edgefirst_client/
lib.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.

7//! # EdgeFirst Studio Client Library
8//!
9//! The EdgeFirst Studio Client Library provides a Rust client for interacting
10//! with EdgeFirst Studio, a comprehensive platform for computer vision and
11//! machine learning workflows. This library enables developers to
12//! programmatically manage datasets, annotations, training sessions, and other
13//! Studio resources.
14//!
15//! ## Features
16//!
17//! - **Authentication**: Secure token-based authentication with automatic
18//!   renewal
19//! - **Dataset Management**: Upload, download, and manage datasets with various
20//!   file types
21//! - **Annotation Management**: Create, update, and retrieve annotations for
22//!   computer vision tasks
23//! - **Training & Validation**: Manage machine learning training and validation
24//!   sessions
25//! - **Project Organization**: Organize work into projects with hierarchical
26//!   structure
27//! - **Polars Integration**: Optional integration with Polars DataFrames for
28//!   data analysis
29//!
30//! ## Quick Start
31//!
32//! ```rust,no_run
33//! use edgefirst_client::{Client, Error};
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Error> {
37//!     // Create a new client
38//!     let client = Client::new()?;
39//!
40//!     // Authenticate with username and password
41//!     let client = client.with_login("username", "password").await?;
42//!
43//!     // List available projects
44//!     let projects = client.projects(None).await?;
45//!     println!("Found {} projects", projects.len());
46//!
47//!     Ok(())
48//! }
49//! ```
50//!
51//! ## Optional Features
52//!
53//! - `polars`: Enables integration with Polars DataFrames for enhanced data
54//!   manipulation
55
56mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62    api::{
63        AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64        ImageId, Organization, OrganizationID, Parameter, PresignedUrl, Project, ProjectID,
65        SampleID, SamplesCountResult, SamplesPopulateParams, SamplesPopulateResult, SequenceId,
66        SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
67        ValidationSession, ValidationSessionID,
68    },
69    client::{Client, Progress},
70    dataset::{
71        Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, GpsData,
72        ImuData, Label, Location, Mask, Sample, SampleFile,
73    },
74    error::Error,
75};
76
77#[cfg(feature = "polars")]
78pub use crate::dataset::annotations_dataframe;
79
80#[cfg(feature = "polars")]
81pub use crate::dataset::samples_dataframe;
82
83#[cfg(feature = "polars")]
84pub use crate::dataset::unflatten_polygon_coordinates;
85
#[cfg(test)]
mod tests {
    //! Integration tests that run against a live EdgeFirst Studio server.
    //!
    //! The server and credentials are taken from the `STUDIO_TOKEN`,
    //! `STUDIO_SERVER`, `STUDIO_USERNAME` and `STUDIO_PASSWORD` environment
    //! variables, and the fixtures named by the constants below must already
    //! exist on that server.
    //!
    //! Cargo runs these tests concurrently, so every test that writes local
    //! files does so inside its own scratch directory (see `scratch_dir`).

    use super::*;
    use std::{
        collections::HashMap,
        env,
        fs::{File, read_to_string},
        io::Write,
        path::PathBuf,
    };

    /// Project that holds every server-side fixture used by these tests.
    const PROJECT_NAME: &str = "Unit Testing";
    /// Experiment inside `PROJECT_NAME` that holds the training sessions.
    const EXPERIMENT_NAME: &str = "Unit Testing";
    /// Training session whose existing artifacts are downloaded.
    const ARTIFACTS_SESSION: &str = "modelpack-960x540";
    /// Training session that tests write metrics, artifacts and checkpoints to.
    const USERMANAGED_SESSION: &str = "modelpack-usermanaged";
    /// Name of the tasks whose status and stages tests are allowed to modify.
    const TASK_NAME: &str = "modelpack-usermanaged";

    /// Get the shared test data directory (`<workspace>/target/testdata`),
    /// creating it if it doesn't exist.
    ///
    /// The workspace root is taken to be two directory levels above this
    /// crate's manifest directory.
    fn get_test_data_dir() -> PathBuf {
        let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("CARGO_MANIFEST_DIR should have parent")
            .parent()
            .expect("workspace root should exist")
            .join("target")
            .join("testdata");

        std::fs::create_dir_all(&test_dir).expect("Failed to create test data directory");
        test_dir
    }

    /// Create a fresh, uniquely named scratch directory inside the shared test
    /// data directory.
    ///
    /// Tests run in parallel, so they must never share local file names: two
    /// tests downloading the same artifact to one path race on the write and
    /// on the cleanup. The returned guard deletes the directory and all of its
    /// contents when dropped, including when a test fails part-way through.
    fn scratch_dir() -> Result<tempfile::TempDir, Error> {
        Ok(tempfile::tempdir_in(get_test_data_dir())?)
    }

    /// Initialise logging once for the test binary (default filter: `info`).
    #[ctor::ctor]
    fn init() {
        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
    }

    /// Build a client configured from the environment and verify its token.
    ///
    /// Starts from `Client::new()?.with_token_path(None)` (NOTE(review):
    /// presumably this avoids using a cached on-disk token — confirm), then
    /// applies `STUDIO_TOKEN`, `STUDIO_SERVER`, and finally
    /// `STUDIO_USERNAME` + `STUDIO_PASSWORD` (only when both are set), in that
    /// order, before calling `verify_token`.
    async fn get_client() -> Result<Client, Error> {
        let client = Client::new()?.with_token_path(None)?;

        let client = match env::var("STUDIO_TOKEN") {
            Ok(token) => client.with_token(&token)?,
            Err(_) => client,
        };

        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => client.with_server(&server)?,
            Err(_) => client,
        };

        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
            _ => client,
        };

        client.verify_token().await?;

        Ok(client)
    }

    /// Fetch the `PROJECT_NAME` project (first match returned by the server).
    async fn unit_testing_project(client: &Client) -> Result<Project, Error> {
        client
            .projects(Some(PROJECT_NAME))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| Error::InvalidParameters(format!("{} project not found", PROJECT_NAME)))
    }

    /// Fetch the `EXPERIMENT_NAME` experiment inside `project`.
    async fn unit_testing_experiment(
        client: &Client,
        project: &Project,
    ) -> Result<Experiment, Error> {
        client
            .experiments(project.id(), Some(EXPERIMENT_NAME))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| {
                Error::InvalidParameters(format!("{} experiment not found", EXPERIMENT_NAME))
            })
    }

    /// Fetch the training session called `name` from the Unit Testing
    /// experiment.
    ///
    /// Takes the caller's client so each test authenticates only once.
    async fn unit_testing_session(client: &Client, name: &str) -> Result<TrainingSession, Error> {
        let project = unit_testing_project(client).await?;
        let experiment = unit_testing_experiment(client, &project).await?;
        client
            .training_sessions(experiment.id(), Some(name))
            .await?
            .into_iter()
            .next()
            .ok_or_else(|| Error::InvalidParameters(format!("{} session not found", name)))
    }

    /// Fetch detailed info for every `TASK_NAME` task that belongs to
    /// `project`.
    ///
    /// Tests that change task status or stages must restrict themselves to
    /// these tasks; the unfiltered `tasks(None, ..)` listing is not limited to
    /// test fixtures.
    async fn unit_testing_tasks(
        client: &Client,
        project: &Project,
    ) -> Result<Vec<TaskInfo>, Error> {
        let tasks = client.tasks(Some(TASK_NAME), None, None, None).await?;
        let infos = futures::future::try_join_all(
            tasks.into_iter().map(|t| client.task_info(t.id())),
        )
        .await?;
        Ok(infos
            .into_iter()
            .filter(|t| t.project_id() == Some(project.id()))
            .collect())
    }

    #[tokio::test]
    async fn test_training_session() -> Result<(), Error> {
        let client = get_client().await?;
        let session = unit_testing_session(&client, USERMANAGED_SESSION).await?;

        let metrics = HashMap::from([
            ("epochs".to_string(), Parameter::Integer(10)),
            ("loss".to_string(), Parameter::Real(0.05)),
            (
                "model".to_string(),
                Parameter::String("modelpack".to_string()),
            ),
        ]);

        session.set_metrics(&client, metrics).await?;
        let updated_metrics = session.metrics(&client).await?;
        assert_eq!(updated_metrics.len(), 3);
        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
        assert_eq!(
            updated_metrics.get("model"),
            Some(&Parameter::String("modelpack".to_string()))
        );

        println!("Updated Metrics: {:?}", updated_metrics);

        let mut labels = tempfile::NamedTempFile::new()?;
        write!(labels, "background")?;
        labels.flush()?;

        session
            .upload(
                &client,
                &[(
                    "artifacts/labels.txt".to_string(),
                    labels.path().to_path_buf(),
                )],
            )
            .await?;

        let labels = session.download(&client, "artifacts/labels.txt").await?;
        assert_eq!(labels, "background");

        Ok(())
    }

    #[tokio::test]
    async fn test_validate() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // Every listed session must round-trip through the single-session lookup.
        let sessions = client.validation_sessions(project.id()).await?;
        for session in &sessions {
            let s = client.validation_session(session.id()).await?;
            assert_eq!(s.id(), session.id());
            assert_eq!(s.description(), session.description());
        }

        let session = sessions
            .into_iter()
            .find(|s| s.name() == "modelpack-usermanaged")
            .ok_or_else(|| {
                Error::InvalidParameters(format!(
                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
                    project.name()
                ))
            })?;

        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
        session.set_metrics(&client, metrics).await?;

        let metrics = session.metrics(&client).await?;
        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));

        Ok(())
    }

    #[tokio::test]
    async fn test_download_artifact_success() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, ARTIFACTS_SESSION).await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        // Private directory: `test_artifacts` downloads the same artifact concurrently.
        let scratch = scratch_dir()?;
        let artifact = &artifacts[0];
        let output_path = scratch.path().join(artifact.name());

        client
            .download_artifact(
                trainer.id(),
                artifact.name(),
                Some(output_path.clone()),
                None,
            )
            .await?;

        assert!(output_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_download_artifact_not_found() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, ARTIFACTS_SESSION).await?;
        let scratch = scratch_dir()?;
        let fake_path = scratch.path().join("nonexistent_artifact.txt");

        let result = client
            .download_artifact(
                trainer.id(),
                "nonexistent_artifact.txt",
                Some(fake_path.clone()),
                None,
            )
            .await;

        // A failed download must not leave a (partial) file behind.
        assert!(result.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_artifacts() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, ARTIFACTS_SESSION).await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        // Downloaded files are removed when `scratch` is dropped.
        let scratch = scratch_dir()?;

        for artifact in artifacts {
            let output_path = scratch.path().join(artifact.name());
            client
                .download_artifact(
                    trainer.id(),
                    artifact.name(),
                    Some(output_path.clone()),
                    None,
                )
                .await?;
        }

        let fake_path = scratch.path().join("fakefile.txt");
        let res = client
            .download_artifact(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_download_checkpoint_success() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, USERMANAGED_SESSION).await?;
        let scratch = scratch_dir()?;

        // Create the local file to upload.
        let checkpoint_path = scratch.path().join("test_checkpoint.txt");
        File::create(&checkpoint_path)?.write_all(b"Test Checkpoint Content")?;

        // Upload the checkpoint
        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/test_checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        // Download and verify
        let download_path = scratch.path().join("downloaded_checkpoint.txt");
        client
            .download_checkpoint(
                trainer.id(),
                "test_checkpoint.txt",
                Some(download_path.clone()),
                None,
            )
            .await?;

        let content = read_to_string(&download_path)?;
        assert_eq!(content, "Test Checkpoint Content");

        Ok(())
    }

    #[tokio::test]
    async fn test_download_checkpoint_not_found() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, USERMANAGED_SESSION).await?;
        let scratch = scratch_dir()?;
        let fake_path = scratch.path().join("nonexistent_checkpoint.txt");

        let result = client
            .download_checkpoint(
                trainer.id(),
                "nonexistent_checkpoint.txt",
                Some(fake_path.clone()),
                None,
            )
            .await;

        assert!(result.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_checkpoints() -> Result<(), Error> {
        let client = get_client().await?;
        let trainer = unit_testing_session(&client, USERMANAGED_SESSION).await?;

        let scratch = scratch_dir()?;
        let checkpoint_path = scratch.path().join("checkpoint.txt");
        let checkpoint2_path = scratch.path().join("checkpoint2.txt");

        File::create(&checkpoint_path)?.write_all(b"Test Checkpoint")?;

        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        client
            .download_checkpoint(
                trainer.id(),
                "checkpoint.txt",
                Some(checkpoint2_path.clone()),
                None,
            )
            .await?;

        let chkpt = read_to_string(&checkpoint2_path)?;
        assert_eq!(chkpt, "Test Checkpoint");

        let fake_path = scratch.path().join("fakefile.txt");
        let res = client
            .download_checkpoint(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_task_retrieval() -> Result<(), Error> {
        let client = get_client().await?;

        // Test: Get all tasks
        let tasks = client.tasks(None, None, None, None).await?;
        assert!(!tasks.is_empty());

        // Test: Get task info for first task (read-only)
        let task_id = tasks[0].id();
        let task_info = client.task_info(task_id).await?;
        assert_eq!(task_info.id(), task_id);

        Ok(())
    }

    #[tokio::test]
    async fn test_task_filtering_by_name() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // Test: Get tasks by name
        let tasks = client.tasks(Some(TASK_NAME), None, None, None).await?;
        assert!(
            !tasks.is_empty(),
            "expected at least one '{}' task",
            TASK_NAME
        );

        // Detailed info must describe exactly the tasks that were requested.
        let task_ids = tasks.iter().map(|t| t.id()).collect::<Vec<_>>();
        let task_infos =
            futures::future::try_join_all(task_ids.iter().map(|&id| client.task_info(id)))
                .await?;
        for (id, info) in task_ids.iter().zip(&task_infos) {
            assert_eq!(info.id(), *id);
        }

        // At least one named task must belong to the Unit Testing project.
        assert!(
            task_infos
                .iter()
                .any(|t| t.project_id() == Some(project.id())),
            "no '{}' task found in project '{}'",
            TASK_NAME,
            project.name()
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_task_status_and_stages() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // This test rewrites a task's status and stages, so it must only touch
        // the dedicated fixture task, never the first entry of the unfiltered
        // task listing (which is not restricted to test fixtures).
        let task_id = unit_testing_tasks(&client, &project)
            .await?
            .into_iter()
            .next()
            .map(|t| t.id())
            .ok_or_else(|| {
                Error::InvalidParameters(format!(
                    "No '{}' task found in project '{}'",
                    TASK_NAME,
                    project.name()
                ))
            })?;

        // Test: Get task status
        let status = client.task_status(task_id, "training").await?;
        assert_eq!(status.id(), task_id);
        assert_eq!(status.status(), "training");

        // Test: Set stages
        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task_id, &stages).await?;

        // Test: Update stage
        client
            .update_stage(task_id, "download", "running", "Downloading dataset", 50)
            .await?;

        // Verify task with updated stages
        let updated_task = client.task_info(task_id).await?;
        assert_eq!(updated_task.id(), task_id);

        Ok(())
    }

    #[tokio::test]
    async fn test_tasks() -> Result<(), Error> {
        let client = get_client().await?;
        let project = unit_testing_project(&client).await?;

        // Read-only dump of every visible task.
        for task in client.tasks(None, None, None, None).await? {
            let task_info = client.task_info(task.id()).await?;
            println!("{} - {}", task, task_info);
        }

        // Mutations below are restricted to the dedicated fixture task.
        let tasks = unit_testing_tasks(&client, &project).await?;
        assert_ne!(tasks.len(), 0);
        let task = &tasks[0];

        let t = client.task_status(task.id(), "training").await?;
        assert_eq!(t.id(), task.id());
        assert_eq!(t.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task.id(), &stages).await?;

        client
            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
            .await?;

        let task = client.task_info(task.id()).await?;
        println!("task progress: {:?}", task.stages());

        Ok(())
    }
}