1mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62 api::{
63 AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64 ImageId, Organization, OrganizationID, Parameter, PresignedUrl, Project, ProjectID,
65 SampleID, SamplesCountResult, SamplesPopulateParams, SamplesPopulateResult, SequenceId,
66 SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
67 ValidationSession, ValidationSessionID,
68 },
69 client::{Client, Progress},
70 dataset::{
71 Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, GpsData,
72 ImuData, Label, Location, Mask, Sample, SampleFile,
73 },
74 error::Error,
75};
76
77#[cfg(feature = "polars")]
78pub use crate::dataset::annotations_dataframe;
79
80#[cfg(feature = "polars")]
81pub use crate::dataset::samples_dataframe;
82
83#[cfg(feature = "polars")]
84pub use crate::dataset::unflatten_polygon_coordinates;
85
#[cfg(test)]
mod tests {
    //! Integration tests that run against a live Studio server.
    //!
    //! The server must provide a "Unit Testing" project containing a
    //! "Unit Testing" experiment with the `modelpack-960x540` and
    //! `modelpack-usermanaged` training sessions, a `modelpack-usermanaged`
    //! validation session and a matching task. Credentials and server are
    //! taken from the `STUDIO_*` environment variables (see `get_client`).

    use super::*;
    use std::{
        collections::HashMap,
        env,
        fs::{self, read_to_string},
        io::Write,
        path::PathBuf,
    };

    /// Name shared by the test project and the test experiment on the server.
    const UNIT_TESTING: &str = "Unit Testing";
    /// Training session whose existing artifacts the artifact tests download.
    const ARTIFACTS_SESSION: &str = "modelpack-960x540";
    /// Name of the training session, validation session and task that the
    /// tests write metrics, checkpoints and task stages to.
    const USER_MANAGED_SESSION: &str = "modelpack-usermanaged";

    /// Returns `<workspace>/target/testdata/<test_name>`, creating it if needed.
    ///
    /// The test harness runs tests in parallel, so each test gets its own
    /// subdirectory. With a single shared directory, two tests downloading the
    /// same artifact would race to create, check and delete the same file.
    ///
    /// # Panics
    /// Panics if the workspace root cannot be derived from `CARGO_MANIFEST_DIR`
    /// or the directory cannot be created.
    fn get_test_data_dir(test_name: &str) -> PathBuf {
        let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("CARGO_MANIFEST_DIR should have parent")
            .parent()
            .expect("workspace root should exist")
            .join("target")
            .join("testdata")
            .join(test_name);

        fs::create_dir_all(&test_dir).expect("Failed to create test data directory");
        test_dir
    }

    /// Deletes the wrapped files when dropped.
    ///
    /// Files a test creates are therefore cleaned up even when an assertion
    /// panics part-way through the test.
    struct RemoveOnDrop(Vec<PathBuf>);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            for path in &self.0 {
                // Best effort: a path may never have been created (e.g. a
                // download that failed), and a cleanup failure must not mask
                // the test's own outcome.
                let _ = fs::remove_file(path);
            }
        }
    }

    /// Installs `env_logger` once per test binary (default level `info`).
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    /// Builds an authenticated [`Client`] from the environment.
    ///
    /// Each setting is applied only when present:
    /// - `STUDIO_TOKEN` sets the token directly;
    /// - `STUDIO_SERVER` selects the server;
    /// - `STUDIO_USERNAME` + `STUDIO_PASSWORD` (both required) perform a login.
    ///
    /// The resulting token is then verified against the server.
    async fn get_client() -> Result<Client, Error> {
        let client = Client::new()?.with_token_path(None)?;

        let client = match env::var("STUDIO_TOKEN") {
            Ok(token) => client.with_token(&token)?,
            Err(_) => client,
        };

        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => client.with_server(&server)?,
            Err(_) => client,
        };

        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
            _ => client,
        };

        client.verify_token().await?;

        Ok(client)
    }

    /// Looks up the "Unit Testing" project.
    ///
    /// # Errors
    /// Returns `Error::InvalidParameters` if the project does not exist, or
    /// any error from the `projects` request.
    async fn unit_testing_project(client: &Client) -> Result<Project, Error> {
        client
            .projects(Some(UNIT_TESTING))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| Error::InvalidParameters("Unit Testing project not found".into()))
    }

    /// Looks up the "Unit Testing" experiment inside `project`.
    ///
    /// # Errors
    /// Returns `Error::InvalidParameters` if the experiment does not exist, or
    /// any error from the `experiments` request.
    async fn unit_testing_experiment(
        client: &Client,
        project: &Project,
    ) -> Result<Experiment, Error> {
        client
            .experiments(project.id(), Some(UNIT_TESTING))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| {
                Error::InvalidParameters(format!(
                    "Experiment 'Unit Testing' not found in project '{}'",
                    project.name()
                ))
            })
    }

    /// Returns the training session called `name` in the Unit Testing
    /// project/experiment, reusing the caller's authenticated `client`.
    ///
    /// # Errors
    /// Returns `Error::InvalidParameters` if the project, experiment or session
    /// is missing, or any error from the underlying requests.
    async fn get_training_session(client: &Client, name: &str) -> Result<TrainingSession, Error> {
        let project = unit_testing_project(client).await?;
        let experiment = unit_testing_experiment(client, &project).await?;
        client
            .training_sessions(experiment.id(), Some(name))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| Error::InvalidParameters(format!("{name} session not found")))
    }

    /// Returns the info of the `modelpack-usermanaged` task that belongs to
    /// `project`. Tasks with that name in other projects are ignored.
    ///
    /// # Errors
    /// Returns `Error::InvalidParameters` if no such task exists, or any error
    /// from the task requests.
    async fn unit_testing_task(client: &Client, project: &Project) -> Result<TaskInfo, Error> {
        let tasks = client
            .tasks(Some(USER_MANAGED_SESSION), None, None, None)
            .await?;
        let infos = tasks
            .into_iter()
            .map(|t| client.task_info(t.id()))
            .collect::<Vec<_>>();
        let infos = futures::future::try_join_all(infos).await?;

        infos
            .into_iter()
            .find(|t| t.project_id() == Some(project.id()))
            .ok_or_else(|| {
                Error::InvalidParameters(format!(
                    "Task '{USER_MANAGED_SESSION}' not found in project '{}'",
                    project.name()
                ))
            })
    }

    #[tokio::test]
    async fn test_training_session() -> Result<(), Error> {
        let client = get_client().await?;
        let session = get_training_session(&client, USER_MANAGED_SESSION).await?;

        let metrics = HashMap::from([
            ("epochs".to_string(), Parameter::Integer(10)),
            ("loss".to_string(), Parameter::Real(0.05)),
            (
                "model".to_string(),
                Parameter::String("modelpack".to_string()),
            ),
        ]);

        session.set_metrics(&client, metrics).await?;
        let updated_metrics = session.metrics(&client).await?;
        assert_eq!(updated_metrics.len(), 3);
        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
        assert_eq!(
            updated_metrics.get("model"),
            Some(&Parameter::String("modelpack".to_string()))
        );

        println!("Updated Metrics: {:?}", updated_metrics);

        // NamedTempFile deletes itself on drop, so no manual cleanup is needed.
        let mut labels = tempfile::NamedTempFile::new()?;
        write!(labels, "background")?;
        labels.flush()?;

        session
            .upload(
                &client,
                &[(
                    "artifacts/labels.txt".to_string(),
                    labels.path().to_path_buf(),
                )],
            )
            .await?;

        let labels = session.download(&client, "artifacts/labels.txt").await?;
        assert_eq!(labels, "background");

        Ok(())
    }

    #[tokio::test]
    async fn test_validate() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // Every listed session must be retrievable individually with the same data.
        let sessions = client.validation_sessions(project.id()).await?;
        for session in &sessions {
            let s = client.validation_session(session.id()).await?;
            assert_eq!(s.id(), session.id());
            assert_eq!(s.description(), session.description());
        }

        let session = sessions
            .into_iter()
            .find(|s| s.name() == USER_MANAGED_SESSION)
            .ok_or_else(|| {
                Error::InvalidParameters(format!(
                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
                    project.name()
                ))
            })?;

        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
        session.set_metrics(&client, metrics).await?;

        let metrics = session.metrics(&client).await?;
        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));

        Ok(())
    }

    #[tokio::test]
    async fn test_download_artifact_success() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, ARTIFACTS_SESSION).await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        let artifact = &artifacts[0];
        let output_path = get_test_data_dir("test_download_artifact_success").join(artifact.name());
        let _cleanup = RemoveOnDrop(vec![output_path.clone()]);

        client
            .download_artifact(
                trainer.id(),
                artifact.name(),
                Some(output_path.clone()),
                None,
            )
            .await?;

        assert!(output_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_download_artifact_not_found() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, ARTIFACTS_SESSION).await?;
        let fake_path =
            get_test_data_dir("test_download_artifact_not_found").join("nonexistent_artifact.txt");
        // Only exists if the download wrongly succeeded; remove it either way.
        let _cleanup = RemoveOnDrop(vec![fake_path.clone()]);

        let result = client
            .download_artifact(
                trainer.id(),
                "nonexistent_artifact.txt",
                Some(fake_path.clone()),
                None,
            )
            .await;

        assert!(result.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_artifacts() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, ARTIFACTS_SESSION).await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        let test_dir = get_test_data_dir("test_artifacts");

        for artifact in artifacts {
            let output_path = test_dir.join(artifact.name());
            // Dropped at the end of each iteration, removing that artifact's file.
            let _cleanup = RemoveOnDrop(vec![output_path.clone()]);
            client
                .download_artifact(
                    trainer.id(),
                    artifact.name(),
                    Some(output_path.clone()),
                    None,
                )
                .await?;
        }

        let fake_path = test_dir.join("fakefile.txt");
        let _cleanup = RemoveOnDrop(vec![fake_path.clone()]);
        let res = client
            .download_artifact(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_download_checkpoint_success() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, USER_MANAGED_SESSION).await?;
        let test_dir = get_test_data_dir("test_download_checkpoint_success");

        let checkpoint_path = test_dir.join("test_checkpoint.txt");
        let download_path = test_dir.join("downloaded_checkpoint.txt");
        let _cleanup = RemoveOnDrop(vec![checkpoint_path.clone(), download_path.clone()]);

        fs::write(&checkpoint_path, b"Test Checkpoint Content")?;

        // Uploading under "checkpoints/" makes the file downloadable by its bare name.
        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/test_checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        client
            .download_checkpoint(
                trainer.id(),
                "test_checkpoint.txt",
                Some(download_path.clone()),
                None,
            )
            .await?;

        let content = read_to_string(&download_path)?;
        assert_eq!(content, "Test Checkpoint Content");

        Ok(())
    }

    #[tokio::test]
    async fn test_download_checkpoint_not_found() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, USER_MANAGED_SESSION).await?;
        let fake_path =
            get_test_data_dir("test_download_checkpoint_not_found").join("nonexistent_checkpoint.txt");
        // Only exists if the download wrongly succeeded; remove it either way.
        let _cleanup = RemoveOnDrop(vec![fake_path.clone()]);

        let result = client
            .download_checkpoint(
                trainer.id(),
                "nonexistent_checkpoint.txt",
                Some(fake_path.clone()),
                None,
            )
            .await;

        assert!(result.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_checkpoints() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = get_training_session(&client, USER_MANAGED_SESSION).await?;

        let test_dir = get_test_data_dir("test_checkpoints");
        let checkpoint_path = test_dir.join("checkpoint.txt");
        let checkpoint2_path = test_dir.join("checkpoint2.txt");
        let fake_path = test_dir.join("fakefile.txt");
        let _cleanup = RemoveOnDrop(vec![
            checkpoint_path.clone(),
            checkpoint2_path.clone(),
            fake_path.clone(),
        ]);

        fs::write(&checkpoint_path, b"Test Checkpoint")?;

        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        client
            .download_checkpoint(
                trainer.id(),
                "checkpoint.txt",
                Some(checkpoint2_path.clone()),
                None,
            )
            .await?;

        let chkpt = read_to_string(&checkpoint2_path)?;
        assert_eq!(chkpt, "Test Checkpoint");

        let res = client
            .download_checkpoint(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_task_retrieval() -> Result<(), Error> {
        let client = get_client().await?;

        let tasks = client.tasks(None, None, None, None).await?;
        assert!(!tasks.is_empty());

        let task_id = tasks[0].id();
        let task_info = client.task_info(task_id).await?;
        assert_eq!(task_info.id(), task_id);

        Ok(())
    }

    #[tokio::test]
    async fn test_task_filtering_by_name() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        let tasks = client
            .tasks(Some(USER_MANAGED_SESSION), None, None, None)
            .await?;
        let task_infos = tasks
            .into_iter()
            .map(|t| client.task_info(t.id()))
            .collect::<Vec<_>>();
        let task_infos = futures::future::try_join_all(task_infos).await?;

        // The name filter must surface the task belonging to the test project.
        assert!(
            task_infos
                .iter()
                .any(|t| t.project_id() == Some(project.id())),
            "no '{USER_MANAGED_SESSION}' task found in project '{}'",
            project.name()
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_task_status_and_stages() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // This test overwrites the task's status and stages on the server, so it
        // must target the test project's own task, not whichever task happens
        // to be listed first.
        let task = unit_testing_task(&client, &project).await?;
        let task_id = task.id();

        let status = client.task_status(task_id, "training").await?;
        assert_eq!(status.id(), task_id);
        assert_eq!(status.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task_id, &stages).await?;

        client
            .update_stage(task_id, "download", "running", "Downloading dataset", 50)
            .await?;

        let updated_task = client.task_info(task_id).await?;
        assert_eq!(updated_task.id(), task_id);

        Ok(())
    }

    #[tokio::test]
    async fn test_tasks() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        for task in client.tasks(None, None, None, None).await? {
            let task_info = client.task_info(task.id()).await?;
            println!("{} - {}", task, task_info);
        }

        let task = unit_testing_task(&client, &project).await?;

        let t = client.task_status(task.id(), "training").await?;
        assert_eq!(t.id(), task.id());
        assert_eq!(t.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task.id(), &stages).await?;

        client
            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
            .await?;

        let task = client.task_info(task.id()).await?;
        println!("task progress: {:?}", task.stages());

        Ok(())
    }
}