1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14 retry::{create_retry_policy, log_retry_configuration},
15};
16use base64::Engine as _;
17use chrono::{DateTime, Utc};
18use directories::ProjectDirs;
19use futures::{StreamExt as _, future::join_all};
20use log::{Level, debug, error, log_enabled, trace, warn};
21use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
22use serde::{Deserialize, Serialize, de::DeserializeOwned};
23use std::{
24 collections::HashMap,
25 ffi::OsStr,
26 fs::create_dir_all,
27 io::{SeekFrom, Write as _},
28 path::{Path, PathBuf},
29 sync::{
30 Arc,
31 atomic::{AtomicUsize, Ordering},
32 },
33 time::Duration,
34 vec,
35};
36use tokio::{
37 fs::{self, File},
38 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
39 sync::{RwLock, Semaphore, mpsc::Sender},
40};
41use tokio_util::codec::{BytesCodec, FramedRead};
42use walkdir::WalkDir;
43
44#[cfg(feature = "polars")]
45use polars::prelude::*;
46
/// Size of each chunk in multipart snapshot uploads: 100 MiB.
/// `const` rather than `static`: this is a plain compile-time constant.
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Maximum number of concurrent worker tasks.
///
/// Honours the `MAX_TASKS` environment variable when it parses as an integer;
/// otherwise defaults to half the available CPUs, clamped to the 2..=8 range
/// (with 4 CPUs assumed when parallelism cannot be queried).
fn max_tasks() -> usize {
    if let Some(n) = std::env::var("MAX_TASKS").ok().and_then(|v| v.parse().ok()) {
        return n;
    }

    let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
    (cpus / 2).clamp(2, 8)
}
62
/// Keep only the items whose name contains `filter` (case-insensitive) and
/// order them best-match-first: exact match, then case-insensitive exact
/// match, then shorter names, then lexicographic order.
///
/// Uses `sort_by_cached_key` so each item's lowercase form is computed once,
/// instead of on every comparison as the previous `sort_by` closure did.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();
    let mut filtered: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&filter_lower))
        .collect();

    // Tuples compare lexicographically, reproducing the tie-break cascade:
    // `false` (a match) sorts before `true`, then length, then name.
    filtered.sort_by_cached_key(|item| {
        let name = get_name(item);
        (
            name != filter,
            name.to_lowercase() != filter_lower,
            name.len(),
            name.to_owned(),
        )
    });

    filtered
}
112
/// Reduce `name` to a single safe path component for use in filenames.
///
/// Trims whitespace, keeps only the final path component, and replaces
/// characters that are invalid on common filesystems with `_`. Inputs that
/// would produce an empty component — or the special `.`/`..` components,
/// which could otherwise escape the output directory — yield `"unnamed"`.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // `file_name()` strips any directory part. It returns None for inputs
    // like "." and ".."; we fall back to the trimmed input there and rely on
    // the final check below to reject those traversal components.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // Never emit "." or ".." as a file/directory name (path traversal).
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
138
/// Progress report emitted on the optional `Sender<Progress>` channels used
/// throughout this client: `current` items completed out of `total`.
#[derive(Debug, Clone)]
pub struct Progress {
    /// Number of items processed so far.
    pub current: usize,
    /// Total number of items expected.
    pub total: usize,
}
167
/// JSON-RPC 2.0 request envelope sent to the server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    /// Request id; the `Default` impl uses 0.
    id: u64,
    /// Protocol version string, always "2.0".
    jsonrpc: String,
    /// Method name, e.g. "dataset.list".
    method: String,
    /// Method parameters; `None` for parameterless calls.
    params: Option<Params>,
}
175
176impl<T> Default for RpcRequest<T> {
177 fn default() -> Self {
178 RpcRequest {
179 id: 0,
180 jsonrpc: "2.0".to_string(),
181 method: "".to_string(),
182 params: None,
183 }
184 }
185}
186
/// Error object embedded in a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric JSON-RPC error code.
    code: i32,
    /// Human-readable error message from the server.
    message: String,
}
192
/// JSON-RPC 2.0 response envelope. Exactly one of `error` / `result` is
/// expected to be populated.
///
/// NOTE(review): requests serialize `id` as `u64` but responses deserialize
/// it as `String` — presumably the server echoes ids as strings; confirm.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    /// Populated when the call failed.
    error: Option<RpcError>,
    /// Populated when the call succeeded.
    result: Option<RpcResult>,
}
202
/// Placeholder result type for RPC calls whose result payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
206
/// Parameters for creating a (single-part) snapshot: a name plus the object
/// keys to include.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}
213
/// Result of snapshot creation: the new snapshot id and presigned upload
/// URLs, one per requested key.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
220
/// Parameters for creating a multipart snapshot upload. `file_sizes` is
/// parallel to `keys` and lets the server size the per-part URL lists.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}
227
/// One element of a multipart-create response: either the numeric snapshot
/// id or a per-file part descriptor.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so a bare
/// number deserializes as `Id` and an object as `Part`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
234
/// Parameters to finalize one file of a multipart upload: the object key,
/// the upload session id, and the ETag collected for every uploaded part.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}
241
/// ETag for a single uploaded part. Field names are renamed to the
/// S3-style `ETag`/`PartNumber` casing the completion API expects.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
249
/// Per-file descriptor in a multipart-create response: the object key (when
/// provided), the upload session id, and one presigned URL per part.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
256
/// Parameters for updating a snapshot's status string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}
262
/// Snapshot record returned by a status update. All fields are currently
/// unused by callers, hence the `dead_code` allowances.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
276
/// Parameters for an image-listing RPC: dataset filter, free-form file
/// filters, and whether to return only ids.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    only_ids: bool,
}
284
/// Image filter restricting results to a single dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
290
/// Asynchronous EdgeFirst Studio API client.
///
/// Cheap to clone: `reqwest::Client` clones share one connection pool and
/// the token lives behind an `Arc<RwLock<_>>`.
#[derive(Clone, Debug)]
pub struct Client {
    /// Shared HTTP client (timeouts, pooling, retry policy).
    http: reqwest::Client,
    /// Base server URL, e.g. `https://<server>.edgefirst.studio`.
    url: String,
    /// Bearer token; empty string when not logged in.
    token: Arc<RwLock<String>>,
    /// Where the token is persisted on disk, if configured.
    token_path: Option<PathBuf>,
}
346
/// Filter state shared by the paginated sample/annotation fetch helpers.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    /// Set when annotations should be fetched from a specific set.
    annotation_set_id: Option<AnnotationSetID>,
    /// Group names to restrict to; empty means all groups.
    groups: &'a [String],
    /// Annotation/file type names, stringified for the RPC.
    types: Vec<String>,
    /// Label name → numeric index map for the dataset.
    labels: &'a HashMap<String, u64>,
}
355
356impl Client {
357 pub fn new() -> Result<Self, Error> {
366 log_retry_configuration();
367
368 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
370 .ok()
371 .and_then(|s| s.parse().ok())
372 .unwrap_or(30); let http = reqwest::Client::builder()
384 .connect_timeout(Duration::from_secs(10))
385 .timeout(Duration::from_secs(timeout_secs))
386 .pool_idle_timeout(Duration::from_secs(90))
387 .pool_max_idle_per_host(10)
388 .retry(create_retry_policy())
389 .build()?;
390
391 Ok(Client {
392 http,
393 url: "https://edgefirst.studio".to_string(),
394 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
395 token_path: None,
396 })
397 }
398
399 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
403 Ok(Client {
404 url: format!("https://{}.edgefirst.studio", server),
405 ..self.clone()
406 })
407 }
408
409 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
412 let params = HashMap::from([("username", username), ("password", password)]);
413 let login: LoginResult = self
414 .rpc_without_auth("auth.login".to_owned(), Some(params))
415 .await?;
416
417 if login.token.is_empty() {
419 return Err(Error::EmptyToken);
420 }
421
422 Ok(Client {
423 token: Arc::new(tokio::sync::RwLock::new(login.token)),
424 ..self.clone()
425 })
426 }
427
    /// Configure where the bearer token is persisted and, when a token file
    /// already exists there, load it.
    ///
    /// With `None`, the path defaults to `<platform config dir>/token` for
    /// the "ai.EdgeFirst / EdgeFirst Studio" application. A stored token that
    /// fails validation in [`Client::with_token`] is treated as corrupted:
    /// the file is removed (best effort) and the current token is kept.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        // Resolve the explicit path or fall back to the per-user config dir.
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path: {:?}", token_path);

        // A missing file is not an error — it simply means "not logged in yet".
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            match self.with_token(&token) {
                Ok(client) => Ok(Client {
                    token_path: Some(token_path),
                    ..client
                }),
                Err(e) => {
                    // The stored token did not parse/validate; remove it so
                    // later runs start clean instead of failing repeatedly.
                    warn!(
                        "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
                        token_path, e
                    );
                    if let Err(remove_err) = std::fs::remove_file(&token_path) {
                        warn!("Failed to remove corrupted token file: {:?}", remove_err);
                    }
                    Ok(Client {
                        token_path: Some(token_path),
                        ..self.clone()
                    })
                }
            }
        } else {
            Ok(Client {
                token_path: Some(token_path),
                ..self.clone()
            })
        }
    }
479
    /// Adopt an existing JWT-style token, deriving the server URL from the
    /// `server` claim in its payload. An empty token returns a plain clone.
    ///
    /// NOTE(review): the payload is decoded with the *standard* (non-URL-safe)
    /// base64 alphabet, while JWT payloads are normally base64url-encoded;
    /// payloads containing `-` or `_` would be rejected here. Confirm against
    /// the tokens the server actually issues (same applies to `token_field`).
    pub fn with_token(&self, token: &str) -> Result<Self, Error> {
        if token.is_empty() {
            return Ok(self.clone());
        }

        // Expect the three dot-separated JWT segments: header.payload.signature.
        let token_parts: Vec<&str> = token.split('.').collect();
        if token_parts.len() != 3 {
            return Err(Error::InvalidToken);
        }

        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(token_parts[1])
            .map_err(|_| Error::InvalidToken)?;
        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
        // The "server" claim selects the target instance.
        let server = match payload.get("server") {
            Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
            None => return Err(Error::InvalidToken),
        };

        Ok(Client {
            url: format!("https://{}.edgefirst.studio", server),
            token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
            ..self.clone()
        })
    }
506
507 pub async fn save_token(&self) -> Result<(), Error> {
508 let path = self.token_path.clone().unwrap_or_else(|| {
509 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
510 .map(|dirs| dirs.config_dir().join("token"))
511 .unwrap_or_else(|| PathBuf::from(".token"))
512 });
513
514 create_dir_all(path.parent().ok_or_else(|| {
515 Error::IoError(std::io::Error::new(
516 std::io::ErrorKind::InvalidInput,
517 "Token path has no parent directory",
518 ))
519 })?)?;
520 let mut file = std::fs::File::create(&path)?;
521 file.write_all(self.token.read().await.as_bytes())?;
522
523 debug!("Saved token to {:?}", path);
524
525 Ok(())
526 }
527
528 pub async fn version(&self) -> Result<String, Error> {
531 let version: HashMap<String, String> = self
532 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
533 .await?;
534 let version = version.get("version").ok_or(Error::InvalidResponse)?;
535 Ok(version.to_owned())
536 }
537
538 pub async fn logout(&self) -> Result<(), Error> {
542 {
543 let mut token = self.token.write().await;
544 *token = "".to_string();
545 }
546
547 if let Some(path) = &self.token_path
548 && path.exists()
549 {
550 fs::remove_file(path).await?;
551 }
552
553 Ok(())
554 }
555
556 pub async fn token(&self) -> String {
560 self.token.read().await.clone()
561 }
562
563 pub async fn verify_token(&self) -> Result<(), Error> {
568 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
569 .await?;
570 Ok::<(), Error>(())
571 }
572
573 pub async fn renew_token(&self) -> Result<(), Error> {
578 let params = HashMap::from([("username".to_string(), self.username().await?)]);
579 let result: LoginResult = self
580 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
581 .await?;
582
583 {
584 let mut token = self.token.write().await;
585 *token = result.token;
586 }
587
588 if self.token_path.is_some() {
589 self.save_token().await?;
590 }
591
592 Ok(())
593 }
594
    /// Decode the bearer token's JWT payload and return the named claim.
    ///
    /// Errors: `EmptyToken` when no token is held; `InvalidToken` when the
    /// token is malformed, fails base64/JSON decoding, or lacks the field.
    /// NOTE(review): uses the standard base64 alphabet like `with_token`;
    /// base64url-encoded payloads would fail to decode — verify.
    async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
        let token = self.token.read().await;
        if token.is_empty() {
            return Err(Error::EmptyToken);
        }

        // header.payload.signature — we only need the payload segment.
        let token_parts: Vec<&str> = token.split('.').collect();
        if token_parts.len() != 3 {
            return Err(Error::InvalidToken);
        }

        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(token_parts[1])
            .map_err(|_| Error::InvalidToken)?;
        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
        match payload.get(field) {
            Some(value) => Ok(value.to_owned()),
            None => Err(Error::InvalidToken),
        }
    }
615
616 pub fn url(&self) -> &str {
618 &self.url
619 }
620
621 pub async fn username(&self) -> Result<String, Error> {
623 match self.token_field("username").await? {
624 serde_json::Value::String(username) => Ok(username),
625 _ => Err(Error::InvalidToken),
626 }
627 }
628
629 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
631 let ts = match self.token_field("exp").await? {
632 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
633 _ => return Err(Error::InvalidToken),
634 };
635
636 match DateTime::<Utc>::from_timestamp_secs(ts) {
637 Some(dt) => Ok(dt),
638 None => Err(Error::InvalidToken),
639 }
640 }
641
642 pub async fn organization(&self) -> Result<Organization, Error> {
644 self.rpc::<(), Organization>("org.get".to_owned(), None)
645 .await
646 }
647
648 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
660 let projects = self
661 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
662 .await?;
663 if let Some(name) = name {
664 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
665 } else {
666 Ok(projects)
667 }
668 }
669
670 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
673 let params = HashMap::from([("project_id", project_id)]);
674 self.rpc("project.get".to_owned(), Some(params)).await
675 }
676
677 pub async fn datasets(
686 &self,
687 project_id: ProjectID,
688 name: Option<&str>,
689 ) -> Result<Vec<Dataset>, Error> {
690 let params = HashMap::from([("project_id", project_id)]);
691 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
692 if let Some(name) = name {
693 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
694 } else {
695 Ok(datasets)
696 }
697 }
698
699 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
702 let params = HashMap::from([("dataset_id", dataset_id)]);
703 self.rpc("dataset.get".to_owned(), Some(params)).await
704 }
705
706 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
708 let params = HashMap::from([("dataset_id", dataset_id)]);
709 self.rpc("label.list".to_owned(), Some(params)).await
710 }
711
712 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
714 let new_label = NewLabel {
715 dataset_id,
716 labels: vec![NewLabelObject {
717 name: name.to_owned(),
718 }],
719 };
720 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
721 Ok(())
722 }
723
724 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
727 let params = HashMap::from([("label_id", label_id)]);
728 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
729 Ok(())
730 }
731
732 pub async fn create_dataset(
744 &self,
745 project_id: &str,
746 name: &str,
747 description: Option<&str>,
748 ) -> Result<DatasetID, Error> {
749 let mut params = HashMap::new();
750 params.insert("project_id", project_id);
751 params.insert("name", name);
752 if let Some(desc) = description {
753 params.insert("description", desc);
754 }
755
756 #[derive(Deserialize)]
757 struct CreateDatasetResult {
758 id: DatasetID,
759 }
760
761 let result: CreateDatasetResult =
762 self.rpc("dataset.create".to_owned(), Some(params)).await?;
763 Ok(result.id)
764 }
765
766 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
776 let params = HashMap::from([("id", dataset_id)]);
777 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
778 Ok(())
779 }
780
781 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
785 #[derive(Serialize)]
786 struct Params {
787 dataset_id: DatasetID,
788 label_id: u64,
789 label_name: String,
790 label_index: u64,
791 }
792
793 let _: String = self
794 .rpc(
795 "label.update".to_owned(),
796 Some(Params {
797 dataset_id: label.dataset_id(),
798 label_id: label.id(),
799 label_name: label.name().to_owned(),
800 label_index: label.index(),
801 }),
802 )
803 .await?;
804 Ok(())
805 }
806
    /// Download every matching sample file of a dataset into `output`.
    ///
    /// Samples are fetched with the given group/file-type filters, then their
    /// files are downloaded concurrently via `parallel_foreach_items`. With
    /// `flatten` set, all files land directly in `output` with the sequence
    /// name (and frame number) folded into the filename; otherwise files are
    /// grouped into per-sequence subdirectories. Progress is reported on
    /// `progress` both while listing samples and while downloading.
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        flatten: bool,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Owned copies moved into the per-sample closure below.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Images get their extension sniffed from the bytes;
                        // other file types use the type name as extension.
                        // NOTE(review): this `expect` panics if the bytes are
                        // not a recognizable image format — confirm intended.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        // Flattened output keeps everything in `output`;
                        // otherwise sequence samples get their own subdir.
                        let target_dir = if flatten {
                            output.clone()
                        } else {
                            sequence_dir
                                .as_ref()
                                .map(|seq| output.join(seq))
                                .unwrap_or_else(|| output.clone())
                        };
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Images prefer their original filename; everything
                        // else is named after the sample plus the extension.
                        let file_name = if is_image {
                            if let Some(img_name) = image_name {
                                Self::build_filename(
                                    &img_name,
                                    flatten,
                                    sequence_dir.as_ref(),
                                    sample.frame_number(),
                                )
                            } else {
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            }
                        } else {
                            let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
                            Self::build_filename(
                                &base_name,
                                flatten,
                                sequence_dir.as_ref(),
                                sample.frame_number(),
                            )
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing data is logged, not fatal — remaining file
                        // types and samples continue downloading.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
970
971 fn build_filename(
987 base_name: &str,
988 flatten: bool,
989 sequence_name: Option<&String>,
990 frame_number: Option<u32>,
991 ) -> String {
992 if !flatten || sequence_name.is_none() {
993 return base_name.to_string();
994 }
995
996 let seq_name = sequence_name.unwrap();
997 let prefix = format!("{}_", seq_name);
998
999 if base_name.starts_with(&prefix) {
1001 base_name.to_string()
1002 } else {
1003 match frame_number {
1005 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1006 None => format!("{}{}", prefix, base_name),
1007 }
1008 }
1009 }
1010
1011 pub async fn annotation_sets(
1013 &self,
1014 dataset_id: DatasetID,
1015 ) -> Result<Vec<AnnotationSet>, Error> {
1016 let params = HashMap::from([("dataset_id", dataset_id)]);
1017 self.rpc("annset.list".to_owned(), Some(params)).await
1018 }
1019
1020 pub async fn create_annotation_set(
1032 &self,
1033 dataset_id: DatasetID,
1034 name: &str,
1035 description: Option<&str>,
1036 ) -> Result<AnnotationSetID, Error> {
1037 #[derive(Serialize)]
1038 struct Params<'a> {
1039 dataset_id: DatasetID,
1040 name: &'a str,
1041 operator: &'a str,
1042 #[serde(skip_serializing_if = "Option::is_none")]
1043 description: Option<&'a str>,
1044 }
1045
1046 #[derive(Deserialize)]
1047 struct CreateAnnotationSetResult {
1048 id: AnnotationSetID,
1049 }
1050
1051 let username = self.username().await?;
1052 let result: CreateAnnotationSetResult = self
1053 .rpc(
1054 "annset.add".to_owned(),
1055 Some(Params {
1056 dataset_id,
1057 name,
1058 operator: &username,
1059 description,
1060 }),
1061 )
1062 .await?;
1063 Ok(result.id)
1064 }
1065
1066 pub async fn delete_annotation_set(
1077 &self,
1078 annotation_set_id: AnnotationSetID,
1079 ) -> Result<(), Error> {
1080 let params = HashMap::from([("id", annotation_set_id)]);
1081 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1082 Ok(())
1083 }
1084
1085 pub async fn annotation_set(
1087 &self,
1088 annotation_set_id: AnnotationSetID,
1089 ) -> Result<AnnotationSet, Error> {
1090 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1091 self.rpc("annset.get".to_owned(), Some(params)).await
1092 }
1093
    /// Fetch all annotations in an annotation set, optionally restricted to
    /// specific groups and annotation types.
    ///
    /// Label names are resolved to numeric indices via the dataset's label
    /// list; progress (samples fetched vs. total) is sent on `progress` when
    /// provided. Returns an empty vector when zero samples match.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        // The annotation set determines which dataset (and label table) applies.
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total count drives the progress reports in the paginated fetch.
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id: Some(annotation_set_id),
            groups,
            types: annotation_types.iter().map(|t| t.to_string()).collect(),
            labels: &labels,
        };

        self.fetch_annotations_paginated(context, total, progress)
            .await
    }
1146
    /// Page through `samples.list` and accumulate annotations until the
    /// server returns an empty page or stops sending a continuation token.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        // Opaque server-side cursor; None requests the first page.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Send failures are ignored: a dropped progress receiver must not
            // abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // An absent or empty token marks the final page.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        drop(progress);
        Ok(annotations)
    }
1190
1191 fn process_sample_annotations(
1192 &self,
1193 samples: &[Sample],
1194 labels: &HashMap<String, u64>,
1195 annotations: &mut Vec<Annotation>,
1196 ) {
1197 for sample in samples {
1198 if sample.annotations().is_empty() {
1199 let mut annotation = Annotation::new();
1200 annotation.set_sample_id(sample.id());
1201 annotation.set_name(sample.name());
1202 annotation.set_sequence_name(sample.sequence_name().cloned());
1203 annotation.set_frame_number(sample.frame_number());
1204 annotation.set_group(sample.group().cloned());
1205 annotations.push(annotation);
1206 continue;
1207 }
1208
1209 for annotation in sample.annotations() {
1210 let mut annotation = annotation.clone();
1211 annotation.set_sample_id(sample.id());
1212 annotation.set_name(sample.name());
1213 annotation.set_sequence_name(sample.sequence_name().cloned());
1214 annotation.set_frame_number(sample.frame_number());
1215 annotation.set_group(sample.group().cloned());
1216 Self::set_label_index_from_map(&mut annotation, labels);
1217 annotations.push(annotation);
1218 }
1219 }
1220 }
1221
1222 fn parse_frame_from_image_name(
1230 image_name: Option<&String>,
1231 sequence_name: Option<&String>,
1232 ) -> Option<u32> {
1233 use std::path::Path;
1234
1235 let sequence = sequence_name?;
1236 let name = image_name?;
1237
1238 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1240
1241 stem.strip_prefix(sequence)
1243 .and_then(|suffix| suffix.strip_prefix('_'))
1244 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1245 }
1246
1247 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1249 if let Some(label) = annotation.label() {
1250 annotation.set_label_index(Some(labels[label.as_str()]));
1251 }
1252 }
1253
1254 pub async fn samples_count(
1255 &self,
1256 dataset_id: DatasetID,
1257 annotation_set_id: Option<AnnotationSetID>,
1258 annotation_types: &[AnnotationType],
1259 groups: &[String],
1260 types: &[FileType],
1261 ) -> Result<SamplesCountResult, Error> {
1262 let types = annotation_types
1263 .iter()
1264 .map(|t| t.to_string())
1265 .chain(types.iter().map(|t| t.to_string()))
1266 .collect::<Vec<_>>();
1267
1268 let params = SamplesListParams {
1269 dataset_id,
1270 annotation_set_id,
1271 group_names: groups.to_vec(),
1272 types,
1273 continue_token: None,
1274 };
1275
1276 self.rpc("samples.count".to_owned(), Some(params)).await
1277 }
1278
    /// Fetch samples (with annotations from `annotation_set_id` when given)
    /// matching the group and type filters.
    ///
    /// Each sample's annotations are stamped with the sample's metadata,
    /// label indices are resolved from the dataset's labels, and missing
    /// frame numbers are recovered from `<sequence>_<frame>`-style image
    /// names. Progress is reported on `progress` when provided.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // The RPC takes annotation and file types as one flat string list.
        let types_vec = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total count drives the progress reports in the paginated fetch.
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        if total == 0 {
            return Ok(vec![]);
        }

        let context = FetchContext {
            dataset_id,
            annotation_set_id,
            groups,
            types: types_vec,
            labels: &labels,
        };

        self.fetch_samples_paginated(context, total, progress).await
    }
1318
    /// Page through `samples.list`, enriching each sample (frame-number
    /// recovery, annotation stamping, label-index resolution) until the
    /// server returns an empty page or stops sending a continuation token.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        // Opaque server-side cursor; None requests the first page.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Fall back to parsing the frame number out of the
                        // image filename when the server did not provide one.
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        // Stamp each annotation with its sample's metadata.
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Send failures are ignored: a dropped progress receiver must not
            // abort the fetch.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // An absent or empty token marks the final page.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        drop(progress);
        Ok(samples)
    }
1390
    /// Create/update samples in a dataset and upload any local files they
    /// reference.
    ///
    /// Samples without a UUID are assigned one. Files whose `filename` points
    /// at an existing local file are replaced by their basename in the RPC
    /// payload; the server then returns presigned URLs, and the matching
    /// local files are uploaded afterwards.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;

        // (sample uuid, file type, local path, basename) for each local file.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;

        let has_files_to_upload = !files_to_upload.is_empty();

        // Only request presigned URLs when there is something to upload.
        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            self.upload_sample_files(&results, files_to_upload, progress)
                .await?;
        }

        Ok(results)
    }
1519
    /// Assign UUIDs to samples that lack one and rewrite their file entries
    /// for upload, collecting local files into `files_to_upload`.
    fn prepare_samples_for_upload(
        &self,
        samples: Vec<Sample>,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> Result<Vec<Sample>, Error> {
        Ok(samples
            .into_iter()
            .map(|mut sample| {
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().expect("UUID just set above");

                // Iterate over a clone of the files because
                // `process_sample_file` also mutates `sample` (width/height).
                let files_copy = sample.files.clone();
                let updated_files: Vec<crate::SampleFile> = files_copy
                    .iter()
                    .map(|file| {
                        self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect())
    }
1549
    /// Inspect one sample file entry. When its filename points at an existing
    /// local file, queue it in `files_to_upload` and return a copy of the
    /// entry that carries only the basename (for the populate RPC); otherwise
    /// return the entry unchanged. As a side effect, image files fill in the
    /// sample's missing width/height from the image header.
    fn process_sample_file(
        &self,
        file: &crate::SampleFile,
        sample_uuid: &str,
        sample: &mut Sample,
        files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
    ) -> crate::SampleFile {
        use std::path::Path;

        if let Some(filename) = file.filename() {
            let path = Path::new(filename);

            if path.exists()
                && path.is_file()
                && let Some(basename) = path.file_name().and_then(|s| s.to_str())
            {
                // Read dimensions from the image header without decoding the
                // full image; failures leave width/height untouched.
                if file.file_type() == "image"
                    && (sample.width.is_none() || sample.height.is_none())
                    && let Ok(size) = imagesize::size(path)
                {
                    sample.width = Some(size.width as u32);
                    sample.height = Some(size.height as u32);
                }

                files_to_upload.push((
                    sample_uuid.to_string(),
                    file.file_type().to_string(),
                    path.to_path_buf(),
                    basename.to_string(),
                ));

                // The server only needs the basename; the local path is kept
                // in `files_to_upload` for the later upload step.
                return crate::SampleFile::with_filename(
                    file.file_type().to_string(),
                    basename.to_string(),
                );
            }
        }
        file.clone()
    }
1594
    /// Upload queued local files to the presigned URLs the populate RPC
    /// returned, in parallel, matching them by (sample uuid, basename).
    async fn upload_sample_files(
        &self,
        results: &[crate::SamplesPopulateResult],
        files_to_upload: Vec<(String, String, PathBuf, String)>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Index local paths by (uuid, basename); the file type is not needed
        // to match a presigned URL to its file.
        let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
        for (uuid, _file_type, path, basename) in files_to_upload {
            upload_map.insert((uuid, basename), path);
        }

        let http = self.http.clone();

        // Owned (uuid, urls) pairs so the parallel closure is 'static.
        let upload_tasks: Vec<_> = results
            .iter()
            .map(|result| (result.uuid.clone(), result.urls.clone()))
            .collect();

        parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
            let http = http.clone();
            let upload_map = upload_map.clone();

            async move {
                for url_info in &urls {
                    // URLs with no matching queued file are silently skipped.
                    if let Some(local_path) =
                        upload_map.get(&(uuid.clone(), url_info.filename.clone()))
                    {
                        upload_file_to_presigned_url(
                            http.clone(),
                            &url_info.url,
                            local_path.clone(),
                        )
                        .await?;
                    }
                }

                Ok(())
            }
        })
        .await
    }
1640
1641 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1642 let resp = self.http.get(url).send().await?;
1644
1645 if !resp.status().is_success() {
1646 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
1647 }
1648
1649 let bytes = resp.bytes().await?;
1650 Ok(bytes.to_vec())
1651 }
1652
    /// Build a polars [`DataFrame`] of annotations for an annotation set.
    ///
    /// Fetches annotations filtered by `groups` and `types`, then converts
    /// them with the `annotations_dataframe` helper (which shares this
    /// method's deprecation).
    #[deprecated(
        since = "0.8.0",
        note = "Use `samples_dataframe()` for complete 2025.10 schema support"
    )]
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        // Silence the deprecation lint for the internal call to the
        // equally-deprecated helper.
        #[allow(deprecated)]
        annotations_dataframe(&annotations)
    }
1712
1713 #[cfg(feature = "polars")]
1750 pub async fn samples_dataframe(
1751 &self,
1752 dataset_id: DatasetID,
1753 annotation_set_id: Option<AnnotationSetID>,
1754 groups: &[String],
1755 types: &[AnnotationType],
1756 progress: Option<Sender<Progress>>,
1757 ) -> Result<DataFrame, Error> {
1758 use crate::dataset::samples_dataframe;
1759
1760 let samples = self
1761 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
1762 .await?;
1763 samples_dataframe(&samples)
1764 }
1765
1766 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1773 let snapshots: Vec<Snapshot> = self
1774 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1775 .await?;
1776 if let Some(name) = name {
1777 Ok(filter_and_sort_by_name(snapshots, name, |s| {
1778 s.description()
1779 }))
1780 } else {
1781 Ok(snapshots)
1782 }
1783 }
1784
1785 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1787 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1788 self.rpc("snapshots.get".to_owned(), Some(params)).await
1789 }
1790
    /// Create a snapshot from a local file or directory and upload it.
    ///
    /// Directories are delegated to `create_snapshot_folder`; a single file
    /// is uploaded through the server's multipart-upload flow. Byte-level
    /// progress is reported through `progress` when provided. Returns the
    /// snapshot record after its status has been set to "available".
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The snapshot is named after the file's basename.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server to allocate the snapshot plus presigned multipart
        // upload URLs for the single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Upload-part entries are keyed by "<storage prefix>/<name>", where
        // the prefix is the tail of the snapshot's storage path after the
        // "::/" separator.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Mark the snapshot usable now that all bytes are uploaded.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel before the final fetch so consumers
        // observe completion promptly.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1966
    /// Create a snapshot from a directory, uploading every file under it.
    ///
    /// Walks `path` recursively, registers all files (keyed by their path
    /// relative to the directory) in one multipart-create RPC, then uploads
    /// each file sequentially — `upload_multipart` itself parallelizes the
    /// parts of a single file. `current` bytes accumulate across files so
    /// progress is reported against the combined `total`.
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        // The snapshot is named after the directory's basename.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid directory name",
            ))
        })?;

        // Collect regular files as paths relative to the snapshot root;
        // unreadable entries are skipped silently.
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
            .collect::<Vec<_>>();

        let total: usize = files
            .iter()
            .filter_map(|file| path.join(file).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // keys and file_sizes are parallel vectors; both skip entries that
        // fail conversion/stat, which keeps them aligned only if the same
        // entries fail both — NOTE(review): a file that stats but has
        // non-UTF-8 name would desynchronize them; confirm inputs are UTF-8.
        let keys = files
            .iter()
            .filter_map(|key| key.to_str().map(|s| s.to_owned()))
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .filter_map(|key| path.join(key).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Upload-part entries are keyed by "<storage prefix>/<relative
        // path>", the prefix being the tail of the snapshot's storage path
        // after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();

        for file in files {
            let file_str = file.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "File path contains invalid UTF-8",
                ))
            })?;
            let part_key = format!("{}/{}", part_prefix, file_str);
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            // Each file's multipart upload is completed individually.
            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // Mark the snapshot usable now that every file is uploaded.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel before the final fetch so consumers
        // observe completion promptly.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
2082
2083 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2116 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2117 let _: String = self
2118 .rpc("snapshots.delete".to_owned(), Some(params))
2119 .await?;
2120 Ok(())
2121 }
2122
    /// Download every file of a snapshot into the `output` directory.
    ///
    /// The server returns a map of file key to presigned download URL; each
    /// file is streamed to `output/<key>` with up to `max_tasks()` downloads
    /// in flight. The progress `total` grows as each response's
    /// Content-Length becomes known, so early reports may understate the
    /// final total.
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // Byte counters shared across all download tasks.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        let sem = Arc::new(Semaphore::new(max_tasks()));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    // Bound concurrency; the permit lives for the whole task.
                    let _permit = sem.acquire().await.map_err(|_| {
                        Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                    })?;
                    let res = http.get(url).send().await?;
                    let content_length = res.content_length().unwrap_or(0) as usize;

                    if let Some(progress) = &progress {
                        // fetch_add returns the prior total, so add this
                        // file's length back for the report.
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        let _ = progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await;
                    }

                    // NOTE(review): assumes `key` has no subdirectory
                    // components — File::create does not create parents;
                    // confirm against server-issued keys.
                    let mut file = File::create(output.join(key)).await?;
                    let mut stream = res.bytes_stream();

                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk?;
                        file.write_all(&chunk).await?;
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            let _ = progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await;
                        }
                    }

                    Ok::<(), Error>(())
                })
            })
            .collect::<Vec<_>>();

        // First collect surfaces join/panic errors, second surfaces
        // per-download errors.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        Ok(())
    }
2256
2257 #[allow(clippy::too_many_arguments)]
2324 pub async fn restore_snapshot(
2325 &self,
2326 project_id: ProjectID,
2327 snapshot_id: SnapshotID,
2328 topics: &[String],
2329 autolabel: &[String],
2330 autodepth: bool,
2331 dataset_name: Option<&str>,
2332 dataset_description: Option<&str>,
2333 ) -> Result<SnapshotRestoreResult, Error> {
2334 let params = SnapshotRestore {
2335 project_id,
2336 snapshot_id,
2337 fps: 1,
2338 autodepth,
2339 agtg_pipeline: !autolabel.is_empty(),
2340 autolabel: autolabel.to_vec(),
2341 topics: topics.to_vec(),
2342 dataset_name: dataset_name.map(|s| s.to_owned()),
2343 dataset_description: dataset_description.map(|s| s.to_owned()),
2344 };
2345 self.rpc("snapshots.restore".to_owned(), Some(params)).await
2346 }
2347
2348 pub async fn experiments(
2361 &self,
2362 project_id: ProjectID,
2363 name: Option<&str>,
2364 ) -> Result<Vec<Experiment>, Error> {
2365 let params = HashMap::from([("project_id", project_id)]);
2366 let experiments: Vec<Experiment> =
2367 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
2368 if let Some(name) = name {
2369 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
2370 } else {
2371 Ok(experiments)
2372 }
2373 }
2374
2375 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
2378 let params = HashMap::from([("trainer_id", experiment_id)]);
2379 self.rpc("trainer.get".to_owned(), Some(params)).await
2380 }
2381
2382 pub async fn training_sessions(
2395 &self,
2396 experiment_id: ExperimentID,
2397 name: Option<&str>,
2398 ) -> Result<Vec<TrainingSession>, Error> {
2399 let params = HashMap::from([("trainer_id", experiment_id)]);
2400 let sessions: Vec<TrainingSession> = self
2401 .rpc("trainer.session.list".to_owned(), Some(params))
2402 .await?;
2403 if let Some(name) = name {
2404 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
2405 } else {
2406 Ok(sessions)
2407 }
2408 }
2409
2410 pub async fn training_session(
2413 &self,
2414 session_id: TrainingSessionID,
2415 ) -> Result<TrainingSession, Error> {
2416 let params = HashMap::from([("trainer_session_id", session_id)]);
2417 self.rpc("trainer.session.get".to_owned(), Some(params))
2418 .await
2419 }
2420
2421 pub async fn validation_sessions(
2423 &self,
2424 project_id: ProjectID,
2425 ) -> Result<Vec<ValidationSession>, Error> {
2426 let params = HashMap::from([("project_id", project_id)]);
2427 self.rpc("validate.session.list".to_owned(), Some(params))
2428 .await
2429 }
2430
2431 pub async fn validation_session(
2433 &self,
2434 session_id: ValidationSessionID,
2435 ) -> Result<ValidationSession, Error> {
2436 let params = HashMap::from([("validate_session_id", session_id)]);
2437 self.rpc("validate.session.get".to_owned(), Some(params))
2438 .await
2439 }
2440
2441 pub async fn artifacts(
2444 &self,
2445 training_session_id: TrainingSessionID,
2446 ) -> Result<Vec<Artifact>, Error> {
2447 let params = HashMap::from([("training_session_id", training_session_id)]);
2448 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2449 .await
2450 }
2451
    /// Download a trained model artifact to a local file.
    ///
    /// The artifact `modelname` from `training_session_id` is saved to
    /// `filename` (defaulting to `modelname` in the current directory). With
    /// a `progress` sender the body is streamed to disk with byte-level
    /// updates; without one the whole body is buffered in memory first.
    pub async fn download_artifact(
        &self,
        training_session_id: TrainingSessionID,
        modelname: &str,
        filename: Option<PathBuf>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
        // NOTE(review): `modelname` is interpolated without URL-encoding; a
        // name containing '&' or '#' would corrupt the query — confirm
        // server-side naming constraints.
        let resp = self
            .http
            .get(format!(
                "{}/download_model?training_session_id={}&file={}",
                self.url,
                training_session_id.value(),
                modelname
            ))
            .header("Authorization", format!("Bearer {}", self.token().await))
            .send()
            .await?;
        if !resp.status().is_success() {
            let err = resp.error_for_status_ref().unwrap_err();
            return Err(Error::HttpError(err));
        }

        // Ensure the destination directory exists before writing.
        if let Some(parent) = filename.parent() {
            fs::create_dir_all(parent).await?;
        }

        if let Some(progress) = progress {
            // Content-Length may be absent; total falls back to 0.
            let total = resp.content_length().unwrap_or(0) as usize;
            let _ = progress.send(Progress { current: 0, total }).await;

            let mut file = File::create(filename).await?;
            let mut current = 0;
            let mut stream = resp.bytes_stream();

            while let Some(item) = stream.next().await {
                let chunk = item?;
                file.write_all(&chunk).await?;
                current += chunk.len();
                let _ = progress.send(Progress { current, total }).await;
            }
        } else {
            let body = resp.bytes().await?;
            fs::write(filename, body).await?;
        }

        Ok(())
    }
2506
    /// Download a training checkpoint file to a local file.
    ///
    /// The checkpoint named `checkpoint` (from the session's "checkpoints"
    /// folder) is saved to `filename` (defaulting to `checkpoint` in the
    /// current directory). With a `progress` sender the body is streamed to
    /// disk with byte-level updates; without one the whole body is buffered
    /// in memory first.
    pub async fn download_checkpoint(
        &self,
        training_session_id: TrainingSessionID,
        checkpoint: &str,
        filename: Option<PathBuf>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
        // NOTE(review): `checkpoint` is interpolated without URL-encoding;
        // confirm server-side naming constraints.
        let resp = self
            .http
            .get(format!(
                "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
                self.url,
                training_session_id.value(),
                checkpoint
            ))
            .header("Authorization", format!("Bearer {}", self.token().await))
            .send()
            .await?;
        if !resp.status().is_success() {
            let err = resp.error_for_status_ref().unwrap_err();
            return Err(Error::HttpError(err));
        }

        // Ensure the destination directory exists before writing.
        if let Some(parent) = filename.parent() {
            fs::create_dir_all(parent).await?;
        }

        if let Some(progress) = progress {
            // Content-Length may be absent; total falls back to 0.
            let total = resp.content_length().unwrap_or(0) as usize;
            let _ = progress.send(Progress { current: 0, total }).await;

            let mut file = File::create(filename).await?;
            let mut current = 0;
            let mut stream = resp.bytes_stream();

            while let Some(item) = stream.next().await {
                let chunk = item?;
                file.write_all(&chunk).await?;
                current += chunk.len();
                let _ = progress.send(Progress { current, total }).await;
            }
        } else {
            let body = resp.bytes().await?;
            fs::write(filename, body).await?;
        }

        Ok(())
    }
2565
2566 pub async fn tasks(
2581 &self,
2582 name: Option<&str>,
2583 workflow: Option<&str>,
2584 status: Option<&str>,
2585 manager: Option<&str>,
2586 ) -> Result<Vec<Task>, Error> {
2587 let mut params = TasksListParams {
2588 continue_token: None,
2589 types: workflow.map(|w| vec![w.to_owned()]),
2590 status: status.map(|s| vec![s.to_owned()]),
2591 manager: manager.map(|m| vec![m.to_owned()]),
2592 };
2593 let mut tasks = Vec::new();
2594
2595 loop {
2596 let result = self
2597 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2598 .await?;
2599 tasks.extend(result.tasks);
2600
2601 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2602 params.continue_token = None;
2603 } else {
2604 params.continue_token = result.continue_token;
2605 }
2606
2607 if params.continue_token.is_none() {
2608 break;
2609 }
2610 }
2611
2612 if let Some(name) = name {
2613 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
2614 }
2615
2616 Ok(tasks)
2617 }
2618
2619 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2621 self.rpc(
2622 "task.get".to_owned(),
2623 Some(HashMap::from([("id", task_id)])),
2624 )
2625 .await
2626 }
2627
2628 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2630 let status = TaskStatus {
2631 task_id,
2632 status: status.to_owned(),
2633 };
2634 self.rpc("docker.update.status".to_owned(), Some(status))
2635 .await
2636 }
2637
2638 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
2642 let stages: Vec<HashMap<String, String>> = stages
2643 .iter()
2644 .map(|(key, value)| {
2645 let mut stage_map = HashMap::new();
2646 stage_map.insert(key.to_string(), value.to_string());
2647 stage_map
2648 })
2649 .collect();
2650 let params = TaskStages { task_id, stages };
2651 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
2652 Ok(())
2653 }
2654
2655 pub async fn update_stage(
2658 &self,
2659 task_id: TaskID,
2660 stage: &str,
2661 status: &str,
2662 message: &str,
2663 percentage: u8,
2664 ) -> Result<(), Error> {
2665 let stage = Stage::new(
2666 Some(task_id),
2667 stage.to_owned(),
2668 Some(status.to_owned()),
2669 Some(message.to_owned()),
2670 percentage,
2671 );
2672 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
2673 Ok(())
2674 }
2675
2676 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
2678 let req = self
2679 .http
2680 .get(format!("{}/{}", self.url, query))
2681 .header("User-Agent", "EdgeFirst Client")
2682 .header("Authorization", format!("Bearer {}", self.token().await));
2683 let resp = req.send().await?;
2684
2685 if resp.status().is_success() {
2686 let body = resp.bytes().await?;
2687
2688 if log_enabled!(Level::Trace) {
2689 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
2690 }
2691
2692 Ok(body.to_vec())
2693 } else {
2694 let err = resp.error_for_status_ref().unwrap_err();
2695 Err(Error::HttpError(err))
2696 }
2697 }
2698
2699 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2703 let req = self
2704 .http
2705 .post(format!("{}/api?method={}", self.url, method))
2706 .header("Accept", "application/json")
2707 .header("User-Agent", "EdgeFirst Client")
2708 .header("Authorization", format!("Bearer {}", self.token().await))
2709 .multipart(form);
2710 let resp = req.send().await?;
2711
2712 if resp.status().is_success() {
2713 let body = resp.bytes().await?;
2714
2715 if log_enabled!(Level::Trace) {
2716 trace!(
2717 "POST Multipart Response: {}",
2718 String::from_utf8_lossy(&body)
2719 );
2720 }
2721
2722 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2723 Ok(response) => response,
2724 Err(err) => {
2725 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2726 return Err(err.into());
2727 }
2728 };
2729
2730 if let Some(error) = response.error {
2731 Err(Error::RpcError(error.code, error.message))
2732 } else if let Some(result) = response.result {
2733 Ok(result)
2734 } else {
2735 Err(Error::InvalidResponse)
2736 }
2737 } else {
2738 let err = resp.error_for_status_ref().unwrap_err();
2739 Err(Error::HttpError(err))
2740 }
2741 }
2742
2743 pub async fn rpc<Params, RpcResult>(
2752 &self,
2753 method: String,
2754 params: Option<Params>,
2755 ) -> Result<RpcResult, Error>
2756 where
2757 Params: Serialize,
2758 RpcResult: DeserializeOwned,
2759 {
2760 let auth_expires = self.token_expiration().await?;
2761 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2762 self.renew_token().await?;
2763 }
2764
2765 self.rpc_without_auth(method, params).await
2766 }
2767
2768 async fn rpc_without_auth<Params, RpcResult>(
2769 &self,
2770 method: String,
2771 params: Option<Params>,
2772 ) -> Result<RpcResult, Error>
2773 where
2774 Params: Serialize,
2775 RpcResult: DeserializeOwned,
2776 {
2777 let request = RpcRequest {
2778 method,
2779 params,
2780 ..Default::default()
2781 };
2782
2783 if log_enabled!(Level::Trace) {
2784 trace!(
2785 "RPC Request: {}",
2786 serde_json::ser::to_string_pretty(&request)?
2787 );
2788 }
2789
2790 let url = format!("{}/api", self.url);
2791
2792 let res = self
2795 .http
2796 .post(&url)
2797 .header("Accept", "application/json")
2798 .header("User-Agent", "EdgeFirst Client")
2799 .header("Authorization", format!("Bearer {}", self.token().await))
2800 .json(&request)
2801 .send()
2802 .await?;
2803
2804 self.process_rpc_response(res).await
2805 }
2806
2807 async fn process_rpc_response<RpcResult>(
2808 &self,
2809 res: reqwest::Response,
2810 ) -> Result<RpcResult, Error>
2811 where
2812 RpcResult: DeserializeOwned,
2813 {
2814 let body = res.bytes().await?;
2815
2816 if log_enabled!(Level::Trace) {
2817 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2818 }
2819
2820 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2821 Ok(response) => response,
2822 Err(err) => {
2823 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2824 return Err(err.into());
2825 }
2826 };
2827
2828 if let Some(error) = response.error {
2834 Err(Error::RpcError(error.code, error.message))
2835 } else if let Some(result) = response.result {
2836 Ok(result)
2837 } else {
2838 Err(Error::InvalidResponse)
2839 }
2840 }
2841}
2842
/// Run `work_fn` over `items` concurrently, at most `max_tasks()` at a time.
///
/// Each item runs in its own spawned task gated by a shared semaphore.
/// Item-count progress (not bytes) is reported through `progress` after each
/// item completes, and the sender is dropped at the end so receivers observe
/// the channel closing. All tasks are awaited before errors are surfaced.
async fn parallel_foreach_items<T, F, Fut>(
    items: Vec<T>,
    progress: Option<Sender<Progress>>,
    work_fn: F,
) -> Result<(), Error>
where
    T: Send + 'static,
    F: Fn(T) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<(), Error>> + Send + 'static,
{
    let total = items.len();
    let current = Arc::new(AtomicUsize::new(0));
    let sem = Arc::new(Semaphore::new(max_tasks()));
    // Shared so every spawned task can invoke the same closure.
    let work_fn = Arc::new(work_fn);

    let tasks = items
        .into_iter()
        .map(|item| {
            let sem = sem.clone();
            let current = current.clone();
            let progress = progress.clone();
            let work_fn = work_fn.clone();

            tokio::spawn(async move {
                // Bound concurrency; the permit is held for the task's
                // entire lifetime.
                let _permit = sem.acquire().await.map_err(|_| {
                    Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                })?;

                work_fn(item).await?;

                if let Some(progress) = &progress {
                    // fetch_add returns the previous count, hence +1.
                    let current = current.fetch_add(1, Ordering::SeqCst);
                    let _ = progress
                        .send(Progress {
                            current: current + 1,
                            total,
                        })
                        .await;
                }

                Ok::<(), Error>(())
            })
        })
        .collect::<Vec<_>>();

    // First collect surfaces join/panic errors, second surfaces work errors.
    join_all(tasks)
        .await
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .collect::<Result<Vec<_>, _>>()?;

    // Explicitly close the progress channel so receivers terminate.
    if let Some(progress) = progress {
        drop(progress);
    }

    Ok(())
}
2933
2934async fn upload_multipart(
2960 http: reqwest::Client,
2961 part: SnapshotPart,
2962 path: PathBuf,
2963 total: usize,
2964 current: Arc<AtomicUsize>,
2965 progress: Option<Sender<Progress>>,
2966) -> Result<SnapshotCompleteMultipartParams, Error> {
2967 let filesize = path.metadata()?.len() as usize;
2968 let n_parts = filesize.div_ceil(PART_SIZE);
2969 let sem = Arc::new(Semaphore::new(max_tasks()));
2970
2971 let key = part.key.ok_or(Error::InvalidResponse)?;
2972 let upload_id = part.upload_id;
2973
2974 let urls = part.urls.clone();
2975 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2977 EtagPart {
2978 etag: "".to_owned(),
2979 part_number: 0,
2980 };
2981 n_parts
2982 ]));
2983
2984 let tasks = (0..n_parts)
2986 .map(|part| {
2987 let http = http.clone();
2988 let url = urls[part].clone();
2989 let etags = etags.clone();
2990 let path = path.to_owned();
2991 let sem = sem.clone();
2992 let progress = progress.clone();
2993 let current = current.clone();
2994
2995 tokio::spawn(async move {
2996 let _permit = sem.acquire().await?;
2998
2999 let etag =
3001 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
3002
3003 let mut etags = etags.lock().await;
3005 etags[part] = EtagPart {
3006 etag,
3007 part_number: part + 1,
3008 };
3009
3010 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
3012 if let Some(progress) = &progress {
3013 let _ = progress
3014 .send(Progress {
3015 current: current + PART_SIZE,
3016 total,
3017 })
3018 .await;
3019 }
3020
3021 Ok::<(), Error>(())
3022 })
3023 })
3024 .collect::<Vec<_>>();
3025
3026 join_all(tasks)
3028 .await
3029 .into_iter()
3030 .collect::<Result<Vec<_>, _>>()?;
3031
3032 Ok(SnapshotCompleteMultipartParams {
3033 key,
3034 upload_id,
3035 etag_list: etags.lock().await.clone(),
3036 })
3037}
3038
3039async fn upload_part(
3040 http: reqwest::Client,
3041 url: String,
3042 path: PathBuf,
3043 part: usize,
3044 n_parts: usize,
3045) -> Result<String, Error> {
3046 let filesize = path.metadata()?.len() as usize;
3047 let mut file = File::open(path).await?;
3048 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
3049 .await?;
3050 let file = file.take(PART_SIZE as u64);
3051
3052 let body_length = if part + 1 == n_parts {
3053 filesize % PART_SIZE
3054 } else {
3055 PART_SIZE
3056 };
3057
3058 let stream = FramedRead::new(file, BytesCodec::new());
3059 let body = Body::wrap_stream(stream);
3060
3061 let resp = http
3062 .put(url.clone())
3063 .header(CONTENT_LENGTH, body_length)
3064 .body(body)
3065 .send()
3066 .await?
3067 .error_for_status()?;
3068
3069 let etag = resp
3070 .headers()
3071 .get("etag")
3072 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3073 .to_str()
3074 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3075 .to_owned();
3076
3077 let etag = etag
3079 .strip_prefix("\"")
3080 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3081 let etag = etag
3082 .strip_suffix("\"")
3083 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3084
3085 Ok(etag.to_owned())
3086}
3087
3088async fn upload_file_to_presigned_url(
3093 http: reqwest::Client,
3094 url: &str,
3095 path: PathBuf,
3096) -> Result<(), Error> {
3097 let file_data = fs::read(&path).await?;
3099 let file_size = file_data.len();
3100
3101 let resp = http
3103 .put(url)
3104 .header(CONTENT_LENGTH, file_size)
3105 .body(file_data)
3106 .send()
3107 .await?;
3108
3109 if resp.status().is_success() {
3110 debug!(
3111 "Successfully uploaded file: {:?} ({} bytes)",
3112 path, file_size
3113 );
3114 Ok(())
3115 } else {
3116 let status = resp.status();
3117 let error_text = resp.text().await.unwrap_or_default();
3118 Err(Error::InvalidParameters(format!(
3119 "Upload failed: HTTP {} - {}",
3120 status, error_text
3121 )))
3122 }
3123}
3124
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact (case-insensitive) matches sort ahead of partial matches.
    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer");
        assert_eq!(result[1], "DEER");
    }

    /// Shorter matching names rank before longer ones.
    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test");
        assert_eq!(result[1], "Test Dataset");
        assert_eq!(result[2], "Test Dataset ABC");
    }

    /// Filtering ignores case on both the filter and the item names.
    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3);
    }

    /// A filter that matches nothing yields an empty result.
    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    /// Equal-quality matches fall back to alphabetical order.
    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    /// Without flattening the filename passes through untouched.
    #[test]
    fn test_build_filename_no_flatten() {
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    /// Flattening without a sequence leaves the filename untouched.
    #[test]
    fn test_build_filename_flatten_no_sequence() {
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    /// Flattening prefixes "<sequence>_<frame>_" when not already present.
    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    /// Without a frame number only the sequence name is prefixed.
    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    /// Filenames already carrying the sequence prefix are left alone.
    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    /// An existing prefix wins even when the frame number differs.
    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    /// A sequence name embedded mid-filename does not count as a prefix.
    #[test]
    fn test_build_filename_flatten_partial_match() {
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    /// Flattening preserves single and compound file extensions.
    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    /// Underscore-heavy sequence names survive the prefixing scheme.
    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }
}