1mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62 api::{
63 AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64 ImageId, Organization, OrganizationID, Parameter, Project, ProjectID, SampleID, SequenceId,
65 SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
66 ValidationSession, ValidationSessionID,
67 },
68 client::{Client, Progress},
69 dataset::{
70 Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, Label, Mask,
71 Sample,
72 },
73 error::Error,
74};
75
76#[cfg(feature = "polars")]
77pub use crate::dataset::annotations_dataframe;
78
#[cfg(test)]
mod tests {
    //! Integration tests that run the client against a live Studio server.
    //!
    //! Connection settings are read from the environment (`STUDIO_SERVER`,
    //! `STUDIO_TOKEN`, `STUDIO_USERNAME`, `STUDIO_PASSWORD`; see `get_client`).
    //! Most tests also expect server-side fixtures to exist, e.g. the
    //! "Unit Testing" project/experiment and the "COCO" dataset in
    //! "Sample Project".

    use super::*;
    // Only the polars-gated dataframe test uses this; gating the import keeps
    // the test build working (and warning-free) without the `polars` feature.
    #[cfg(feature = "polars")]
    use polars::frame::UniqueKeepStrategy;
    use std::{
        collections::HashMap,
        env,
        fs::{File, read_to_string},
        io::Write,
    };
    use tokio::time::{Duration, sleep};

    /// Installs an `env_logger` logger (default filter `info`) before tests run.
    #[ctor::ctor]
    fn init() {
        // `try_init` returns an error instead of panicking if a global logger
        // has already been installed.
        let env = env_logger::Env::default().default_filter_or("info");
        let _ = env_logger::Builder::from_env(env).try_init();
    }

    /// Queries the server version; needs no credentials, only an optional
    /// `STUDIO_SERVER` override.
    #[tokio::test]
    async fn test_version() -> Result<(), Error> {
        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => Client::new()?.with_server(&server)?,
            Err(_) => Client::new()?,
        };
        let result = client.version().await?;
        println!("EdgeFirst Studio Version: {}", result);
        Ok(())
    }

    /// Builds a client from the environment and verifies its token.
    ///
    /// Starts from `Client::new()?.with_token_path(None)?`, then applies, when
    /// set: `STUDIO_TOKEN` (explicit token), `STUDIO_SERVER` (server override)
    /// and `STUDIO_USERNAME` + `STUDIO_PASSWORD` (login). Finally calls
    /// `verify_token`, so an unauthenticated setup fails here with an `Error`.
    async fn get_client() -> Result<Client, Error> {
        let client = Client::new()?.with_token_path(None)?;

        let client = match env::var("STUDIO_TOKEN") {
            Ok(token) => client.with_token(&token)?,
            Err(_) => client,
        };

        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => client.with_server(&server)?,
            Err(_) => client,
        };

        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
            _ => client,
        };

        client.verify_token().await?;

        Ok(client)
    }

    /// Returns the first project matching `name`.
    ///
    /// # Panics
    /// If the server returns no matching project (missing test fixture).
    async fn get_project(client: &Client, name: &str) -> Result<Project, Error> {
        let project = client
            .projects(Some(name))
            .await?
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("Project '{}' not found", name));
        Ok(project)
    }

    /// Returns the first experiment in `project` matching `name`.
    ///
    /// # Panics
    /// If the project has no matching experiment (missing test fixture).
    async fn get_experiment(
        client: &Client,
        project: &Project,
        name: &str,
    ) -> Result<Experiment, Error> {
        let experiment = client
            .experiments(project.id(), Some(name))
            .await?
            .into_iter()
            .next()
            .unwrap_or_else(|| {
                panic!(
                    "Experiment '{}' not found in project '{}'",
                    name,
                    project.name()
                )
            });
        Ok(experiment)
    }

    /// Returns the first training session in `experiment` matching `name`.
    ///
    /// # Panics
    /// If the experiment has no matching session (missing test fixture).
    async fn get_training_session(
        client: &Client,
        experiment: &Experiment,
        name: &str,
    ) -> Result<TrainingSession, Error> {
        let session = client
            .training_sessions(experiment.id(), Some(name))
            .await?
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("Training session '{}' not found", name));
        Ok(session)
    }

    /// Returns the dataset named exactly "COCO" from the "Sample Project".
    ///
    /// # Panics
    /// If the project or the dataset is missing (missing test fixture).
    async fn get_coco_dataset(client: &Client) -> Result<Dataset, Error> {
        let project = get_project(client, "Sample Project").await?;
        let dataset = client
            .datasets(project.id(), Some("COCO"))
            .await?
            .into_iter()
            // The name filter may also match other datasets; require an exact match.
            .find(|d| d.name() == "COCO")
            .unwrap_or_else(|| {
                panic!("Dataset 'COCO' not found in project '{}'", project.name())
            });
        Ok(dataset)
    }

    /// Checks token, expiration and username, then that renewal yields a new token.
    #[tokio::test]
    async fn test_token() -> Result<(), Error> {
        let client = get_client().await?;
        let token = client.token().await;
        assert!(!token.is_empty());
        println!("Token: {}", token);

        let exp = client.token_expiration().await?;
        println!("Token Expiration: {}", exp);

        let username = client.username().await?;
        assert!(!username.is_empty());
        println!("Username: {}", username);

        // Wait before renewing; presumably so the renewed token cannot be
        // identical to the original (e.g. same issue timestamp) — TODO confirm.
        sleep(Duration::from_secs(2)).await;

        client.renew_token().await?;
        let new_token = client.token().await;
        assert!(!new_token.is_empty());
        assert_ne!(token, new_token);
        println!("New Token Expiration: {}", client.token_expiration().await?);

        Ok(())
    }

    /// Fetches and prints the caller's organization.
    #[tokio::test]
    async fn test_organization() -> Result<(), Error> {
        let client = get_client().await?;
        let org = client.organization().await?;
        println!(
            "Organization: {}\nID: {}\nCredits: {}",
            org.name(),
            org.id(),
            org.credits()
        );
        Ok(())
    }

    /// The "Unit Testing" project must be listed.
    #[tokio::test]
    async fn test_projects() -> Result<(), Error> {
        let client = get_client().await?;
        let projects = client.projects(Some("Unit Testing")).await?;
        assert!(!projects.is_empty());
        Ok(())
    }

    /// Every listed dataset can be fetched individually by id.
    #[tokio::test]
    async fn test_datasets() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), None).await?;

        for dataset in datasets {
            let dataset_id = dataset.id();
            let result = client.dataset(dataset_id).await?;
            assert_eq!(result.id(), dataset_id);
        }

        Ok(())
    }

    /// Clears all labels of "Test Labels", then adds and removes one label.
    #[tokio::test]
    async fn test_labels() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), Some("Test Labels")).await?;
        let dataset = datasets.first().unwrap_or_else(|| {
            panic!(
                "Dataset 'Test Labels' not found in project '{}'",
                project.name()
            )
        });

        // Start from a clean slate so the counts below are deterministic.
        let labels = dataset.labels(&client).await?;
        for label in labels {
            label.remove(&client).await?;
        }

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        dataset.add_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].name(), "test");

        dataset.remove_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        Ok(())
    }

    /// The COCO dataset must carry the 80 COCO labels at their canonical
    /// indices and have 5000 samples in the "val" group.
    #[tokio::test]
    async fn test_coco() -> Result<(), Error> {
        let coco_labels = HashMap::from([
            (0, "person"),
            (1, "bicycle"),
            (2, "car"),
            (3, "motorcycle"),
            (4, "airplane"),
            (5, "bus"),
            (6, "train"),
            (7, "truck"),
            (8, "boat"),
            (9, "traffic light"),
            (10, "fire hydrant"),
            (11, "stop sign"),
            (12, "parking meter"),
            (13, "bench"),
            (14, "bird"),
            (15, "cat"),
            (16, "dog"),
            (17, "horse"),
            (18, "sheep"),
            (19, "cow"),
            (20, "elephant"),
            (21, "bear"),
            (22, "zebra"),
            (23, "giraffe"),
            (24, "backpack"),
            (25, "umbrella"),
            (26, "handbag"),
            (27, "tie"),
            (28, "suitcase"),
            (29, "frisbee"),
            (30, "skis"),
            (31, "snowboard"),
            (32, "sports ball"),
            (33, "kite"),
            (34, "baseball bat"),
            (35, "baseball glove"),
            (36, "skateboard"),
            (37, "surfboard"),
            (38, "tennis racket"),
            (39, "bottle"),
            (40, "wine glass"),
            (41, "cup"),
            (42, "fork"),
            (43, "knife"),
            (44, "spoon"),
            (45, "bowl"),
            (46, "banana"),
            (47, "apple"),
            (48, "sandwich"),
            (49, "orange"),
            (50, "broccoli"),
            (51, "carrot"),
            (52, "hot dog"),
            (53, "pizza"),
            (54, "donut"),
            (55, "cake"),
            (56, "chair"),
            (57, "couch"),
            (58, "potted plant"),
            (59, "bed"),
            (60, "dining table"),
            (61, "toilet"),
            (62, "tv"),
            (63, "laptop"),
            (64, "mouse"),
            (65, "remote"),
            (66, "keyboard"),
            (67, "cell phone"),
            (68, "microwave"),
            (69, "oven"),
            (70, "toaster"),
            (71, "sink"),
            (72, "refrigerator"),
            (73, "book"),
            (74, "clock"),
            (75, "vase"),
            (76, "scissors"),
            (77, "teddy bear"),
            (78, "hair drier"),
            (79, "toothbrush"),
        ]);

        let client = get_client().await?;
        let dataset = get_coco_dataset(&client).await?;

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 80);

        for label in &labels {
            assert_eq!(label.name(), coco_labels[&label.index()]);
        }

        let n_samples = client
            .samples_count(dataset.id(), None, &[], &["val".to_string()], &[])
            .await?;
        assert_eq!(n_samples.total, 5000);

        let samples = client
            .samples(dataset.id(), None, &[], &["val".to_string()], &[], None)
            .await?;
        assert_eq!(samples.len(), 5000);

        Ok(())
    }

    /// The "val" annotations of COCO's first annotation set, deduplicated by
    /// the `name` column, must yield 5000 rows.
    #[cfg(feature = "polars")]
    #[tokio::test]
    async fn test_coco_dataframe() -> Result<(), Error> {
        let client = get_client().await?;
        let dataset = get_coco_dataset(&client).await?;

        let annotation_set_id = dataset
            .annotation_sets(&client)
            .await?
            .first()
            .expect("Dataset 'COCO' has no annotation sets")
            .id();

        let annotations = client
            .annotations(annotation_set_id, &["val".to_string()], &[], None)
            .await?;
        let df = annotations_dataframe(&annotations);
        let df = df
            .unique_stable(Some(&["name".to_string()]), UniqueKeepStrategy::First, None)
            .unwrap();
        assert_eq!(df.height(), 5000);

        Ok(())
    }

    /// Every listed snapshot can be fetched individually by id.
    #[tokio::test]
    async fn test_snapshots() -> Result<(), Error> {
        let client = get_client().await?;
        let snapshots = client.snapshots(None).await?;

        for snapshot in snapshots {
            let snapshot_id = snapshot.id();
            let result = client.snapshot(snapshot_id).await?;
            assert_eq!(result.id(), snapshot_id);
        }

        Ok(())
    }

    /// Every experiment in "Unit Testing" can be fetched individually by id.
    #[tokio::test]
    async fn test_experiments() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiments = client.experiments(project.id(), None).await?;

        for experiment in experiments {
            let experiment_id = experiment.id();
            let result = client.experiment(experiment_id).await?;
            assert_eq!(result.id(), experiment_id);
        }

        Ok(())
    }

    /// Round-trips metrics and an uploaded artifact on a training session.
    #[tokio::test]
    async fn test_training_session() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let session =
            get_training_session(&client, &experiment, "modelpack-usermanaged").await?;

        let metrics = HashMap::from([
            ("epochs".to_string(), Parameter::Integer(10)),
            ("loss".to_string(), Parameter::Real(0.05)),
            (
                "model".to_string(),
                Parameter::String("modelpack".to_string()),
            ),
        ]);

        session.set_metrics(&client, metrics).await?;
        let updated_metrics = session.metrics(&client).await?;
        assert_eq!(updated_metrics.len(), 3);
        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
        assert_eq!(
            updated_metrics.get("model"),
            Some(&Parameter::String("modelpack".to_string()))
        );

        println!("Updated Metrics: {:?}", updated_metrics);

        let mut labels = tempfile::NamedTempFile::new()?;
        write!(labels, "background")?;
        labels.flush()?;

        session
            .upload(
                &client,
                &[(
                    "artifacts/labels.txt".to_string(),
                    labels.path().to_path_buf(),
                )],
            )
            .await?;

        let labels = session.download(&client, "artifacts/labels.txt").await?;
        assert_eq!(labels, "background");

        Ok(())
    }

    /// Lists validation sessions, checks per-id lookup, and round-trips a metric.
    #[tokio::test]
    async fn test_validate() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;

        let sessions = client.validation_sessions(project.id()).await?;
        for session in &sessions {
            let s = client.validation_session(session.id()).await?;
            assert_eq!(s.id(), session.id());
            assert_eq!(s.description(), session.description());
        }

        let session = sessions
            .into_iter()
            .find(|s| s.name() == "modelpack-usermanaged")
            .unwrap_or_else(|| {
                panic!(
                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
                    project.name()
                )
            });

        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
        session.set_metrics(&client, metrics).await?;

        let metrics = session.metrics(&client).await?;
        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));

        Ok(())
    }

    /// Downloads every artifact of a session, then checks that a missing
    /// artifact errors without creating a file.
    #[tokio::test]
    async fn test_artifacts() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let trainer = get_training_session(&client, &experiment, "modelpack-960x540").await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        // Download into a private temporary directory (deleted on drop) so the
        // test leaves no files behind and does not race other tests on shared
        // file names in the working directory.
        let dir = tempfile::tempdir()?;

        for artifact in &artifacts {
            let path = dir.path().join(artifact.name());
            client
                .download_artifact(trainer.id(), artifact.name(), Some(path.clone()), None)
                .await?;
            assert!(path.exists(), "artifact '{}' was not written", artifact.name());
        }

        let fake_path = dir.path().join("fakefile.txt");
        let res = client
            .download_artifact(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    /// Uploads a checkpoint, downloads it back, and checks that a missing
    /// checkpoint errors without creating a file.
    #[tokio::test]
    async fn test_checkpoints() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let trainer =
            get_training_session(&client, &experiment, "modelpack-usermanaged").await?;

        // Work in a private temporary directory (deleted on drop) so the test
        // leaves no files behind and does not race other tests on shared names.
        let dir = tempfile::tempdir()?;
        let upload_path = dir.path().join("checkpoint.txt");
        // The temporary `File` is dropped (closed) at the end of this statement.
        File::create(&upload_path)?.write_all(b"Test Checkpoint")?;

        trainer
            .upload(
                &client,
                &[("checkpoints/checkpoint.txt".to_string(), upload_path)],
            )
            .await?;

        let download_path = dir.path().join("checkpoint2.txt");
        client
            .download_checkpoint(
                trainer.id(),
                "checkpoint.txt",
                Some(download_path.clone()),
                None,
            )
            .await?;

        let chkpt = read_to_string(&download_path)?;
        assert_eq!(chkpt, "Test Checkpoint");

        let fake_path = dir.path().join("fakefile.txt");
        let res = client
            .download_checkpoint(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    /// Lists tasks, then drives status/stage updates on a "modelpack-usermanaged"
    /// task that belongs to the "Unit Testing" project.
    #[tokio::test]
    async fn test_tasks() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let tasks = client.tasks(None, None, None, None).await?;

        for task in tasks {
            let task_info = client.task_info(task.id()).await?;
            println!("{} - {}", task, task_info);
        }

        let tasks = client
            .tasks(Some("modelpack-usermanaged"), None, None, None)
            .await?;
        // Fetch task details concurrently; the first failure aborts the test.
        let tasks = tasks
            .into_iter()
            .map(|t| client.task_info(t.id()))
            .collect::<Vec<_>>();
        let tasks = futures::future::try_join_all(tasks).await?;
        let tasks = tasks
            .into_iter()
            .filter(|t| t.project_id() == Some(project.id()))
            .collect::<Vec<_>>();
        assert_ne!(tasks.len(), 0);
        let task = &tasks[0];

        let t = client.task_status(task.id(), "training").await?;
        assert_eq!(t.id(), task.id());
        assert_eq!(t.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task.id(), &stages).await?;

        client
            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
            .await?;

        let task = client.task_info(task.id()).await?;
        println!("task progress: {:?}", task.stages());

        Ok(())
    }
}