edgefirst_client/
lib.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.

7//! # EdgeFirst Studio Client Library
8//!
9//! The EdgeFirst Studio Client Library provides a Rust client for interacting
10//! with EdgeFirst Studio, a comprehensive platform for computer vision and
11//! machine learning workflows. This library enables developers to
12//! programmatically manage datasets, annotations, training sessions, and other
13//! Studio resources.
14//!
15//! ## Features
16//!
17//! - **Authentication**: Secure token-based authentication with automatic
18//!   renewal
19//! - **Dataset Management**: Upload, download, and manage datasets with various
20//!   file types
21//! - **Annotation Management**: Create, update, and retrieve annotations for
22//!   computer vision tasks
23//! - **Training & Validation**: Manage machine learning training and validation
24//!   sessions
25//! - **Project Organization**: Organize work into projects with hierarchical
26//!   structure
27//! - **Polars Integration**: Optional integration with Polars DataFrames for
28//!   data analysis
29//!
30//! ## Quick Start
31//!
32//! ```rust,no_run
33//! use edgefirst_client::{Client, Error};
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Error> {
37//!     // Create a new client
38//!     let client = Client::new()?;
39//!
40//!     // Authenticate with username and password
41//!     let client = client.with_login("username", "password").await?;
42//!
43//!     // List available projects
44//!     let projects = client.projects(None).await?;
45//!     println!("Found {} projects", projects.len());
46//!
47//!     Ok(())
48//! }
49//! ```
50//!
51//! ## Optional Features
52//!
53//! - `polars`: Enables integration with Polars DataFrames for enhanced data
54//!   manipulation
55
56mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62    api::{
63        AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64        ImageId, Organization, OrganizationID, Parameter, Project, ProjectID, SampleID, SequenceId,
65        SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
66        ValidationSession, ValidationSessionID,
67    },
68    client::{Client, Progress},
69    dataset::{
70        Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, Label, Mask,
71        Sample,
72    },
73    error::Error,
74};
75
76#[cfg(feature = "polars")]
77pub use crate::dataset::annotations_dataframe;
78
#[cfg(test)]
mod tests {
    //! Integration tests that run against a live EdgeFirst Studio server.
    //!
    //! The server and credentials come from the `STUDIO_SERVER`,
    //! `STUDIO_TOKEN`, `STUDIO_USERNAME` and `STUDIO_PASSWORD` environment
    //! variables (see `get_client`). Most tests also expect fixtures such as
    //! the "Unit Testing" and "Sample Project" projects to exist on the server.

    use super::*;
    // Only the dataframe test needs polars; gating the import keeps
    // `cargo test` building when the optional `polars` feature is disabled.
    #[cfg(feature = "polars")]
    use polars::frame::UniqueKeepStrategy;
    use std::{
        collections::HashMap,
        env,
        fs::read_to_string,
        io::Write,
        path::{Path, PathBuf},
    };
    use tokio::time::{Duration, sleep};

    /// Returns the scratch directory `<workspace>/target/testdata`, creating it
    /// if it does not exist yet.
    ///
    /// The workspace root is taken to be two directories above this crate's
    /// `CARGO_MANIFEST_DIR`.
    fn get_test_data_dir() -> PathBuf {
        let test_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .and_then(Path::parent)
            .expect("CARGO_MANIFEST_DIR should be nested two levels below the workspace root")
            .join("target")
            .join("testdata");

        std::fs::create_dir_all(&test_dir).expect("Failed to create test data directory");
        test_dir
    }

    /// Deletes `path` if it exists; a missing file is not an error.
    fn remove_if_exists(path: &Path) -> std::io::Result<()> {
        if path.exists() {
            std::fs::remove_file(path)?;
        }
        Ok(())
    }

    /// Installs `env_logger` (default filter `info`) once per test binary.
    #[ctor::ctor]
    fn init() {
        // `try_init` instead of `init`: this runs before `main`, where a panic
        // (e.g. because a logger is already installed) would abort the binary.
        let env = env_logger::Env::default().default_filter_or("info");
        let _ = env_logger::Builder::from_env(env).try_init();
    }

    #[tokio::test]
    async fn test_version() -> Result<(), Error> {
        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => Client::new()?.with_server(&server)?,
            Err(_) => Client::new()?,
        };
        let result = client.version().await?;
        println!("EdgeFirst Studio Version: {}", result);
        Ok(())
    }

    /// Builds an authenticated client from the `STUDIO_*` environment
    /// variables and verifies its token before returning it.
    async fn get_client() -> Result<Client, Error> {
        let client = Client::new()?.with_token_path(None)?;

        let client = match env::var("STUDIO_TOKEN") {
            Ok(token) => client.with_token(&token)?,
            Err(_) => client,
        };

        let client = match env::var("STUDIO_SERVER") {
            Ok(server) => client.with_server(&server)?,
            Err(_) => client,
        };

        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
            _ => client,
        };

        client.verify_token().await?;

        Ok(client)
    }

    /// Returns the first project matched by `client.projects(Some(name))`.
    ///
    /// # Panics
    /// Panics with a descriptive message if no project matches.
    async fn get_project(client: &Client, name: &str) -> Result<Project, Error> {
        let projects = client.projects(Some(name)).await?;
        Ok(projects
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("Project '{}' not found", name)))
    }

    /// Returns the first experiment in `project` matched by `name`.
    ///
    /// # Panics
    /// Panics with a descriptive message if no experiment matches.
    async fn get_experiment(
        client: &Client,
        project: &Project,
        name: &str,
    ) -> Result<Experiment, Error> {
        let experiments = client.experiments(project.id(), Some(name)).await?;
        Ok(experiments.into_iter().next().unwrap_or_else(|| {
            panic!(
                "Experiment '{}' not found in project '{}'",
                name,
                project.name()
            )
        }))
    }

    /// Returns the first training session in `experiment` matched by `name`.
    ///
    /// # Panics
    /// Panics with a descriptive message if no session matches.
    async fn get_training_session(
        client: &Client,
        experiment: &Experiment,
        name: &str,
    ) -> Result<TrainingSession, Error> {
        let sessions = client.training_sessions(experiment.id(), Some(name)).await?;
        Ok(sessions
            .into_iter()
            .next()
            .unwrap_or_else(|| panic!("Training session '{}' not found", name)))
    }

    #[tokio::test]
    async fn test_token() -> Result<(), Error> {
        let client = get_client().await?;
        let token = client.token().await;
        assert!(!token.is_empty());
        // The token itself is deliberately not printed: it is a live credential
        // and test output commonly ends up in CI logs.

        let exp = client.token_expiration().await?;
        println!("Token Expiration: {}", exp);

        let username = client.username().await?;
        assert!(!username.is_empty());
        println!("Username: {}", username);

        // Wait for 2 seconds to ensure token renewal updates the time
        sleep(Duration::from_secs(2)).await;

        client.renew_token().await?;
        let new_token = client.token().await;
        assert!(!new_token.is_empty());
        // Plain `assert!` because `assert_ne!` would print both tokens on failure.
        assert!(token != new_token, "renew_token() did not issue a new token");
        println!("New Token Expiration: {}", client.token_expiration().await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_organization() -> Result<(), Error> {
        let client = get_client().await?;
        let org = client.organization().await?;
        println!(
            "Organization: {}\nID: {}\nCredits: {}",
            org.name(),
            org.id(),
            org.credits()
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_projects() -> Result<(), Error> {
        let client = get_client().await?;
        let project = client.projects(Some("Unit Testing")).await?;
        assert!(!project.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_datasets() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), None).await?;

        for dataset in datasets {
            let dataset_id = dataset.id();
            let result = client.dataset(dataset_id).await?;
            assert_eq!(result.id(), dataset_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_labels() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let datasets = client.datasets(project.id(), Some("Test Labels")).await?;
        let dataset = datasets.first().unwrap_or_else(|| {
            panic!(
                "Dataset 'Test Labels' not found in project '{}'",
                project.name()
            )
        });

        // Start from a clean slate so the counts below are deterministic.
        let labels = dataset.labels(&client).await?;
        for label in labels {
            label.remove(&client).await?;
        }

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        dataset.add_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 1);
        assert_eq!(labels[0].name(), "test");

        dataset.remove_label(&client, "test").await?;
        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_coco() -> Result<(), Error> {
        let coco_labels = HashMap::from([
            (0, "person"),
            (1, "bicycle"),
            (2, "car"),
            (3, "motorcycle"),
            (4, "airplane"),
            (5, "bus"),
            (6, "train"),
            (7, "truck"),
            (8, "boat"),
            (9, "traffic light"),
            (10, "fire hydrant"),
            (11, "stop sign"),
            (12, "parking meter"),
            (13, "bench"),
            (14, "bird"),
            (15, "cat"),
            (16, "dog"),
            (17, "horse"),
            (18, "sheep"),
            (19, "cow"),
            (20, "elephant"),
            (21, "bear"),
            (22, "zebra"),
            (23, "giraffe"),
            (24, "backpack"),
            (25, "umbrella"),
            (26, "handbag"),
            (27, "tie"),
            (28, "suitcase"),
            (29, "frisbee"),
            (30, "skis"),
            (31, "snowboard"),
            (32, "sports ball"),
            (33, "kite"),
            (34, "baseball bat"),
            (35, "baseball glove"),
            (36, "skateboard"),
            (37, "surfboard"),
            (38, "tennis racket"),
            (39, "bottle"),
            (40, "wine glass"),
            (41, "cup"),
            (42, "fork"),
            (43, "knife"),
            (44, "spoon"),
            (45, "bowl"),
            (46, "banana"),
            (47, "apple"),
            (48, "sandwich"),
            (49, "orange"),
            (50, "broccoli"),
            (51, "carrot"),
            (52, "hot dog"),
            (53, "pizza"),
            (54, "donut"),
            (55, "cake"),
            (56, "chair"),
            (57, "couch"),
            (58, "potted plant"),
            (59, "bed"),
            (60, "dining table"),
            (61, "toilet"),
            (62, "tv"),
            (63, "laptop"),
            (64, "mouse"),
            (65, "remote"),
            (66, "keyboard"),
            (67, "cell phone"),
            (68, "microwave"),
            (69, "oven"),
            (70, "toaster"),
            (71, "sink"),
            (72, "refrigerator"),
            (73, "book"),
            (74, "clock"),
            (75, "vase"),
            (76, "scissors"),
            (77, "teddy bear"),
            (78, "hair drier"),
            (79, "toothbrush"),
        ]);

        let client = get_client().await?;
        let project = get_project(&client, "Sample Project").await?;
        let datasets = client.datasets(project.id(), Some("COCO")).await?;
        assert!(!datasets.is_empty());
        // Filter to avoid fetching the COCO People dataset.
        let dataset = datasets
            .iter()
            .find(|d| d.name() == "COCO")
            .expect("Dataset 'COCO' not found in project 'Sample Project'");

        let labels = dataset.labels(&client).await?;
        assert_eq!(labels.len(), 80);

        for label in &labels {
            assert_eq!(label.name(), coco_labels[&label.index()]);
        }

        let n_samples = client
            .samples_count(dataset.id(), None, &[], &["val".to_string()], &[])
            .await?;
        assert_eq!(n_samples.total, 5000);

        let samples = client
            .samples(dataset.id(), None, &[], &["val".to_string()], &[], None)
            .await?;
        assert_eq!(samples.len(), 5000);

        Ok(())
    }

    #[cfg(feature = "polars")]
    #[tokio::test]
    async fn test_coco_dataframe() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Sample Project").await?;
        let datasets = client.datasets(project.id(), Some("COCO")).await?;
        assert!(!datasets.is_empty());
        // Filter to avoid fetching the COCO People dataset.
        let dataset = datasets
            .iter()
            .find(|d| d.name() == "COCO")
            .expect("Dataset 'COCO' not found in project 'Sample Project'");

        let annotation_sets = dataset.annotation_sets(&client).await?;
        let annotation_set_id = annotation_sets
            .first()
            .expect("Dataset 'COCO' has no annotation sets")
            .id();

        let annotations = client
            .annotations(annotation_set_id, &["val".to_string()], &[], None)
            .await?;
        let df = annotations_dataframe(&annotations);
        // One row per distinct sample name should equal the val split size.
        let df = df
            .unique_stable(Some(&["name".to_string()]), UniqueKeepStrategy::First, None)
            .unwrap();
        assert_eq!(df.height(), 5000);

        Ok(())
    }

    #[tokio::test]
    async fn test_snapshots() -> Result<(), Error> {
        let client = get_client().await?;
        let snapshots = client.snapshots(None).await?;

        for snapshot in snapshots {
            let snapshot_id = snapshot.id();
            let result = client.snapshot(snapshot_id).await?;
            assert_eq!(result.id(), snapshot_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_experiments() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiments = client.experiments(project.id(), None).await?;

        for experiment in experiments {
            let experiment_id = experiment.id();
            let result = client.experiment(experiment_id).await?;
            assert_eq!(result.id(), experiment_id);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_training_session() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let session = get_training_session(&client, &experiment, "modelpack-usermanaged").await?;

        let metrics = HashMap::from([
            ("epochs".to_string(), Parameter::Integer(10)),
            ("loss".to_string(), Parameter::Real(0.05)),
            (
                "model".to_string(),
                Parameter::String("modelpack".to_string()),
            ),
        ]);

        session.set_metrics(&client, metrics).await?;
        let updated_metrics = session.metrics(&client).await?;
        assert_eq!(updated_metrics.len(), 3);
        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
        assert_eq!(
            updated_metrics.get("model"),
            Some(&Parameter::String("modelpack".to_string()))
        );

        println!("Updated Metrics: {:?}", updated_metrics);

        let mut labels_file = tempfile::NamedTempFile::new()?;
        write!(labels_file, "background")?;
        labels_file.flush()?;

        session
            .upload(
                &client,
                &[(
                    "artifacts/labels.txt".to_string(),
                    labels_file.path().to_path_buf(),
                )],
            )
            .await?;

        let labels = session.download(&client, "artifacts/labels.txt").await?;
        assert_eq!(labels, "background");

        Ok(())
    }

    #[tokio::test]
    async fn test_validate() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;

        let sessions = client.validation_sessions(project.id()).await?;
        for session in &sessions {
            let s = client.validation_session(session.id()).await?;
            assert_eq!(s.id(), session.id());
            assert_eq!(s.description(), session.description());
        }

        let session = sessions
            .into_iter()
            .find(|s| s.name() == "modelpack-usermanaged")
            .unwrap_or_else(|| {
                panic!(
                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
                    project.name()
                )
            });

        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
        session.set_metrics(&client, metrics).await?;

        let metrics = session.metrics(&client).await?;
        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));

        Ok(())
    }

    #[tokio::test]
    async fn test_artifacts() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let trainer = get_training_session(&client, &experiment, "modelpack-960x540").await?;
        let artifacts = client.artifacts(trainer.id()).await?;
        assert!(!artifacts.is_empty());

        let test_dir = get_test_data_dir();

        for artifact in artifacts {
            let output_path = test_dir.join(artifact.name());
            client
                .download_artifact(
                    trainer.id(),
                    artifact.name(),
                    Some(output_path.clone()),
                    None,
                )
                .await?;

            // Clean up downloaded file
            remove_if_exists(&output_path)?;
        }

        // Distinct local file name: tests run in parallel and `test_checkpoints`
        // performs the same "missing remote file" probe in this directory.
        let fake_path = test_dir.join("fake_artifact.txt");
        let res = client
            .download_artifact(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        assert!(res.is_err());
        assert!(!fake_path.exists());

        Ok(())
    }

    #[tokio::test]
    async fn test_checkpoints() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let experiment = get_experiment(&client, &project, "Unit Testing").await?;
        let trainer = get_training_session(&client, &experiment, "modelpack-usermanaged").await?;

        let test_dir = get_test_data_dir();
        let checkpoint_path = test_dir.join("checkpoint.txt");
        let checkpoint2_path = test_dir.join("checkpoint2.txt");

        std::fs::write(&checkpoint_path, b"Test Checkpoint")?;

        trainer
            .upload(
                &client,
                &[(
                    "checkpoints/checkpoint.txt".to_string(),
                    checkpoint_path.clone(),
                )],
            )
            .await?;

        client
            .download_checkpoint(
                trainer.id(),
                "checkpoint.txt",
                Some(checkpoint2_path.clone()),
                None,
            )
            .await?;

        let chkpt = read_to_string(&checkpoint2_path)?;

        // Distinct local file name: tests run in parallel and `test_artifacts`
        // performs the same "missing remote file" probe in this directory.
        let fake_path = test_dir.join("fake_checkpoint.txt");
        let res = client
            .download_checkpoint(trainer.id(), "fakefile.txt", Some(fake_path.clone()), None)
            .await;
        let fake_created = fake_path.exists();

        // Clean up before asserting so a failing assertion does not leave
        // stale files behind for the next run.
        remove_if_exists(&checkpoint_path)?;
        remove_if_exists(&checkpoint2_path)?;
        remove_if_exists(&fake_path)?;

        assert_eq!(chkpt, "Test Checkpoint");
        assert!(res.is_err());
        assert!(!fake_created);

        Ok(())
    }

    #[tokio::test]
    async fn test_tasks() -> Result<(), Error> {
        let client = get_client().await?;
        let project = get_project(&client, "Unit Testing").await?;
        let tasks = client.tasks(None, None, None, None).await?;

        for task in tasks {
            let task_info = client.task_info(task.id()).await?;
            println!("{} - {}", task, task_info);
        }

        let tasks = client
            .tasks(Some("modelpack-usermanaged"), None, None, None)
            .await?;
        let tasks = tasks
            .into_iter()
            .map(|t| client.task_info(t.id()))
            .collect::<Vec<_>>();
        let tasks = futures::future::try_join_all(tasks).await?;
        // Only consider tasks that belong to the "Unit Testing" project.
        let tasks = tasks
            .into_iter()
            .filter(|t| t.project_id() == Some(project.id()))
            .collect::<Vec<_>>();
        let task = tasks.first().unwrap_or_else(|| {
            panic!(
                "No 'modelpack-usermanaged' task found in project '{}'",
                project.name()
            )
        });

        let t = client.task_status(task.id(), "training").await?;
        assert_eq!(t.id(), task.id());
        assert_eq!(t.status(), "training");

        let stages = [
            ("download", "Downloading Dataset"),
            ("train", "Training Model"),
            ("export", "Exporting Model"),
        ];
        client.set_stages(task.id(), &stages).await?;

        client
            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
            .await?;

        let task = client.task_info(task.id()).await?;
        println!("task progress: {:?}", task.stages());

        Ok(())
    }
}