1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14 retry::{create_retry_policy, log_retry_configuration},
15};
16use base64::Engine as _;
17use chrono::{DateTime, Utc};
18use directories::ProjectDirs;
19use futures::{StreamExt as _, future::join_all};
20use log::{Level, debug, error, log_enabled, trace, warn};
21use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
22use serde::{Deserialize, Serialize, de::DeserializeOwned};
23use std::{
24 collections::HashMap,
25 ffi::OsStr,
26 fs::create_dir_all,
27 io::{SeekFrom, Write as _},
28 path::{Path, PathBuf},
29 sync::{
30 Arc,
31 atomic::{AtomicUsize, Ordering},
32 },
33 time::Duration,
34 vec,
35};
36use tokio::{
37 fs::{self, File},
38 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
39 sync::{RwLock, Semaphore, mpsc::Sender},
40};
41use tokio_util::codec::{BytesCodec, FramedRead};
42use walkdir::WalkDir;
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
/// Size of each part in a multipart snapshot upload (100 MiB).
// `const` rather than `static`: this is a compile-time constant with no
// identity or interior mutability.
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Maximum number of concurrent tasks for parallel operations.
///
/// Honours the `MAX_TASKS` environment variable when it parses as an
/// integer; otherwise defaults to half the available CPU cores,
/// clamped to 2..=8 (assuming 4 cores when the count is unavailable).
fn max_tasks() -> usize {
    match std::env::var("MAX_TASKS").ok().and_then(|v| v.parse::<usize>().ok()) {
        Some(configured) => configured,
        None => {
            let cores = std::thread::available_parallelism().map_or(4, |n| n.get());
            (cores / 2).clamp(2, 8)
        }
    }
}
62
/// Sanitize a single path component derived from a server- or
/// user-supplied name so it is safe to join onto a local output
/// directory.
///
/// Keeps only the final path component, replaces characters invalid on
/// common filesystems with `_`, and maps empty, `.` or `..` results to
/// `"unnamed"` so a crafted name cannot escape the target directory.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // Drop any directory portion of a path-like name.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // Reject empty and relative-traversal components: joining "." or
    // ".." onto the output directory would escape or alias it.
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
88
/// Progress update reported over an optional channel during long
/// running operations such as downloads and uploads.
#[derive(Debug, Clone)]
pub struct Progress {
    /// Number of items processed so far.
    pub current: usize,
    /// Total number of items expected.
    pub total: usize,
}
117
/// JSON-RPC 2.0 request envelope sent to the Studio API.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier; the Default impl uses 0.
    id: u64,
    // Protocol version string, "2.0".
    jsonrpc: String,
    method: String,
    params: Option<Params>,
}
125
126impl<T> Default for RpcRequest<T> {
127 fn default() -> Self {
128 RpcRequest {
129 id: 0,
130 jsonrpc: "2.0".to_string(),
131 method: "".to_string(),
132 params: None,
133 }
134 }
135}
136
/// Error object of a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}
142
/// JSON-RPC response envelope; exactly one of `error` / `result` is
/// expected to be populated.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): requests serialize a numeric id but this expects a
    // string — confirm against actual server responses.
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}
152
/// Placeholder for RPC calls whose result payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
156
/// Parameters for creating a snapshot with single-part upload URLs.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Object keys (file names) to request upload URLs for.
    keys: Vec<String>,
}
163
/// Result of snapshot creation: the new snapshot ID plus one upload
/// URL per requested key.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
170
/// Parameters for creating a snapshot via multipart upload; file sizes
/// let the server determine how many part URLs each key needs.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}
177
/// One element of the multipart-create result: either the numeric
/// snapshot ID or a per-key part descriptor. Untagged so serde picks
/// whichever variant matches the JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
184
/// Parameters for completing the multipart upload of one key.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    // ETags returned by the storage backend for each uploaded part.
    etag_list: Vec<EtagPart>,
}
191
/// ETag for one uploaded part, serialized with the S3-style field
/// names the completion API expects.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
199
/// Multipart upload descriptor for one key: the upload session ID and
/// one presigned URL per part.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
206
/// Parameters for updating a snapshot's status string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}
212
/// Snapshot record returned by the status-update call.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
226
/// Parameters for listing images with dataset and per-file filters.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    // Presumably restricts the response to image IDs only — unverified.
    only_ids: bool,
}
234
/// Dataset filter nested inside `ImageListParams`.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
240
/// Client for the EdgeFirst Studio JSON-RPC API.
///
/// Cheap to clone: the HTTP connection pool and the auth token are
/// shared between clones.
#[derive(Clone, Debug)]
pub struct Client {
    http: reqwest::Client,
    /// Base URL of the Studio server, e.g. "https://edgefirst.studio".
    url: String,
    /// Auth token (JWT-style, see `with_token`); empty when logged out.
    token: Arc<RwLock<String>>,
    /// Where to persist the token, if persistence was requested.
    token_path: Option<PathBuf>,
}
296
/// Shared filter/lookup state for paginated sample and annotation
/// fetches.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    /// Group names to filter by.
    groups: &'a [String],
    /// Combined annotation-type / file-type filter strings.
    types: Vec<String>,
    /// Label name -> numeric index map for the dataset.
    labels: &'a HashMap<String, u64>,
}
305
306impl Client {
307 pub fn new() -> Result<Self, Error> {
316 log_retry_configuration();
317
318 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
320 .ok()
321 .and_then(|s| s.parse().ok())
322 .unwrap_or(30); let http = reqwest::Client::builder()
334 .connect_timeout(Duration::from_secs(10))
335 .timeout(Duration::from_secs(timeout_secs))
336 .pool_idle_timeout(Duration::from_secs(90))
337 .pool_max_idle_per_host(10)
338 .retry(create_retry_policy())
339 .build()?;
340
341 Ok(Client {
342 http,
343 url: "https://edgefirst.studio".to_string(),
344 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
345 token_path: None,
346 })
347 }
348
349 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
353 Ok(Client {
354 url: format!("https://{}.edgefirst.studio", server),
355 ..self.clone()
356 })
357 }
358
359 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
362 let params = HashMap::from([("username", username), ("password", password)]);
363 let login: LoginResult = self
364 .rpc_without_auth("auth.login".to_owned(), Some(params))
365 .await?;
366
367 if login.token.is_empty() {
369 return Err(Error::EmptyToken);
370 }
371
372 Ok(Client {
373 token: Arc::new(tokio::sync::RwLock::new(login.token)),
374 ..self.clone()
375 })
376 }
377
    /// Return a copy of this client that persists its token at
    /// `token_path` (or the platform config directory when `None`),
    /// loading any token already stored there.
    ///
    /// A corrupted or invalid token file is removed and ignored rather
    /// than failing, so a fresh login can proceed.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        // Resolve an explicit path, or fall back to the per-user config
        // directory.
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path: {:?}", token_path);

        // A missing file is treated the same as an empty token.
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            match self.with_token(&token) {
                Ok(client) => Ok(Client {
                    token_path: Some(token_path),
                    ..client
                }),
                Err(e) => {
                    // The stored token failed validation; drop it so the
                    // next login starts from a clean state.
                    warn!(
                        "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
                        token_path, e
                    );
                    if let Err(remove_err) = std::fs::remove_file(&token_path) {
                        warn!("Failed to remove corrupted token file: {:?}", remove_err);
                    }
                    Ok(Client {
                        token_path: Some(token_path),
                        ..self.clone()
                    })
                }
            }
        } else {
            Ok(Client {
                token_path: Some(token_path),
                ..self.clone()
            })
        }
    }
429
430 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
432 if token.is_empty() {
433 return Ok(self.clone());
434 }
435
436 let token_parts: Vec<&str> = token.split('.').collect();
437 if token_parts.len() != 3 {
438 return Err(Error::InvalidToken);
439 }
440
441 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
442 .decode(token_parts[1])
443 .map_err(|_| Error::InvalidToken)?;
444 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
445 let server = match payload.get("server") {
446 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
447 None => return Err(Error::InvalidToken),
448 };
449
450 Ok(Client {
451 url: format!("https://{}.edgefirst.studio", server),
452 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
453 ..self.clone()
454 })
455 }
456
457 pub async fn save_token(&self) -> Result<(), Error> {
458 let path = self.token_path.clone().unwrap_or_else(|| {
459 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
460 .map(|dirs| dirs.config_dir().join("token"))
461 .unwrap_or_else(|| PathBuf::from(".token"))
462 });
463
464 create_dir_all(path.parent().ok_or_else(|| {
465 Error::IoError(std::io::Error::new(
466 std::io::ErrorKind::InvalidInput,
467 "Token path has no parent directory",
468 ))
469 })?)?;
470 let mut file = std::fs::File::create(&path)?;
471 file.write_all(self.token.read().await.as_bytes())?;
472
473 debug!("Saved token to {:?}", path);
474
475 Ok(())
476 }
477
478 pub async fn version(&self) -> Result<String, Error> {
481 let version: HashMap<String, String> = self
482 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
483 .await?;
484 let version = version.get("version").ok_or(Error::InvalidResponse)?;
485 Ok(version.to_owned())
486 }
487
488 pub async fn logout(&self) -> Result<(), Error> {
492 {
493 let mut token = self.token.write().await;
494 *token = "".to_string();
495 }
496
497 if let Some(path) = &self.token_path
498 && path.exists()
499 {
500 fs::remove_file(path).await?;
501 }
502
503 Ok(())
504 }
505
    /// Return a copy of the current authentication token (empty string
    /// when not logged in).
    pub async fn token(&self) -> String {
        self.token.read().await.clone()
    }
512
513 pub async fn verify_token(&self) -> Result<(), Error> {
518 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
519 .await?;
520 Ok::<(), Error>(())
521 }
522
523 pub async fn renew_token(&self) -> Result<(), Error> {
528 let params = HashMap::from([("username".to_string(), self.username().await?)]);
529 let result: LoginResult = self
530 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
531 .await?;
532
533 {
534 let mut token = self.token.write().await;
535 *token = result.token;
536 }
537
538 if self.token_path.is_some() {
539 self.save_token().await?;
540 }
541
542 Ok(())
543 }
544
545 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
546 let token = self.token.read().await;
547 if token.is_empty() {
548 return Err(Error::EmptyToken);
549 }
550
551 let token_parts: Vec<&str> = token.split('.').collect();
552 if token_parts.len() != 3 {
553 return Err(Error::InvalidToken);
554 }
555
556 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
557 .decode(token_parts[1])
558 .map_err(|_| Error::InvalidToken)?;
559 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
560 match payload.get(field) {
561 Some(value) => Ok(value.to_owned()),
562 None => Err(Error::InvalidToken),
563 }
564 }
565
    /// Base URL of the Studio server this client talks to.
    pub fn url(&self) -> &str {
        &self.url
    }
570
571 pub async fn username(&self) -> Result<String, Error> {
573 match self.token_field("username").await? {
574 serde_json::Value::String(username) => Ok(username),
575 _ => Err(Error::InvalidToken),
576 }
577 }
578
579 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
581 let ts = match self.token_field("exp").await? {
582 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
583 _ => return Err(Error::InvalidToken),
584 };
585
586 match DateTime::<Utc>::from_timestamp_secs(ts) {
587 Some(dt) => Ok(dt),
588 None => Err(Error::InvalidToken),
589 }
590 }
591
    /// Fetch the organization associated with the current account.
    pub async fn organization(&self) -> Result<Organization, Error> {
        self.rpc::<(), Organization>("org.get".to_owned(), None)
            .await
    }
597
598 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
606 let projects = self
607 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
608 .await?;
609 if let Some(name) = name {
610 Ok(projects
611 .into_iter()
612 .filter(|p| p.name().contains(name))
613 .collect())
614 } else {
615 Ok(projects)
616 }
617 }
618
    /// Fetch a single project by ID.
    pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
        let params = HashMap::from([("project_id", project_id)]);
        self.rpc("project.get".to_owned(), Some(params)).await
    }
625
626 pub async fn datasets(
630 &self,
631 project_id: ProjectID,
632 name: Option<&str>,
633 ) -> Result<Vec<Dataset>, Error> {
634 let params = HashMap::from([("project_id", project_id)]);
635 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
636 if let Some(name) = name {
637 Ok(datasets
638 .into_iter()
639 .filter(|d| d.name().contains(name))
640 .collect())
641 } else {
642 Ok(datasets)
643 }
644 }
645
    /// Fetch a single dataset by ID.
    pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("dataset.get".to_owned(), Some(params)).await
    }
652
    /// List the labels defined on a dataset.
    pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("label.list".to_owned(), Some(params)).await
    }
658
659 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
661 let new_label = NewLabel {
662 dataset_id,
663 labels: vec![NewLabelObject {
664 name: name.to_owned(),
665 }],
666 };
667 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
668 Ok(())
669 }
670
    /// Delete a label by ID.
    pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
        let params = HashMap::from([("label_id", label_id)]);
        // The server returns a status string which is discarded.
        let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
        Ok(())
    }
678
679 pub async fn create_dataset(
691 &self,
692 project_id: &str,
693 name: &str,
694 description: Option<&str>,
695 ) -> Result<DatasetID, Error> {
696 let mut params = HashMap::new();
697 params.insert("project_id", project_id);
698 params.insert("name", name);
699 if let Some(desc) = description {
700 params.insert("description", desc);
701 }
702
703 #[derive(Deserialize)]
704 struct CreateDatasetResult {
705 id: DatasetID,
706 }
707
708 let result: CreateDatasetResult =
709 self.rpc("dataset.create".to_owned(), Some(params)).await?;
710 Ok(result.id)
711 }
712
    /// Delete a dataset by ID.
    pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
        let params = HashMap::from([("id", dataset_id)]);
        // The server returns a status string which is discarded.
        let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
        Ok(())
    }
727
728 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
732 #[derive(Serialize)]
733 struct Params {
734 dataset_id: DatasetID,
735 label_id: u64,
736 label_name: String,
737 label_index: u64,
738 }
739
740 let _: String = self
741 .rpc(
742 "label.update".to_owned(),
743 Some(Params {
744 dataset_id: label.dataset_id(),
745 label_id: label.id(),
746 label_name: label.name().to_owned(),
747 label_index: label.index(),
748 }),
749 )
750 .await?;
751 Ok(())
752 }
753
    /// Download the files of every sample in a dataset to `output`.
    ///
    /// Samples matching `groups` and `file_types` are downloaded in
    /// parallel. When `flatten` is false, sequence samples go into a
    /// per-sequence subdirectory; when true, everything goes directly
    /// into `output` with the sequence name folded into the file name
    /// to avoid collisions. Progress updates are sent over `progress`
    /// when provided.
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        flatten: bool,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Owned clones moved into the per-sample closure below.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Image extension is sniffed from the bytes;
                        // other file types use their type name.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        // Per-sequence subdirectory unless flattening.
                        let target_dir = if flatten {
                            output.clone()
                        } else {
                            sequence_dir
                                .as_ref()
                                .map(|seq| output.join(seq))
                                .unwrap_or_else(|| output.clone())
                        };
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Images keep their server-side name when known;
                        // other files are named after the sample. Either
                        // way build_filename() adds the sequence prefix
                        // when flattening.
                        let file_name = if is_image {
                            if let Some(img_name) = image_name {
                                Self::build_filename(
                                    &img_name,
                                    flatten,
                                    sequence_dir.as_ref(),
                                    sample.frame_number(),
                                )
                            } else {
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            }
                        } else {
                            let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
                            Self::build_filename(
                                &base_name,
                                flatten,
                                sequence_dir.as_ref(),
                                sample.frame_number(),
                            )
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing data is logged, not fatal.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
916
917 fn build_filename(
933 base_name: &str,
934 flatten: bool,
935 sequence_name: Option<&String>,
936 frame_number: Option<u32>,
937 ) -> String {
938 if !flatten || sequence_name.is_none() {
939 return base_name.to_string();
940 }
941
942 let seq_name = sequence_name.unwrap();
943 let prefix = format!("{}_", seq_name);
944
945 if base_name.starts_with(&prefix) {
947 base_name.to_string()
948 } else {
949 match frame_number {
951 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
952 None => format!("{}{}", prefix, base_name),
953 }
954 }
955 }
956
    /// List the annotation sets belonging to a dataset.
    pub async fn annotation_sets(
        &self,
        dataset_id: DatasetID,
    ) -> Result<Vec<AnnotationSet>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("annset.list".to_owned(), Some(params)).await
    }
965
966 pub async fn create_annotation_set(
978 &self,
979 dataset_id: DatasetID,
980 name: &str,
981 description: Option<&str>,
982 ) -> Result<AnnotationSetID, Error> {
983 #[derive(Serialize)]
984 struct Params<'a> {
985 dataset_id: DatasetID,
986 name: &'a str,
987 operator: &'a str,
988 #[serde(skip_serializing_if = "Option::is_none")]
989 description: Option<&'a str>,
990 }
991
992 #[derive(Deserialize)]
993 struct CreateAnnotationSetResult {
994 id: AnnotationSetID,
995 }
996
997 let username = self.username().await?;
998 let result: CreateAnnotationSetResult = self
999 .rpc(
1000 "annset.add".to_owned(),
1001 Some(Params {
1002 dataset_id,
1003 name,
1004 operator: &username,
1005 description,
1006 }),
1007 )
1008 .await?;
1009 Ok(result.id)
1010 }
1011
    /// Delete an annotation set by ID.
    pub async fn delete_annotation_set(
        &self,
        annotation_set_id: AnnotationSetID,
    ) -> Result<(), Error> {
        let params = HashMap::from([("id", annotation_set_id)]);
        // The server returns a status string which is discarded.
        let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
        Ok(())
    }
1030
    /// Fetch a single annotation set by ID.
    pub async fn annotation_set(
        &self,
        annotation_set_id: AnnotationSetID,
    ) -> Result<AnnotationSet, Error> {
        let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
        self.rpc("annset.get".to_owned(), Some(params)).await
    }
1039
    /// Fetch all annotations in an annotation set as a flat list.
    ///
    /// Annotations are filtered by sample `groups` and by
    /// `annotation_types`, enriched with sample metadata (name,
    /// sequence, frame, group) and the dataset's label indices.
    /// Progress updates are sent over `progress` when provided.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Label name -> numeric index map for this dataset.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total matching samples, used for progress reporting.
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id: Some(annotation_set_id),
            groups,
            types: annotation_types.iter().map(|t| t.to_string()).collect(),
            labels: &labels,
        };

        self.fetch_annotations_paginated(context, total, progress)
            .await
    }
1092
    /// Page through `samples.list` and collect annotations until the
    /// server stops returning a continuation token.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            // An empty page ends pagination even if a token is present.
            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Send failures are ignored: the receiver may be gone.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Close the progress channel so receivers observe completion.
        drop(progress);
        Ok(annotations)
    }
1136
    /// Flatten a page of samples into `annotations`.
    ///
    /// Samples without annotations contribute one placeholder carrying
    /// only the sample metadata; annotated samples contribute one entry
    /// per annotation, with label indices resolved through `labels`.
    fn process_sample_annotations(
        &self,
        samples: &[Sample],
        labels: &HashMap<String, u64>,
        annotations: &mut Vec<Annotation>,
    ) {
        for sample in samples {
            if sample.annotations().is_empty() {
                let mut annotation = Annotation::new();
                annotation.set_sample_id(sample.id());
                annotation.set_name(sample.name());
                annotation.set_sequence_name(sample.sequence_name().cloned());
                annotation.set_frame_number(sample.frame_number());
                annotation.set_group(sample.group().cloned());
                annotations.push(annotation);
                continue;
            }

            for annotation in sample.annotations() {
                // Copy sample-level metadata onto each annotation.
                let mut annotation = annotation.clone();
                annotation.set_sample_id(sample.id());
                annotation.set_name(sample.name());
                annotation.set_sequence_name(sample.sequence_name().cloned());
                annotation.set_frame_number(sample.frame_number());
                annotation.set_group(sample.group().cloned());
                Self::set_label_index_from_map(&mut annotation, labels);
                annotations.push(annotation);
            }
        }
    }
1167
1168 fn parse_frame_from_image_name(
1176 image_name: Option<&String>,
1177 sequence_name: Option<&String>,
1178 ) -> Option<u32> {
1179 use std::path::Path;
1180
1181 let sequence = sequence_name?;
1182 let name = image_name?;
1183
1184 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1186
1187 stem.strip_prefix(sequence)
1189 .and_then(|suffix| suffix.strip_prefix('_'))
1190 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1191 }
1192
1193 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1195 if let Some(label) = annotation.label() {
1196 annotation.set_label_index(Some(labels[label.as_str()]));
1197 }
1198 }
1199
1200 pub async fn samples_count(
1201 &self,
1202 dataset_id: DatasetID,
1203 annotation_set_id: Option<AnnotationSetID>,
1204 annotation_types: &[AnnotationType],
1205 groups: &[String],
1206 types: &[FileType],
1207 ) -> Result<SamplesCountResult, Error> {
1208 let types = annotation_types
1209 .iter()
1210 .map(|t| t.to_string())
1211 .chain(types.iter().map(|t| t.to_string()))
1212 .collect::<Vec<_>>();
1213
1214 let params = SamplesListParams {
1215 dataset_id,
1216 annotation_set_id,
1217 group_names: groups.to_vec(),
1218 types,
1219 continue_token: None,
1220 };
1221
1222 self.rpc("samples.count".to_owned(), Some(params)).await
1223 }
1224
    /// Fetch all samples in a dataset matching the given filters.
    ///
    /// Each sample's annotations are enriched with the sample metadata
    /// and dataset label indices; missing frame numbers are recovered
    /// from the image file name when possible. Progress updates are
    /// sent over `progress` when provided.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // Annotation-type and file-type filters share one list.
        let types_vec = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> numeric index map for this dataset.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total matching samples, used for progress reporting.
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id,
            groups,
            types: types_vec,
            labels: &labels,
        };

        self.fetch_samples_paginated(context, total, progress).await
    }
1264
    /// Page through `samples.list`, enriching each sample's annotations
    /// with sample metadata and label indices, until the server stops
    /// returning a continuation token.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            // An empty page ends pagination even if a token is present.
            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Recover a missing frame number from the image
                        // file name (<sequence>_<frame>.<ext>).
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        // Copy sample-level metadata onto each of its
                        // annotations.
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Send failures are ignored: the receiver may be gone.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Close the progress channel so receivers observe completion.
        drop(progress);
        Ok(samples)
    }
1336
    /// Create or update samples in a dataset, uploading any local
    /// files they reference.
    ///
    /// Samples whose files point at existing local paths are rewritten
    /// to bare basenames and presigned upload URLs are requested; the
    /// files are then uploaded in parallel. Progress updates are sent
    /// over `progress` when provided.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;

        // (sample uuid, file type, local path, basename) for each file
        // that needs uploading.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;

        let has_files_to_upload = !files_to_upload.is_empty();

        // Only request presigned URLs when there is something to upload.
        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            self.upload_sample_files(&results, files_to_upload, progress)
                .await?;
        }

        Ok(results)
    }
1465
    /// Assign UUIDs to samples that lack one and rewrite their file
    /// entries for upload.
    ///
    /// Local files discovered on disk are appended to `files_to_upload`
    /// as (sample uuid, file type, path, basename) tuples.
    fn prepare_samples_for_upload(
        &self,
        samples: Vec<Sample>,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> Result<Vec<Sample>, Error> {
        Ok(samples
            .into_iter()
            .map(|mut sample| {
                // Every sample needs a UUID so uploads can be matched
                // to the server's presigned URLs.
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().expect("UUID just set above");

                // Iterate a copy because process_sample_file may mutate
                // the sample (image dimensions) while rewriting files.
                let files_copy = sample.files.clone();
                let updated_files: Vec<crate::SampleFile> = files_copy
                    .iter()
                    .map(|file| {
                        self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect())
    }
1495
    /// Rewrite one sample file entry for upload.
    ///
    /// When the entry points at an existing local file, records it in
    /// `files_to_upload`, fills in missing image dimensions on the
    /// sample, and returns an entry holding just the basename. Entries
    /// that do not resolve to a local file are returned unchanged.
    fn process_sample_file(
        &self,
        file: &crate::SampleFile,
        sample_uuid: &str,
        sample: &mut Sample,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> crate::SampleFile {
        use std::path::Path;

        if let Some(filename) = file.filename() {
            let path = Path::new(filename);

            if path.exists()
                && path.is_file()
                && let Some(basename) = path.file_name().and_then(|s| s.to_str())
            {
                // Probe image dimensions from the file header when the
                // sample does not already carry them.
                if file.file_type() == "image"
                    && (sample.width.is_none() || sample.height.is_none())
                    && let Ok(size) = imagesize::size(path)
                {
                    sample.width = Some(size.width as u32);
                    sample.height = Some(size.height as u32);
                }

                files_to_upload.push((
                    sample_uuid.to_string(),
                    file.file_type().to_string(),
                    path.to_path_buf(),
                    basename.to_string(),
                ));

                // The server only needs the basename; the local path
                // stays in files_to_upload for the actual transfer.
                return crate::SampleFile::with_filename(
                    file.file_type().to_string(),
                    basename.to_string(),
                );
            }
        }
        file.clone()
    }
1540
    /// Upload local files to the presigned URLs returned by
    /// `samples.populate2`, in parallel.
    async fn upload_sample_files(
        &self,
        results: &[crate::SamplesPopulateResult],
        files_to_upload: Vec<(String, String, PathBuf, String)>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Index local paths by (sample uuid, basename) to match them
        // against the server's URL list.
        let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
        for (uuid, _file_type, path, basename) in files_to_upload {
            upload_map.insert((uuid, basename), path);
        }

        let http = self.http.clone();

        let upload_tasks: Vec<_> = results
            .iter()
            .map(|result| (result.uuid.clone(), result.urls.clone()))
            .collect();

        parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
            let http = http.clone();
            let upload_map = upload_map.clone();

            async move {
                // URLs with no matching local file are skipped silently.
                for url_info in &urls {
                    if let Some(local_path) =
                        upload_map.get(&(uuid.clone(), url_info.filename.clone()))
                    {
                        upload_file_to_presigned_url(
                            http.clone(),
                            &url_info.url,
                            local_path.clone(),
                        )
                        .await?;
                    }
                }

                Ok(())
            }
        })
        .await
    }
1586
1587 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1588 let resp = self.http.get(url).send().await?;
1590
1591 if !resp.status().is_success() {
1592 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
1593 }
1594
1595 let bytes = resp.bytes().await?;
1596 Ok(bytes.to_vec())
1597 }
1598
    /// Fetch annotations as a polars `DataFrame` (legacy schema).
    #[deprecated(
        since = "0.8.0",
        note = "Use `samples_dataframe()` for complete 2025.10 schema support"
    )]
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        // The helper is deprecated alongside this method.
        #[allow(deprecated)]
        annotations_dataframe(&annotations)
    }
1658
1659 #[cfg(feature = "polars")]
1696 pub async fn samples_dataframe(
1697 &self,
1698 dataset_id: DatasetID,
1699 annotation_set_id: Option<AnnotationSetID>,
1700 groups: &[String],
1701 types: &[AnnotationType],
1702 progress: Option<Sender<Progress>>,
1703 ) -> Result<DataFrame, Error> {
1704 use crate::dataset::samples_dataframe;
1705
1706 let samples = self
1707 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
1708 .await?;
1709 samples_dataframe(&samples)
1710 }
1711
1712 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1715 let snapshots: Vec<Snapshot> = self
1716 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1717 .await?;
1718 if let Some(name) = name {
1719 Ok(snapshots
1720 .into_iter()
1721 .filter(|s| s.description().contains(name))
1722 .collect())
1723 } else {
1724 Ok(snapshots)
1725 }
1726 }
1727
1728 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1730 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1731 self.rpc("snapshots.get".to_owned(), Some(params)).await
1732 }
1733
    /// Creates a snapshot from a local file or directory.
    ///
    /// Directories are delegated to `create_snapshot_folder`. For a single
    /// file, a multipart upload is negotiated with the server, the file is
    /// uploaded in parts, the upload completed, and the snapshot marked
    /// "available" before being fetched and returned.
    ///
    /// Byte-level progress is reported through `progress` when provided.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The snapshot takes its name (and only key) from the file name.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server for presigned multipart-upload URLs for this file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-key part entries.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed "<storage prefix>/<file name>", where the
        // prefix is the tail of the snapshot's "::/"-delimited path.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // All bytes are uploaded; flip the snapshot to "available".
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Dropping the sender signals progress consumers that work is done.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1909
    /// Creates a snapshot from a local directory tree.
    ///
    /// Every regular file beneath `path` becomes a key (relative to `path`)
    /// in a single multipart negotiation; files are then uploaded one after
    /// another (each file's parts in parallel via `upload_multipart`), each
    /// upload is completed, and the snapshot is finally marked "available".
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        // The snapshot is named after the directory itself.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid directory name",
            ))
        })?;

        // Collect all regular files, expressed relative to the root folder.
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
            .collect::<Vec<_>>();

        // Total bytes across all files, used for aggregate progress.
        let total: usize = files
            .iter()
            .filter_map(|file| path.join(file).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Keys and sizes are sent in matching order so the server can build
        // presigned part URLs per file.
        let keys = files
            .iter()
            .filter_map(|key| key.to_str().map(|s| s.to_owned()))
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .filter_map(|key| path.join(key).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed "<storage prefix>/<relative path>", where
        // the prefix is the tail of the snapshot's "::/"-delimited path.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();

        // Upload each file in turn and complete its multipart upload.
        for file in files {
            let file_str = file.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "File path contains invalid UTF-8",
                ))
            })?;
            let part_key = format!("{}/{}", part_prefix, file_str);
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // All bytes are uploaded; mark the snapshot usable.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Dropping the sender tells progress consumers the work is finished.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
2025
2026 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2059 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2060 let _: String = self
2061 .rpc("snapshots.delete".to_owned(), Some(params))
2062 .await?;
2063 Ok(())
2064 }
2065
2066 pub async fn download_snapshot(
2119 &self,
2120 snapshot_id: SnapshotID,
2121 output: PathBuf,
2122 progress: Option<Sender<Progress>>,
2123 ) -> Result<(), Error> {
2124 fs::create_dir_all(&output).await?;
2125
2126 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2127 let items: HashMap<String, String> = self
2128 .rpc("snapshots.create_download_url".to_owned(), Some(params))
2129 .await?;
2130
2131 let total = Arc::new(AtomicUsize::new(0));
2132 let current = Arc::new(AtomicUsize::new(0));
2133 let sem = Arc::new(Semaphore::new(max_tasks()));
2134
2135 let tasks = items
2136 .iter()
2137 .map(|(key, url)| {
2138 let http = self.http.clone();
2139 let key = key.clone();
2140 let url = url.clone();
2141 let output = output.clone();
2142 let progress = progress.clone();
2143 let current = current.clone();
2144 let total = total.clone();
2145 let sem = sem.clone();
2146
2147 tokio::spawn(async move {
2148 let _permit = sem.acquire().await.map_err(|_| {
2149 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
2150 })?;
2151 let res = http.get(url).send().await?;
2152 let content_length = res.content_length().unwrap_or(0) as usize;
2153
2154 if let Some(progress) = &progress {
2155 let total = total.fetch_add(content_length, Ordering::SeqCst);
2156 let _ = progress
2157 .send(Progress {
2158 current: current.load(Ordering::SeqCst),
2159 total: total + content_length,
2160 })
2161 .await;
2162 }
2163
2164 let mut file = File::create(output.join(key)).await?;
2165 let mut stream = res.bytes_stream();
2166
2167 while let Some(chunk) = stream.next().await {
2168 let chunk = chunk?;
2169 file.write_all(&chunk).await?;
2170 let len = chunk.len();
2171
2172 if let Some(progress) = &progress {
2173 let total = total.load(Ordering::SeqCst);
2174 let current = current.fetch_add(len, Ordering::SeqCst);
2175
2176 let _ = progress
2177 .send(Progress {
2178 current: current + len,
2179 total,
2180 })
2181 .await;
2182 }
2183 }
2184
2185 Ok::<(), Error>(())
2186 })
2187 })
2188 .collect::<Vec<_>>();
2189
2190 join_all(tasks)
2191 .await
2192 .into_iter()
2193 .collect::<Result<Vec<_>, _>>()?
2194 .into_iter()
2195 .collect::<Result<Vec<_>, _>>()?;
2196
2197 Ok(())
2198 }
2199
2200 #[allow(clippy::too_many_arguments)]
2267 pub async fn restore_snapshot(
2268 &self,
2269 project_id: ProjectID,
2270 snapshot_id: SnapshotID,
2271 topics: &[String],
2272 autolabel: &[String],
2273 autodepth: bool,
2274 dataset_name: Option<&str>,
2275 dataset_description: Option<&str>,
2276 ) -> Result<SnapshotRestoreResult, Error> {
2277 let params = SnapshotRestore {
2278 project_id,
2279 snapshot_id,
2280 fps: 1,
2281 autodepth,
2282 agtg_pipeline: !autolabel.is_empty(),
2283 autolabel: autolabel.to_vec(),
2284 topics: topics.to_vec(),
2285 dataset_name: dataset_name.map(|s| s.to_owned()),
2286 dataset_description: dataset_description.map(|s| s.to_owned()),
2287 };
2288 self.rpc("snapshots.restore".to_owned(), Some(params)).await
2289 }
2290
2291 pub async fn experiments(
2300 &self,
2301 project_id: ProjectID,
2302 name: Option<&str>,
2303 ) -> Result<Vec<Experiment>, Error> {
2304 let params = HashMap::from([("project_id", project_id)]);
2305 let experiments: Vec<Experiment> =
2306 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
2307 if let Some(name) = name {
2308 Ok(experiments
2309 .into_iter()
2310 .filter(|e| e.name().contains(name))
2311 .collect())
2312 } else {
2313 Ok(experiments)
2314 }
2315 }
2316
2317 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
2320 let params = HashMap::from([("trainer_id", experiment_id)]);
2321 self.rpc("trainer.get".to_owned(), Some(params)).await
2322 }
2323
2324 pub async fn training_sessions(
2333 &self,
2334 experiment_id: ExperimentID,
2335 name: Option<&str>,
2336 ) -> Result<Vec<TrainingSession>, Error> {
2337 let params = HashMap::from([("trainer_id", experiment_id)]);
2338 let sessions: Vec<TrainingSession> = self
2339 .rpc("trainer.session.list".to_owned(), Some(params))
2340 .await?;
2341 if let Some(name) = name {
2342 Ok(sessions
2343 .into_iter()
2344 .filter(|s| s.name().contains(name))
2345 .collect())
2346 } else {
2347 Ok(sessions)
2348 }
2349 }
2350
2351 pub async fn training_session(
2354 &self,
2355 session_id: TrainingSessionID,
2356 ) -> Result<TrainingSession, Error> {
2357 let params = HashMap::from([("trainer_session_id", session_id)]);
2358 self.rpc("trainer.session.get".to_owned(), Some(params))
2359 .await
2360 }
2361
2362 pub async fn validation_sessions(
2364 &self,
2365 project_id: ProjectID,
2366 ) -> Result<Vec<ValidationSession>, Error> {
2367 let params = HashMap::from([("project_id", project_id)]);
2368 self.rpc("validate.session.list".to_owned(), Some(params))
2369 .await
2370 }
2371
2372 pub async fn validation_session(
2374 &self,
2375 session_id: ValidationSessionID,
2376 ) -> Result<ValidationSession, Error> {
2377 let params = HashMap::from([("validate_session_id", session_id)]);
2378 self.rpc("validate.session.get".to_owned(), Some(params))
2379 .await
2380 }
2381
2382 pub async fn artifacts(
2385 &self,
2386 training_session_id: TrainingSessionID,
2387 ) -> Result<Vec<Artifact>, Error> {
2388 let params = HashMap::from([("training_session_id", training_session_id)]);
2389 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2390 .await
2391 }
2392
2393 pub async fn download_artifact(
2399 &self,
2400 training_session_id: TrainingSessionID,
2401 modelname: &str,
2402 filename: Option<PathBuf>,
2403 progress: Option<Sender<Progress>>,
2404 ) -> Result<(), Error> {
2405 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2406 let resp = self
2407 .http
2408 .get(format!(
2409 "{}/download_model?training_session_id={}&file={}",
2410 self.url,
2411 training_session_id.value(),
2412 modelname
2413 ))
2414 .header("Authorization", format!("Bearer {}", self.token().await))
2415 .send()
2416 .await?;
2417 if !resp.status().is_success() {
2418 let err = resp.error_for_status_ref().unwrap_err();
2419 return Err(Error::HttpError(err));
2420 }
2421
2422 if let Some(parent) = filename.parent() {
2423 fs::create_dir_all(parent).await?;
2424 }
2425
2426 if let Some(progress) = progress {
2427 let total = resp.content_length().unwrap_or(0) as usize;
2428 let _ = progress.send(Progress { current: 0, total }).await;
2429
2430 let mut file = File::create(filename).await?;
2431 let mut current = 0;
2432 let mut stream = resp.bytes_stream();
2433
2434 while let Some(item) = stream.next().await {
2435 let chunk = item?;
2436 file.write_all(&chunk).await?;
2437 current += chunk.len();
2438 let _ = progress.send(Progress { current, total }).await;
2439 }
2440 } else {
2441 let body = resp.bytes().await?;
2442 fs::write(filename, body).await?;
2443 }
2444
2445 Ok(())
2446 }
2447
2448 pub async fn download_checkpoint(
2458 &self,
2459 training_session_id: TrainingSessionID,
2460 checkpoint: &str,
2461 filename: Option<PathBuf>,
2462 progress: Option<Sender<Progress>>,
2463 ) -> Result<(), Error> {
2464 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2465 let resp = self
2466 .http
2467 .get(format!(
2468 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2469 self.url,
2470 training_session_id.value(),
2471 checkpoint
2472 ))
2473 .header("Authorization", format!("Bearer {}", self.token().await))
2474 .send()
2475 .await?;
2476 if !resp.status().is_success() {
2477 let err = resp.error_for_status_ref().unwrap_err();
2478 return Err(Error::HttpError(err));
2479 }
2480
2481 if let Some(parent) = filename.parent() {
2482 fs::create_dir_all(parent).await?;
2483 }
2484
2485 if let Some(progress) = progress {
2486 let total = resp.content_length().unwrap_or(0) as usize;
2487 let _ = progress.send(Progress { current: 0, total }).await;
2488
2489 let mut file = File::create(filename).await?;
2490 let mut current = 0;
2491 let mut stream = resp.bytes_stream();
2492
2493 while let Some(item) = stream.next().await {
2494 let chunk = item?;
2495 file.write_all(&chunk).await?;
2496 current += chunk.len();
2497 let _ = progress.send(Progress { current, total }).await;
2498 }
2499 } else {
2500 let body = resp.bytes().await?;
2501 fs::write(filename, body).await?;
2502 }
2503
2504 Ok(())
2505 }
2506
2507 pub async fn tasks(
2522 &self,
2523 name: Option<&str>,
2524 workflow: Option<&str>,
2525 status: Option<&str>,
2526 manager: Option<&str>,
2527 ) -> Result<Vec<Task>, Error> {
2528 let mut params = TasksListParams {
2529 continue_token: None,
2530 types: workflow.map(|w| vec![w.to_owned()]),
2531 status: status.map(|s| vec![s.to_owned()]),
2532 manager: manager.map(|m| vec![m.to_owned()]),
2533 };
2534 let mut tasks = Vec::new();
2535
2536 loop {
2537 let result = self
2538 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2539 .await?;
2540 tasks.extend(result.tasks);
2541
2542 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2543 params.continue_token = None;
2544 } else {
2545 params.continue_token = result.continue_token;
2546 }
2547
2548 if params.continue_token.is_none() {
2549 break;
2550 }
2551 }
2552
2553 if let Some(name) = name {
2554 tasks.retain(|t| t.name().contains(name));
2555 }
2556
2557 Ok(tasks)
2558 }
2559
2560 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2562 self.rpc(
2563 "task.get".to_owned(),
2564 Some(HashMap::from([("id", task_id)])),
2565 )
2566 .await
2567 }
2568
2569 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2571 let status = TaskStatus {
2572 task_id,
2573 status: status.to_owned(),
2574 };
2575 self.rpc("docker.update.status".to_owned(), Some(status))
2576 .await
2577 }
2578
2579 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
2583 let stages: Vec<HashMap<String, String>> = stages
2584 .iter()
2585 .map(|(key, value)| {
2586 let mut stage_map = HashMap::new();
2587 stage_map.insert(key.to_string(), value.to_string());
2588 stage_map
2589 })
2590 .collect();
2591 let params = TaskStages { task_id, stages };
2592 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
2593 Ok(())
2594 }
2595
2596 pub async fn update_stage(
2599 &self,
2600 task_id: TaskID,
2601 stage: &str,
2602 status: &str,
2603 message: &str,
2604 percentage: u8,
2605 ) -> Result<(), Error> {
2606 let stage = Stage::new(
2607 Some(task_id),
2608 stage.to_owned(),
2609 Some(status.to_owned()),
2610 Some(message.to_owned()),
2611 percentage,
2612 );
2613 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
2614 Ok(())
2615 }
2616
2617 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
2619 let req = self
2620 .http
2621 .get(format!("{}/{}", self.url, query))
2622 .header("User-Agent", "EdgeFirst Client")
2623 .header("Authorization", format!("Bearer {}", self.token().await));
2624 let resp = req.send().await?;
2625
2626 if resp.status().is_success() {
2627 let body = resp.bytes().await?;
2628
2629 if log_enabled!(Level::Trace) {
2630 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
2631 }
2632
2633 Ok(body.to_vec())
2634 } else {
2635 let err = resp.error_for_status_ref().unwrap_err();
2636 Err(Error::HttpError(err))
2637 }
2638 }
2639
2640 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2644 let req = self
2645 .http
2646 .post(format!("{}/api?method={}", self.url, method))
2647 .header("Accept", "application/json")
2648 .header("User-Agent", "EdgeFirst Client")
2649 .header("Authorization", format!("Bearer {}", self.token().await))
2650 .multipart(form);
2651 let resp = req.send().await?;
2652
2653 if resp.status().is_success() {
2654 let body = resp.bytes().await?;
2655
2656 if log_enabled!(Level::Trace) {
2657 trace!(
2658 "POST Multipart Response: {}",
2659 String::from_utf8_lossy(&body)
2660 );
2661 }
2662
2663 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2664 Ok(response) => response,
2665 Err(err) => {
2666 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2667 return Err(err.into());
2668 }
2669 };
2670
2671 if let Some(error) = response.error {
2672 Err(Error::RpcError(error.code, error.message))
2673 } else if let Some(result) = response.result {
2674 Ok(result)
2675 } else {
2676 Err(Error::InvalidResponse)
2677 }
2678 } else {
2679 let err = resp.error_for_status_ref().unwrap_err();
2680 Err(Error::HttpError(err))
2681 }
2682 }
2683
2684 pub async fn rpc<Params, RpcResult>(
2693 &self,
2694 method: String,
2695 params: Option<Params>,
2696 ) -> Result<RpcResult, Error>
2697 where
2698 Params: Serialize,
2699 RpcResult: DeserializeOwned,
2700 {
2701 let auth_expires = self.token_expiration().await?;
2702 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2703 self.renew_token().await?;
2704 }
2705
2706 self.rpc_without_auth(method, params).await
2707 }
2708
2709 async fn rpc_without_auth<Params, RpcResult>(
2710 &self,
2711 method: String,
2712 params: Option<Params>,
2713 ) -> Result<RpcResult, Error>
2714 where
2715 Params: Serialize,
2716 RpcResult: DeserializeOwned,
2717 {
2718 let request = RpcRequest {
2719 method,
2720 params,
2721 ..Default::default()
2722 };
2723
2724 if log_enabled!(Level::Trace) {
2725 trace!(
2726 "RPC Request: {}",
2727 serde_json::ser::to_string_pretty(&request)?
2728 );
2729 }
2730
2731 let url = format!("{}/api", self.url);
2732
2733 let res = self
2736 .http
2737 .post(&url)
2738 .header("Accept", "application/json")
2739 .header("User-Agent", "EdgeFirst Client")
2740 .header("Authorization", format!("Bearer {}", self.token().await))
2741 .json(&request)
2742 .send()
2743 .await?;
2744
2745 self.process_rpc_response(res).await
2746 }
2747
2748 async fn process_rpc_response<RpcResult>(
2749 &self,
2750 res: reqwest::Response,
2751 ) -> Result<RpcResult, Error>
2752 where
2753 RpcResult: DeserializeOwned,
2754 {
2755 let body = res.bytes().await?;
2756
2757 if log_enabled!(Level::Trace) {
2758 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2759 }
2760
2761 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2762 Ok(response) => response,
2763 Err(err) => {
2764 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2765 return Err(err.into());
2766 }
2767 };
2768
2769 if let Some(error) = response.error {
2775 Err(Error::RpcError(error.code, error.message))
2776 } else if let Some(result) = response.result {
2777 Ok(result)
2778 } else {
2779 Err(Error::InvalidResponse)
2780 }
2781 }
2782}
2783
2784async fn parallel_foreach_items<T, F, Fut>(
2815 items: Vec<T>,
2816 progress: Option<Sender<Progress>>,
2817 work_fn: F,
2818) -> Result<(), Error>
2819where
2820 T: Send + 'static,
2821 F: Fn(T) -> Fut + Send + Sync + 'static,
2822 Fut: Future<Output = Result<(), Error>> + Send + 'static,
2823{
2824 let total = items.len();
2825 let current = Arc::new(AtomicUsize::new(0));
2826 let sem = Arc::new(Semaphore::new(max_tasks()));
2827 let work_fn = Arc::new(work_fn);
2828
2829 let tasks = items
2830 .into_iter()
2831 .map(|item| {
2832 let sem = sem.clone();
2833 let current = current.clone();
2834 let progress = progress.clone();
2835 let work_fn = work_fn.clone();
2836
2837 tokio::spawn(async move {
2838 let _permit = sem.acquire().await.map_err(|_| {
2839 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
2840 })?;
2841
2842 work_fn(item).await?;
2844
2845 if let Some(progress) = &progress {
2847 let current = current.fetch_add(1, Ordering::SeqCst);
2848 let _ = progress
2849 .send(Progress {
2850 current: current + 1,
2851 total,
2852 })
2853 .await;
2854 }
2855
2856 Ok::<(), Error>(())
2857 })
2858 })
2859 .collect::<Vec<_>>();
2860
2861 join_all(tasks)
2862 .await
2863 .into_iter()
2864 .collect::<Result<Vec<_>, _>>()?
2865 .into_iter()
2866 .collect::<Result<Vec<_>, _>>()?;
2867
2868 if let Some(progress) = progress {
2869 drop(progress);
2870 }
2871
2872 Ok(())
2873}
2874
2875async fn upload_multipart(
2901 http: reqwest::Client,
2902 part: SnapshotPart,
2903 path: PathBuf,
2904 total: usize,
2905 current: Arc<AtomicUsize>,
2906 progress: Option<Sender<Progress>>,
2907) -> Result<SnapshotCompleteMultipartParams, Error> {
2908 let filesize = path.metadata()?.len() as usize;
2909 let n_parts = filesize.div_ceil(PART_SIZE);
2910 let sem = Arc::new(Semaphore::new(max_tasks()));
2911
2912 let key = part.key.ok_or(Error::InvalidResponse)?;
2913 let upload_id = part.upload_id;
2914
2915 let urls = part.urls.clone();
2916 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2918 EtagPart {
2919 etag: "".to_owned(),
2920 part_number: 0,
2921 };
2922 n_parts
2923 ]));
2924
2925 let tasks = (0..n_parts)
2927 .map(|part| {
2928 let http = http.clone();
2929 let url = urls[part].clone();
2930 let etags = etags.clone();
2931 let path = path.to_owned();
2932 let sem = sem.clone();
2933 let progress = progress.clone();
2934 let current = current.clone();
2935
2936 tokio::spawn(async move {
2937 let _permit = sem.acquire().await?;
2939
2940 let etag =
2942 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
2943
2944 let mut etags = etags.lock().await;
2946 etags[part] = EtagPart {
2947 etag,
2948 part_number: part + 1,
2949 };
2950
2951 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2953 if let Some(progress) = &progress {
2954 let _ = progress
2955 .send(Progress {
2956 current: current + PART_SIZE,
2957 total,
2958 })
2959 .await;
2960 }
2961
2962 Ok::<(), Error>(())
2963 })
2964 })
2965 .collect::<Vec<_>>();
2966
2967 join_all(tasks)
2969 .await
2970 .into_iter()
2971 .collect::<Result<Vec<_>, _>>()?;
2972
2973 Ok(SnapshotCompleteMultipartParams {
2974 key,
2975 upload_id,
2976 etag_list: etags.lock().await.clone(),
2977 })
2978}
2979
2980async fn upload_part(
2981 http: reqwest::Client,
2982 url: String,
2983 path: PathBuf,
2984 part: usize,
2985 n_parts: usize,
2986) -> Result<String, Error> {
2987 let filesize = path.metadata()?.len() as usize;
2988 let mut file = File::open(path).await?;
2989 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2990 .await?;
2991 let file = file.take(PART_SIZE as u64);
2992
2993 let body_length = if part + 1 == n_parts {
2994 filesize % PART_SIZE
2995 } else {
2996 PART_SIZE
2997 };
2998
2999 let stream = FramedRead::new(file, BytesCodec::new());
3000 let body = Body::wrap_stream(stream);
3001
3002 let resp = http
3003 .put(url.clone())
3004 .header(CONTENT_LENGTH, body_length)
3005 .body(body)
3006 .send()
3007 .await?
3008 .error_for_status()?;
3009
3010 let etag = resp
3011 .headers()
3012 .get("etag")
3013 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3014 .to_str()
3015 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3016 .to_owned();
3017
3018 let etag = etag
3020 .strip_prefix("\"")
3021 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3022 let etag = etag
3023 .strip_suffix("\"")
3024 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3025
3026 Ok(etag.to_owned())
3027}
3028
3029async fn upload_file_to_presigned_url(
3034 http: reqwest::Client,
3035 url: &str,
3036 path: PathBuf,
3037) -> Result<(), Error> {
3038 let file_data = fs::read(&path).await?;
3040 let file_size = file_data.len();
3041
3042 let resp = http
3044 .put(url)
3045 .header(CONTENT_LENGTH, file_size)
3046 .body(file_data)
3047 .send()
3048 .await?;
3049
3050 if resp.status().is_success() {
3051 debug!(
3052 "Successfully uploaded file: {:?} ({} bytes)",
3053 path, file_size
3054 );
3055 Ok(())
3056 } else {
3057 let status = resp.status();
3058 let error_text = resp.text().await.unwrap_or_default();
3059 Err(Error::InvalidParameters(format!(
3060 "Upload failed: HTTP {} - {}",
3061 status, error_text
3062 )))
3063 }
3064}
3065
#[cfg(test)]
mod tests {
    use super::*;

    // flatten=false returns the name untouched, regardless of sequence or
    // frame information.
    #[test]
    fn test_build_filename_no_flatten() {
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    // Flattening without a sequence leaves the filename unchanged.
    #[test]
    fn test_build_filename_flatten_no_sequence() {
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    // Flattening prepends "<sequence>_<frame>_" when the name does not
    // already start with the sequence prefix.
    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    // Without a frame number only the sequence name is prepended.
    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    // A name already starting with the sequence prefix is kept as-is.
    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    // The existing prefix wins even when the supplied frame differs.
    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    // The sequence appearing mid-name does not count as a prefix; the
    // "<sequence>_<frame>_" prefix is still prepended.
    #[test]
    fn test_build_filename_flatten_partial_match() {
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    // Prefixing must never clobber single or compound file extensions.
    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    // Underscores in the sequence name pass through unchanged.
    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }
}