1mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62 api::{
63 AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64 ImageId, Organization, OrganizationID, Parameter, Project, ProjectID, SampleID, SequenceId,
65 SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
66 ValidationSession, ValidationSessionID,
67 },
68 client::{Client, Progress},
69 dataset::{
70 Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, Label, Mask,
71 Sample,
72 },
73 error::Error,
74};
75
76#[cfg(feature = "polars")]
77pub use crate::dataset::annotations_dataframe;
78
#[cfg(test)]
mod tests {
    use super::*;
    // Only the polars-gated dataframe test uses this. Gating the import lets
    // `cargo test` build (without warnings) when the feature is off.
    #[cfg(feature = "polars")]
    use polars::frame::UniqueKeepStrategy;
    use std::{
        collections::HashMap,
        env,
        fs::{File, read_to_string},
        io::Write,
        path::{Path, PathBuf},
    };
    use tokio::time::{Duration, sleep};

    /// Returns `<workspace>/target/testdata`, creating it if needed.
    ///
    /// The directory is resolved two levels above this crate's manifest
    /// directory (assumes the crate lives at `<workspace>/<dir>/<crate>`).
    ///
    /// # Panics
    /// Panics if the manifest directory has fewer than two ancestors or if the
    /// directory cannot be created.
    fn get_test_data_dir() -> PathBuf {
        let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("crate manifest directory has no parent")
            .parent()
            .expect("crate manifest directory has no grandparent")
            .join("target")
            .join("testdata");

        std::fs::create_dir_all(&test_dir).expect("Failed to create test data directory");
        test_dir
    }

    /// Deletes `path` if it exists. Used both for cleanup and to clear stale
    /// files from earlier (possibly aborted) runs before asserting on them.
    fn remove_if_exists(path: &Path) -> std::io::Result<()> {
        if path.exists() {
            std::fs::remove_file(path)?;
        }
        Ok(())
    }

    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    #[tokio::test]
    async fn test_version() -> Result<(), Error> {
        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => Client::new()?.with_server(&server)?,
            Err(_) => Client::new()?,
        };
        let result = client.version().await?;
        println!("EdgeFirst Studio Version: {}", result);
        Ok(())
    }

    /// Builds an authenticated client from the environment.
    ///
    /// Honors `STUDIO_TOKEN`, `STUDIO_SERVER`, and the
    /// `STUDIO_USERNAME`/`STUDIO_PASSWORD` pair (all optional), then verifies
    /// the resulting token with the server.
    ///
    /// # Errors
    /// Propagates any client construction, login, or verification error.
    async fn get_client() -> Result<Client, Error> {
        let client = Client::new()?.with_token_path(None)?;

        let client = match env::var("STUDIO_TOKEN") {
            Ok(token) => client.with_token(&token)?,
            Err(_) => client,
        };

        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => client.with_server(&server)?,
            Err(_) => client,
        };

        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
            _ => client,
        };

        client.verify_token().await?;

        Ok(client)
    }

    /// Returns the first project that [`Client::projects`] yields for the
    /// `name` filter.
    ///
    /// # Panics
    /// Panics if no project is returned for `name`.
    async fn get_project(client: &Client, name: &str) -> Result<Project, Error> {
        let projects = client.projects(Some(name)).await?;
        let project = projects
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("Project '{}' not found", name));
        Ok(project)
    }

    /// Returns the dataset named exactly "COCO" in the "Sample Project" project.
    ///
    /// # Panics
    /// Panics if the project or the dataset cannot be found.
    async fn get_coco_dataset(client: &Client) -> Result<Dataset, Error> {
        let project = get_project(client, "Sample Project").await?;
        let datasets = client.datasets(project.id(), Some("COCO")).await?;
        assert!(!datasets.is_empty());
        // The server-side filter may return partial matches; pick the exact one.
        let dataset = datasets
            .into_iter()
            .find(|d| d.name() == "COCO")
            .unwrap_or_else(|| {
                panic!(
                    "Dataset 'COCO' not found in project '{}'",
                    project.name()
                )
            });
        Ok(dataset)
    }

    #[tokio::test]
    async fn test_token() -> Result<(), Error> {
        let client = get_client().await?;
        let token = client.token().await;
        assert!(!token.is_empty());
        // The token is a bearer credential: never print it, because test
        // output commonly ends up in shared CI logs.

        let exp = client.token_expiration().await?;
        println!("Token Expiration: {}", exp);

        let username = client.username().await?;
        assert!(!username.is_empty());
        println!("Username: {}", username);

        // Let time advance so the renewed token differs from the original
        // (assumes token contents depend on issue time — TODO confirm).
        sleep(Duration::from_secs(2)).await;

        client.renew_token().await?;
        let new_token = client.token().await;
        assert!(!new_token.is_empty());
        assert_ne!(token, new_token);
        println!("New Token Expiration: {}", client.token_expiration().await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_organization() -> Result<(), Error> {
        let client = get_client().await?;
        let org = client.organization().await?;
        println!(
            "Organization: {}\nID: {}\nCredits: {}",
            org.name(),
            org.id(),
            org.credits()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_projects() -> Result<(), Error> {
        let client = get_client().await?;
        let projects = client.projects(Some("Unit Testing")).await?;
        assert!(!projects.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_datasets() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), None).await?;

        for dataset in datasets {
            let dataset_id = dataset.id();
            let result = client.dataset(dataset_id).await?;
            assert_eq!(result.id(), dataset_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_labels() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), Some("Test Labels")).await?;
        let dataset = datasets.first().unwrap_or_else(|| {
            panic!(
                "Dataset 'Test Labels' not found in project '{}'",
                project.name()
            )
        });

        // Start from a clean slate so the counts below are deterministic.
        let labels = dataset.labels(&client).await?;
        for label in labels {
            label.remove(&client).await?;
        }

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        dataset.add_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].name(), "test");

        dataset.remove_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_coco() -> Result<(), Error> {
        let coco_labels = HashMap::from([
            (0, "person"),
            (1, "bicycle"),
            (2, "car"),
            (3, "motorcycle"),
            (4, "airplane"),
            (5, "bus"),
            (6, "train"),
            (7, "truck"),
            (8, "boat"),
            (9, "traffic light"),
            (10, "fire hydrant"),
            (11, "stop sign"),
            (12, "parking meter"),
            (13, "bench"),
            (14, "bird"),
            (15, "cat"),
            (16, "dog"),
            (17, "horse"),
            (18, "sheep"),
            (19, "cow"),
            (20, "elephant"),
            (21, "bear"),
            (22, "zebra"),
            (23, "giraffe"),
            (24, "backpack"),
            (25, "umbrella"),
            (26, "handbag"),
            (27, "tie"),
            (28, "suitcase"),
            (29, "frisbee"),
            (30, "skis"),
            (31, "snowboard"),
            (32, "sports ball"),
            (33, "kite"),
            (34, "baseball bat"),
            (35, "baseball glove"),
            (36, "skateboard"),
            (37, "surfboard"),
            (38, "tennis racket"),
            (39, "bottle"),
            (40, "wine glass"),
            (41, "cup"),
            (42, "fork"),
            (43, "knife"),
            (44, "spoon"),
            (45, "bowl"),
            (46, "banana"),
            (47, "apple"),
            (48, "sandwich"),
            (49, "orange"),
            (50, "broccoli"),
            (51, "carrot"),
            (52, "hot dog"),
            (53, "pizza"),
            (54, "donut"),
            (55, "cake"),
            (56, "chair"),
            (57, "couch"),
            (58, "potted plant"),
            (59, "bed"),
            (60, "dining table"),
            (61, "toilet"),
            (62, "tv"),
            (63, "laptop"),
            (64, "mouse"),
            (65, "remote"),
            (66, "keyboard"),
            (67, "cell phone"),
            (68, "microwave"),
            (69, "oven"),
            (70, "toaster"),
            (71, "sink"),
            (72, "refrigerator"),
            (73, "book"),
            (74, "clock"),
            (75, "vase"),
            (76, "scissors"),
            (77, "teddy bear"),
            (78, "hair drier"),
            (79, "toothbrush"),
        ]);

        let client = get_client().await?;
        let dataset = get_coco_dataset(&client).await?;

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 80);

        for label in &labels {
            assert_eq!(label.name(), coco_labels[&label.index()]);
        }

        let n_samples = client
            .samples_count(dataset.id(), None, &[], &["val".to_string()], &[])
            .await?;
        assert_eq!(n_samples.total, 5000);

        let samples = client
            .samples(dataset.id(), None, &[], &["val".to_string()], &[], None)
            .await?;
        assert_eq!(samples.len(), 5000);

        Ok(())
    }

    #[cfg(feature = "polars")]
    #[tokio::test]
    async fn test_coco_dataframe() -> Result<(), Error> {
        let client = get_client().await?;
        let dataset = get_coco_dataset(&client).await?;

        let annotation_set_id = dataset
            .annotation_sets(&client)
            .await?
            .first()
            .expect("COCO dataset has no annotation sets")
            .id();

        let annotations = client
            .annotations(annotation_set_id, &["val".to_string()], &[], None)
            .await?;
        let df = annotations_dataframe(&annotations);
        // One row per distinct sample name: the val split holds 5000 samples.
        let df = df
            .unique_stable(Some(&["name".to_string()]), UniqueKeepStrategy::First, None)
            .unwrap();
        assert_eq!(df.height(), 5000);

        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> Result<(), Error> {
        let client = get_client().await?;
        let snapshots = client.snapshots(None).await?;

        for snapshot in snapshots {
            let snapshot_id = snapshot.id();
            let result = client.snapshot(snapshot_id).await?;
            assert_eq!(result.id(), snapshot_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_experiments() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiments = client.experiments(project.id(), None).await?;

        for experiment in experiments {
            let experiment_id = experiment.id();
            let result = client.experiment(experiment_id).await?;
            assert_eq!(result.id(), experiment_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_training_session() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = client
            .experiments(project.id(), Some("Unit Testing"))
            .await?;
        let experiment = experiment.first().unwrap();

        let sessions = client
            .training_sessions(experiment.id(), Some("modelpack-usermanaged"))
            .await?;
        assert_ne!(sessions.len(), 0);
        let session = sessions.first().unwrap();

        let metrics = HashMap::from([
            ("epochs".to_string(), Parameter::Integer(10)),
            ("loss".to_string(), Parameter::Real(0.05)),
            (
                "model".to_string(),
                Parameter::String("modelpack".to_string()),
            ),
        ]);

        session.set_metrics(&client, metrics).await?;
        let updated_metrics = session.metrics(&client).await?;
        assert_eq!(updated_metrics.len(), 3);
        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
        assert_eq!(
            updated_metrics.get("model"),
            Some(&Parameter::String("modelpack".to_string()))
        );

        println!("Updated Metrics: {:?}", updated_metrics);

        // Temporary file is deleted automatically when `labels` is dropped.
        let mut labels = tempfile::NamedTempFile::new()?;
        write!(labels, "background")?;
        labels.flush()?;

        session
            .upload(
                &client,
                &[(
                    "artifacts/labels.txt".to_string(),
                    labels.path().to_path_buf(),
                )],
            )
            .await?;

        let labels = session.download(&client, "artifacts/labels.txt").await?;
        assert_eq!(labels, "background");

        Ok(())
    }

    #[tokio::test]
    async fn test_validate() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;

        let sessions = client.validation_sessions(project.id()).await?;
        for session in &sessions {
            let s = client.validation_session(session.id()).await?;
            assert_eq!(s.id(), session.id());
            assert_eq!(s.description(), session.description());
        }

        let session = sessions
            .into_iter()
            .find(|s| s.name() == "modelpack-usermanaged")
            .unwrap_or_else(|| {
                panic!(
                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
                    project.name()
                )
            });

        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
        session.set_metrics(&client, metrics).await?;

        let metrics = session.metrics(&client).await?;
        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));

        Ok(())
    }

    #[tokio::test]
    async fn test_artifacts() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = client
            .experiments(project.id(), Some("Unit Testing"))
            .await?;
        let experiment = experiment.first().unwrap();
        let trainer = client
            .training_sessions(experiment.id(), Some("modelpack-960x540"))
            .await?;
        let trainer = trainer.first().unwrap();
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        let test_dir = get_test_data_dir();

        for artifact in artifacts {
            let output_path = test_dir.join(artifact.name());
            client
                .download_artifact(
                    trainer.id(),
                    artifact.name(),
                    Some(output_path.clone()),
                    None,
                )
                .await?;

            remove_if_exists(&output_path)?;
        }

        // The local path must differ from the one test_checkpoints uses: the
        // test harness runs tests in parallel, and both assert their file is
        // absent. Clear any leftover from an aborted run before asserting.
        let fake_path = test_dir.join("fake_artifact.txt");
        remove_if_exists(&fake_path)?;
        let res = client
            .download_artifact(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_checkpoints() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = client
            .experiments(project.id(), Some("Unit Testing"))
            .await?;
        let experiment = experiment.first().unwrap_or_else(|| {
            panic!(
                "Experiment 'Unit Testing' not found in project '{}'",
                project.name()
            )
        });
        let trainer = client
            .training_sessions(experiment.id(), Some("modelpack-usermanaged"))
            .await?;
        let trainer = trainer.first().unwrap();

        let test_dir = get_test_data_dir();
        let checkpoint_path = test_dir.join("checkpoint.txt");
        let checkpoint2_path = test_dir.join("checkpoint2.txt");

        {
            // Scoped so the file is closed before it is uploaded.
            let mut chkpt = File::create(&checkpoint_path)?;
            chkpt.write_all(b"Test Checkpoint")?;
        }

        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        client
            .download_checkpoint(
                trainer.id(),
                "checkpoint.txt",
                Some(checkpoint2_path.clone()),
                None,
            )
            .await?;

        let chkpt = read_to_string(&checkpoint2_path)?;
        // Clean up before asserting so a failed comparison does not leave
        // files behind in the shared test data directory.
        remove_if_exists(&checkpoint_path)?;
        remove_if_exists(&checkpoint2_path)?;
        assert_eq!(chkpt, "Test Checkpoint");

        // The local path must differ from the one test_artifacts uses: the
        // test harness runs tests in parallel, and both assert their file is
        // absent. Clear any leftover from an aborted run before asserting.
        let fake_path = test_dir.join("fake_checkpoint.txt");
        remove_if_exists(&fake_path)?;
        let res = client
            .download_checkpoint(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_tasks() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let tasks = client.tasks(None, None, None, None).await?;

        for task in tasks {
            let task_info = client.task_info(task.id()).await?;
            println!("{} - {}", task, task_info);
        }

        // Fetch task details concurrently, then keep only this project's tasks.
        let tasks = client
            .tasks(Some("modelpack-usermanaged"), None, None, None)
            .await?;
        let tasks = tasks
            .into_iter()
            .map(|t| client.task_info(t.id()))
            .collect::<Vec<_>>();
        let tasks = futures::future::try_join_all(tasks).await?;
        let tasks = tasks
            .into_iter()
            .filter(|t| t.project_id() == Some(project.id()))
            .collect::<Vec<_>>();
        assert_ne!(tasks.len(), 0);
        let task = &tasks[0];

        let t = client.task_status(task.id(), "training").await?;
        assert_eq!(t.id(), task.id());
        assert_eq!(t.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task.id(), &stages).await?;

        client
            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
            .await?;

        let task = client.task_info(task.id()).await?;
        println!("task progress: {:?}", task.stages());

        Ok(())
    }
}