1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23 collections::HashMap,
24 ffi::OsStr,
25 fs::create_dir_all,
26 io::{SeekFrom, Write as _},
27 path::{Path, PathBuf},
28 sync::{
29 Arc,
30 atomic::{AtomicUsize, Ordering},
31 },
32 time::Duration,
33 vec,
34};
35use tokio::{
36 fs::{self, File},
37 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
38 sync::{RwLock, Semaphore, mpsc::Sender},
39};
40use tokio_util::codec::{BytesCodec, FramedRead};
41use walkdir::WalkDir;
42
43#[cfg(feature = "polars")]
44use polars::prelude::*;
45
/// Maximum number of concurrent transfer tasks used by the parallel helpers.
const MAX_TASKS: usize = 32;
/// Maximum retry attempts for HTTP transfers before giving up.
const MAX_RETRIES: u32 = 10;
/// Part size (100 MiB) used for multipart snapshot uploads.
// `const` is preferred over non-mut `static` for plain compile-time values;
// usage sites are unchanged.
const PART_SIZE: usize = 100 * 1024 * 1024;
49
/// Make `name` safe to use as a single path component.
///
/// Leading/trailing whitespace is trimmed, any directory portion is dropped
/// (only the final path component is kept), and characters that are
/// problematic on common filesystems are replaced with `_`. Empty or
/// all-whitespace input yields `"unnamed"`.
fn sanitize_path_component(name: &str) -> String {
    const REPLACED: &[char] = &['/', '\\', ':', '*', '?', '"', '<', '>', '|'];

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // Keep only the final component; fall back to the trimmed input when the
    // path has no final component (e.g. "..").
    let base = match Path::new(trimmed).file_name() {
        Some(component) => component.to_string_lossy().into_owned(),
        None => trimmed.to_string(),
    };

    let cleaned: String = base
        .chars()
        .map(|ch| if REPLACED.contains(&ch) { '_' } else { ch })
        .collect();

    if cleaned.is_empty() {
        "unnamed".to_string()
    } else {
        cleaned
    }
}
75
/// Progress report sent over an optional channel during long-running
/// operations (paginated listings, downloads, uploads).
#[derive(Debug, Clone)]
pub struct Progress {
    // Units depend on the operation: items for listings/downloads, bytes for
    // snapshot file uploads.
    pub current: usize,
    pub total: usize,
}
104
/// JSON-RPC 2.0 request envelope.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier; defaults to 0.
    id: u64,
    // Protocol version, always "2.0".
    jsonrpc: String,
    // Remote method name, e.g. "dataset.list".
    method: String,
    // Method parameters; serialized as `null` when absent.
    params: Option<Params>,
}
112
113impl<T> Default for RpcRequest<T> {
114 fn default() -> Self {
115 RpcRequest {
116 id: 0,
117 jsonrpc: "2.0".to_string(),
118 method: "".to_string(),
119 params: None,
120 }
121 }
122}
123
/// Error object of a JSON-RPC 2.0 response.
#[derive(Deserialize)]
struct RpcError {
    // Numeric JSON-RPC error code.
    code: i32,
    // Human-readable error description from the server.
    message: String,
}
129
/// JSON-RPC 2.0 response envelope; one of `error`/`result` is populated.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): the response id is a String while the request id is a
    // u64 — presumably the server echoes ids as strings; confirm.
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    // Populated when the call failed.
    error: Option<RpcError>,
    // Populated when the call succeeded.
    result: Option<RpcResult>,
}
139
/// Placeholder result type for RPC calls whose result payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
143
/// Parameters for creating a snapshot from a list of object keys.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Object keys (file names) that make up the snapshot.
    keys: Vec<String>,
}
150
/// Server response to snapshot creation: the new snapshot id plus one
/// presigned upload URL per requested key.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
157
/// Parameters for starting a multipart snapshot upload. `file_sizes` is
/// parallel to `keys`, letting the server size the per-part URLs.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}
164
/// One value in the heterogeneous map returned by
/// `snapshots.create_upload_url_multipart`: either the numeric snapshot id
/// or a per-key multipart upload description. `untagged` lets serde pick
/// the variant by JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
171
/// Parameters for completing a multipart upload: the object key, the upload
/// session id, and the ETag collected for every uploaded part.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}
178
/// ETag of a single uploaded part, serialized with the S3-style field names
/// (`ETag`/`PartNumber`) the completion API expects.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
186
/// Multipart upload description for one object: its key (when provided by
/// the server), the upload session id, and one presigned URL per part.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
193
/// Parameters for updating a snapshot's status string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}
199
/// Snapshot record returned by the status-update RPC.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
213
/// Parameters for listing images with dataset- and file-level filters.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    // When true, only image ids are returned.
    only_ids: bool,
}
221
/// Image filter restricting results to a single dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
227
/// Asynchronous EdgeFirst Studio API client.
///
/// Cheap to clone: the HTTP client is reqwest's shared handle and the token
/// lives behind an `Arc<RwLock>`, so all clones observe token updates made
/// through any copy.
#[derive(Clone, Debug)]
pub struct Client {
    // Shared HTTP connection pool.
    http: reqwest::Client,
    // Base server URL, e.g. `https://{server}.edgefirst.studio`.
    url: String,
    // Current session token (JWT); empty when not logged in.
    token: Arc<RwLock<String>>,
    // Where the token is cached on disk, when persistence is enabled.
    token_path: Option<PathBuf>,
}
283
/// Filter parameters shared by the paginated sample/annotation fetchers.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    // Restrict to one annotation set, when set.
    annotation_set_id: Option<AnnotationSetID>,
    // Group names to filter by (empty = all groups).
    groups: &'a [String],
    // Annotation/file type names to request.
    types: Vec<String>,
    // Label name -> numeric index map used to resolve label indices.
    labels: &'a HashMap<String, u64>,
}
292
293impl Client {
294 pub fn new() -> Result<Self, Error> {
303 Ok(Client {
304 http: reqwest::Client::builder()
305 .read_timeout(Duration::from_secs(60))
306 .build()?,
307 url: "https://edgefirst.studio".to_string(),
308 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
309 token_path: None,
310 })
311 }
312
313 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
317 Ok(Client {
318 url: format!("https://{}.edgefirst.studio", server),
319 ..self.clone()
320 })
321 }
322
323 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
326 let params = HashMap::from([("username", username), ("password", password)]);
327 let login: LoginResult = self
328 .rpc_without_auth("auth.login".to_owned(), Some(params))
329 .await?;
330 Ok(Client {
331 token: Arc::new(tokio::sync::RwLock::new(login.token)),
332 ..self.clone()
333 })
334 }
335
    /// Return a copy of this client that caches its token at `token_path`.
    ///
    /// When `token_path` is `None`, the platform-specific EdgeFirst Studio
    /// config directory is used (`<config>/token`). If the file already
    /// exists, its contents are loaded as the session token via
    /// `with_token`, which also re-derives the server URL from the token.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path: {:?}", token_path);

        // A missing token file is not an error: start with an empty token.
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            let client = self.with_token(&token)?;
            Ok(Client {
                token_path: Some(token_path),
                ..client
            })
        } else {
            Ok(Client {
                token_path: Some(token_path),
                ..self.clone()
            })
        }
    }
372
    /// Return a copy of this client using `token` for authentication.
    ///
    /// The token must be a three-part JWT. Its payload is base64-decoded
    /// (unpadded) and must contain a `database` claim, which selects the
    /// server URL (`https://{database}.edgefirst.studio`). An empty token is
    /// a no-op and returns an unchanged clone.
    pub fn with_token(&self, token: &str) -> Result<Self, Error> {
        if token.is_empty() {
            return Ok(self.clone());
        }

        let token_parts: Vec<&str> = token.split('.').collect();
        if token_parts.len() != 3 {
            return Err(Error::InvalidToken);
        }

        // NOTE(review): STANDARD_NO_PAD is used here, but JWTs are normally
        // base64url-encoded — confirm tokens never contain '-' or '_'.
        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(token_parts[1])
            .map_err(|_| Error::InvalidToken)?;
        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
        let server = match payload.get("database") {
            Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
            None => return Err(Error::InvalidToken),
        };

        Ok(Client {
            url: format!("https://{}.edgefirst.studio", server),
            token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
            ..self.clone()
        })
    }
399
    /// Persist the current token to the configured token path.
    ///
    /// Falls back to the platform config directory, then to a local
    /// `.token` file, when no explicit path was configured. Parent
    /// directories are created as needed.
    pub async fn save_token(&self) -> Result<(), Error> {
        let path = self.token_path.clone().unwrap_or_else(|| {
            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .map(|dirs| dirs.config_dir().join("token"))
                .unwrap_or_else(|| PathBuf::from(".token"))
        });

        create_dir_all(path.parent().ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Token path has no parent directory",
            ))
        })?)?;
        let mut file = std::fs::File::create(&path)?;
        file.write_all(self.token.read().await.as_bytes())?;

        debug!("Saved token to {:?}", path);

        Ok(())
    }
420
421 pub async fn version(&self) -> Result<String, Error> {
424 let version: HashMap<String, String> = self
425 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
426 .await?;
427 let version = version.get("version").ok_or(Error::InvalidResponse)?;
428 Ok(version.to_owned())
429 }
430
    /// Clear the in-memory token and delete the cached token file, if any.
    pub async fn logout(&self) -> Result<(), Error> {
        // Scope the write guard so the lock is released before file I/O.
        {
            let mut token = self.token.write().await;
            *token = "".to_string();
        }

        if let Some(path) = &self.token_path
            && path.exists()
        {
            fs::remove_file(path).await?;
        }

        Ok(())
    }
448
449 pub async fn token(&self) -> String {
453 self.token.read().await.clone()
454 }
455
456 pub async fn verify_token(&self) -> Result<(), Error> {
461 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
462 .await?;
463 Ok::<(), Error>(())
464 }
465
    /// Refresh the session token for the current user and persist it when a
    /// token path is configured.
    pub async fn renew_token(&self) -> Result<(), Error> {
        let params = HashMap::from([("username".to_string(), self.username().await?)]);
        let result: LoginResult = self
            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
            .await?;

        // Scope the write guard so the lock is released before saving.
        {
            let mut token = self.token.write().await;
            *token = result.token;
        }

        if self.token_path.is_some() {
            self.save_token().await?;
        }

        Ok(())
    }
487
    /// Decode the JWT payload of the current token and return the value of
    /// one claim. Errors with `EmptyToken` when no token is set and
    /// `InvalidToken` when the token is malformed or the claim is absent.
    async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
        let token = self.token.read().await;
        if token.is_empty() {
            return Err(Error::EmptyToken);
        }

        let token_parts: Vec<&str> = token.split('.').collect();
        if token_parts.len() != 3 {
            return Err(Error::InvalidToken);
        }

        // Same decoding scheme as `with_token`.
        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(token_parts[1])
            .map_err(|_| Error::InvalidToken)?;
        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
        match payload.get(field) {
            Some(value) => Ok(value.to_owned()),
            None => Err(Error::InvalidToken),
        }
    }
508
    /// The base URL of the server this client talks to.
    pub fn url(&self) -> &str {
        &self.url
    }
513
514 pub async fn username(&self) -> Result<String, Error> {
516 match self.token_field("username").await? {
517 serde_json::Value::String(username) => Ok(username),
518 _ => Err(Error::InvalidToken),
519 }
520 }
521
    /// The expiration time of the current token, from its `exp` claim.
    pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
        let ts = match self.token_field("exp").await? {
            serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
            _ => return Err(Error::InvalidToken),
        };

        // `exp` is a Unix timestamp in seconds; out-of-range values map to
        // InvalidToken.
        match DateTime::<Utc>::from_timestamp_secs(ts) {
            Some(dt) => Ok(dt),
            None => Err(Error::InvalidToken),
        }
    }
534
535 pub async fn organization(&self) -> Result<Organization, Error> {
537 self.rpc::<(), Organization>("org.get".to_owned(), None)
538 .await
539 }
540
541 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
549 let projects = self
550 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
551 .await?;
552 if let Some(name) = name {
553 Ok(projects
554 .into_iter()
555 .filter(|p| p.name().contains(name))
556 .collect())
557 } else {
558 Ok(projects)
559 }
560 }
561
562 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
565 let params = HashMap::from([("project_id", project_id)]);
566 self.rpc("project.get".to_owned(), Some(params)).await
567 }
568
569 pub async fn datasets(
573 &self,
574 project_id: ProjectID,
575 name: Option<&str>,
576 ) -> Result<Vec<Dataset>, Error> {
577 let params = HashMap::from([("project_id", project_id)]);
578 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
579 if let Some(name) = name {
580 Ok(datasets
581 .into_iter()
582 .filter(|d| d.name().contains(name))
583 .collect())
584 } else {
585 Ok(datasets)
586 }
587 }
588
589 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
592 let params = HashMap::from([("dataset_id", dataset_id)]);
593 self.rpc("dataset.get".to_owned(), Some(params)).await
594 }
595
596 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
598 let params = HashMap::from([("dataset_id", dataset_id)]);
599 self.rpc("label.list".to_owned(), Some(params)).await
600 }
601
602 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
604 let new_label = NewLabel {
605 dataset_id,
606 labels: vec![NewLabelObject {
607 name: name.to_owned(),
608 }],
609 };
610 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
611 Ok(())
612 }
613
614 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
617 let params = HashMap::from([("label_id", label_id)]);
618 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
619 Ok(())
620 }
621
622 pub async fn create_dataset(
634 &self,
635 project_id: &str,
636 name: &str,
637 description: Option<&str>,
638 ) -> Result<DatasetID, Error> {
639 let mut params = HashMap::new();
640 params.insert("project_id", project_id);
641 params.insert("name", name);
642 if let Some(desc) = description {
643 params.insert("description", desc);
644 }
645
646 #[derive(Deserialize)]
647 struct CreateDatasetResult {
648 id: DatasetID,
649 }
650
651 let result: CreateDatasetResult =
652 self.rpc("dataset.create".to_owned(), Some(params)).await?;
653 Ok(result.id)
654 }
655
656 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
666 let params = HashMap::from([("id", dataset_id)]);
667 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
668 Ok(())
669 }
670
671 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
675 #[derive(Serialize)]
676 struct Params {
677 dataset_id: DatasetID,
678 label_id: u64,
679 label_name: String,
680 label_index: u64,
681 }
682
683 let _: String = self
684 .rpc(
685 "label.update".to_owned(),
686 Some(Params {
687 dataset_id: label.dataset_id(),
688 label_id: label.id(),
689 label_name: label.name().to_owned(),
690 label_index: label.index(),
691 }),
692 )
693 .await?;
694 Ok(())
695 }
696
    /// Download every matching sample file in a dataset to `output`.
    ///
    /// Samples are listed first (reporting listing progress on `progress`),
    /// then their files are downloaded concurrently. Files are written under
    /// `output`, grouped into per-sequence subdirectories when a sample
    /// carries a sequence name; all path components are sanitized.
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Clone captured state so the per-sample closure can be 'static.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Images get their extension sniffed from the bytes;
                        // other file types use the type name as extension.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = sequence_dir
                            .map(|seq| output.join(seq))
                            .unwrap_or_else(|| output.clone());
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Prefer the original image filename for images; fall
                        // back to "{sample}.{ext}" otherwise.
                        let file_name = if is_image {
                            image_name.unwrap_or_else(|| {
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            })
                        } else {
                            format!("{}.{}", sanitized_sample_name, file_ext)
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing data is logged, not fatal.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
781
782 pub async fn annotation_sets(
784 &self,
785 dataset_id: DatasetID,
786 ) -> Result<Vec<AnnotationSet>, Error> {
787 let params = HashMap::from([("dataset_id", dataset_id)]);
788 self.rpc("annset.list".to_owned(), Some(params)).await
789 }
790
    /// Create an annotation set in a dataset and return its new identifier.
    /// The current username is recorded as the set's operator; `description`
    /// is omitted from the request when `None`.
    pub async fn create_annotation_set(
        &self,
        dataset_id: DatasetID,
        name: &str,
        description: Option<&str>,
    ) -> Result<AnnotationSetID, Error> {
        #[derive(Serialize)]
        struct Params<'a> {
            dataset_id: DatasetID,
            name: &'a str,
            operator: &'a str,
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<&'a str>,
        }

        #[derive(Deserialize)]
        struct CreateAnnotationSetResult {
            id: AnnotationSetID,
        }

        let username = self.username().await?;
        let result: CreateAnnotationSetResult = self
            .rpc(
                "annset.add".to_owned(),
                Some(Params {
                    dataset_id,
                    name,
                    operator: &username,
                    description,
                }),
            )
            .await?;
        Ok(result.id)
    }
836
837 pub async fn delete_annotation_set(
848 &self,
849 annotation_set_id: AnnotationSetID,
850 ) -> Result<(), Error> {
851 let params = HashMap::from([("id", annotation_set_id)]);
852 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
853 Ok(())
854 }
855
856 pub async fn annotation_set(
858 &self,
859 annotation_set_id: AnnotationSetID,
860 ) -> Result<AnnotationSet, Error> {
861 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
862 self.rpc("annset.get".to_owned(), Some(params)).await
863 }
864
    /// Fetch all annotations in an annotation set, flattened to one record
    /// per annotation (samples without annotations contribute one
    /// metadata-only placeholder record). Progress is reported per page of
    /// samples.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Label name -> numeric index map for resolving label indices.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id: Some(annotation_set_id),
            groups,
            types: annotation_types.iter().map(|t| t.to_string()).collect(),
            labels: &labels,
        };

        self.fetch_annotations_paginated(context, total, progress)
            .await
    }
917
    /// Page through `samples.list` and flatten the annotations of every
    /// returned sample, reporting progress after each page. Pagination
    /// follows the server's `continue_token` until it is absent or empty.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            if let Some(progress) = &progress {
                // Best-effort: a dropped receiver must not abort the fetch.
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Explicitly close the progress channel before returning.
        drop(progress);
        Ok(annotations)
    }
961
962 fn process_sample_annotations(
963 &self,
964 samples: &[Sample],
965 labels: &HashMap<String, u64>,
966 annotations: &mut Vec<Annotation>,
967 ) {
968 for sample in samples {
969 if sample.annotations().is_empty() {
970 let mut annotation = Annotation::new();
971 annotation.set_sample_id(sample.id());
972 annotation.set_name(sample.name());
973 annotation.set_sequence_name(sample.sequence_name().cloned());
974 annotation.set_frame_number(sample.frame_number());
975 annotation.set_group(sample.group().cloned());
976 annotations.push(annotation);
977 continue;
978 }
979
980 for annotation in sample.annotations() {
981 let mut annotation = annotation.clone();
982 annotation.set_sample_id(sample.id());
983 annotation.set_name(sample.name());
984 annotation.set_sequence_name(sample.sequence_name().cloned());
985 annotation.set_frame_number(sample.frame_number());
986 annotation.set_group(sample.group().cloned());
987 Self::set_label_index_from_map(&mut annotation, labels);
988 annotations.push(annotation);
989 }
990 }
991 }
992
993 fn parse_frame_from_image_name(
1001 image_name: Option<&String>,
1002 sequence_name: Option<&String>,
1003 ) -> Option<u32> {
1004 use std::path::Path;
1005
1006 let sequence = sequence_name?;
1007 let name = image_name?;
1008
1009 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1011
1012 stem.strip_prefix(sequence)
1014 .and_then(|suffix| suffix.strip_prefix('_'))
1015 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1016 }
1017
1018 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1020 if let Some(label) = annotation.label() {
1021 annotation.set_label_index(Some(labels[label.as_str()]));
1022 }
1023 }
1024
1025 pub async fn samples_count(
1026 &self,
1027 dataset_id: DatasetID,
1028 annotation_set_id: Option<AnnotationSetID>,
1029 annotation_types: &[AnnotationType],
1030 groups: &[String],
1031 types: &[FileType],
1032 ) -> Result<SamplesCountResult, Error> {
1033 let types = annotation_types
1034 .iter()
1035 .map(|t| t.to_string())
1036 .chain(types.iter().map(|t| t.to_string()))
1037 .collect::<Vec<_>>();
1038
1039 let params = SamplesListParams {
1040 dataset_id,
1041 annotation_set_id,
1042 group_names: groups.to_vec(),
1043 types,
1044 continue_token: None,
1045 };
1046
1047 self.rpc("samples.count".to_owned(), Some(params)).await
1048 }
1049
    /// Fetch all samples in a dataset matching the given filters, with each
    /// sample's annotations enriched with sample metadata and resolved label
    /// indices.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // The server accepts annotation types and file types as one list.
        let types_vec = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> numeric index map for resolving label indices.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id,
            groups,
            types: types_vec,
            labels: &labels,
        };

        self.fetch_samples_paginated(context, total, progress).await
    }
1089
    /// Page through `samples.list`, enriching each sample's annotations with
    /// the sample metadata and resolved label index, and filling in frame
    /// numbers parsed from the image name when the server omits them.
    /// Pagination follows `continue_token` until it is absent or empty.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Fall back to parsing "{sequence}_{frame}" from the
                        // image name when no frame number was provided.
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                // Best-effort: a dropped receiver must not abort the fetch.
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Explicitly close the progress channel before returning.
        drop(progress);
        Ok(samples)
    }
1161
    /// Create or update samples in a dataset via `samples.populate2`.
    ///
    /// Sample files that reference existing local paths are rewritten to
    /// their basenames and queued for upload; when any exist, presigned
    /// upload URLs are requested from the server and the files are uploaded
    /// afterwards.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;

        // (sample uuid, file type, local path, basename) per queued upload.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;

        let has_files_to_upload = !files_to_upload.is_empty();

        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            // Only ask for presigned URLs when something needs uploading.
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            self.upload_sample_files(&results, files_to_upload, progress)
                .await?;
        }

        Ok(results)
    }
1290
    /// Assign a UUID to each sample that lacks one and rewrite its file
    /// entries via `process_sample_file`, collecting local files that need
    /// uploading into `files_to_upload`.
    fn prepare_samples_for_upload(
        &self,
        samples: Vec<Sample>,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> Result<Vec<Sample>, Error> {
        Ok(samples
            .into_iter()
            .map(|mut sample| {
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().expect("UUID just set above");

                // Iterate over a copy of the file list so the sample itself
                // can be mutated (width/height) while processing.
                let files_copy = sample.files.clone();
                let updated_files: Vec<crate::SampleFile> = files_copy
                    .iter()
                    .map(|file| {
                        self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect())
    }
1320
    /// Rewrite one sample file entry for upload.
    ///
    /// When the entry's filename points at an existing local file, the file
    /// is queued in `files_to_upload` and the returned entry carries just
    /// the basename; image dimensions are probed from the file and filled in
    /// on the sample when missing. Entries that do not reference a local
    /// file are returned unchanged.
    fn process_sample_file(
        &self,
        file: &crate::SampleFile,
        sample_uuid: &str,
        sample: &mut Sample,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> crate::SampleFile {
        use std::path::Path;

        if let Some(filename) = file.filename() {
            let path = Path::new(filename);

            if path.exists()
                && path.is_file()
                && let Some(basename) = path.file_name().and_then(|s| s.to_str())
            {
                // Probe image dimensions from the file header when absent.
                if file.file_type() == "image"
                    && (sample.width.is_none() || sample.height.is_none())
                    && let Ok(size) = imagesize::size(path)
                {
                    sample.width = Some(size.width as u32);
                    sample.height = Some(size.height as u32);
                }

                files_to_upload.push((
                    sample_uuid.to_string(),
                    file.file_type().to_string(),
                    path.to_path_buf(),
                    basename.to_string(),
                ));

                // The server-side entry references only the basename.
                return crate::SampleFile::with_filename(
                    file.file_type().to_string(),
                    basename.to_string(),
                );
            }
        }
        file.clone()
    }
1365
    /// Upload queued local files to the presigned URLs returned by
    /// `samples.populate2`, matching files to URLs by (sample uuid,
    /// basename). Uploads run concurrently via `parallel_foreach_items`.
    async fn upload_sample_files(
        &self,
        results: &[crate::SamplesPopulateResult],
        files_to_upload: Vec<(String, String, PathBuf, String)>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Index local files by (uuid, basename) for URL matching below.
        let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
        for (uuid, _file_type, path, basename) in files_to_upload {
            upload_map.insert((uuid, basename), path);
        }

        let http = self.http.clone();

        let upload_tasks: Vec<_> = results
            .iter()
            .map(|result| (result.uuid.clone(), result.urls.clone()))
            .collect();

        parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
            let http = http.clone();
            let upload_map = upload_map.clone();

            async move {
                for url_info in &urls {
                    // URLs without a matching local file are skipped.
                    if let Some(local_path) =
                        upload_map.get(&(uuid.clone(), url_info.filename.clone()))
                    {
                        upload_file_to_presigned_url(
                            http.clone(),
                            &url_info.url,
                            local_path.clone(),
                        )
                        .await?;
                    }
                }

                Ok(())
            }
        })
        .await
    }
1411
1412 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1413 for attempt in 1..MAX_RETRIES {
1414 trace!("Download attempt {}/{} for URL: {}", attempt, MAX_RETRIES, url);
1415 let resp = match self.http.get(url).send().await {
1416 Ok(resp) => resp,
1417 Err(err) => {
1418 warn!(
1419 "Socket Error [retry {}/{}]: {:?}",
1420 attempt, MAX_RETRIES, err
1421 );
1422 trace!(
1423 "Socket error details - Kind: {:?}, URL: {}, Is timeout: {}, Is connect: {}",
1424 err.status(),
1425 url,
1426 err.is_timeout(),
1427 err.is_connect()
1428 );
1429 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1430 continue;
1431 }
1432 };
1433
1434 trace!("Download response status: {} for URL: {}", resp.status(), url);
1435 match resp.bytes().await {
1436 Ok(body) => {
1437 trace!("Successfully downloaded {} bytes from {}", body.len(), url);
1438 return Ok(body.to_vec());
1439 }
1440 Err(err) => {
1441 warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
1442 trace!("HTTP error details - Is timeout: {}, Is connect: {}", err.is_timeout(), err.is_connect());
1443 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1444 continue;
1445 }
1446 };
1447 }
1448
1449 error!("Max retries ({}) exceeded for download of URL: {}", MAX_RETRIES, url);
1450 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1451 }
1452
    /// Fetch annotations and convert them to a Polars `DataFrame`.
    ///
    /// Deprecated in favor of `samples_dataframe`, which supports the
    /// complete 2025.10 schema.
    #[deprecated(
        since = "0.8.0",
        note = "Use `samples_dataframe()` for complete 2025.10 schema support"
    )]
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        // The conversion helper is itself deprecated alongside this method.
        #[allow(deprecated)]
        annotations_dataframe(&annotations)
    }
1512
    /// Fetch samples (with annotations) and convert them to a Polars
    /// `DataFrame`.
    #[cfg(feature = "polars")]
    pub async fn samples_dataframe(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::samples_dataframe;

        let samples = self
            .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
            .await?;
        samples_dataframe(&samples)
    }
1565
1566 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1569 let snapshots: Vec<Snapshot> = self
1570 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1571 .await?;
1572 if let Some(name) = name {
1573 Ok(snapshots
1574 .into_iter()
1575 .filter(|s| s.description().contains(name))
1576 .collect())
1577 } else {
1578 Ok(snapshots)
1579 }
1580 }
1581
1582 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1584 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1585 self.rpc("snapshots.get".to_owned(), Some(params)).await
1586 }
1587
    /// Create a snapshot from a local file or directory and upload it.
    ///
    /// Directories are delegated to `create_snapshot_folder`; a single file
    /// is uploaded through the server's multipart-upload protocol. When
    /// `progress` is provided it receives byte-level upload progress and is
    /// dropped (closed) once the snapshot is marked available.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories take the folder path, which uploads every contained
        // file as its own multipart object.
        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The snapshot is named after the file being uploaded.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server for presigned multipart upload URLs for this file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-key upload parts.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part keys are namespaced by the snapshot's storage prefix (the
        // portion of its path after "::/"), so fetch the snapshot to learn it.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload the file in parts, then tell the server to assemble them.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Mark the snapshot available now that the upload has completed.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel so receivers observe completion.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1688
    /// Create a snapshot from a local directory, uploading every regular file
    /// underneath it as its own multipart object.
    ///
    /// Keys are the file paths relative to `path`; `progress` (if provided)
    /// receives byte-level progress aggregated across all files and is
    /// dropped once the snapshot is marked available.
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid directory name",
            ))
        })?;

        // Collect every regular file as a path relative to the root; entries
        // that fail to read are silently skipped (best-effort walk).
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
            .collect::<Vec<_>>();

        // Total bytes across all files, used for aggregate progress.
        let total: usize = files
            .iter()
            .filter_map(|file| path.join(file).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Parallel arrays of object keys and their sizes for the RPC call.
        let keys = files
            .iter()
            .filter_map(|key| key.to_str().map(|s| s.to_owned()))
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .filter_map(|key| path.join(key).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-key upload parts.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part keys are namespaced by the snapshot's storage prefix (the
        // portion of its path after "::/").
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();

        // Upload each file and complete its multipart upload sequentially;
        // the parts of each individual file upload in parallel internally.
        for file in files {
            let file_str = file.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "File path contains invalid UTF-8",
                ))
            })?;
            let part_key = format!("{}/{}", part_prefix, file_str);
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // Mark the snapshot available now that every file is uploaded.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel so receivers observe completion.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1804
1805 pub async fn download_snapshot(
1810 &self,
1811 snapshot_id: SnapshotID,
1812 output: PathBuf,
1813 progress: Option<Sender<Progress>>,
1814 ) -> Result<(), Error> {
1815 fs::create_dir_all(&output).await?;
1816
1817 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1818 let items: HashMap<String, String> = self
1819 .rpc("snapshots.create_download_url".to_owned(), Some(params))
1820 .await?;
1821
1822 let total = Arc::new(AtomicUsize::new(0));
1823 let current = Arc::new(AtomicUsize::new(0));
1824 let sem = Arc::new(Semaphore::new(MAX_TASKS));
1825
1826 let tasks = items
1827 .iter()
1828 .map(|(key, url)| {
1829 let http = self.http.clone();
1830 let key = key.clone();
1831 let url = url.clone();
1832 let output = output.clone();
1833 let progress = progress.clone();
1834 let current = current.clone();
1835 let total = total.clone();
1836 let sem = sem.clone();
1837
1838 tokio::spawn(async move {
1839 let _permit = sem.acquire().await.map_err(|_| {
1840 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1841 })?;
1842 let res = http.get(url).send().await?;
1843 let content_length = res.content_length().unwrap_or(0) as usize;
1844
1845 if let Some(progress) = &progress {
1846 let total = total.fetch_add(content_length, Ordering::SeqCst);
1847 let _ = progress
1848 .send(Progress {
1849 current: current.load(Ordering::SeqCst),
1850 total: total + content_length,
1851 })
1852 .await;
1853 }
1854
1855 let mut file = File::create(output.join(key)).await?;
1856 let mut stream = res.bytes_stream();
1857
1858 while let Some(chunk) = stream.next().await {
1859 let chunk = chunk?;
1860 file.write_all(&chunk).await?;
1861 let len = chunk.len();
1862
1863 if let Some(progress) = &progress {
1864 let total = total.load(Ordering::SeqCst);
1865 let current = current.fetch_add(len, Ordering::SeqCst);
1866
1867 let _ = progress
1868 .send(Progress {
1869 current: current + len,
1870 total,
1871 })
1872 .await;
1873 }
1874 }
1875
1876 Ok::<(), Error>(())
1877 })
1878 })
1879 .collect::<Vec<_>>();
1880
1881 join_all(tasks)
1882 .await
1883 .into_iter()
1884 .collect::<Result<Vec<_>, _>>()?
1885 .into_iter()
1886 .collect::<Result<Vec<_>, _>>()?;
1887
1888 Ok(())
1889 }
1890
1891 #[allow(clippy::too_many_arguments)]
1906 pub async fn restore_snapshot(
1907 &self,
1908 project_id: ProjectID,
1909 snapshot_id: SnapshotID,
1910 topics: &[String],
1911 autolabel: &[String],
1912 autodepth: bool,
1913 dataset_name: Option<&str>,
1914 dataset_description: Option<&str>,
1915 ) -> Result<SnapshotRestoreResult, Error> {
1916 let params = SnapshotRestore {
1917 project_id,
1918 snapshot_id,
1919 fps: 1,
1920 autodepth,
1921 agtg_pipeline: !autolabel.is_empty(),
1922 autolabel: autolabel.to_vec(),
1923 topics: topics.to_vec(),
1924 dataset_name: dataset_name.map(|s| s.to_owned()),
1925 dataset_description: dataset_description.map(|s| s.to_owned()),
1926 };
1927 self.rpc("snapshots.restore".to_owned(), Some(params)).await
1928 }
1929
1930 pub async fn experiments(
1939 &self,
1940 project_id: ProjectID,
1941 name: Option<&str>,
1942 ) -> Result<Vec<Experiment>, Error> {
1943 let params = HashMap::from([("project_id", project_id)]);
1944 let experiments: Vec<Experiment> =
1945 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1946 if let Some(name) = name {
1947 Ok(experiments
1948 .into_iter()
1949 .filter(|e| e.name().contains(name))
1950 .collect())
1951 } else {
1952 Ok(experiments)
1953 }
1954 }
1955
1956 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1959 let params = HashMap::from([("trainer_id", experiment_id)]);
1960 self.rpc("trainer.get".to_owned(), Some(params)).await
1961 }
1962
1963 pub async fn training_sessions(
1972 &self,
1973 experiment_id: ExperimentID,
1974 name: Option<&str>,
1975 ) -> Result<Vec<TrainingSession>, Error> {
1976 let params = HashMap::from([("trainer_id", experiment_id)]);
1977 let sessions: Vec<TrainingSession> = self
1978 .rpc("trainer.session.list".to_owned(), Some(params))
1979 .await?;
1980 if let Some(name) = name {
1981 Ok(sessions
1982 .into_iter()
1983 .filter(|s| s.name().contains(name))
1984 .collect())
1985 } else {
1986 Ok(sessions)
1987 }
1988 }
1989
1990 pub async fn training_session(
1993 &self,
1994 session_id: TrainingSessionID,
1995 ) -> Result<TrainingSession, Error> {
1996 let params = HashMap::from([("trainer_session_id", session_id)]);
1997 self.rpc("trainer.session.get".to_owned(), Some(params))
1998 .await
1999 }
2000
2001 pub async fn validation_sessions(
2003 &self,
2004 project_id: ProjectID,
2005 ) -> Result<Vec<ValidationSession>, Error> {
2006 let params = HashMap::from([("project_id", project_id)]);
2007 self.rpc("validate.session.list".to_owned(), Some(params))
2008 .await
2009 }
2010
2011 pub async fn validation_session(
2013 &self,
2014 session_id: ValidationSessionID,
2015 ) -> Result<ValidationSession, Error> {
2016 let params = HashMap::from([("validate_session_id", session_id)]);
2017 self.rpc("validate.session.get".to_owned(), Some(params))
2018 .await
2019 }
2020
2021 pub async fn artifacts(
2024 &self,
2025 training_session_id: TrainingSessionID,
2026 ) -> Result<Vec<Artifact>, Error> {
2027 let params = HashMap::from([("training_session_id", training_session_id)]);
2028 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2029 .await
2030 }
2031
2032 pub async fn download_artifact(
2038 &self,
2039 training_session_id: TrainingSessionID,
2040 modelname: &str,
2041 filename: Option<PathBuf>,
2042 progress: Option<Sender<Progress>>,
2043 ) -> Result<(), Error> {
2044 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2045 let resp = self
2046 .http
2047 .get(format!(
2048 "{}/download_model?training_session_id={}&file={}",
2049 self.url,
2050 training_session_id.value(),
2051 modelname
2052 ))
2053 .header("Authorization", format!("Bearer {}", self.token().await))
2054 .send()
2055 .await?;
2056 if !resp.status().is_success() {
2057 let err = resp.error_for_status_ref().unwrap_err();
2058 return Err(Error::HttpError(err));
2059 }
2060
2061 if let Some(parent) = filename.parent() {
2062 fs::create_dir_all(parent).await?;
2063 }
2064
2065 if let Some(progress) = progress {
2066 let total = resp.content_length().unwrap_or(0) as usize;
2067 let _ = progress.send(Progress { current: 0, total }).await;
2068
2069 let mut file = File::create(filename).await?;
2070 let mut current = 0;
2071 let mut stream = resp.bytes_stream();
2072
2073 while let Some(item) = stream.next().await {
2074 let chunk = item?;
2075 file.write_all(&chunk).await?;
2076 current += chunk.len();
2077 let _ = progress.send(Progress { current, total }).await;
2078 }
2079 } else {
2080 let body = resp.bytes().await?;
2081 fs::write(filename, body).await?;
2082 }
2083
2084 Ok(())
2085 }
2086
2087 pub async fn download_checkpoint(
2097 &self,
2098 training_session_id: TrainingSessionID,
2099 checkpoint: &str,
2100 filename: Option<PathBuf>,
2101 progress: Option<Sender<Progress>>,
2102 ) -> Result<(), Error> {
2103 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2104 let resp = self
2105 .http
2106 .get(format!(
2107 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2108 self.url,
2109 training_session_id.value(),
2110 checkpoint
2111 ))
2112 .header("Authorization", format!("Bearer {}", self.token().await))
2113 .send()
2114 .await?;
2115 if !resp.status().is_success() {
2116 let err = resp.error_for_status_ref().unwrap_err();
2117 return Err(Error::HttpError(err));
2118 }
2119
2120 if let Some(parent) = filename.parent() {
2121 fs::create_dir_all(parent).await?;
2122 }
2123
2124 if let Some(progress) = progress {
2125 let total = resp.content_length().unwrap_or(0) as usize;
2126 let _ = progress.send(Progress { current: 0, total }).await;
2127
2128 let mut file = File::create(filename).await?;
2129 let mut current = 0;
2130 let mut stream = resp.bytes_stream();
2131
2132 while let Some(item) = stream.next().await {
2133 let chunk = item?;
2134 file.write_all(&chunk).await?;
2135 current += chunk.len();
2136 let _ = progress.send(Progress { current, total }).await;
2137 }
2138 } else {
2139 let body = resp.bytes().await?;
2140 fs::write(filename, body).await?;
2141 }
2142
2143 Ok(())
2144 }
2145
2146 pub async fn tasks(
2148 &self,
2149 name: Option<&str>,
2150 workflow: Option<&str>,
2151 status: Option<&str>,
2152 manager: Option<&str>,
2153 ) -> Result<Vec<Task>, Error> {
2154 let mut params = TasksListParams {
2155 continue_token: None,
2156 status: status.map(|s| vec![s.to_owned()]),
2157 manager: manager.map(|m| vec![m.to_owned()]),
2158 };
2159 let mut tasks = Vec::new();
2160
2161 loop {
2162 let result = self
2163 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2164 .await?;
2165 tasks.extend(result.tasks);
2166
2167 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2168 params.continue_token = None;
2169 } else {
2170 params.continue_token = result.continue_token;
2171 }
2172
2173 if params.continue_token.is_none() {
2174 break;
2175 }
2176 }
2177
2178 if let Some(name) = name {
2179 tasks.retain(|t| t.name().contains(name));
2180 }
2181
2182 if let Some(workflow) = workflow {
2183 tasks.retain(|t| t.workflow().contains(workflow));
2184 }
2185
2186 Ok(tasks)
2187 }
2188
2189 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2191 self.rpc(
2192 "task.get".to_owned(),
2193 Some(HashMap::from([("id", task_id)])),
2194 )
2195 .await
2196 }
2197
2198 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2200 let status = TaskStatus {
2201 task_id,
2202 status: status.to_owned(),
2203 };
2204 self.rpc("docker.update.status".to_owned(), Some(status))
2205 .await
2206 }
2207
2208 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
2212 let stages: Vec<HashMap<String, String>> = stages
2213 .iter()
2214 .map(|(key, value)| {
2215 let mut stage_map = HashMap::new();
2216 stage_map.insert(key.to_string(), value.to_string());
2217 stage_map
2218 })
2219 .collect();
2220 let params = TaskStages { task_id, stages };
2221 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
2222 Ok(())
2223 }
2224
2225 pub async fn update_stage(
2228 &self,
2229 task_id: TaskID,
2230 stage: &str,
2231 status: &str,
2232 message: &str,
2233 percentage: u8,
2234 ) -> Result<(), Error> {
2235 let stage = Stage::new(
2236 Some(task_id),
2237 stage.to_owned(),
2238 Some(status.to_owned()),
2239 Some(message.to_owned()),
2240 percentage,
2241 );
2242 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
2243 Ok(())
2244 }
2245
2246 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
2248 let req = self
2249 .http
2250 .get(format!("{}/{}", self.url, query))
2251 .header("User-Agent", "EdgeFirst Client")
2252 .header("Authorization", format!("Bearer {}", self.token().await));
2253 let resp = req.send().await?;
2254
2255 if resp.status().is_success() {
2256 let body = resp.bytes().await?;
2257
2258 if log_enabled!(Level::Trace) {
2259 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
2260 }
2261
2262 Ok(body.to_vec())
2263 } else {
2264 let err = resp.error_for_status_ref().unwrap_err();
2265 Err(Error::HttpError(err))
2266 }
2267 }
2268
2269 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2273 let req = self
2274 .http
2275 .post(format!("{}/api?method={}", self.url, method))
2276 .header("Accept", "application/json")
2277 .header("User-Agent", "EdgeFirst Client")
2278 .header("Authorization", format!("Bearer {}", self.token().await))
2279 .multipart(form);
2280 let resp = req.send().await?;
2281
2282 if resp.status().is_success() {
2283 let body = resp.bytes().await?;
2284
2285 if log_enabled!(Level::Trace) {
2286 trace!(
2287 "POST Multipart Response: {}",
2288 String::from_utf8_lossy(&body)
2289 );
2290 }
2291
2292 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2293 Ok(response) => response,
2294 Err(err) => {
2295 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2296 return Err(err.into());
2297 }
2298 };
2299
2300 if let Some(error) = response.error {
2301 Err(Error::RpcError(error.code, error.message))
2302 } else if let Some(result) = response.result {
2303 Ok(result)
2304 } else {
2305 Err(Error::InvalidResponse)
2306 }
2307 } else {
2308 let err = resp.error_for_status_ref().unwrap_err();
2309 Err(Error::HttpError(err))
2310 }
2311 }
2312
2313 pub async fn rpc<Params, RpcResult>(
2322 &self,
2323 method: String,
2324 params: Option<Params>,
2325 ) -> Result<RpcResult, Error>
2326 where
2327 Params: Serialize,
2328 RpcResult: DeserializeOwned,
2329 {
2330 let auth_expires = self.token_expiration().await?;
2331 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2332 self.renew_token().await?;
2333 }
2334
2335 self.rpc_without_auth(method, params).await
2336 }
2337
2338 async fn rpc_without_auth<Params, RpcResult>(
2339 &self,
2340 method: String,
2341 params: Option<Params>,
2342 ) -> Result<RpcResult, Error>
2343 where
2344 Params: Serialize,
2345 RpcResult: DeserializeOwned,
2346 {
2347 let request = RpcRequest {
2348 method,
2349 params,
2350 ..Default::default()
2351 };
2352
2353 if log_enabled!(Level::Trace) {
2354 trace!(
2355 "RPC Request: {}",
2356 serde_json::ser::to_string_pretty(&request)?
2357 );
2358 }
2359
2360 for attempt in 0..MAX_RETRIES {
2361 match self.try_rpc_request(&request, attempt).await {
2362 Ok(result) => return Ok(result),
2363 Err(Error::MaxRetriesExceeded(_)) => continue,
2364 Err(err) => return Err(err),
2365 }
2366 }
2367
2368 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2369 }
2370
2371 async fn try_rpc_request<Params, RpcResult>(
2372 &self,
2373 request: &RpcRequest<Params>,
2374 attempt: u32,
2375 ) -> Result<RpcResult, Error>
2376 where
2377 Params: Serialize,
2378 RpcResult: DeserializeOwned,
2379 {
2380 trace!(
2381 "RPC attempt {}/{} for method: {}",
2382 attempt + 1,
2383 MAX_RETRIES,
2384 request.method
2385 );
2386
2387 let res = match self
2388 .http
2389 .post(format!("{}/api", self.url))
2390 .header("Accept", "application/json")
2391 .header("User-Agent", "EdgeFirst Client")
2392 .header("Authorization", format!("Bearer {}", self.token().await))
2393 .json(&request)
2394 .send()
2395 .await
2396 {
2397 Ok(res) => res,
2398 Err(err) => {
2399 warn!("Socket Error: {:?}", err);
2400 trace!(
2401 "Socket error details for method '{}' - Status: {:?}, Is timeout: {}, Is connect: {}, Is request: {}, URL: {}/api",
2402 request.method,
2403 err.status(),
2404 err.is_timeout(),
2405 err.is_connect(),
2406 err.is_request(),
2407 self.url
2408 );
2409 return Err(Error::MaxRetriesExceeded(attempt));
2410 }
2411 };
2412
2413 trace!(
2414 "RPC response for method '{}': status={}, content-length={:?}",
2415 request.method,
2416 res.status(),
2417 res.content_length()
2418 );
2419
2420 if res.status().is_success() {
2421 self.process_rpc_response(res).await
2422 } else {
2423 let status = res.status();
2424 let err = res.error_for_status_ref().unwrap_err();
2425 let body = res.text().await?;
2426
2427 warn!("HTTP Error {}: {}", err, body);
2428 trace!(
2429 "HTTP error details for method '{}' - Status: {}, Body length: {}, Response body: {}",
2430 request.method,
2431 status,
2432 body.len(),
2433 if body.len() < 1000 { &body } else { &format!("{}...(truncated)", &body[..1000]) }
2434 );
2435 warn!(
2436 "Retrying RPC request (attempt {}/{})...",
2437 attempt + 1,
2438 MAX_RETRIES
2439 );
2440 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2441 Err(Error::MaxRetriesExceeded(attempt))
2442 }
2443 }
2444
2445 async fn process_rpc_response<RpcResult>(
2446 &self,
2447 res: reqwest::Response,
2448 ) -> Result<RpcResult, Error>
2449 where
2450 RpcResult: DeserializeOwned,
2451 {
2452 let body = res.bytes().await?;
2453
2454 if log_enabled!(Level::Trace) {
2455 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2456 }
2457
2458 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2459 Ok(response) => response,
2460 Err(err) => {
2461 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2462 return Err(err.into());
2463 }
2464 };
2465
2466 if let Some(error) = response.error {
2472 Err(Error::RpcError(error.code, error.message))
2473 } else if let Some(result) = response.result {
2474 Ok(result)
2475 } else {
2476 Err(Error::InvalidResponse)
2477 }
2478 }
2479}
2480
2481async fn parallel_foreach_items<T, F, Fut>(
2510 items: Vec<T>,
2511 progress: Option<Sender<Progress>>,
2512 work_fn: F,
2513) -> Result<(), Error>
2514where
2515 T: Send + 'static,
2516 F: Fn(T) -> Fut + Send + Sync + 'static,
2517 Fut: Future<Output = Result<(), Error>> + Send + 'static,
2518{
2519 let total = items.len();
2520 let current = Arc::new(AtomicUsize::new(0));
2521 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2522 let work_fn = Arc::new(work_fn);
2523
2524 let tasks = items
2525 .into_iter()
2526 .map(|item| {
2527 let sem = sem.clone();
2528 let current = current.clone();
2529 let progress = progress.clone();
2530 let work_fn = work_fn.clone();
2531
2532 tokio::spawn(async move {
2533 let _permit = sem.acquire().await.map_err(|_| {
2534 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
2535 })?;
2536
2537 work_fn(item).await?;
2539
2540 if let Some(progress) = &progress {
2542 let current = current.fetch_add(1, Ordering::SeqCst);
2543 let _ = progress
2544 .send(Progress {
2545 current: current + 1,
2546 total,
2547 })
2548 .await;
2549 }
2550
2551 Ok::<(), Error>(())
2552 })
2553 })
2554 .collect::<Vec<_>>();
2555
2556 join_all(tasks)
2557 .await
2558 .into_iter()
2559 .collect::<Result<Vec<_>, _>>()?
2560 .into_iter()
2561 .collect::<Result<Vec<_>, _>>()?;
2562
2563 if let Some(progress) = progress {
2564 drop(progress);
2565 }
2566
2567 Ok(())
2568}
2569
2570async fn upload_multipart(
2595 http: reqwest::Client,
2596 part: SnapshotPart,
2597 path: PathBuf,
2598 total: usize,
2599 current: Arc<AtomicUsize>,
2600 progress: Option<Sender<Progress>>,
2601) -> Result<SnapshotCompleteMultipartParams, Error> {
2602 let filesize = path.metadata()?.len() as usize;
2603 let n_parts = filesize.div_ceil(PART_SIZE);
2604 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2605
2606 let key = part.key.ok_or(Error::InvalidResponse)?;
2607 let upload_id = part.upload_id;
2608
2609 let urls = part.urls.clone();
2610 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2612 EtagPart {
2613 etag: "".to_owned(),
2614 part_number: 0,
2615 };
2616 n_parts
2617 ]));
2618
2619 let tasks = (0..n_parts)
2621 .map(|part| {
2622 let http = http.clone();
2623 let url = urls[part].clone();
2624 let etags = etags.clone();
2625 let path = path.to_owned();
2626 let sem = sem.clone();
2627 let progress = progress.clone();
2628 let current = current.clone();
2629
2630 tokio::spawn(async move {
2631 let _permit = sem.acquire().await?;
2633 let mut etag = None;
2634
2635 for attempt in 0..MAX_RETRIES {
2637 match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
2638 {
2639 Ok(v) => {
2640 etag = Some(v);
2641 break;
2642 }
2643 Err(err) => {
2644 warn!("Upload Part Error: {:?}", err);
2645 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2646 }
2647 }
2648 }
2649
2650 if let Some(etag) = etag {
2651 let mut etags = etags.lock().await;
2653 etags[part] = EtagPart {
2654 etag,
2655 part_number: part + 1,
2656 };
2657
2658 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2660 if let Some(progress) = &progress {
2661 let _ = progress
2662 .send(Progress {
2663 current: current + PART_SIZE,
2664 total,
2665 })
2666 .await;
2667 }
2668
2669 Ok(())
2670 } else {
2671 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2672 }
2673 })
2674 })
2675 .collect::<Vec<_>>();
2676
2677 join_all(tasks)
2679 .await
2680 .into_iter()
2681 .collect::<Result<Vec<_>, _>>()?;
2682
2683 Ok(SnapshotCompleteMultipartParams {
2684 key,
2685 upload_id,
2686 etag_list: etags.lock().await.clone(),
2687 })
2688}
2689
2690async fn upload_part(
2691 http: reqwest::Client,
2692 url: String,
2693 path: PathBuf,
2694 part: usize,
2695 n_parts: usize,
2696) -> Result<String, Error> {
2697 let filesize = path.metadata()?.len() as usize;
2698 let mut file = File::open(path).await?;
2699 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2700 .await?;
2701 let file = file.take(PART_SIZE as u64);
2702
2703 let body_length = if part + 1 == n_parts {
2704 filesize % PART_SIZE
2705 } else {
2706 PART_SIZE
2707 };
2708
2709 let stream = FramedRead::new(file, BytesCodec::new());
2710 let body = Body::wrap_stream(stream);
2711
2712 let resp = http
2713 .put(url.clone())
2714 .header(CONTENT_LENGTH, body_length)
2715 .body(body)
2716 .send()
2717 .await?
2718 .error_for_status()?;
2719
2720 let etag = resp
2721 .headers()
2722 .get("etag")
2723 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
2724 .to_str()
2725 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
2726 .to_owned();
2727
2728 let etag = etag
2730 .strip_prefix("\"")
2731 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
2732 let etag = etag
2733 .strip_suffix("\"")
2734 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
2735
2736 Ok(etag.to_owned())
2737}
2738
2739async fn upload_file_to_presigned_url(
2744 http: reqwest::Client,
2745 url: &str,
2746 path: PathBuf,
2747) -> Result<(), Error> {
2748 let file_data = fs::read(&path).await?;
2750 let file_size = file_data.len();
2751
2752 for attempt in 1..=MAX_RETRIES {
2754 match http
2755 .put(url)
2756 .header(CONTENT_LENGTH, file_size)
2757 .body(file_data.clone())
2758 .send()
2759 .await
2760 {
2761 Ok(resp) => {
2762 if resp.status().is_success() {
2763 debug!(
2764 "Successfully uploaded file: {:?} ({} bytes)",
2765 path, file_size
2766 );
2767 return Ok(());
2768 } else {
2769 let status = resp.status();
2770 let error_text = resp.text().await.unwrap_or_default();
2771 warn!(
2772 "Upload failed [attempt {}/{}]: HTTP {} - {}",
2773 attempt, MAX_RETRIES, status, error_text
2774 );
2775 }
2776 }
2777 Err(err) => {
2778 warn!(
2779 "Upload error [attempt {}/{}]: {:?}",
2780 attempt, MAX_RETRIES, err
2781 );
2782 }
2783 }
2784
2785 if attempt < MAX_RETRIES {
2786 tokio::time::sleep(Duration::from_secs(attempt as u64)).await;
2787 }
2788 }
2789
2790 Err(Error::InvalidParameters(format!(
2791 "Failed to upload file {:?} after {} attempts",
2792 path, MAX_RETRIES
2793 )))
2794}