1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23 collections::HashMap,
24 ffi::OsStr,
25 fs::create_dir_all,
26 io::{SeekFrom, Write as _},
27 path::{Path, PathBuf},
28 sync::{
29 Arc,
30 atomic::{AtomicUsize, Ordering},
31 },
32 time::Duration,
33 vec,
34};
35use tokio::{
36 fs::{self, File},
37 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
38 sync::{RwLock, Semaphore, mpsc::Sender},
39};
40use tokio_util::codec::{BytesCodec, FramedRead};
41use walkdir::WalkDir;
42
43#[cfg(feature = "polars")]
44use polars::prelude::*;
45
/// Maximum number of concurrent transfer tasks used by the parallel helpers.
const MAX_TASKS: usize = 32;
/// Maximum attempts for retried HTTP operations before giving up.
const MAX_RETRIES: u32 = 10;
/// Size of each part in multipart snapshot uploads (100 MiB).
const PART_SIZE: usize = 100 * 1024 * 1024;
49
/// Sanitize an untrusted name into a single safe path component.
///
/// Returns `"unnamed"` for blank input, strips any directory portion,
/// replaces characters that are illegal in common filesystems with `_`,
/// and rejects the traversal components `"."` and `".."` so the result is
/// always safe to join onto an output directory.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // Keep only the final component of any path-like input.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            // Separators plus characters invalid on Windows filesystems.
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // "." and ".." survive the steps above ("..".file_name() is None, so the
    // fallback keeps the raw input) and would escape the target directory
    // when joined — map them to "unnamed" as well.
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
75
/// Progress of a long-running operation: `current` items done out of `total`.
#[derive(Debug, Clone)]
pub struct Progress {
    // Number of items processed so far.
    pub current: usize,
    // Total number of items expected.
    pub total: usize,
}

/// JSON-RPC 2.0 request envelope sent to the Studio server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request id; defaults to 0 (see `Default` below).
    id: u64,
    // Protocol version, always "2.0".
    jsonrpc: String,
    method: String,
    // Method parameters; `None` when the call takes no arguments.
    params: Option<Params>,
}

impl<T> Default for RpcRequest<T> {
    // Default request: id 0, protocol "2.0", empty method, no params.
    fn default() -> Self {
        RpcRequest {
            id: 0,
            jsonrpc: "2.0".to_string(),
            method: "".to_string(),
            params: None,
        }
    }
}

/// JSON-RPC error object returned by the server.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}

/// JSON-RPC 2.0 response envelope; the server populates either `error`
/// or `result`.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): the response id deserializes as a string even though
    // requests send a numeric id — confirm against the server contract.
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Placeholder for RPC calls whose result payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
143
/// Parameters for creating a snapshot with single-part uploads.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Object keys (file names) to be uploaded into the snapshot.
    keys: Vec<String>,
}

/// Result of a single-part snapshot creation: presigned upload URLs.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for `snapshots.create_upload_url_multipart`.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    // Size in bytes of each file, parallel to `keys`; lets the server
    // decide how many presigned part URLs are required.
    file_sizes: Vec<usize>,
}

/// One value in the multipart-create response map: either the numeric
/// snapshot id (under the "snapshot_id" key) or a per-key part descriptor.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters for `snapshots.complete_multipart_upload`.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    // ETags returned by the storage backend for each uploaded part.
    etag_list: Vec<EtagPart>,
}

/// ETag + part number pair, serialized with S3-style field casing.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Per-key multipart upload descriptor: upload id plus one presigned URL
/// per part. `key` is optional in the wire format and is filled in locally
/// by the caller (see `create_snapshot`).
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}

/// Parameters for updating a snapshot's status.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by status updates.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}

/// Parameters for listing images in a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    // When true, the server returns only image ids.
    only_ids: bool,
}

/// Dataset filter used by [`ImageListParams`].
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
227
/// EdgeFirst Studio API client.
///
/// Cheap to clone: the HTTP client pools connections internally and the
/// token is shared behind `Arc<RwLock<_>>`.
#[derive(Clone, Debug)]
pub struct Client {
    http: reqwest::Client,
    // Base URL of the Studio server, e.g. "https://edgefirst.studio".
    url: String,
    // Bearer token; an empty string means "not logged in".
    token: Arc<RwLock<String>>,
    // Where the token is persisted on disk, if anywhere.
    token_path: Option<PathBuf>,
}

/// Shared arguments for the paginated sample/annotation fetch helpers.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    // Group names to filter by; empty means all groups.
    groups: &'a [String],
    // File/annotation type names to request.
    types: Vec<String>,
    // Label name -> label index, used to back-fill annotation label indices.
    labels: &'a HashMap<String, u64>,
}
292
293impl Client {
294 pub fn new() -> Result<Self, Error> {
303 Ok(Client {
304 http: reqwest::Client::builder()
305 .read_timeout(Duration::from_secs(60))
306 .build()?,
307 url: "https://edgefirst.studio".to_string(),
308 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
309 token_path: None,
310 })
311 }
312
313 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
317 Ok(Client {
318 url: format!("https://{}.edgefirst.studio", server),
319 ..self.clone()
320 })
321 }
322
323 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
326 let params = HashMap::from([("username", username), ("password", password)]);
327 let login: LoginResult = self
328 .rpc_without_auth("auth.login".to_owned(), Some(params))
329 .await?;
330 Ok(Client {
331 token: Arc::new(tokio::sync::RwLock::new(login.token)),
332 ..self.clone()
333 })
334 }
335
336 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
339 let token_path = match token_path {
340 Some(path) => path.to_path_buf(),
341 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
342 .ok_or_else(|| {
343 Error::IoError(std::io::Error::new(
344 std::io::ErrorKind::NotFound,
345 "Could not determine user config directory",
346 ))
347 })?
348 .config_dir()
349 .join("token"),
350 };
351
352 debug!("Using token path: {:?}", token_path);
353
354 let token = match token_path.exists() {
355 true => std::fs::read_to_string(&token_path)?,
356 false => "".to_string(),
357 };
358
359 if !token.is_empty() {
360 let client = self.with_token(&token)?;
361 Ok(Client {
362 token_path: Some(token_path),
363 ..client
364 })
365 } else {
366 Ok(Client {
367 token_path: Some(token_path),
368 ..self.clone()
369 })
370 }
371 }
372
373 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
375 if token.is_empty() {
376 return Ok(self.clone());
377 }
378
379 let token_parts: Vec<&str> = token.split('.').collect();
380 if token_parts.len() != 3 {
381 return Err(Error::InvalidToken);
382 }
383
384 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
385 .decode(token_parts[1])
386 .map_err(|_| Error::InvalidToken)?;
387 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
388 let server = match payload.get("database") {
389 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
390 None => return Err(Error::InvalidToken),
391 };
392
393 Ok(Client {
394 url: format!("https://{}.edgefirst.studio", server),
395 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
396 ..self.clone()
397 })
398 }
399
400 pub async fn save_token(&self) -> Result<(), Error> {
401 let path = self.token_path.clone().unwrap_or_else(|| {
402 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
403 .map(|dirs| dirs.config_dir().join("token"))
404 .unwrap_or_else(|| PathBuf::from(".token"))
405 });
406
407 create_dir_all(path.parent().ok_or_else(|| {
408 Error::IoError(std::io::Error::new(
409 std::io::ErrorKind::InvalidInput,
410 "Token path has no parent directory",
411 ))
412 })?)?;
413 let mut file = std::fs::File::create(&path)?;
414 file.write_all(self.token.read().await.as_bytes())?;
415
416 debug!("Saved token to {:?}", path);
417
418 Ok(())
419 }
420
421 pub async fn version(&self) -> Result<String, Error> {
424 let version: HashMap<String, String> = self
425 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
426 .await?;
427 let version = version.get("version").ok_or(Error::InvalidResponse)?;
428 Ok(version.to_owned())
429 }
430
431 pub async fn logout(&self) -> Result<(), Error> {
435 {
436 let mut token = self.token.write().await;
437 *token = "".to_string();
438 }
439
440 if let Some(path) = &self.token_path
441 && path.exists()
442 {
443 fs::remove_file(path).await?;
444 }
445
446 Ok(())
447 }
448
449 pub async fn token(&self) -> String {
453 self.token.read().await.clone()
454 }
455
456 pub async fn verify_token(&self) -> Result<(), Error> {
461 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
462 .await?;
463 Ok::<(), Error>(())
464 }
465
466 pub async fn renew_token(&self) -> Result<(), Error> {
471 let params = HashMap::from([("username".to_string(), self.username().await?)]);
472 let result: LoginResult = self
473 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
474 .await?;
475
476 {
477 let mut token = self.token.write().await;
478 *token = result.token;
479 }
480
481 if self.token_path.is_some() {
482 self.save_token().await?;
483 }
484
485 Ok(())
486 }
487
488 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
489 let token = self.token.read().await;
490 if token.is_empty() {
491 return Err(Error::EmptyToken);
492 }
493
494 let token_parts: Vec<&str> = token.split('.').collect();
495 if token_parts.len() != 3 {
496 return Err(Error::InvalidToken);
497 }
498
499 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
500 .decode(token_parts[1])
501 .map_err(|_| Error::InvalidToken)?;
502 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
503 match payload.get(field) {
504 Some(value) => Ok(value.to_owned()),
505 None => Err(Error::InvalidToken),
506 }
507 }
508
    /// Base URL of the Studio server this client talks to.
    pub fn url(&self) -> &str {
        &self.url
    }
513
514 pub async fn username(&self) -> Result<String, Error> {
516 match self.token_field("username").await? {
517 serde_json::Value::String(username) => Ok(username),
518 _ => Err(Error::InvalidToken),
519 }
520 }
521
522 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
524 let ts = match self.token_field("exp").await? {
525 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
526 _ => return Err(Error::InvalidToken),
527 };
528
529 match DateTime::<Utc>::from_timestamp_secs(ts) {
530 Some(dt) => Ok(dt),
531 None => Err(Error::InvalidToken),
532 }
533 }
534
    /// Fetch the organization that owns the current session.
    pub async fn organization(&self) -> Result<Organization, Error> {
        self.rpc::<(), Organization>("org.get".to_owned(), None)
            .await
    }
540
541 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
549 let projects = self
550 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
551 .await?;
552 if let Some(name) = name {
553 Ok(projects
554 .into_iter()
555 .filter(|p| p.name().contains(name))
556 .collect())
557 } else {
558 Ok(projects)
559 }
560 }
561
    /// Fetch a single project by id.
    pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
        let params = HashMap::from([("project_id", project_id)]);
        self.rpc("project.get".to_owned(), Some(params)).await
    }
568
569 pub async fn datasets(
573 &self,
574 project_id: ProjectID,
575 name: Option<&str>,
576 ) -> Result<Vec<Dataset>, Error> {
577 let params = HashMap::from([("project_id", project_id)]);
578 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
579 if let Some(name) = name {
580 Ok(datasets
581 .into_iter()
582 .filter(|d| d.name().contains(name))
583 .collect())
584 } else {
585 Ok(datasets)
586 }
587 }
588
    /// Fetch a single dataset by id.
    pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("dataset.get".to_owned(), Some(params)).await
    }

    /// List the labels defined for a dataset.
    pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("label.list".to_owned(), Some(params)).await
    }
601
602 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
604 let new_label = NewLabel {
605 dataset_id,
606 labels: vec![NewLabelObject {
607 name: name.to_owned(),
608 }],
609 };
610 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
611 Ok(())
612 }
613
    /// Delete a label by id.
    pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
        let params = HashMap::from([("label_id", label_id)]);
        // The server replies with a status string which is discarded.
        let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
        Ok(())
    }
621
622 pub async fn create_dataset(
634 &self,
635 project_id: &str,
636 name: &str,
637 description: Option<&str>,
638 ) -> Result<DatasetID, Error> {
639 let mut params = HashMap::new();
640 params.insert("project_id", project_id);
641 params.insert("name", name);
642 if let Some(desc) = description {
643 params.insert("description", desc);
644 }
645
646 #[derive(Deserialize)]
647 struct CreateDatasetResult {
648 id: DatasetID,
649 }
650
651 let result: CreateDatasetResult =
652 self.rpc("dataset.create".to_owned(), Some(params)).await?;
653 Ok(result.id)
654 }
655
    /// Delete a dataset by id.
    pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
        let params = HashMap::from([("id", dataset_id)]);
        // The server replies with a status string which is discarded.
        let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
        Ok(())
    }
670
    /// Push the label's current name and index back to the server.
    pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
        #[derive(Serialize)]
        struct Params {
            dataset_id: DatasetID,
            label_id: u64,
            label_name: String,
            label_index: u64,
        }

        // The server replies with a status string which is discarded.
        let _: String = self
            .rpc(
                "label.update".to_owned(),
                Some(Params {
                    dataset_id: label.dataset_id(),
                    label_id: label.id(),
                    label_name: label.name().to_owned(),
                    label_index: label.index(),
                }),
            )
            .await?;
        Ok(())
    }
696
    /// Download every sample file of the requested `file_types` into
    /// `output`, grouped into sanitized per-sequence subdirectories when a
    /// sample belongs to a sequence. Samples are processed via
    /// `parallel_foreach_items`, with progress reported over `progress`.
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Owned copies moved into the per-sample closure below.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Images get their extension sniffed from the bytes;
                        // other file types use the type name as extension.
                        // NOTE(review): `expect` panics on unrecognized image
                        // data — consider surfacing an Error instead.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        // Samples from a sequence land in a sanitized
                        // per-sequence subdirectory of `output`.
                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = sequence_dir
                            .map(|seq| output.join(seq))
                            .unwrap_or_else(|| output.clone());
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Prefer the original image filename for images;
                        // fall back to "<sample>.<ext>" otherwise.
                        let file_name = if is_image {
                            image_name.unwrap_or_else(|| {
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            })
                        } else {
                            format!("{}.{}", sanitized_sample_name, file_ext)
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
781
    /// List the annotation sets of a dataset.
    pub async fn annotation_sets(
        &self,
        dataset_id: DatasetID,
    ) -> Result<Vec<AnnotationSet>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("annset.list".to_owned(), Some(params)).await
    }
790
    /// Create an annotation set in a dataset and return its id. The
    /// current session's username is recorded as the operator.
    pub async fn create_annotation_set(
        &self,
        dataset_id: DatasetID,
        name: &str,
        description: Option<&str>,
    ) -> Result<AnnotationSetID, Error> {
        #[derive(Serialize)]
        struct Params<'a> {
            dataset_id: DatasetID,
            name: &'a str,
            operator: &'a str,
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<&'a str>,
        }

        // Only the new annotation set's id is needed from the response.
        #[derive(Deserialize)]
        struct CreateAnnotationSetResult {
            id: AnnotationSetID,
        }

        let username = self.username().await?;
        let result: CreateAnnotationSetResult = self
            .rpc(
                "annset.add".to_owned(),
                Some(Params {
                    dataset_id,
                    name,
                    operator: &username,
                    description,
                }),
            )
            .await?;
        Ok(result.id)
    }
836
    /// Delete an annotation set by id.
    pub async fn delete_annotation_set(
        &self,
        annotation_set_id: AnnotationSetID,
    ) -> Result<(), Error> {
        let params = HashMap::from([("id", annotation_set_id)]);
        // The server replies with a status string which is discarded.
        let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
        Ok(())
    }

    /// Fetch a single annotation set by id.
    pub async fn annotation_set(
        &self,
        annotation_set_id: AnnotationSetID,
    ) -> Result<AnnotationSet, Error> {
        let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
        self.rpc("annset.get".to_owned(), Some(params)).await
    }
864
865 pub async fn annotations(
878 &self,
879 annotation_set_id: AnnotationSetID,
880 groups: &[String],
881 annotation_types: &[AnnotationType],
882 progress: Option<Sender<Progress>>,
883 ) -> Result<Vec<Annotation>, Error> {
884 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
885 let labels = self
886 .labels(dataset_id)
887 .await?
888 .into_iter()
889 .map(|label| (label.name().to_string(), label.index()))
890 .collect::<HashMap<_, _>>();
891 let total = self
892 .samples_count(
893 dataset_id,
894 Some(annotation_set_id),
895 annotation_types,
896 groups,
897 &[],
898 )
899 .await?
900 .total as usize;
901
902 if total == 0 {
903 return Ok(vec![]);
904 }
905
906 let context = FetchContext {
907 dataset_id,
908 annotation_set_id: Some(annotation_set_id),
909 groups,
910 types: annotation_types.iter().map(|t| t.to_string()).collect(),
911 labels: &labels,
912 };
913
914 self.fetch_annotations_paginated(context, total, progress)
915 .await
916 }
917
    /// Page through `samples.list`, flattening each page's annotations
    /// into the result until the server returns an empty page or stops
    /// returning a non-empty continuation token. Progress is reported
    /// after each page against `total`.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Send errors are ignored: a dropped progress receiver must
            // not abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // An empty or missing token marks the final page.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        drop(progress);
        Ok(annotations)
    }
961
962 fn process_sample_annotations(
963 &self,
964 samples: &[Sample],
965 labels: &HashMap<String, u64>,
966 annotations: &mut Vec<Annotation>,
967 ) {
968 for sample in samples {
969 if sample.annotations().is_empty() {
970 let mut annotation = Annotation::new();
971 annotation.set_sample_id(sample.id());
972 annotation.set_name(sample.name());
973 annotation.set_sequence_name(sample.sequence_name().cloned());
974 annotation.set_frame_number(sample.frame_number());
975 annotation.set_group(sample.group().cloned());
976 annotations.push(annotation);
977 continue;
978 }
979
980 for annotation in sample.annotations() {
981 let mut annotation = annotation.clone();
982 annotation.set_sample_id(sample.id());
983 annotation.set_name(sample.name());
984 annotation.set_sequence_name(sample.sequence_name().cloned());
985 annotation.set_frame_number(sample.frame_number());
986 annotation.set_group(sample.group().cloned());
987 Self::set_label_index_from_map(&mut annotation, labels);
988 annotations.push(annotation);
989 }
990 }
991 }
992
993 fn parse_frame_from_image_name(
1001 image_name: Option<&String>,
1002 sequence_name: Option<&String>,
1003 ) -> Option<u32> {
1004 use std::path::Path;
1005
1006 let sequence = sequence_name?;
1007 let name = image_name?;
1008
1009 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1011
1012 stem.strip_prefix(sequence)
1014 .and_then(|suffix| suffix.strip_prefix('_'))
1015 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1016 }
1017
1018 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1020 if let Some(label) = annotation.label() {
1021 annotation.set_label_index(Some(labels[label.as_str()]));
1022 }
1023 }
1024
1025 pub async fn samples_count(
1026 &self,
1027 dataset_id: DatasetID,
1028 annotation_set_id: Option<AnnotationSetID>,
1029 annotation_types: &[AnnotationType],
1030 groups: &[String],
1031 types: &[FileType],
1032 ) -> Result<SamplesCountResult, Error> {
1033 let types = annotation_types
1034 .iter()
1035 .map(|t| t.to_string())
1036 .chain(types.iter().map(|t| t.to_string()))
1037 .collect::<Vec<_>>();
1038
1039 let params = SamplesListParams {
1040 dataset_id,
1041 annotation_set_id,
1042 group_names: groups.to_vec(),
1043 types,
1044 continue_token: None,
1045 };
1046
1047 self.rpc("samples.count".to_owned(), Some(params)).await
1048 }
1049
1050 pub async fn samples(
1051 &self,
1052 dataset_id: DatasetID,
1053 annotation_set_id: Option<AnnotationSetID>,
1054 annotation_types: &[AnnotationType],
1055 groups: &[String],
1056 types: &[FileType],
1057 progress: Option<Sender<Progress>>,
1058 ) -> Result<Vec<Sample>, Error> {
1059 let types_vec = annotation_types
1060 .iter()
1061 .map(|t| t.to_string())
1062 .chain(types.iter().map(|t| t.to_string()))
1063 .collect::<Vec<_>>();
1064 let labels = self
1065 .labels(dataset_id)
1066 .await?
1067 .into_iter()
1068 .map(|label| (label.name().to_string(), label.index()))
1069 .collect::<HashMap<_, _>>();
1070 let total = self
1071 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
1072 .await?
1073 .total as usize;
1074
1075 if total == 0 {
1076 return Ok(vec![]);
1077 }
1078
1079 let context = FetchContext {
1080 dataset_id,
1081 annotation_set_id,
1082 groups,
1083 types: types_vec,
1084 labels: &labels,
1085 };
1086
1087 self.fetch_samples_paginated(context, total, progress).await
1088 }
1089
    /// Page through `samples.list`, accumulating samples until the server
    /// returns an empty page or stops returning a non-empty continuation
    /// token. Each sample's annotations are stamped with the sample's
    /// metadata, and frame numbers are recovered from the image filename
    /// when the server did not provide one.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Fall back to parsing "<sequence>_<frame>" out of
                        // the image filename when no frame number was set.
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Send errors are ignored: a dropped progress receiver must
            // not abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // An empty or missing token marks the final page.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        drop(progress);
        Ok(samples)
    }
1161
    /// Upload samples (metadata plus any local files) into a dataset.
    ///
    /// Samples missing a UUID are assigned one, local file paths are
    /// collected for upload, then `samples.populate2` is called —
    /// requesting presigned URLs only when files actually need uploading —
    /// and finally the files are pushed to those URLs.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;

        // (sample uuid, file type, local path, basename) per local file.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;

        let has_files_to_upload = !files_to_upload.is_empty();

        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            self.upload_sample_files(&results, files_to_upload, progress)
                .await?;
        }

        Ok(results)
    }
1290
    /// Assign a UUID to each sample lacking one and rewrite its file
    /// entries for upload: files that exist locally are recorded in
    /// `files_to_upload` and replaced with basename-only entries for the
    /// server (see `process_sample_file`).
    fn prepare_samples_for_upload(
        &self,
        samples: Vec<Sample>,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> Result<Vec<Sample>, Error> {
        Ok(samples
            .into_iter()
            .map(|mut sample| {
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().expect("UUID just set above");

                // Cloned so `sample` can be borrowed mutably while the
                // original file list is iterated.
                let files_copy = sample.files.clone();
                let updated_files: Vec<crate::SampleFile> = files_copy
                    .iter()
                    .map(|file| {
                        self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect())
    }
1320
    /// Handle one sample file entry: when it names an existing local file,
    /// record it for upload (capturing image dimensions into the sample if
    /// still unset) and return a basename-only entry; otherwise return the
    /// entry unchanged.
    fn process_sample_file(
        &self,
        file: &crate::SampleFile,
        sample_uuid: &str,
        sample: &mut Sample,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> crate::SampleFile {
        use std::path::Path;

        if let Some(filename) = file.filename() {
            let path = Path::new(filename);

            if path.exists()
                && path.is_file()
                && let Some(basename) = path.file_name().and_then(|s| s.to_str())
            {
                // Probe image dimensions from the file header without
                // decoding the full image.
                if file.file_type() == "image"
                    && (sample.width.is_none() || sample.height.is_none())
                    && let Ok(size) = imagesize::size(path)
                {
                    sample.width = Some(size.width as u32);
                    sample.height = Some(size.height as u32);
                }

                files_to_upload.push((
                    sample_uuid.to_string(),
                    file.file_type().to_string(),
                    path.to_path_buf(),
                    basename.to_string(),
                ));

                // The server only needs the basename; the local path stays
                // in `files_to_upload` for the upload phase.
                return crate::SampleFile::with_filename(
                    file.file_type().to_string(),
                    basename.to_string(),
                );
            }
        }
        file.clone()
    }
1365
1366 async fn upload_sample_files(
1367 &self,
1368 results: &[crate::SamplesPopulateResult],
1369 files_to_upload: Vec<(String, String, PathBuf, String)>,
1370 progress: Option<Sender<Progress>>,
1371 ) -> Result<(), Error> {
1372 let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
1374 for (uuid, _file_type, path, basename) in files_to_upload {
1375 upload_map.insert((uuid, basename), path);
1376 }
1377
1378 let http = self.http.clone();
1379
1380 let upload_tasks: Vec<_> = results
1382 .iter()
1383 .map(|result| (result.uuid.clone(), result.urls.clone()))
1384 .collect();
1385
1386 parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
1387 let http = http.clone();
1388 let upload_map = upload_map.clone();
1389
1390 async move {
1391 for url_info in &urls {
1393 if let Some(local_path) =
1394 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
1395 {
1396 upload_file_to_presigned_url(
1398 http.clone(),
1399 &url_info.url,
1400 local_path.clone(),
1401 )
1402 .await?;
1403 }
1404 }
1405
1406 Ok(())
1407 }
1408 })
1409 .await
1410 }
1411
1412 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1413 for attempt in 1..MAX_RETRIES {
1414 let resp = match self.http.get(url).send().await {
1415 Ok(resp) => resp,
1416 Err(err) => {
1417 warn!(
1418 "Socket Error [retry {}/{}]: {:?}",
1419 attempt, MAX_RETRIES, err
1420 );
1421 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1422 continue;
1423 }
1424 };
1425
1426 match resp.bytes().await {
1427 Ok(body) => return Ok(body.to_vec()),
1428 Err(err) => {
1429 warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
1430 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1431 continue;
1432 }
1433 };
1434 }
1435
1436 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1437 }
1438
    /// Fetch annotations and convert them to a Polars [`DataFrame`]
    /// (legacy schema).
    #[deprecated(
        since = "0.8.0",
        note = "Use `samples_dataframe()` for complete 2025.10 schema support"
    )]
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        // The conversion helper is deprecated alongside this method.
        #[allow(deprecated)]
        annotations_dataframe(&annotations)
    }
1498
    /// Fetch samples and convert them to a Polars [`DataFrame`] (the
    /// replacement for the deprecated `annotations_dataframe`).
    #[cfg(feature = "polars")]
    pub async fn samples_dataframe(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::samples_dataframe;

        let samples = self
            .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
            .await?;
        samples_dataframe(&samples)
    }
1551
1552 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1555 let snapshots: Vec<Snapshot> = self
1556 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1557 .await?;
1558 if let Some(name) = name {
1559 Ok(snapshots
1560 .into_iter()
1561 .filter(|s| s.description().contains(name))
1562 .collect())
1563 } else {
1564 Ok(snapshots)
1565 }
1566 }
1567
    /// Fetch a single snapshot by id.
    pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        self.rpc("snapshots.get".to_owned(), Some(params)).await
    }
1573
    /// Create a snapshot from a local file or directory and upload it,
    /// returning the finished snapshot record.
    ///
    /// Directories are delegated to `create_snapshot_folder`, which uploads
    /// every contained file. A single file goes through the multipart upload
    /// flow below, after which the snapshot status is set to "available".
    /// `progress` (if given) receives byte-level updates and is dropped on
    /// completion so receivers observe end-of-stream.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The file's basename serves as both the snapshot name and the
        // single upload key.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server for presigned multipart upload URLs for this file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response mixes the snapshot id with per-key part descriptors
        // in one map; the id lives under the reserved "snapshot_id" entry.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part descriptors are keyed by "<storage prefix>/<filename>", where
        // the prefix is the tail of the snapshot's storage path after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot out of its pending state so it becomes usable.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Explicitly close the progress channel so receivers see the end of
        // the upload.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1674
1675 async fn create_snapshot_folder(
1676 &self,
1677 path: &str,
1678 progress: Option<Sender<Progress>>,
1679 ) -> Result<Snapshot, Error> {
1680 let path = Path::new(path);
1681 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
1682 Error::IoError(std::io::Error::new(
1683 std::io::ErrorKind::InvalidInput,
1684 "Invalid directory name",
1685 ))
1686 })?;
1687
1688 let files = WalkDir::new(path)
1689 .into_iter()
1690 .filter_map(|entry| entry.ok())
1691 .filter(|entry| entry.file_type().is_file())
1692 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
1693 .collect::<Vec<_>>();
1694
1695 let total: usize = files
1696 .iter()
1697 .filter_map(|file| path.join(file).metadata().ok())
1698 .map(|metadata| metadata.len() as usize)
1699 .sum();
1700 let current = Arc::new(AtomicUsize::new(0));
1701
1702 if let Some(progress) = &progress {
1703 let _ = progress.send(Progress { current: 0, total }).await;
1704 }
1705
1706 let keys = files
1707 .iter()
1708 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
1709 .collect::<Vec<_>>();
1710 let file_sizes = files
1711 .iter()
1712 .filter_map(|key| path.join(key).metadata().ok())
1713 .map(|metadata| metadata.len() as usize)
1714 .collect::<Vec<_>>();
1715
1716 let params = SnapshotCreateMultipartParams {
1717 snapshot_name: name.to_owned(),
1718 keys,
1719 file_sizes,
1720 };
1721
1722 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
1723 .rpc(
1724 "snapshots.create_upload_url_multipart".to_owned(),
1725 Some(params),
1726 )
1727 .await?;
1728
1729 let snapshot_id = match multipart.get("snapshot_id") {
1730 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
1731 _ => return Err(Error::InvalidResponse),
1732 };
1733
1734 let snapshot = self.snapshot(snapshot_id).await?;
1735 let part_prefix = snapshot
1736 .path()
1737 .split("::/")
1738 .last()
1739 .ok_or(Error::InvalidResponse)?
1740 .to_owned();
1741
1742 for file in files {
1743 let file_str = file.to_str().ok_or_else(|| {
1744 Error::IoError(std::io::Error::new(
1745 std::io::ErrorKind::InvalidInput,
1746 "File path contains invalid UTF-8",
1747 ))
1748 })?;
1749 let part_key = format!("{}/{}", part_prefix, file_str);
1750 let mut part = match multipart.get(&part_key) {
1751 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
1752 _ => return Err(Error::InvalidResponse),
1753 }
1754 .clone();
1755 part.key = Some(part_key);
1756
1757 let params = upload_multipart(
1758 self.http.clone(),
1759 part.clone(),
1760 path.join(file),
1761 total,
1762 current.clone(),
1763 progress.clone(),
1764 )
1765 .await?;
1766
1767 let complete: String = self
1768 .rpc(
1769 "snapshots.complete_multipart_upload".to_owned(),
1770 Some(params),
1771 )
1772 .await?;
1773 debug!("Snapshot Part Complete: {:?}", complete);
1774 }
1775
1776 let params = SnapshotStatusParams {
1777 snapshot_id,
1778 status: "available".to_owned(),
1779 };
1780 let _: SnapshotStatusResult = self
1781 .rpc("snapshots.update".to_owned(), Some(params))
1782 .await?;
1783
1784 if let Some(progress) = progress {
1785 drop(progress);
1786 }
1787
1788 self.snapshot(snapshot_id).await
1789 }
1790
1791 pub async fn download_snapshot(
1796 &self,
1797 snapshot_id: SnapshotID,
1798 output: PathBuf,
1799 progress: Option<Sender<Progress>>,
1800 ) -> Result<(), Error> {
1801 fs::create_dir_all(&output).await?;
1802
1803 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1804 let items: HashMap<String, String> = self
1805 .rpc("snapshots.create_download_url".to_owned(), Some(params))
1806 .await?;
1807
1808 let total = Arc::new(AtomicUsize::new(0));
1809 let current = Arc::new(AtomicUsize::new(0));
1810 let sem = Arc::new(Semaphore::new(MAX_TASKS));
1811
1812 let tasks = items
1813 .iter()
1814 .map(|(key, url)| {
1815 let http = self.http.clone();
1816 let key = key.clone();
1817 let url = url.clone();
1818 let output = output.clone();
1819 let progress = progress.clone();
1820 let current = current.clone();
1821 let total = total.clone();
1822 let sem = sem.clone();
1823
1824 tokio::spawn(async move {
1825 let _permit = sem.acquire().await.map_err(|_| {
1826 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1827 })?;
1828 let res = http.get(url).send().await?;
1829 let content_length = res.content_length().unwrap_or(0) as usize;
1830
1831 if let Some(progress) = &progress {
1832 let total = total.fetch_add(content_length, Ordering::SeqCst);
1833 let _ = progress
1834 .send(Progress {
1835 current: current.load(Ordering::SeqCst),
1836 total: total + content_length,
1837 })
1838 .await;
1839 }
1840
1841 let mut file = File::create(output.join(key)).await?;
1842 let mut stream = res.bytes_stream();
1843
1844 while let Some(chunk) = stream.next().await {
1845 let chunk = chunk?;
1846 file.write_all(&chunk).await?;
1847 let len = chunk.len();
1848
1849 if let Some(progress) = &progress {
1850 let total = total.load(Ordering::SeqCst);
1851 let current = current.fetch_add(len, Ordering::SeqCst);
1852
1853 let _ = progress
1854 .send(Progress {
1855 current: current + len,
1856 total,
1857 })
1858 .await;
1859 }
1860 }
1861
1862 Ok::<(), Error>(())
1863 })
1864 })
1865 .collect::<Vec<_>>();
1866
1867 join_all(tasks)
1868 .await
1869 .into_iter()
1870 .collect::<Result<Vec<_>, _>>()?
1871 .into_iter()
1872 .collect::<Result<Vec<_>, _>>()?;
1873
1874 Ok(())
1875 }
1876
    /// Restore a snapshot into a project as a new dataset.
    ///
    /// `topics` selects which topics to restore; `autolabel` lists the
    /// auto-label models to run (a non-empty list also enables the AGTG
    /// pipeline); `autodepth` requests automatic depth generation. Optional
    /// `dataset_name` / `dataset_description` override server defaults.
    ///
    /// NOTE(review): the restore frame rate is hardcoded to `fps: 1` here —
    /// confirm no caller needs a different rate before exposing it.
    #[allow(clippy::too_many_arguments)]
    pub async fn restore_snapshot(
        &self,
        project_id: ProjectID,
        snapshot_id: SnapshotID,
        topics: &[String],
        autolabel: &[String],
        autodepth: bool,
        dataset_name: Option<&str>,
        dataset_description: Option<&str>,
    ) -> Result<SnapshotRestoreResult, Error> {
        let params = SnapshotRestore {
            project_id,
            snapshot_id,
            fps: 1,
            autodepth,
            // The AGTG pipeline is implied by requesting any autolabel model.
            agtg_pipeline: !autolabel.is_empty(),
            autolabel: autolabel.to_vec(),
            topics: topics.to_vec(),
            dataset_name: dataset_name.map(|s| s.to_owned()),
            dataset_description: dataset_description.map(|s| s.to_owned()),
        };
        self.rpc("snapshots.restore".to_owned(), Some(params)).await
    }
1915
1916 pub async fn experiments(
1925 &self,
1926 project_id: ProjectID,
1927 name: Option<&str>,
1928 ) -> Result<Vec<Experiment>, Error> {
1929 let params = HashMap::from([("project_id", project_id)]);
1930 let experiments: Vec<Experiment> =
1931 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1932 if let Some(name) = name {
1933 Ok(experiments
1934 .into_iter()
1935 .filter(|e| e.name().contains(name))
1936 .collect())
1937 } else {
1938 Ok(experiments)
1939 }
1940 }
1941
1942 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1945 let params = HashMap::from([("trainer_id", experiment_id)]);
1946 self.rpc("trainer.get".to_owned(), Some(params)).await
1947 }
1948
1949 pub async fn training_sessions(
1958 &self,
1959 experiment_id: ExperimentID,
1960 name: Option<&str>,
1961 ) -> Result<Vec<TrainingSession>, Error> {
1962 let params = HashMap::from([("trainer_id", experiment_id)]);
1963 let sessions: Vec<TrainingSession> = self
1964 .rpc("trainer.session.list".to_owned(), Some(params))
1965 .await?;
1966 if let Some(name) = name {
1967 Ok(sessions
1968 .into_iter()
1969 .filter(|s| s.name().contains(name))
1970 .collect())
1971 } else {
1972 Ok(sessions)
1973 }
1974 }
1975
1976 pub async fn training_session(
1979 &self,
1980 session_id: TrainingSessionID,
1981 ) -> Result<TrainingSession, Error> {
1982 let params = HashMap::from([("trainer_session_id", session_id)]);
1983 self.rpc("trainer.session.get".to_owned(), Some(params))
1984 .await
1985 }
1986
1987 pub async fn validation_sessions(
1989 &self,
1990 project_id: ProjectID,
1991 ) -> Result<Vec<ValidationSession>, Error> {
1992 let params = HashMap::from([("project_id", project_id)]);
1993 self.rpc("validate.session.list".to_owned(), Some(params))
1994 .await
1995 }
1996
1997 pub async fn validation_session(
1999 &self,
2000 session_id: ValidationSessionID,
2001 ) -> Result<ValidationSession, Error> {
2002 let params = HashMap::from([("validate_session_id", session_id)]);
2003 self.rpc("validate.session.get".to_owned(), Some(params))
2004 .await
2005 }
2006
2007 pub async fn artifacts(
2010 &self,
2011 training_session_id: TrainingSessionID,
2012 ) -> Result<Vec<Artifact>, Error> {
2013 let params = HashMap::from([("training_session_id", training_session_id)]);
2014 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2015 .await
2016 }
2017
2018 pub async fn download_artifact(
2024 &self,
2025 training_session_id: TrainingSessionID,
2026 modelname: &str,
2027 filename: Option<PathBuf>,
2028 progress: Option<Sender<Progress>>,
2029 ) -> Result<(), Error> {
2030 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2031 let resp = self
2032 .http
2033 .get(format!(
2034 "{}/download_model?training_session_id={}&file={}",
2035 self.url,
2036 training_session_id.value(),
2037 modelname
2038 ))
2039 .header("Authorization", format!("Bearer {}", self.token().await))
2040 .send()
2041 .await?;
2042 if !resp.status().is_success() {
2043 let err = resp.error_for_status_ref().unwrap_err();
2044 return Err(Error::HttpError(err));
2045 }
2046
2047 if let Some(parent) = filename.parent() {
2048 fs::create_dir_all(parent).await?;
2049 }
2050
2051 if let Some(progress) = progress {
2052 let total = resp.content_length().unwrap_or(0) as usize;
2053 let _ = progress.send(Progress { current: 0, total }).await;
2054
2055 let mut file = File::create(filename).await?;
2056 let mut current = 0;
2057 let mut stream = resp.bytes_stream();
2058
2059 while let Some(item) = stream.next().await {
2060 let chunk = item?;
2061 file.write_all(&chunk).await?;
2062 current += chunk.len();
2063 let _ = progress.send(Progress { current, total }).await;
2064 }
2065 } else {
2066 let body = resp.bytes().await?;
2067 fs::write(filename, body).await?;
2068 }
2069
2070 Ok(())
2071 }
2072
2073 pub async fn download_checkpoint(
2083 &self,
2084 training_session_id: TrainingSessionID,
2085 checkpoint: &str,
2086 filename: Option<PathBuf>,
2087 progress: Option<Sender<Progress>>,
2088 ) -> Result<(), Error> {
2089 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2090 let resp = self
2091 .http
2092 .get(format!(
2093 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2094 self.url,
2095 training_session_id.value(),
2096 checkpoint
2097 ))
2098 .header("Authorization", format!("Bearer {}", self.token().await))
2099 .send()
2100 .await?;
2101 if !resp.status().is_success() {
2102 let err = resp.error_for_status_ref().unwrap_err();
2103 return Err(Error::HttpError(err));
2104 }
2105
2106 if let Some(parent) = filename.parent() {
2107 fs::create_dir_all(parent).await?;
2108 }
2109
2110 if let Some(progress) = progress {
2111 let total = resp.content_length().unwrap_or(0) as usize;
2112 let _ = progress.send(Progress { current: 0, total }).await;
2113
2114 let mut file = File::create(filename).await?;
2115 let mut current = 0;
2116 let mut stream = resp.bytes_stream();
2117
2118 while let Some(item) = stream.next().await {
2119 let chunk = item?;
2120 file.write_all(&chunk).await?;
2121 current += chunk.len();
2122 let _ = progress.send(Progress { current, total }).await;
2123 }
2124 } else {
2125 let body = resp.bytes().await?;
2126 fs::write(filename, body).await?;
2127 }
2128
2129 Ok(())
2130 }
2131
2132 pub async fn tasks(
2134 &self,
2135 name: Option<&str>,
2136 workflow: Option<&str>,
2137 status: Option<&str>,
2138 manager: Option<&str>,
2139 ) -> Result<Vec<Task>, Error> {
2140 let mut params = TasksListParams {
2141 continue_token: None,
2142 status: status.map(|s| vec![s.to_owned()]),
2143 manager: manager.map(|m| vec![m.to_owned()]),
2144 };
2145 let mut tasks = Vec::new();
2146
2147 loop {
2148 let result = self
2149 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2150 .await?;
2151 tasks.extend(result.tasks);
2152
2153 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2154 params.continue_token = None;
2155 } else {
2156 params.continue_token = result.continue_token;
2157 }
2158
2159 if params.continue_token.is_none() {
2160 break;
2161 }
2162 }
2163
2164 if let Some(name) = name {
2165 tasks.retain(|t| t.name().contains(name));
2166 }
2167
2168 if let Some(workflow) = workflow {
2169 tasks.retain(|t| t.workflow().contains(workflow));
2170 }
2171
2172 Ok(tasks)
2173 }
2174
2175 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2177 self.rpc(
2178 "task.get".to_owned(),
2179 Some(HashMap::from([("id", task_id)])),
2180 )
2181 .await
2182 }
2183
2184 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2186 let status = TaskStatus {
2187 task_id,
2188 status: status.to_owned(),
2189 };
2190 self.rpc("docker.update.status".to_owned(), Some(status))
2191 .await
2192 }
2193
2194 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
2198 let stages: Vec<HashMap<String, String>> = stages
2199 .iter()
2200 .map(|(key, value)| {
2201 let mut stage_map = HashMap::new();
2202 stage_map.insert(key.to_string(), value.to_string());
2203 stage_map
2204 })
2205 .collect();
2206 let params = TaskStages { task_id, stages };
2207 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
2208 Ok(())
2209 }
2210
2211 pub async fn update_stage(
2214 &self,
2215 task_id: TaskID,
2216 stage: &str,
2217 status: &str,
2218 message: &str,
2219 percentage: u8,
2220 ) -> Result<(), Error> {
2221 let stage = Stage::new(
2222 Some(task_id),
2223 stage.to_owned(),
2224 Some(status.to_owned()),
2225 Some(message.to_owned()),
2226 percentage,
2227 );
2228 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
2229 Ok(())
2230 }
2231
2232 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
2234 let req = self
2235 .http
2236 .get(format!("{}/{}", self.url, query))
2237 .header("User-Agent", "EdgeFirst Client")
2238 .header("Authorization", format!("Bearer {}", self.token().await));
2239 let resp = req.send().await?;
2240
2241 if resp.status().is_success() {
2242 let body = resp.bytes().await?;
2243
2244 if log_enabled!(Level::Trace) {
2245 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
2246 }
2247
2248 Ok(body.to_vec())
2249 } else {
2250 let err = resp.error_for_status_ref().unwrap_err();
2251 Err(Error::HttpError(err))
2252 }
2253 }
2254
2255 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2259 let req = self
2260 .http
2261 .post(format!("{}/api?method={}", self.url, method))
2262 .header("Accept", "application/json")
2263 .header("User-Agent", "EdgeFirst Client")
2264 .header("Authorization", format!("Bearer {}", self.token().await))
2265 .multipart(form);
2266 let resp = req.send().await?;
2267
2268 if resp.status().is_success() {
2269 let body = resp.bytes().await?;
2270
2271 if log_enabled!(Level::Trace) {
2272 trace!(
2273 "POST Multipart Response: {}",
2274 String::from_utf8_lossy(&body)
2275 );
2276 }
2277
2278 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2279 Ok(response) => response,
2280 Err(err) => {
2281 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2282 return Err(err.into());
2283 }
2284 };
2285
2286 if let Some(error) = response.error {
2287 Err(Error::RpcError(error.code, error.message))
2288 } else if let Some(result) = response.result {
2289 Ok(result)
2290 } else {
2291 Err(Error::InvalidResponse)
2292 }
2293 } else {
2294 let err = resp.error_for_status_ref().unwrap_err();
2295 Err(Error::HttpError(err))
2296 }
2297 }
2298
    /// Perform an authenticated JSON-RPC call, transparently renewing the
    /// auth token when it is within an hour of expiring.
    ///
    /// `method` is the server-side RPC method name; `params`, when given, is
    /// serialized as the request parameters. Returns the deserialized result
    /// on success.
    pub async fn rpc<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        // Renew ahead of time so the token cannot expire mid-request.
        let auth_expires = self.token_expiration().await?;
        if auth_expires <= Utc::now() + Duration::from_secs(3600) {
            self.renew_token().await?;
        }

        self.rpc_without_auth(method, params).await
    }
2323
    /// Perform a JSON-RPC call without the token-renewal check, retrying on
    /// transient failures.
    ///
    /// Up to MAX_RETRIES attempts are made: `try_rpc_request` signals a
    /// retryable failure with `Error::MaxRetriesExceeded`; any other error
    /// is returned immediately.
    async fn rpc_without_auth<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        let request = RpcRequest {
            method,
            params,
            ..Default::default()
        };

        if log_enabled!(Level::Trace) {
            trace!(
                "RPC Request: {}",
                serde_json::ser::to_string_pretty(&request)?
            );
        }

        for attempt in 0..MAX_RETRIES {
            match self.try_rpc_request(&request, attempt).await {
                Ok(result) => return Ok(result),
                // Retryable: transport or HTTP-level failure in the attempt.
                Err(Error::MaxRetriesExceeded(_)) => continue,
                Err(err) => return Err(err),
            }
        }

        Err(Error::MaxRetriesExceeded(MAX_RETRIES))
    }
2356
    /// Execute one RPC attempt against the server.
    ///
    /// Transport errors and non-success HTTP statuses are mapped to
    /// `Error::MaxRetriesExceeded(attempt)` so `rpc_without_auth` treats
    /// them as retryable; a successful HTTP response is decoded by
    /// `process_rpc_response`.
    async fn try_rpc_request<Params, RpcResult>(
        &self,
        request: &RpcRequest<Params>,
        attempt: u32,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        let res = match self
            .http
            .post(format!("{}/api", self.url))
            .header("Accept", "application/json")
            .header("User-Agent", "EdgeFirst Client")
            .header("Authorization", format!("Bearer {}", self.token().await))
            .json(&request)
            .send()
            .await
        {
            Ok(res) => res,
            Err(err) => {
                warn!("Socket Error: {:?}", err);
                return Err(Error::MaxRetriesExceeded(attempt));
            }
        };

        if res.status().is_success() {
            self.process_rpc_response(res).await
        } else {
            let err = res.error_for_status_ref().unwrap_err();
            // NOTE(review): `res.text().await?` can surface a body-read
            // error here as non-retryable instead of retrying — confirm
            // this is intended.
            warn!("HTTP Error {}: {}", err, res.text().await?);
            warn!(
                "Retrying RPC request (attempt {}/{})...",
                attempt + 1,
                MAX_RETRIES
            );
            // Linear backoff; attempt 0 sleeps zero seconds.
            tokio::time::sleep(Duration::from_secs(1) * attempt).await;
            Err(Error::MaxRetriesExceeded(attempt))
        }
    }
2397
2398 async fn process_rpc_response<RpcResult>(
2399 &self,
2400 res: reqwest::Response,
2401 ) -> Result<RpcResult, Error>
2402 where
2403 RpcResult: DeserializeOwned,
2404 {
2405 let body = res.bytes().await?;
2406
2407 if log_enabled!(Level::Trace) {
2408 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2409 }
2410
2411 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2412 Ok(response) => response,
2413 Err(err) => {
2414 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2415 return Err(err.into());
2416 }
2417 };
2418
2419 if let Some(error) = response.error {
2425 Err(Error::RpcError(error.code, error.message))
2426 } else if let Some(result) = response.result {
2427 Ok(result)
2428 } else {
2429 Err(Error::InvalidResponse)
2430 }
2431 }
2432}
2433
/// Run `work_fn` over `items` concurrently, bounded to MAX_TASKS in flight.
///
/// Each item is processed in its own spawned task. `progress` (if given)
/// receives one count-based update per completed item and is dropped when
/// all work has finished so receivers observe end-of-stream. The first join
/// failure or work error aborts the whole operation with that error.
async fn parallel_foreach_items<T, F, Fut>(
    items: Vec<T>,
    progress: Option<Sender<Progress>>,
    work_fn: F,
) -> Result<(), Error>
where
    T: Send + 'static,
    F: Fn(T) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), Error>> + Send + 'static,
{
    let total = items.len();
    let current = Arc::new(AtomicUsize::new(0));
    let sem = Arc::new(Semaphore::new(MAX_TASKS));
    // Shared so every spawned task can invoke the same closure.
    let work_fn = Arc::new(work_fn);

    let tasks = items
        .into_iter()
        .map(|item| {
            let sem = sem.clone();
            let current = current.clone();
            let progress = progress.clone();
            let work_fn = work_fn.clone();

            tokio::spawn(async move {
                // The permit bounds concurrency and is held for the whole
                // task.
                let _permit = sem.acquire().await.map_err(|_| {
                    Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                })?;

                work_fn(item).await?;

                if let Some(progress) = &progress {
                    let current = current.fetch_add(1, Ordering::SeqCst);
                    let _ = progress
                        .send(Progress {
                            current: current + 1,
                            total,
                        })
                        .await;
                }

                Ok::<(), Error>(())
            })
        })
        .collect::<Vec<_>>();

    // The first collect surfaces join/panic errors; the second surfaces the
    // per-item work errors.
    join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?;

    // Close the progress channel so receivers observe end-of-stream.
    if let Some(progress) = progress {
        drop(progress);
    }

    Ok(())
}
2522
2523async fn upload_multipart(
2548 http: reqwest::Client,
2549 part: SnapshotPart,
2550 path: PathBuf,
2551 total: usize,
2552 current: Arc<AtomicUsize>,
2553 progress: Option<Sender<Progress>>,
2554) -> Result<SnapshotCompleteMultipartParams, Error> {
2555 let filesize = path.metadata()?.len() as usize;
2556 let n_parts = filesize.div_ceil(PART_SIZE);
2557 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2558
2559 let key = part.key.ok_or(Error::InvalidResponse)?;
2560 let upload_id = part.upload_id;
2561
2562 let urls = part.urls.clone();
2563 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2565 EtagPart {
2566 etag: "".to_owned(),
2567 part_number: 0,
2568 };
2569 n_parts
2570 ]));
2571
2572 let tasks = (0..n_parts)
2574 .map(|part| {
2575 let http = http.clone();
2576 let url = urls[part].clone();
2577 let etags = etags.clone();
2578 let path = path.to_owned();
2579 let sem = sem.clone();
2580 let progress = progress.clone();
2581 let current = current.clone();
2582
2583 tokio::spawn(async move {
2584 let _permit = sem.acquire().await?;
2586 let mut etag = None;
2587
2588 for attempt in 0..MAX_RETRIES {
2590 match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
2591 {
2592 Ok(v) => {
2593 etag = Some(v);
2594 break;
2595 }
2596 Err(err) => {
2597 warn!("Upload Part Error: {:?}", err);
2598 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2599 }
2600 }
2601 }
2602
2603 if let Some(etag) = etag {
2604 let mut etags = etags.lock().await;
2606 etags[part] = EtagPart {
2607 etag,
2608 part_number: part + 1,
2609 };
2610
2611 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2613 if let Some(progress) = &progress {
2614 let _ = progress
2615 .send(Progress {
2616 current: current + PART_SIZE,
2617 total,
2618 })
2619 .await;
2620 }
2621
2622 Ok(())
2623 } else {
2624 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2625 }
2626 })
2627 })
2628 .collect::<Vec<_>>();
2629
2630 join_all(tasks)
2632 .await
2633 .into_iter()
2634 .collect::<Result<Vec<_>, _>>()?;
2635
2636 Ok(SnapshotCompleteMultipartParams {
2637 key,
2638 upload_id,
2639 etag_list: etags.lock().await.clone(),
2640 })
2641}
2642
2643async fn upload_part(
2644 http: reqwest::Client,
2645 url: String,
2646 path: PathBuf,
2647 part: usize,
2648 n_parts: usize,
2649) -> Result<String, Error> {
2650 let filesize = path.metadata()?.len() as usize;
2651 let mut file = File::open(path).await?;
2652 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2653 .await?;
2654 let file = file.take(PART_SIZE as u64);
2655
2656 let body_length = if part + 1 == n_parts {
2657 filesize % PART_SIZE
2658 } else {
2659 PART_SIZE
2660 };
2661
2662 let stream = FramedRead::new(file, BytesCodec::new());
2663 let body = Body::wrap_stream(stream);
2664
2665 let resp = http
2666 .put(url.clone())
2667 .header(CONTENT_LENGTH, body_length)
2668 .body(body)
2669 .send()
2670 .await?
2671 .error_for_status()?;
2672
2673 let etag = resp
2674 .headers()
2675 .get("etag")
2676 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
2677 .to_str()
2678 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
2679 .to_owned();
2680
2681 let etag = etag
2683 .strip_prefix("\"")
2684 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
2685 let etag = etag
2686 .strip_suffix("\"")
2687 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
2688
2689 Ok(etag.to_owned())
2690}
2691
2692async fn upload_file_to_presigned_url(
2697 http: reqwest::Client,
2698 url: &str,
2699 path: PathBuf,
2700) -> Result<(), Error> {
2701 let file_data = fs::read(&path).await?;
2703 let file_size = file_data.len();
2704
2705 for attempt in 1..=MAX_RETRIES {
2707 match http
2708 .put(url)
2709 .header(CONTENT_LENGTH, file_size)
2710 .body(file_data.clone())
2711 .send()
2712 .await
2713 {
2714 Ok(resp) => {
2715 if resp.status().is_success() {
2716 debug!(
2717 "Successfully uploaded file: {:?} ({} bytes)",
2718 path, file_size
2719 );
2720 return Ok(());
2721 } else {
2722 let status = resp.status();
2723 let error_text = resp.text().await.unwrap_or_default();
2724 warn!(
2725 "Upload failed [attempt {}/{}]: HTTP {} - {}",
2726 attempt, MAX_RETRIES, status, error_text
2727 );
2728 }
2729 }
2730 Err(err) => {
2731 warn!(
2732 "Upload error [attempt {}/{}]: {:?}",
2733 attempt, MAX_RETRIES, err
2734 );
2735 }
2736 }
2737
2738 if attempt < MAX_RETRIES {
2739 tokio::time::sleep(Duration::from_secs(attempt as u64)).await;
2740 }
2741 }
2742
2743 Err(Error::InvalidParameters(format!(
2744 "Failed to upload file {:?} after {} attempts",
2745 path, MAX_RETRIES
2746 )))
2747}