edgefirst_client/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4// SPDX-License-Identifier: Apache-2.0
5// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
6
7//! # EdgeFirst Studio Client Library
8//!
9//! The EdgeFirst Studio Client Library provides a Rust client for interacting
10//! with EdgeFirst Studio, a comprehensive platform for computer vision and
11//! machine learning workflows. This library enables developers to
12//! programmatically manage datasets, annotations, training sessions, and other
13//! Studio resources.
14//!
15//! ## Features
16//!
17//! - **Authentication**: Secure token-based authentication with automatic
18//!   renewal
19//! - **Dataset Management**: Upload, download, and manage datasets with various
20//!   file types
21//! - **Annotation Management**: Create, update, and retrieve annotations for
22//!   computer vision tasks
23//! - **Training & Validation**: Manage machine learning training and validation
24//!   sessions
25//! - **Project Organization**: Organize work into projects with hierarchical
26//!   structure
27//! - **Polars Integration**: Optional integration with Polars DataFrames for
28//!   data analysis
29//!
30//! ## Quick Start
31//!
32//! ```rust,no_run
33//! use edgefirst_client::{Client, Error};
34//!
35//! #[tokio::main]
36//! async fn main() -> Result<(), Error> {
37//!     // Create a new client
38//!     let client = Client::new()?;
39//!
40//!     // Authenticate with username and password
41//!     let client = client.with_login("username", "password").await?;
42//!
43//!     // List available projects
44//!     let projects = client.projects(None).await?;
45//!     println!("Found {} projects", projects.len());
46//!
47//!     Ok(())
48//! }
49//! ```
50//!
51//! ## Optional Features
52//!
53//! - `polars`: Enables integration with Polars DataFrames for enhanced data
54//!   manipulation
55
56mod api;
57mod client;
58mod dataset;
59mod error;
60
61pub use crate::{
62    api::{
63        AnnotationSetID, AppId, Artifact, DatasetID, DatasetParams, Experiment, ExperimentID,
64        ImageId, Organization, OrganizationID, Parameter, Project, ProjectID, SampleID, SequenceId,
65        SnapshotID, Stage, Task, TaskID, TaskInfo, TrainingSession, TrainingSessionID,
66        ValidationSession, ValidationSessionID,
67    },
68    client::{Client, Progress},
69    dataset::{
70        Annotation, AnnotationSet, AnnotationType, Box2d, Box3d, Dataset, FileType, Label, Mask,
71        Sample,
72    },
73    error::Error,
74};
75
76#[cfg(feature = "polars")]
77pub use crate::dataset::annotations_dataframe;
78
79#[cfg(test)]
80mod tests {
    use super::*;
    // Only `test_coco_dataframe` (itself feature-gated) needs polars.
    #[cfg(feature = "polars")]
    use polars::frame::UniqueKeepStrategy;
    use std::{
        collections::HashMap,
        env,
        fs::{File, read_to_string},
        io::Write,
        path::Path,
    };
    use tokio::time::{Duration, sleep};
91
92    #[ctor::ctor]
93    fn init() {
94        env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
95    }
96
97    #[tokio::test]
98    async fn test_version() -> Result<(), Error> {
99        let client = match env::var("STUDIO_SERVER") {
100            Ok(server) => Client::new()?.with_server(&server)?,
101            Err(_) => Client::new()?,
102        };
103        let result = client.version().await?;
104        println!("EdgeFirst Studio Version: {}", result);
105        Ok(())
106    }
107
108    async fn get_client() -> Result<Client, Error> {
109        let client = Client::new()?.with_token_path(None)?;
110
111        let client = match env::var("STUDIO_TOKEN") {
112            Ok(token) => client.with_token(&token)?,
113            Err(_) => client,
114        };
115
116        let client = match env::var("STUDIO_SERVER") {
117            Ok(server) => client.with_server(&server)?,
118            Err(_) => client,
119        };
120
121        let client = match (env::var("STUDIO_USERNAME"), env::var("STUDIO_PASSWORD")) {
122            (Ok(username), Ok(password)) => client.with_login(&username, &password).await?,
123            _ => client,
124        };
125
126        client.verify_token().await?;
127
128        Ok(client)
129    }
130
131    #[tokio::test]
132    async fn test_token() -> Result<(), Error> {
133        let client = get_client().await?;
134        let token = client.token().await;
135        assert!(!token.is_empty());
136        println!("Token: {}", token);
137
138        let exp = client.token_expiration().await?;
139        println!("Token Expiration: {}", exp);
140
141        let username = client.username().await?;
142        assert!(!username.is_empty());
143        println!("Username: {}", username);
144
145        // Wait for 2 seconds to ensure token renewal updates the time
146        sleep(Duration::from_secs(2)).await;
147
148        client.renew_token().await?;
149        let new_token = client.token().await;
150        assert!(!new_token.is_empty());
151        assert_ne!(token, new_token);
152        println!("New Token Expiration: {}", client.token_expiration().await?);
153
154        Ok(())
155    }
156
157    #[tokio::test]
158    async fn test_organization() -> Result<(), Error> {
159        let client = get_client().await?;
160        let org = client.organization().await?;
161        println!(
162            "Organization: {}\nID: {}\nCredits: {}",
163            org.name(),
164            org.id(),
165            org.credits()
166        );
167        Ok(())
168    }
169
170    #[tokio::test]
171    async fn test_projects() -> Result<(), Error> {
172        let client = get_client().await?;
173        let project = client.projects(Some("Unit Testing")).await?;
174        assert!(!project.is_empty());
175        Ok(())
176    }
177
178    #[tokio::test]
179    async fn test_datasets() -> Result<(), Error> {
180        let client = get_client().await?;
181        let project = client.projects(Some("Unit Testing")).await?;
182        assert!(!project.is_empty());
183        let project = project.first().unwrap();
184        let datasets = client.datasets(project.id(), None).await?;
185
186        for dataset in datasets {
187            let dataset_id = dataset.id();
188            let result = client.dataset(dataset_id).await?;
189            assert_eq!(result.id(), dataset_id);
190        }
191
192        Ok(())
193    }
194
195    #[tokio::test]
196    async fn test_labels() -> Result<(), Error> {
197        let client = get_client().await?;
198        let project = client.projects(Some("Unit Testing")).await?;
199        assert!(!project.is_empty());
200        let project = project.first().unwrap();
201        let datasets = client.datasets(project.id(), Some("Test Labels")).await?;
202        let dataset = datasets.first().unwrap_or_else(|| {
203            panic!(
204                "Dataset 'Test Labels' not found in project '{}'",
205                project.name()
206            )
207        });
208
209        let labels = dataset.labels(&client).await?;
210        for label in labels {
211            label.remove(&client).await?;
212        }
213
214        let labels = dataset.labels(&client).await?;
215        assert_eq!(labels.len(), 0);
216
217        dataset.add_label(&client, "test").await?;
218        let labels = dataset.labels(&client).await?;
219        assert_eq!(labels.len(), 1);
220        assert_eq!(labels[0].name(), "test");
221
222        dataset.remove_label(&client, "test").await?;
223        let labels = dataset.labels(&client).await?;
224        assert_eq!(labels.len(), 0);
225
226        Ok(())
227    }
228
229    #[tokio::test]
230    async fn test_coco() -> Result<(), Error> {
231        let coco_labels = HashMap::from([
232            (0, "person"),
233            (1, "bicycle"),
234            (2, "car"),
235            (3, "motorcycle"),
236            (4, "airplane"),
237            (5, "bus"),
238            (6, "train"),
239            (7, "truck"),
240            (8, "boat"),
241            (9, "traffic light"),
242            (10, "fire hydrant"),
243            (11, "stop sign"),
244            (12, "parking meter"),
245            (13, "bench"),
246            (14, "bird"),
247            (15, "cat"),
248            (16, "dog"),
249            (17, "horse"),
250            (18, "sheep"),
251            (19, "cow"),
252            (20, "elephant"),
253            (21, "bear"),
254            (22, "zebra"),
255            (23, "giraffe"),
256            (24, "backpack"),
257            (25, "umbrella"),
258            (26, "handbag"),
259            (27, "tie"),
260            (28, "suitcase"),
261            (29, "frisbee"),
262            (30, "skis"),
263            (31, "snowboard"),
264            (32, "sports ball"),
265            (33, "kite"),
266            (34, "baseball bat"),
267            (35, "baseball glove"),
268            (36, "skateboard"),
269            (37, "surfboard"),
270            (38, "tennis racket"),
271            (39, "bottle"),
272            (40, "wine glass"),
273            (41, "cup"),
274            (42, "fork"),
275            (43, "knife"),
276            (44, "spoon"),
277            (45, "bowl"),
278            (46, "banana"),
279            (47, "apple"),
280            (48, "sandwich"),
281            (49, "orange"),
282            (50, "broccoli"),
283            (51, "carrot"),
284            (52, "hot dog"),
285            (53, "pizza"),
286            (54, "donut"),
287            (55, "cake"),
288            (56, "chair"),
289            (57, "couch"),
290            (58, "potted plant"),
291            (59, "bed"),
292            (60, "dining table"),
293            (61, "toilet"),
294            (62, "tv"),
295            (63, "laptop"),
296            (64, "mouse"),
297            (65, "remote"),
298            (66, "keyboard"),
299            (67, "cell phone"),
300            (68, "microwave"),
301            (69, "oven"),
302            (70, "toaster"),
303            (71, "sink"),
304            (72, "refrigerator"),
305            (73, "book"),
306            (74, "clock"),
307            (75, "vase"),
308            (76, "scissors"),
309            (77, "teddy bear"),
310            (78, "hair drier"),
311            (79, "toothbrush"),
312        ]);
313
314        let client = get_client().await?;
315        let project = client.projects(Some("Sample Project")).await?;
316        assert!(!project.is_empty());
317        let project = project.first().unwrap();
318        let datasets = client.datasets(project.id(), Some("COCO")).await?;
319        assert!(!datasets.is_empty());
320        // Filter to avoid fetching the COCO People dataset.
321        let dataset = datasets.iter().find(|d| d.name() == "COCO").unwrap();
322
323        let labels = dataset.labels(&client).await?;
324        assert_eq!(labels.len(), 80);
325
326        for label in &labels {
327            assert_eq!(label.name(), coco_labels[&label.index()]);
328        }
329
330        let n_samples = client
331            .samples_count(dataset.id(), None, &[], &["val".to_string()], &[])
332            .await?;
333        assert_eq!(n_samples.total, 5000);
334
335        let samples = client
336            .samples(dataset.id(), None, &[], &["val".to_string()], &[], None)
337            .await?;
338        assert_eq!(samples.len(), 5000);
339
340        Ok(())
341    }
342
343    #[cfg(feature = "polars")]
344    #[tokio::test]
345    async fn test_coco_dataframe() -> Result<(), Error> {
346        let client = get_client().await?;
347        let project = client.projects(Some("Sample Project")).await?;
348        assert!(!project.is_empty());
349        let project = project.first().unwrap();
350        let datasets = client.datasets(project.id(), Some("COCO")).await?;
351        assert!(!datasets.is_empty());
352        // Filter to avoid fetching the COCO People dataset.
353        let dataset = datasets.iter().find(|d| d.name() == "COCO").unwrap();
354
355        let annotation_set_id = dataset
356            .annotation_sets(&client)
357            .await?
358            .first()
359            .unwrap()
360            .id();
361
362        let annotations = client
363            .annotations(annotation_set_id, &["val".to_string()], &[], None)
364            .await?;
365        let df = annotations_dataframe(&annotations);
366        let df = df
367            .unique_stable(Some(&["name".to_string()]), UniqueKeepStrategy::First, None)
368            .unwrap();
369        assert_eq!(df.height(), 5000);
370
371        Ok(())
372    }
373
374    #[tokio::test]
375    async fn test_snapshots() -> Result<(), Error> {
376        let client = get_client().await?;
377        let snapshots = client.snapshots(None).await?;
378
379        for snapshot in snapshots {
380            let snapshot_id = snapshot.id();
381            let result = client.snapshot(snapshot_id).await?;
382            assert_eq!(result.id(), snapshot_id);
383        }
384
385        Ok(())
386    }
387
388    #[tokio::test]
389    async fn test_experiments() -> Result<(), Error> {
390        let client = get_client().await?;
391        let project = client.projects(Some("Unit Testing")).await?;
392        assert!(!project.is_empty());
393        let project = project.first().unwrap();
394        let experiments = client.experiments(project.id(), None).await?;
395
396        for experiment in experiments {
397            let experiment_id = experiment.id();
398            let result = client.experiment(experiment_id).await?;
399            assert_eq!(result.id(), experiment_id);
400        }
401
402        Ok(())
403    }
404
405    #[tokio::test]
406    async fn test_training_session() -> Result<(), Error> {
407        let client = get_client().await?;
408        let project = client.projects(Some("Unit Testing")).await?;
409        assert!(!project.is_empty());
410        let project = project.first().unwrap();
411        let experiment = client
412            .experiments(project.id(), Some("Unit Testing"))
413            .await?;
414        let experiment = experiment.first().unwrap();
415
416        let sessions = client
417            .training_sessions(experiment.id(), Some("modelpack-usermanaged"))
418            .await?;
419        assert_ne!(sessions.len(), 0);
420        let session = sessions.first().unwrap();
421
422        let metrics = HashMap::from([
423            ("epochs".to_string(), Parameter::Integer(10)),
424            ("loss".to_string(), Parameter::Real(0.05)),
425            (
426                "model".to_string(),
427                Parameter::String("modelpack".to_string()),
428            ),
429        ]);
430
431        session.set_metrics(&client, metrics).await?;
432        let updated_metrics = session.metrics(&client).await?;
433        assert_eq!(updated_metrics.len(), 3);
434        assert_eq!(updated_metrics.get("epochs"), Some(&Parameter::Integer(10)));
435        assert_eq!(updated_metrics.get("loss"), Some(&Parameter::Real(0.05)));
436        assert_eq!(
437            updated_metrics.get("model"),
438            Some(&Parameter::String("modelpack".to_string()))
439        );
440
441        println!("Updated Metrics: {:?}", updated_metrics);
442
443        let mut labels = tempfile::NamedTempFile::new()?;
444        write!(labels, "background")?;
445        labels.flush()?;
446
447        session
448            .upload(
449                &client,
450                &[(
451                    "artifacts/labels.txt".to_string(),
452                    labels.path().to_path_buf(),
453                )],
454            )
455            .await?;
456
457        let labels = session.download(&client, "artifacts/labels.txt").await?;
458        assert_eq!(labels, "background");
459
460        Ok(())
461    }
462
463    #[tokio::test]
464    async fn test_validate() -> Result<(), Error> {
465        let client = get_client().await?;
466        let project = client.projects(Some("Unit Testing")).await?;
467        assert!(!project.is_empty());
468        let project = project.first().unwrap();
469
470        let sessions = client.validation_sessions(project.id()).await?;
471        for session in &sessions {
472            let s = client.validation_session(session.id()).await?;
473            assert_eq!(s.id(), session.id());
474            assert_eq!(s.description(), session.description());
475        }
476
477        let session = sessions
478            .into_iter()
479            .find(|s| s.name() == "modelpack-usermanaged")
480            .unwrap_or_else(|| {
481                panic!(
482                    "Validation session 'modelpack-usermanaged' not found in project '{}'",
483                    project.name()
484                )
485            });
486
487        let metrics = HashMap::from([("accuracy".to_string(), Parameter::Real(0.95))]);
488        session.set_metrics(&client, metrics).await?;
489
490        let metrics = session.metrics(&client).await?;
491        assert_eq!(metrics.get("accuracy"), Some(&Parameter::Real(0.95)));
492
493        Ok(())
494    }
495
496    #[tokio::test]
497    async fn test_artifacts() -> Result<(), Error> {
498        let client = get_client().await?;
499        let project = client.projects(Some("Unit Testing")).await?;
500        assert!(!project.is_empty());
501        let project = project.first().unwrap();
502        let experiment = client
503            .experiments(project.id(), Some("Unit Testing"))
504            .await?;
505        let experiment = experiment.first().unwrap();
506        let trainer = client
507            .training_sessions(experiment.id(), Some("modelpack-960x540"))
508            .await?;
509        let trainer = trainer.first().unwrap();
510        let artifacts = client.artifacts(trainer.id()).await?;
511        assert!(!artifacts.is_empty());
512
513        for artifact in artifacts {
514            client
515                .download_artifact(
516                    trainer.id(),
517                    artifact.name(),
518                    Some(artifact.name().into()),
519                    None,
520                )
521                .await?;
522        }
523
524        let res = client
525            .download_artifact(
526                trainer.id(),
527                "fakefile.txt",
528                Some("fakefile.txt".into()),
529                None,
530            )
531            .await;
532        assert!(res.is_err());
533        assert!(!Path::new("fakefile.txt").exists());
534
535        Ok(())
536    }
537
538    #[tokio::test]
539    async fn test_checkpoints() -> Result<(), Error> {
540        let client = get_client().await?;
541        let project = client.projects(Some("Unit Testing")).await?;
542        assert!(!project.is_empty());
543        let project = project.first().unwrap();
544        let experiment = client
545            .experiments(project.id(), Some("Unit Testing"))
546            .await?;
547        let experiment = experiment.first().unwrap_or_else(|| {
548            panic!(
549                "Experiment 'Unit Testing' not found in project '{}'",
550                project.name()
551            )
552        });
553        let trainer = client
554            .training_sessions(experiment.id(), Some("modelpack-usermanaged"))
555            .await?;
556        let trainer = trainer.first().unwrap();
557
558        {
559            let mut chkpt = File::create("checkpoint.txt")?;
560            chkpt.write_all(b"Test Checkpoint")?;
561        }
562
563        trainer
564            .upload(
565                &client,
566                &[(
567                    "checkpoints/checkpoint.txt".to_string(),
568                    "checkpoint.txt".into(),
569                )],
570            )
571            .await?;
572
573        client
574            .download_checkpoint(
575                trainer.id(),
576                "checkpoint.txt",
577                Some("checkpoint2.txt".into()),
578                None,
579            )
580            .await?;
581
582        let chkpt = read_to_string("checkpoint2.txt")?;
583        assert_eq!(chkpt, "Test Checkpoint");
584
585        let res = client
586            .download_checkpoint(
587                trainer.id(),
588                "fakefile.txt",
589                Some("fakefile.txt".into()),
590                None,
591            )
592            .await;
593        assert!(res.is_err());
594        assert!(!Path::new("fakefile.txt").exists());
595
596        Ok(())
597    }
598
599    #[tokio::test]
600    async fn test_tasks() -> Result<(), Error> {
601        let client = get_client().await?;
602        let project = client.projects(Some("Unit Testing")).await?;
603        let project = project.first().unwrap();
604        let tasks = client.tasks(None, None, None, None).await?;
605
606        for task in tasks {
607            let task_info = client.task_info(task.id()).await?;
608            println!("{} - {}", task, task_info);
609        }
610
611        let tasks = client
612            .tasks(Some("modelpack-usermanaged"), None, None, None)
613            .await?;
614        let tasks = tasks
615            .into_iter()
616            .map(|t| client.task_info(t.id()))
617            .collect::<Vec<_>>();
618        let tasks = futures::future::try_join_all(tasks).await?;
619        let tasks = tasks
620            .into_iter()
621            .filter(|t| t.project_id() == Some(project.id()))
622            .collect::<Vec<_>>();
623        assert_ne!(tasks.len(), 0);
624        let task = &tasks[0];
625
626        let t = client.task_status(task.id(), "training").await?;
627        assert_eq!(t.id(), task.id());
628        assert_eq!(t.status(), "training");
629
630        let stages = [
631            ("download", "Downloading Dataset"),
632            ("train", "Training Model"),
633            ("export", "Exporting Model"),
634        ];
635        client.set_stages(task.id(), &stages).await?;
636
637        client
638            .update_stage(task.id(), "download", "running", "Downloading dataset", 50)
639            .await?;
640
641        let task = client.task_info(task.id()).await?;
642        println!("task progress: {:?}", task.stages());
643
644        Ok(())
645    }
646}