1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14 retry::{create_retry_policy, log_retry_configuration},
15};
16use base64::Engine as _;
17use chrono::{DateTime, Utc};
18use directories::ProjectDirs;
19use futures::{StreamExt as _, future::join_all};
20use log::{Level, debug, error, log_enabled, trace, warn};
21use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
22use serde::{Deserialize, Serialize, de::DeserializeOwned};
23use std::{
24 collections::HashMap,
25 ffi::OsStr,
26 fs::create_dir_all,
27 io::{SeekFrom, Write as _},
28 path::{Path, PathBuf},
29 sync::{
30 Arc,
31 atomic::{AtomicUsize, Ordering},
32 },
33 time::Duration,
34 vec,
35};
36use tokio::{
37 fs::{self, File},
38 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
39 sync::{RwLock, Semaphore, mpsc::Sender},
40};
41use tokio_util::codec::{BytesCodec, FramedRead};
42use walkdir::WalkDir;
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
/// Maximum number of concurrent worker tasks used by parallel
/// download/upload helpers.
const MAX_TASKS: usize = 32;
/// Size of each multipart upload part (100 MiB).
// `const` instead of `static`: these are compile-time constants that
// are only ever read by value, so no shared memory location is needed.
const PART_SIZE: usize = 100 * 1024 * 1024;
49
/// Sanitizes a string for safe use as a single path component.
///
/// Trims whitespace, reduces the input to its final path component
/// (stripping any directory part), and replaces characters that are
/// invalid in file names on common platforms with `_`. Empty or
/// whitespace-only input, and the traversal components `.` / `..`,
/// yield `"unnamed"` so the result can never escape the target
/// directory.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // Keep only the final component so an absolute or relative path
    // cannot smuggle directory separators into the output.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // `Path::file_name` returns `None` for `..`, so the raw input falls
    // through above; reject dot components explicitly to close the
    // remaining traversal hole (previously ".." passed through intact).
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
75
/// Progress update emitted over an optional mpsc channel by
/// long-running operations (downloads, uploads, paginated fetches).
#[derive(Debug, Clone)]
pub struct Progress {
    /// Number of items processed so far.
    pub current: usize,
    /// Total number of items expected.
    pub total: usize,
}
104
/// JSON-RPC 2.0 request envelope.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request id; defaults to 0 (see the `Default` impl below).
    id: u64,
    // Protocol version string, "2.0" by default.
    jsonrpc: String,
    method: String,
    // Method parameters; serialized as `null` when absent.
    params: Option<Params>,
}
112
113impl<T> Default for RpcRequest<T> {
114 fn default() -> Self {
115 RpcRequest {
116 id: 0,
117 jsonrpc: "2.0".to_string(),
118 method: "".to_string(),
119 params: None,
120 }
121 }
122}
123
/// Error object from a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}

/// JSON-RPC 2.0 response envelope; exactly one of `error` / `result`
/// is expected to be populated by the server.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): deserialized as a string even though requests send
    // a numeric id — presumably the server echoes ids as strings;
    // confirm against the server implementation.
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Placeholder result type for RPC methods that return no payload.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
143
/// Parameters for the single-request snapshot upload RPC.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Object keys (file names) to request upload URLs for.
    keys: Vec<String>,
}

/// Result of the single-request snapshot upload RPC: the new snapshot
/// id plus one presigned upload URL per requested key.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for `snapshots.create_upload_url_multipart`.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    // Object keys (file names) to upload.
    keys: Vec<String>,
    // Byte size of each file, parallel to `keys`; lets the server
    // decide how many part URLs to issue per key.
    file_sizes: Vec<usize>,
}

/// One value in the map returned by the multipart-create RPC: either
/// the numeric `snapshot_id` entry or a per-key part descriptor.
/// `untagged` lets serde pick the variant by shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters for `snapshots.complete_multipart_upload`.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    // ETags collected from each uploaded part, in part order.
    etag_list: Vec<EtagPart>,
}

/// ETag/part-number pair in the S3-style multipart completion format
/// (field names renamed to match the expected JSON casing).
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Per-key multipart upload descriptor returned by the server: the
/// upload id plus one presigned URL per part. `key` is optional in the
/// response and filled in client-side after lookup.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}

/// Parameters for updating a snapshot's status after upload.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by the status-update RPC.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}

/// Parameters for the (currently unused) image-list RPC.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    // When true the server returns only image ids.
    only_ids: bool,
}

/// Dataset filter for `ImageListParams`.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
227
/// Client for the EdgeFirst Studio JSON-RPC API.
///
/// Cheap to clone: `reqwest::Client` clones share one connection pool
/// and the token lives behind an `Arc<RwLock<_>>`, so clones observe
/// token updates.
#[derive(Clone, Debug)]
pub struct Client {
    http: reqwest::Client,
    // Base URL of the Studio server, e.g. "https://edgefirst.studio".
    url: String,
    // JWT bearer token; empty string when not logged in.
    token: Arc<RwLock<String>>,
    // Where to persist the token, when token persistence is enabled.
    token_path: Option<PathBuf>,
}

/// Shared filter/lookup state for the paginated sample and annotation
/// fetch loops.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    // Dataset group names to filter by (empty = all groups).
    groups: &'a [String],
    // Annotation/file type names used as the `types` filter.
    types: Vec<String>,
    // Label name -> label index map used to backfill label indices.
    labels: &'a HashMap<String, u64>,
}
292
293impl Client {
294 pub fn new() -> Result<Self, Error> {
303 log_retry_configuration();
304
305 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
307 .ok()
308 .and_then(|s| s.parse().ok())
309 .unwrap_or(30); let http = reqwest::Client::builder()
321 .connect_timeout(Duration::from_secs(10))
322 .timeout(Duration::from_secs(timeout_secs))
323 .pool_idle_timeout(Duration::from_secs(90))
324 .pool_max_idle_per_host(10)
325 .retry(create_retry_policy())
326 .build()?;
327
328 Ok(Client {
329 http,
330 url: "https://edgefirst.studio".to_string(),
331 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
332 token_path: None,
333 })
334 }
335
336 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
340 Ok(Client {
341 url: format!("https://{}.edgefirst.studio", server),
342 ..self.clone()
343 })
344 }
345
346 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
349 let params = HashMap::from([("username", username), ("password", password)]);
350 let login: LoginResult = self
351 .rpc_without_auth("auth.login".to_owned(), Some(params))
352 .await?;
353 Ok(Client {
354 token: Arc::new(tokio::sync::RwLock::new(login.token)),
355 ..self.clone()
356 })
357 }
358
359 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
362 let token_path = match token_path {
363 Some(path) => path.to_path_buf(),
364 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
365 .ok_or_else(|| {
366 Error::IoError(std::io::Error::new(
367 std::io::ErrorKind::NotFound,
368 "Could not determine user config directory",
369 ))
370 })?
371 .config_dir()
372 .join("token"),
373 };
374
375 debug!("Using token path: {:?}", token_path);
376
377 let token = match token_path.exists() {
378 true => std::fs::read_to_string(&token_path)?,
379 false => "".to_string(),
380 };
381
382 if !token.is_empty() {
383 let client = self.with_token(&token)?;
384 Ok(Client {
385 token_path: Some(token_path),
386 ..client
387 })
388 } else {
389 Ok(Client {
390 token_path: Some(token_path),
391 ..self.clone()
392 })
393 }
394 }
395
396 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
398 if token.is_empty() {
399 return Ok(self.clone());
400 }
401
402 let token_parts: Vec<&str> = token.split('.').collect();
403 if token_parts.len() != 3 {
404 return Err(Error::InvalidToken);
405 }
406
407 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
408 .decode(token_parts[1])
409 .map_err(|_| Error::InvalidToken)?;
410 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
411 let server = match payload.get("server") {
412 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
413 None => return Err(Error::InvalidToken),
414 };
415
416 Ok(Client {
417 url: format!("https://{}.edgefirst.studio", server),
418 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
419 ..self.clone()
420 })
421 }
422
423 pub async fn save_token(&self) -> Result<(), Error> {
424 let path = self.token_path.clone().unwrap_or_else(|| {
425 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
426 .map(|dirs| dirs.config_dir().join("token"))
427 .unwrap_or_else(|| PathBuf::from(".token"))
428 });
429
430 create_dir_all(path.parent().ok_or_else(|| {
431 Error::IoError(std::io::Error::new(
432 std::io::ErrorKind::InvalidInput,
433 "Token path has no parent directory",
434 ))
435 })?)?;
436 let mut file = std::fs::File::create(&path)?;
437 file.write_all(self.token.read().await.as_bytes())?;
438
439 debug!("Saved token to {:?}", path);
440
441 Ok(())
442 }
443
444 pub async fn version(&self) -> Result<String, Error> {
447 let version: HashMap<String, String> = self
448 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
449 .await?;
450 let version = version.get("version").ok_or(Error::InvalidResponse)?;
451 Ok(version.to_owned())
452 }
453
454 pub async fn logout(&self) -> Result<(), Error> {
458 {
459 let mut token = self.token.write().await;
460 *token = "".to_string();
461 }
462
463 if let Some(path) = &self.token_path
464 && path.exists()
465 {
466 fs::remove_file(path).await?;
467 }
468
469 Ok(())
470 }
471
472 pub async fn token(&self) -> String {
476 self.token.read().await.clone()
477 }
478
479 pub async fn verify_token(&self) -> Result<(), Error> {
484 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
485 .await?;
486 Ok::<(), Error>(())
487 }
488
489 pub async fn renew_token(&self) -> Result<(), Error> {
494 let params = HashMap::from([("username".to_string(), self.username().await?)]);
495 let result: LoginResult = self
496 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
497 .await?;
498
499 {
500 let mut token = self.token.write().await;
501 *token = result.token;
502 }
503
504 if self.token_path.is_some() {
505 self.save_token().await?;
506 }
507
508 Ok(())
509 }
510
511 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
512 let token = self.token.read().await;
513 if token.is_empty() {
514 return Err(Error::EmptyToken);
515 }
516
517 let token_parts: Vec<&str> = token.split('.').collect();
518 if token_parts.len() != 3 {
519 return Err(Error::InvalidToken);
520 }
521
522 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
523 .decode(token_parts[1])
524 .map_err(|_| Error::InvalidToken)?;
525 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
526 match payload.get(field) {
527 Some(value) => Ok(value.to_owned()),
528 None => Err(Error::InvalidToken),
529 }
530 }
531
532 pub fn url(&self) -> &str {
534 &self.url
535 }
536
537 pub async fn username(&self) -> Result<String, Error> {
539 match self.token_field("username").await? {
540 serde_json::Value::String(username) => Ok(username),
541 _ => Err(Error::InvalidToken),
542 }
543 }
544
545 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
547 let ts = match self.token_field("exp").await? {
548 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
549 _ => return Err(Error::InvalidToken),
550 };
551
552 match DateTime::<Utc>::from_timestamp_secs(ts) {
553 Some(dt) => Ok(dt),
554 None => Err(Error::InvalidToken),
555 }
556 }
557
558 pub async fn organization(&self) -> Result<Organization, Error> {
560 self.rpc::<(), Organization>("org.get".to_owned(), None)
561 .await
562 }
563
564 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
572 let projects = self
573 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
574 .await?;
575 if let Some(name) = name {
576 Ok(projects
577 .into_iter()
578 .filter(|p| p.name().contains(name))
579 .collect())
580 } else {
581 Ok(projects)
582 }
583 }
584
585 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
588 let params = HashMap::from([("project_id", project_id)]);
589 self.rpc("project.get".to_owned(), Some(params)).await
590 }
591
592 pub async fn datasets(
596 &self,
597 project_id: ProjectID,
598 name: Option<&str>,
599 ) -> Result<Vec<Dataset>, Error> {
600 let params = HashMap::from([("project_id", project_id)]);
601 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
602 if let Some(name) = name {
603 Ok(datasets
604 .into_iter()
605 .filter(|d| d.name().contains(name))
606 .collect())
607 } else {
608 Ok(datasets)
609 }
610 }
611
612 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
615 let params = HashMap::from([("dataset_id", dataset_id)]);
616 self.rpc("dataset.get".to_owned(), Some(params)).await
617 }
618
619 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
621 let params = HashMap::from([("dataset_id", dataset_id)]);
622 self.rpc("label.list".to_owned(), Some(params)).await
623 }
624
625 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
627 let new_label = NewLabel {
628 dataset_id,
629 labels: vec![NewLabelObject {
630 name: name.to_owned(),
631 }],
632 };
633 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
634 Ok(())
635 }
636
637 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
640 let params = HashMap::from([("label_id", label_id)]);
641 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
642 Ok(())
643 }
644
645 pub async fn create_dataset(
657 &self,
658 project_id: &str,
659 name: &str,
660 description: Option<&str>,
661 ) -> Result<DatasetID, Error> {
662 let mut params = HashMap::new();
663 params.insert("project_id", project_id);
664 params.insert("name", name);
665 if let Some(desc) = description {
666 params.insert("description", desc);
667 }
668
669 #[derive(Deserialize)]
670 struct CreateDatasetResult {
671 id: DatasetID,
672 }
673
674 let result: CreateDatasetResult =
675 self.rpc("dataset.create".to_owned(), Some(params)).await?;
676 Ok(result.id)
677 }
678
679 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
689 let params = HashMap::from([("id", dataset_id)]);
690 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
691 Ok(())
692 }
693
694 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
698 #[derive(Serialize)]
699 struct Params {
700 dataset_id: DatasetID,
701 label_id: u64,
702 label_name: String,
703 label_index: u64,
704 }
705
706 let _: String = self
707 .rpc(
708 "label.update".to_owned(),
709 Some(Params {
710 dataset_id: label.dataset_id(),
711 label_id: label.id(),
712 label_name: label.name().to_owned(),
713 label_index: label.index(),
714 }),
715 )
716 .await?;
717 Ok(())
718 }
719
    /// Downloads the requested file types for every matching sample in
    /// the dataset into `output`, in parallel.
    ///
    /// Samples belonging to a sequence are written into a sanitized
    /// per-sequence subdirectory; file names are derived from the
    /// image/sample name. Progress is reported both while listing
    /// samples and while downloading them.
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Owned clones moved into the per-sample closure below.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Images get their extension sniffed from the
                        // downloaded bytes; other file types use the
                        // type name as the extension.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = sequence_dir
                            .map(|seq| output.join(seq))
                            .unwrap_or_else(|| output.clone());
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Prefer the original image file name for
                        // images; fall back to "{sample}.{ext}".
                        let file_name = if is_image {
                            image_name.unwrap_or_else(|| {
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            })
                        } else {
                            format!("{}.{}", sanitized_sample_name, file_ext)
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing data is logged, not fatal: remaining
                        // samples continue downloading.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
804
805 pub async fn annotation_sets(
807 &self,
808 dataset_id: DatasetID,
809 ) -> Result<Vec<AnnotationSet>, Error> {
810 let params = HashMap::from([("dataset_id", dataset_id)]);
811 self.rpc("annset.list".to_owned(), Some(params)).await
812 }
813
814 pub async fn create_annotation_set(
826 &self,
827 dataset_id: DatasetID,
828 name: &str,
829 description: Option<&str>,
830 ) -> Result<AnnotationSetID, Error> {
831 #[derive(Serialize)]
832 struct Params<'a> {
833 dataset_id: DatasetID,
834 name: &'a str,
835 operator: &'a str,
836 #[serde(skip_serializing_if = "Option::is_none")]
837 description: Option<&'a str>,
838 }
839
840 #[derive(Deserialize)]
841 struct CreateAnnotationSetResult {
842 id: AnnotationSetID,
843 }
844
845 let username = self.username().await?;
846 let result: CreateAnnotationSetResult = self
847 .rpc(
848 "annset.add".to_owned(),
849 Some(Params {
850 dataset_id,
851 name,
852 operator: &username,
853 description,
854 }),
855 )
856 .await?;
857 Ok(result.id)
858 }
859
860 pub async fn delete_annotation_set(
871 &self,
872 annotation_set_id: AnnotationSetID,
873 ) -> Result<(), Error> {
874 let params = HashMap::from([("id", annotation_set_id)]);
875 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
876 Ok(())
877 }
878
879 pub async fn annotation_set(
881 &self,
882 annotation_set_id: AnnotationSetID,
883 ) -> Result<AnnotationSet, Error> {
884 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
885 self.rpc("annset.get".to_owned(), Some(params)).await
886 }
887
    /// Fetches all annotations in an annotation set, flattened to one
    /// `Annotation` per object; samples without annotations contribute
    /// a single placeholder annotation so they remain visible.
    ///
    /// `groups` and `annotation_types` filter the underlying sample
    /// listing; progress is reported per page of results.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Label name -> index map used to backfill numeric label ids.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id: Some(annotation_set_id),
            groups,
            types: annotation_types.iter().map(|t| t.to_string()).collect(),
            labels: &labels,
        };

        self.fetch_annotations_paginated(context, total, progress)
            .await
    }
940
    /// Pages through `samples.list` using the server's continue token,
    /// collecting flattened annotations until a page comes back empty
    /// or the token is absent/empty.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Send errors are ignored: a dropped progress receiver
            // should not abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Close the progress channel before returning so receivers
        // observe completion.
        drop(progress);
        Ok(annotations)
    }
984
    /// Flattens a page of samples into `annotations`.
    ///
    /// Samples without annotations yield one placeholder annotation
    /// carrying only sample metadata; otherwise each annotation is
    /// cloned, stamped with its sample's metadata, and given a label
    /// index from `labels`.
    fn process_sample_annotations(
        &self,
        samples: &[Sample],
        labels: &HashMap<String, u64>,
        annotations: &mut Vec<Annotation>,
    ) {
        for sample in samples {
            if sample.annotations().is_empty() {
                // Placeholder so annotation-less samples still appear
                // in the flattened output.
                let mut annotation = Annotation::new();
                annotation.set_sample_id(sample.id());
                annotation.set_name(sample.name());
                annotation.set_sequence_name(sample.sequence_name().cloned());
                annotation.set_frame_number(sample.frame_number());
                annotation.set_group(sample.group().cloned());
                annotations.push(annotation);
                continue;
            }

            for annotation in sample.annotations() {
                let mut annotation = annotation.clone();
                annotation.set_sample_id(sample.id());
                annotation.set_name(sample.name());
                annotation.set_sequence_name(sample.sequence_name().cloned());
                annotation.set_frame_number(sample.frame_number());
                annotation.set_group(sample.group().cloned());
                Self::set_label_index_from_map(&mut annotation, labels);
                annotations.push(annotation);
            }
        }
    }
1015
1016 fn parse_frame_from_image_name(
1024 image_name: Option<&String>,
1025 sequence_name: Option<&String>,
1026 ) -> Option<u32> {
1027 use std::path::Path;
1028
1029 let sequence = sequence_name?;
1030 let name = image_name?;
1031
1032 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1034
1035 stem.strip_prefix(sequence)
1037 .and_then(|suffix| suffix.strip_prefix('_'))
1038 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1039 }
1040
1041 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1043 if let Some(label) = annotation.label() {
1044 annotation.set_label_index(Some(labels[label.as_str()]));
1045 }
1046 }
1047
1048 pub async fn samples_count(
1049 &self,
1050 dataset_id: DatasetID,
1051 annotation_set_id: Option<AnnotationSetID>,
1052 annotation_types: &[AnnotationType],
1053 groups: &[String],
1054 types: &[FileType],
1055 ) -> Result<SamplesCountResult, Error> {
1056 let types = annotation_types
1057 .iter()
1058 .map(|t| t.to_string())
1059 .chain(types.iter().map(|t| t.to_string()))
1060 .collect::<Vec<_>>();
1061
1062 let params = SamplesListParams {
1063 dataset_id,
1064 annotation_set_id,
1065 group_names: groups.to_vec(),
1066 types,
1067 continue_token: None,
1068 };
1069
1070 self.rpc("samples.count".to_owned(), Some(params)).await
1071 }
1072
    /// Fetches all samples matching the filters, with each sample's
    /// annotations enriched (name, group, sequence, frame number,
    /// label index) during pagination.
    ///
    /// NOTE(review): the progress total comes from `samples_count`
    /// called with an empty file-type list, while the page listing
    /// filters on both annotation and file types — confirm the count
    /// stays consistent with the listing when file types are given.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // Combined annotation-type + file-type filter names.
        let types_vec = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> index map used to backfill numeric label ids.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id,
            groups,
            types: types_vec,
            labels: &labels,
        };

        self.fetch_samples_paginated(context, total, progress).await
    }
1112
    /// Pages through `samples.list` using the server's continue token,
    /// enriching each sample's annotations with sample-level metadata
    /// as pages arrive.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Fall back to parsing the frame number out of
                        // the image file name when the server did not
                        // provide one.
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        // Stamp each annotation with its parent
                        // sample's metadata and label index.
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Send errors are ignored: a dropped progress receiver
            // should not abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Close the progress channel before returning so receivers
        // observe completion.
        drop(progress);
        Ok(samples)
    }
1184
    /// Creates/updates samples in the dataset via `samples.populate2`,
    /// then uploads any referenced local files.
    ///
    /// Samples without a UUID are assigned one so upload results can
    /// be matched back to local files. When local files are present,
    /// presigned upload URLs are requested and each file is uploaded
    /// before returning.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;

        // (sample uuid, file type, local path, basename) for every
        // file that must be uploaded after the RPC.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;

        let has_files_to_upload = !files_to_upload.is_empty();

        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            // Only request presigned URLs when something needs upload.
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            self.upload_sample_files(&results, files_to_upload, progress)
                .await?;
        }

        Ok(results)
    }
1313
    /// Assigns UUIDs to samples that lack one and rewrites their file
    /// entries for upload, recording local files in `files_to_upload`.
    fn prepare_samples_for_upload(
        &self,
        samples: Vec<Sample>,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> Result<Vec<Sample>, Error> {
        Ok(samples
            .into_iter()
            .map(|mut sample| {
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().expect("UUID just set above");

                // Iterate over a copy of the file list so the sample
                // itself can be mutated (width/height probing) while
                // files are processed.
                let files_copy = sample.files.clone();
                let updated_files: Vec<crate::SampleFile> = files_copy
                    .iter()
                    .map(|file| {
                        self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect())
    }
1343
    /// Processes one sample file entry before populate.
    ///
    /// If the entry points at an existing local file it is queued in
    /// `files_to_upload` and replaced by an entry carrying only the
    /// basename; image dimensions are probed into the sample when
    /// missing. Entries that do not resolve to a local file are
    /// returned unchanged.
    fn process_sample_file(
        &self,
        file: &crate::SampleFile,
        sample_uuid: &str,
        sample: &mut Sample,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> crate::SampleFile {
        use std::path::Path;

        if let Some(filename) = file.filename() {
            let path = Path::new(filename);

            if path.exists()
                && path.is_file()
                && let Some(basename) = path.file_name().and_then(|s| s.to_str())
            {
                // Probe image dimensions from the file header when the
                // sample does not already carry them.
                if file.file_type() == "image"
                    && (sample.width.is_none() || sample.height.is_none())
                    && let Ok(size) = imagesize::size(path)
                {
                    sample.width = Some(size.width as u32);
                    sample.height = Some(size.height as u32);
                }

                files_to_upload.push((
                    sample_uuid.to_string(),
                    file.file_type().to_string(),
                    path.to_path_buf(),
                    basename.to_string(),
                ));

                // The server only needs the basename; the local path
                // stays in `files_to_upload` for the upload phase.
                return crate::SampleFile::with_filename(
                    file.file_type().to_string(),
                    basename.to_string(),
                );
            }
        }
        file.clone()
    }
1388
    /// Uploads queued local files to the presigned URLs returned by
    /// `samples.populate2`, in parallel, reporting progress per
    /// sample.
    async fn upload_sample_files(
        &self,
        results: &[crate::SamplesPopulateResult],
        files_to_upload: Vec<(String, String, PathBuf, String)>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Index local paths by (sample uuid, basename) so each
        // presigned URL can locate its file.
        let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
        for (uuid, _file_type, path, basename) in files_to_upload {
            upload_map.insert((uuid, basename), path);
        }

        let http = self.http.clone();

        let upload_tasks: Vec<_> = results
            .iter()
            .map(|result| (result.uuid.clone(), result.urls.clone()))
            .collect();

        parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
            let http = http.clone();
            let upload_map = upload_map.clone();

            async move {
                for url_info in &urls {
                    // NOTE(review): URLs with no matching local file
                    // are skipped silently — confirm that is intended.
                    if let Some(local_path) =
                        upload_map.get(&(uuid.clone(), url_info.filename.clone()))
                    {
                        upload_file_to_presigned_url(
                            http.clone(),
                            &url_info.url,
                            local_path.clone(),
                        )
                        .await?;
                    }
                }

                Ok(())
            }
        })
        .await
    }
1434
1435 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1436 let resp = self.http.get(url).send().await?;
1438
1439 if !resp.status().is_success() {
1440 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
1441 }
1442
1443 let bytes = resp.bytes().await?;
1444 Ok(bytes.to_vec())
1445 }
1446
    /// Fetches annotations and converts them into a polars
    /// `DataFrame` (legacy schema).
    #[deprecated(
        since = "0.8.0",
        note = "Use `samples_dataframe()` for complete 2025.10 schema support"
    )]
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        // The helper is deprecated alongside this method; silence the
        // warning for this intentional call.
        #[allow(deprecated)]
        annotations_dataframe(&annotations)
    }
1506
    /// Fetches samples and converts them into a polars `DataFrame`.
    #[cfg(feature = "polars")]
    pub async fn samples_dataframe(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::samples_dataframe;

        let samples = self
            .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
            .await?;
        samples_dataframe(&samples)
    }
1559
1560 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1563 let snapshots: Vec<Snapshot> = self
1564 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1565 .await?;
1566 if let Some(name) = name {
1567 Ok(snapshots
1568 .into_iter()
1569 .filter(|s| s.description().contains(name))
1570 .collect())
1571 } else {
1572 Ok(snapshots)
1573 }
1574 }
1575
1576 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1578 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1579 self.rpc("snapshots.get".to_owned(), Some(params)).await
1580 }
1581
    /// Create a snapshot from a local file (or delegate to the folder path
    /// when `path` is a directory), uploading its contents via the server's
    /// multipart upload flow and marking the snapshot "available" at the end.
    ///
    /// `progress` (if provided) receives byte-level `Progress` updates during
    /// the upload; the channel is dropped when the upload completes.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories use a separate flow that uploads every contained file.
        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The file's basename doubles as the snapshot name and the upload key.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        // Shared byte counter; upload_multipart advances it as parts finish.
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server to create the snapshot record and presigned part
        // URLs for the single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The result map mixes the snapshot id with per-key part listings.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed by "<storage prefix>/<file name>", where the
        // prefix is the portion of the snapshot path after the last "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload all parts concurrently; returns the etag list needed to
        // complete the multipart upload.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot status to "available" so it becomes usable.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel to signal completion to the receiver.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1682
    /// Create a snapshot from a directory tree: walks `path`, uploads every
    /// regular file (keyed by its path relative to the directory) through the
    /// multipart flow, then marks the snapshot "available".
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        // The directory's basename becomes the snapshot name.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid directory name",
            ))
        })?;

        // Collect all regular files as paths relative to the root directory;
        // unreadable entries are silently skipped (best-effort walk).
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
            .collect::<Vec<_>>();

        // Total byte count across all files, used for progress reporting.
        let total: usize = files
            .iter()
            .filter_map(|file| path.join(file).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .sum();
        // Shared byte counter advanced by upload_multipart for each file.
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Keys and sizes are parallel vectors; both skip entries that fail
        // conversion/stat, mirroring the filters above.
        let keys = files
            .iter()
            .filter_map(|key| key.to_str().map(|s| s.to_owned()))
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .filter_map(|key| path.join(key).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        // One RPC creates the snapshot record plus presigned part URLs for
        // every file key.
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed by "<storage prefix>/<relative path>", where
        // the prefix is the snapshot path portion after the last "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();

        // Files are uploaded sequentially; within each file the parts are
        // uploaded concurrently by upload_multipart.
        for file in files {
            let file_str = file.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "File path contains invalid UTF-8",
                ))
            })?;
            let part_key = format!("{}/{}", part_prefix, file_str);
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            // Each file's multipart upload is completed individually.
            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // Flip the snapshot status to "available" so it becomes usable.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel to signal completion to the receiver.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1798
1799 pub async fn download_snapshot(
1804 &self,
1805 snapshot_id: SnapshotID,
1806 output: PathBuf,
1807 progress: Option<Sender<Progress>>,
1808 ) -> Result<(), Error> {
1809 fs::create_dir_all(&output).await?;
1810
1811 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1812 let items: HashMap<String, String> = self
1813 .rpc("snapshots.create_download_url".to_owned(), Some(params))
1814 .await?;
1815
1816 let total = Arc::new(AtomicUsize::new(0));
1817 let current = Arc::new(AtomicUsize::new(0));
1818 let sem = Arc::new(Semaphore::new(MAX_TASKS));
1819
1820 let tasks = items
1821 .iter()
1822 .map(|(key, url)| {
1823 let http = self.http.clone();
1824 let key = key.clone();
1825 let url = url.clone();
1826 let output = output.clone();
1827 let progress = progress.clone();
1828 let current = current.clone();
1829 let total = total.clone();
1830 let sem = sem.clone();
1831
1832 tokio::spawn(async move {
1833 let _permit = sem.acquire().await.map_err(|_| {
1834 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1835 })?;
1836 let res = http.get(url).send().await?;
1837 let content_length = res.content_length().unwrap_or(0) as usize;
1838
1839 if let Some(progress) = &progress {
1840 let total = total.fetch_add(content_length, Ordering::SeqCst);
1841 let _ = progress
1842 .send(Progress {
1843 current: current.load(Ordering::SeqCst),
1844 total: total + content_length,
1845 })
1846 .await;
1847 }
1848
1849 let mut file = File::create(output.join(key)).await?;
1850 let mut stream = res.bytes_stream();
1851
1852 while let Some(chunk) = stream.next().await {
1853 let chunk = chunk?;
1854 file.write_all(&chunk).await?;
1855 let len = chunk.len();
1856
1857 if let Some(progress) = &progress {
1858 let total = total.load(Ordering::SeqCst);
1859 let current = current.fetch_add(len, Ordering::SeqCst);
1860
1861 let _ = progress
1862 .send(Progress {
1863 current: current + len,
1864 total,
1865 })
1866 .await;
1867 }
1868 }
1869
1870 Ok::<(), Error>(())
1871 })
1872 })
1873 .collect::<Vec<_>>();
1874
1875 join_all(tasks)
1876 .await
1877 .into_iter()
1878 .collect::<Result<Vec<_>, _>>()?
1879 .into_iter()
1880 .collect::<Result<Vec<_>, _>>()?;
1881
1882 Ok(())
1883 }
1884
1885 #[allow(clippy::too_many_arguments)]
1900 pub async fn restore_snapshot(
1901 &self,
1902 project_id: ProjectID,
1903 snapshot_id: SnapshotID,
1904 topics: &[String],
1905 autolabel: &[String],
1906 autodepth: bool,
1907 dataset_name: Option<&str>,
1908 dataset_description: Option<&str>,
1909 ) -> Result<SnapshotRestoreResult, Error> {
1910 let params = SnapshotRestore {
1911 project_id,
1912 snapshot_id,
1913 fps: 1,
1914 autodepth,
1915 agtg_pipeline: !autolabel.is_empty(),
1916 autolabel: autolabel.to_vec(),
1917 topics: topics.to_vec(),
1918 dataset_name: dataset_name.map(|s| s.to_owned()),
1919 dataset_description: dataset_description.map(|s| s.to_owned()),
1920 };
1921 self.rpc("snapshots.restore".to_owned(), Some(params)).await
1922 }
1923
1924 pub async fn experiments(
1933 &self,
1934 project_id: ProjectID,
1935 name: Option<&str>,
1936 ) -> Result<Vec<Experiment>, Error> {
1937 let params = HashMap::from([("project_id", project_id)]);
1938 let experiments: Vec<Experiment> =
1939 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1940 if let Some(name) = name {
1941 Ok(experiments
1942 .into_iter()
1943 .filter(|e| e.name().contains(name))
1944 .collect())
1945 } else {
1946 Ok(experiments)
1947 }
1948 }
1949
1950 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1953 let params = HashMap::from([("trainer_id", experiment_id)]);
1954 self.rpc("trainer.get".to_owned(), Some(params)).await
1955 }
1956
1957 pub async fn training_sessions(
1966 &self,
1967 experiment_id: ExperimentID,
1968 name: Option<&str>,
1969 ) -> Result<Vec<TrainingSession>, Error> {
1970 let params = HashMap::from([("trainer_id", experiment_id)]);
1971 let sessions: Vec<TrainingSession> = self
1972 .rpc("trainer.session.list".to_owned(), Some(params))
1973 .await?;
1974 if let Some(name) = name {
1975 Ok(sessions
1976 .into_iter()
1977 .filter(|s| s.name().contains(name))
1978 .collect())
1979 } else {
1980 Ok(sessions)
1981 }
1982 }
1983
1984 pub async fn training_session(
1987 &self,
1988 session_id: TrainingSessionID,
1989 ) -> Result<TrainingSession, Error> {
1990 let params = HashMap::from([("trainer_session_id", session_id)]);
1991 self.rpc("trainer.session.get".to_owned(), Some(params))
1992 .await
1993 }
1994
1995 pub async fn validation_sessions(
1997 &self,
1998 project_id: ProjectID,
1999 ) -> Result<Vec<ValidationSession>, Error> {
2000 let params = HashMap::from([("project_id", project_id)]);
2001 self.rpc("validate.session.list".to_owned(), Some(params))
2002 .await
2003 }
2004
2005 pub async fn validation_session(
2007 &self,
2008 session_id: ValidationSessionID,
2009 ) -> Result<ValidationSession, Error> {
2010 let params = HashMap::from([("validate_session_id", session_id)]);
2011 self.rpc("validate.session.get".to_owned(), Some(params))
2012 .await
2013 }
2014
2015 pub async fn artifacts(
2018 &self,
2019 training_session_id: TrainingSessionID,
2020 ) -> Result<Vec<Artifact>, Error> {
2021 let params = HashMap::from([("training_session_id", training_session_id)]);
2022 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2023 .await
2024 }
2025
2026 pub async fn download_artifact(
2032 &self,
2033 training_session_id: TrainingSessionID,
2034 modelname: &str,
2035 filename: Option<PathBuf>,
2036 progress: Option<Sender<Progress>>,
2037 ) -> Result<(), Error> {
2038 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2039 let resp = self
2040 .http
2041 .get(format!(
2042 "{}/download_model?training_session_id={}&file={}",
2043 self.url,
2044 training_session_id.value(),
2045 modelname
2046 ))
2047 .header("Authorization", format!("Bearer {}", self.token().await))
2048 .send()
2049 .await?;
2050 if !resp.status().is_success() {
2051 let err = resp.error_for_status_ref().unwrap_err();
2052 return Err(Error::HttpError(err));
2053 }
2054
2055 if let Some(parent) = filename.parent() {
2056 fs::create_dir_all(parent).await?;
2057 }
2058
2059 if let Some(progress) = progress {
2060 let total = resp.content_length().unwrap_or(0) as usize;
2061 let _ = progress.send(Progress { current: 0, total }).await;
2062
2063 let mut file = File::create(filename).await?;
2064 let mut current = 0;
2065 let mut stream = resp.bytes_stream();
2066
2067 while let Some(item) = stream.next().await {
2068 let chunk = item?;
2069 file.write_all(&chunk).await?;
2070 current += chunk.len();
2071 let _ = progress.send(Progress { current, total }).await;
2072 }
2073 } else {
2074 let body = resp.bytes().await?;
2075 fs::write(filename, body).await?;
2076 }
2077
2078 Ok(())
2079 }
2080
2081 pub async fn download_checkpoint(
2091 &self,
2092 training_session_id: TrainingSessionID,
2093 checkpoint: &str,
2094 filename: Option<PathBuf>,
2095 progress: Option<Sender<Progress>>,
2096 ) -> Result<(), Error> {
2097 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2098 let resp = self
2099 .http
2100 .get(format!(
2101 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2102 self.url,
2103 training_session_id.value(),
2104 checkpoint
2105 ))
2106 .header("Authorization", format!("Bearer {}", self.token().await))
2107 .send()
2108 .await?;
2109 if !resp.status().is_success() {
2110 let err = resp.error_for_status_ref().unwrap_err();
2111 return Err(Error::HttpError(err));
2112 }
2113
2114 if let Some(parent) = filename.parent() {
2115 fs::create_dir_all(parent).await?;
2116 }
2117
2118 if let Some(progress) = progress {
2119 let total = resp.content_length().unwrap_or(0) as usize;
2120 let _ = progress.send(Progress { current: 0, total }).await;
2121
2122 let mut file = File::create(filename).await?;
2123 let mut current = 0;
2124 let mut stream = resp.bytes_stream();
2125
2126 while let Some(item) = stream.next().await {
2127 let chunk = item?;
2128 file.write_all(&chunk).await?;
2129 current += chunk.len();
2130 let _ = progress.send(Progress { current, total }).await;
2131 }
2132 } else {
2133 let body = resp.bytes().await?;
2134 fs::write(filename, body).await?;
2135 }
2136
2137 Ok(())
2138 }
2139
2140 pub async fn tasks(
2142 &self,
2143 name: Option<&str>,
2144 workflow: Option<&str>,
2145 status: Option<&str>,
2146 manager: Option<&str>,
2147 ) -> Result<Vec<Task>, Error> {
2148 let mut params = TasksListParams {
2149 continue_token: None,
2150 status: status.map(|s| vec![s.to_owned()]),
2151 manager: manager.map(|m| vec![m.to_owned()]),
2152 };
2153 let mut tasks = Vec::new();
2154
2155 loop {
2156 let result = self
2157 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2158 .await?;
2159 tasks.extend(result.tasks);
2160
2161 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2162 params.continue_token = None;
2163 } else {
2164 params.continue_token = result.continue_token;
2165 }
2166
2167 if params.continue_token.is_none() {
2168 break;
2169 }
2170 }
2171
2172 if let Some(name) = name {
2173 tasks.retain(|t| t.name().contains(name));
2174 }
2175
2176 if let Some(workflow) = workflow {
2177 tasks.retain(|t| t.workflow().contains(workflow));
2178 }
2179
2180 Ok(tasks)
2181 }
2182
2183 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2185 self.rpc(
2186 "task.get".to_owned(),
2187 Some(HashMap::from([("id", task_id)])),
2188 )
2189 .await
2190 }
2191
2192 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2194 let status = TaskStatus {
2195 task_id,
2196 status: status.to_owned(),
2197 };
2198 self.rpc("docker.update.status".to_owned(), Some(status))
2199 .await
2200 }
2201
2202 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
2206 let stages: Vec<HashMap<String, String>> = stages
2207 .iter()
2208 .map(|(key, value)| {
2209 let mut stage_map = HashMap::new();
2210 stage_map.insert(key.to_string(), value.to_string());
2211 stage_map
2212 })
2213 .collect();
2214 let params = TaskStages { task_id, stages };
2215 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
2216 Ok(())
2217 }
2218
2219 pub async fn update_stage(
2222 &self,
2223 task_id: TaskID,
2224 stage: &str,
2225 status: &str,
2226 message: &str,
2227 percentage: u8,
2228 ) -> Result<(), Error> {
2229 let stage = Stage::new(
2230 Some(task_id),
2231 stage.to_owned(),
2232 Some(status.to_owned()),
2233 Some(message.to_owned()),
2234 percentage,
2235 );
2236 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
2237 Ok(())
2238 }
2239
2240 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
2242 let req = self
2243 .http
2244 .get(format!("{}/{}", self.url, query))
2245 .header("User-Agent", "EdgeFirst Client")
2246 .header("Authorization", format!("Bearer {}", self.token().await));
2247 let resp = req.send().await?;
2248
2249 if resp.status().is_success() {
2250 let body = resp.bytes().await?;
2251
2252 if log_enabled!(Level::Trace) {
2253 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
2254 }
2255
2256 Ok(body.to_vec())
2257 } else {
2258 let err = resp.error_for_status_ref().unwrap_err();
2259 Err(Error::HttpError(err))
2260 }
2261 }
2262
2263 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2267 let req = self
2268 .http
2269 .post(format!("{}/api?method={}", self.url, method))
2270 .header("Accept", "application/json")
2271 .header("User-Agent", "EdgeFirst Client")
2272 .header("Authorization", format!("Bearer {}", self.token().await))
2273 .multipart(form);
2274 let resp = req.send().await?;
2275
2276 if resp.status().is_success() {
2277 let body = resp.bytes().await?;
2278
2279 if log_enabled!(Level::Trace) {
2280 trace!(
2281 "POST Multipart Response: {}",
2282 String::from_utf8_lossy(&body)
2283 );
2284 }
2285
2286 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2287 Ok(response) => response,
2288 Err(err) => {
2289 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2290 return Err(err.into());
2291 }
2292 };
2293
2294 if let Some(error) = response.error {
2295 Err(Error::RpcError(error.code, error.message))
2296 } else if let Some(result) = response.result {
2297 Ok(result)
2298 } else {
2299 Err(Error::InvalidResponse)
2300 }
2301 } else {
2302 let err = resp.error_for_status_ref().unwrap_err();
2303 Err(Error::HttpError(err))
2304 }
2305 }
2306
2307 pub async fn rpc<Params, RpcResult>(
2316 &self,
2317 method: String,
2318 params: Option<Params>,
2319 ) -> Result<RpcResult, Error>
2320 where
2321 Params: Serialize,
2322 RpcResult: DeserializeOwned,
2323 {
2324 let auth_expires = self.token_expiration().await?;
2325 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2326 self.renew_token().await?;
2327 }
2328
2329 self.rpc_without_auth(method, params).await
2330 }
2331
2332 async fn rpc_without_auth<Params, RpcResult>(
2333 &self,
2334 method: String,
2335 params: Option<Params>,
2336 ) -> Result<RpcResult, Error>
2337 where
2338 Params: Serialize,
2339 RpcResult: DeserializeOwned,
2340 {
2341 let request = RpcRequest {
2342 method,
2343 params,
2344 ..Default::default()
2345 };
2346
2347 if log_enabled!(Level::Trace) {
2348 trace!(
2349 "RPC Request: {}",
2350 serde_json::ser::to_string_pretty(&request)?
2351 );
2352 }
2353
2354 let url = format!("{}/api", self.url);
2355
2356 let res = self
2359 .http
2360 .post(&url)
2361 .header("Accept", "application/json")
2362 .header("User-Agent", "EdgeFirst Client")
2363 .header("Authorization", format!("Bearer {}", self.token().await))
2364 .json(&request)
2365 .send()
2366 .await?;
2367
2368 self.process_rpc_response(res).await
2369 }
2370
2371 async fn process_rpc_response<RpcResult>(
2372 &self,
2373 res: reqwest::Response,
2374 ) -> Result<RpcResult, Error>
2375 where
2376 RpcResult: DeserializeOwned,
2377 {
2378 let body = res.bytes().await?;
2379
2380 if log_enabled!(Level::Trace) {
2381 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2382 }
2383
2384 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2385 Ok(response) => response,
2386 Err(err) => {
2387 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2388 return Err(err.into());
2389 }
2390 };
2391
2392 if let Some(error) = response.error {
2398 Err(Error::RpcError(error.code, error.message))
2399 } else if let Some(result) = response.result {
2400 Ok(result)
2401 } else {
2402 Err(Error::InvalidResponse)
2403 }
2404 }
2405}
2406
/// Run `work_fn` over every item concurrently, bounded to `MAX_TASKS`
/// in-flight tasks, reporting item-count progress on `progress`.
///
/// Fails fast in the sense that the first error found while collecting
/// results is returned: join/panic errors first, then the first work error.
/// The progress channel is dropped on completion to signal the receiver.
async fn parallel_foreach_items<T, F, Fut>(
    items: Vec<T>,
    progress: Option<Sender<Progress>>,
    work_fn: F,
) -> Result<(), Error>
where
    T: Send + 'static,
    F: Fn(T) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), Error>> + Send + 'static,
{
    let total = items.len();
    // Count of completed items, shared across tasks for progress reporting.
    let current = Arc::new(AtomicUsize::new(0));
    // Bound the number of concurrently running work futures.
    let sem = Arc::new(Semaphore::new(MAX_TASKS));
    // Arc so each spawned task can share the same closure.
    let work_fn = Arc::new(work_fn);

    let tasks = items
        .into_iter()
        .map(|item| {
            let sem = sem.clone();
            let current = current.clone();
            let progress = progress.clone();
            let work_fn = work_fn.clone();

            tokio::spawn(async move {
                // Permit is held (via RAII guard) for the task's duration.
                let _permit = sem.acquire().await.map_err(|_| {
                    Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                })?;

                work_fn(item).await?;

                // Progress counts completed items (not bytes); fetch_add
                // returns the previous value, hence the +1 when sending.
                if let Some(progress) = &progress {
                    let current = current.fetch_add(1, Ordering::SeqCst);
                    let _ = progress
                        .send(Progress {
                            current: current + 1,
                            total,
                        })
                        .await;
                }

                Ok::<(), Error>(())
            })
        })
        .collect::<Vec<_>>();

    // First collect surfaces join/panic errors, second surfaces the first
    // error returned by the work futures themselves.
    join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?;

    // Closing the channel tells the receiver that all items are done.
    if let Some(progress) = progress {
        drop(progress);
    }

    Ok(())
}
2495
2496async fn upload_multipart(
2521 http: reqwest::Client,
2522 part: SnapshotPart,
2523 path: PathBuf,
2524 total: usize,
2525 current: Arc<AtomicUsize>,
2526 progress: Option<Sender<Progress>>,
2527) -> Result<SnapshotCompleteMultipartParams, Error> {
2528 let filesize = path.metadata()?.len() as usize;
2529 let n_parts = filesize.div_ceil(PART_SIZE);
2530 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2531
2532 let key = part.key.ok_or(Error::InvalidResponse)?;
2533 let upload_id = part.upload_id;
2534
2535 let urls = part.urls.clone();
2536 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2538 EtagPart {
2539 etag: "".to_owned(),
2540 part_number: 0,
2541 };
2542 n_parts
2543 ]));
2544
2545 let tasks = (0..n_parts)
2547 .map(|part| {
2548 let http = http.clone();
2549 let url = urls[part].clone();
2550 let etags = etags.clone();
2551 let path = path.to_owned();
2552 let sem = sem.clone();
2553 let progress = progress.clone();
2554 let current = current.clone();
2555
2556 tokio::spawn(async move {
2557 let _permit = sem.acquire().await?;
2559
2560 let etag =
2562 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
2563
2564 let mut etags = etags.lock().await;
2566 etags[part] = EtagPart {
2567 etag,
2568 part_number: part + 1,
2569 };
2570
2571 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2573 if let Some(progress) = &progress {
2574 let _ = progress
2575 .send(Progress {
2576 current: current + PART_SIZE,
2577 total,
2578 })
2579 .await;
2580 }
2581
2582 Ok::<(), Error>(())
2583 })
2584 })
2585 .collect::<Vec<_>>();
2586
2587 join_all(tasks)
2589 .await
2590 .into_iter()
2591 .collect::<Result<Vec<_>, _>>()?;
2592
2593 Ok(SnapshotCompleteMultipartParams {
2594 key,
2595 upload_id,
2596 etag_list: etags.lock().await.clone(),
2597 })
2598}
2599
2600async fn upload_part(
2601 http: reqwest::Client,
2602 url: String,
2603 path: PathBuf,
2604 part: usize,
2605 n_parts: usize,
2606) -> Result<String, Error> {
2607 let filesize = path.metadata()?.len() as usize;
2608 let mut file = File::open(path).await?;
2609 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2610 .await?;
2611 let file = file.take(PART_SIZE as u64);
2612
2613 let body_length = if part + 1 == n_parts {
2614 filesize % PART_SIZE
2615 } else {
2616 PART_SIZE
2617 };
2618
2619 let stream = FramedRead::new(file, BytesCodec::new());
2620 let body = Body::wrap_stream(stream);
2621
2622 let resp = http
2623 .put(url.clone())
2624 .header(CONTENT_LENGTH, body_length)
2625 .body(body)
2626 .send()
2627 .await?
2628 .error_for_status()?;
2629
2630 let etag = resp
2631 .headers()
2632 .get("etag")
2633 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
2634 .to_str()
2635 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
2636 .to_owned();
2637
2638 let etag = etag
2640 .strip_prefix("\"")
2641 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
2642 let etag = etag
2643 .strip_suffix("\"")
2644 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
2645
2646 Ok(etag.to_owned())
2647}
2648
2649async fn upload_file_to_presigned_url(
2654 http: reqwest::Client,
2655 url: &str,
2656 path: PathBuf,
2657) -> Result<(), Error> {
2658 let file_data = fs::read(&path).await?;
2660 let file_size = file_data.len();
2661
2662 let resp = http
2664 .put(url)
2665 .header(CONTENT_LENGTH, file_size)
2666 .body(file_data)
2667 .send()
2668 .await?;
2669
2670 if resp.status().is_success() {
2671 debug!(
2672 "Successfully uploaded file: {:?} ({} bytes)",
2673 path, file_size
2674 );
2675 Ok(())
2676 } else {
2677 let status = resp.status();
2678 let error_text = resp.text().await.unwrap_or_default();
2679 Err(Error::InvalidParameters(format!(
2680 "Upload failed: HTTP {} - {}",
2681 status, error_text
2682 )))
2683 }
2684}