edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID, SnapshotRestore,
10 SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
11 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
15 retry::{create_retry_policy, log_retry_configuration},
16 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
17};
18use base64::Engine as _;
19use chrono::{DateTime, Utc};
20use directories::ProjectDirs;
21use futures::{StreamExt as _, future::join_all};
22use log::{Level, debug, error, log_enabled, trace, warn};
23use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use std::{
26 collections::HashMap,
27 ffi::OsStr,
28 fs::create_dir_all,
29 io::{SeekFrom, Write as _},
30 path::{Path, PathBuf},
31 sync::{
32 Arc,
33 atomic::{AtomicUsize, Ordering},
34 },
35 time::Duration,
36 vec,
37};
38use tokio::{
39 fs::{self, File},
40 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
41 sync::{RwLock, Semaphore, mpsc::Sender},
42};
43use tokio_util::codec::{BytesCodec, FramedRead};
44use walkdir::WalkDir;
45
46#[cfg(feature = "polars")]
47use polars::prelude::*;
48
/// Size of each part for multipart uploads (100 MiB).
///
/// `const` rather than `static`: this is a plain compile-time value with no
/// identity or interior mutability, so a constant is the idiomatic form.
const PART_SIZE: usize = 100 * 1024 * 1024;
50
/// Maximum number of concurrent transfer tasks.
///
/// Honors the `MAX_TASKS` environment variable when it parses as a number;
/// otherwise defaults to half the available CPUs, clamped to 2..=8 (the
/// lower cap prevents timeout issues with large file uploads).
fn max_tasks() -> usize {
    if let Some(tasks) = std::env::var("MAX_TASKS").ok().and_then(|v| v.parse().ok()) {
        return tasks;
    }
    // CPU count is a best-effort query; fall back to 4 when unavailable.
    let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
    (cpus / 2).clamp(2, 8)
}
64
/// Filters items by name and sorts by match quality.
///
/// Match quality priority (best to worst):
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match (shorter names first, then alphabetically)
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    // Keep only items whose name contains the filter, case-insensitively.
    let filtered: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&filter_lower))
        .collect();

    // Precompute one match-quality rank per item. Ranking inside the
    // comparator would recompute `to_lowercase` O(n log n) times.
    let mut ranked: Vec<(u8, T)> = filtered
        .into_iter()
        .map(|item| {
            let name = get_name(&item);
            let rank = if name == filter {
                0 // Priority 1: exact match (case-sensitive)
            } else if name.to_lowercase() == filter_lower {
                1 // Priority 2: exact match (case-insensitive)
            } else {
                2 // Priority 3: plain substring match
            };
            (rank, item)
        })
        .collect();

    ranked.sort_by(|(rank_a, a), (rank_b, b)| {
        let name_a = get_name(a);
        let name_b = get_name(b);
        // Rank first, then shorter (more specific) names, then
        // alphabetical order for stability.
        rank_a
            .cmp(rank_b)
            .then(name_a.len().cmp(&name_b.len()))
            .then_with(|| name_a.cmp(name_b))
    });

    ranked.into_iter().map(|(_, item)| item).collect()
}
114
/// Sanitizes a name for safe use as a single filesystem path component.
///
/// Trims whitespace, keeps only the final path component, and replaces
/// characters that are invalid in filenames on common platforms with `_`.
/// Empty results — and the special components `.` and `..`, which would
/// otherwise allow escaping the target directory — become `"unnamed"`.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // `file_name()` strips any directory prefix; it returns None for
    // inputs like "." and "..", which fall through to `trimmed` here and
    // are rejected below.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // Reject empty results and dot components that would traverse out of
    // the intended directory.
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
140
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count and total
/// count to enable progress reporting in applications.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "Progress: {:.1}% ({}/{})",
///     percentage, progress.current, progress.total
/// );
/// ```
// Two plain counters: cheap to copy and safe to compare, so derive the full
// value-type trait set in addition to Debug/Clone.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Progress {
    /// Current number of completed items.
    pub current: usize,
    /// Total number of items to process.
    pub total: usize,
}
169
/// JSON-RPC 2.0 request envelope sent to the EdgeFirst Studio server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    /// Request identifier (defaults to 0; see the `Default` impl).
    id: u64,
    /// Protocol version marker, always `"2.0"`.
    jsonrpc: String,
    /// Remote method name, e.g. `"dataset.list"`.
    method: String,
    /// Optional method parameters; serialized as JSON `null` when `None`.
    params: Option<Params>,
}
177
178impl<T> Default for RpcRequest<T> {
179 fn default() -> Self {
180 RpcRequest {
181 id: 0,
182 jsonrpc: "2.0".to_string(),
183 method: "".to_string(),
184 params: None,
185 }
186 }
187}
188
/// Error object embedded in a failed JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric JSON-RPC error code reported by the server.
    code: i32,
    /// Human-readable error description from the server.
    message: String,
}
194
/// JSON-RPC 2.0 response envelope. Either `error` or `result` is expected
/// to be populated; callers inspect `error` before trusting `result`.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    /// Echoed request id (server returns it as a string); unused here.
    #[allow(dead_code)]
    id: String,
    /// Protocol version marker; unused here.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Error details when the call failed.
    error: Option<RpcError>,
    /// Method result when the call succeeded.
    result: Option<RpcResult>,
}
204
/// Placeholder result type for RPC calls whose payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
208
/// Parameters for creating a snapshot via single-part uploads.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    /// Display name for the new snapshot.
    snapshot_name: String,
    /// Storage keys of the files to include (presumably object-store keys
    /// — confirm against the server API).
    keys: Vec<String>,
}
215
/// Server response to a single-part snapshot creation request.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    /// Identifier assigned to the new snapshot.
    snapshot_id: SnapshotID,
    /// Upload URLs, assumed one per requested key — confirm with server API.
    urls: Vec<String>,
}
222
/// Parameters for creating a snapshot via multipart uploads.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    /// Display name for the new snapshot.
    snapshot_name: String,
    /// Storage keys of the files to upload.
    keys: Vec<String>,
    /// Size in bytes of each file, parallel to `keys` (lets the server
    /// compute how many part URLs to issue).
    file_sizes: Vec<usize>,
}
229
/// One element of the multipart snapshot-creation result.
///
/// `#[serde(untagged)]` means each element is matched structurally: a bare
/// number deserializes as `Id` (presumably the snapshot id — confirm with
/// the server API), anything else as a per-file `Part` descriptor.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    /// Bare numeric identifier.
    Id(u64),
    /// Multipart upload descriptor for one file.
    Part(SnapshotPart),
}
236
/// Parameters for finalizing an S3-style multipart upload for one file.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    /// Storage key of the file being completed.
    key: String,
    /// Upload session identifier returned when the multipart upload began.
    upload_id: String,
    /// Etag for every uploaded part, in part order.
    etag_list: Vec<EtagPart>,
}
243
/// Etag record for one uploaded part. The serde renames match the
/// S3 CompleteMultipartUpload field names (`ETag`, `PartNumber`).
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    /// Etag returned by the storage service for this part.
    #[serde(rename = "ETag")]
    etag: String,
    /// 1-based part index within the multipart upload.
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
251
/// Multipart upload descriptor for one file, as returned by the server.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    /// Storage key this descriptor applies to (may be absent — confirm
    /// when the server omits it).
    key: Option<String>,
    /// Upload session identifier to pass back when completing the upload.
    upload_id: String,
    /// Pre-signed URLs, assumed one per part — confirm with server API.
    urls: Vec<String>,
}
258
/// Parameters for updating a snapshot's status.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    /// Snapshot to update.
    snapshot_id: SnapshotID,
    /// New status value (server-defined string).
    status: String,
}
264
/// Snapshot record returned by the status RPC. All fields are currently
/// unread; they are kept so the full payload deserializes and is available
/// in debug output.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    /// Numeric snapshot identifier.
    #[allow(dead_code)]
    pub id: SnapshotID,
    /// Server-assigned unique identifier string.
    #[allow(dead_code)]
    pub uid: String,
    /// Free-form snapshot description.
    #[allow(dead_code)]
    pub description: String,
    /// Creation or update date as a server-formatted string.
    #[allow(dead_code)]
    pub date: String,
    /// Current status value (server-defined string).
    #[allow(dead_code)]
    pub status: String,
}
278
/// Parameters for listing images in a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    /// Selects which dataset's images to list.
    images_filter: ImagesFilter,
    /// Additional key/value filters on image files (semantics are
    /// server-defined — confirm against the RPC API).
    image_files_filter: HashMap<String, String>,
    /// When true, presumably the server returns only image IDs rather than
    /// full records — confirm against the RPC API.
    only_ids: bool,
}
286
/// Dataset selector used inside [`ImageListParams`].
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    /// Dataset whose images are being queried.
    dataset_id: DatasetID,
}
292
293/// Main client for interacting with EdgeFirst Studio Server.
294///
295/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
296/// and manages authentication, RPC calls, and data operations. It provides
297/// methods for managing projects, datasets, experiments, training sessions,
298/// and various utility functions for data processing.
299///
300/// The client supports multiple authentication methods and can work with both
301/// SaaS and self-hosted EdgeFirst Studio instances.
302///
303/// # Features
304///
305/// - **Authentication**: Token-based authentication with automatic persistence
306/// - **Dataset Management**: Upload, download, and manipulate datasets
307/// - **Project Operations**: Create and manage projects and experiments
308/// - **Training & Validation**: Submit and monitor ML training jobs
309/// - **Data Integration**: Convert between EdgeFirst datasets and popular
310/// formats
311/// - **Progress Tracking**: Real-time progress updates for long-running
312/// operations
313///
314/// # Examples
315///
316/// ```no_run
317/// use edgefirst_client::{Client, DatasetID};
318/// use std::str::FromStr;
319///
320/// # async fn example() -> Result<(), edgefirst_client::Error> {
321/// // Create a new client and authenticate
322/// let mut client = Client::new()?;
323/// let client = client
324/// .with_login("your-email@example.com", "password")
325/// .await?;
326///
327/// // Or use an existing token
328/// let base_client = Client::new()?;
329/// let client = base_client.with_token("your-token-here")?;
330///
331/// // Get organization and projects
332/// let org = client.organization().await?;
333/// let projects = client.projects(None).await?;
334///
335/// // Work with datasets
336/// let dataset_id = DatasetID::from_str("ds-abc123")?;
337/// let dataset = client.dataset(dataset_id).await?;
338/// # Ok(())
339/// # }
340/// ```
/// Client is Clone but cannot derive Debug due to dyn TokenStorage
#[derive(Clone)]
pub struct Client {
    /// Shared HTTP client configured with timeouts and the URL-based retry
    /// policy (see `new()`); used for both RPC calls and file transfers.
    http: reqwest::Client,
    /// Base URL of the EdgeFirst Studio server, e.g. `https://edgefirst.studio`.
    url: String,
    /// Authentication token for RPC calls; empty string when logged out.
    /// Shared via `Arc<RwLock>` so clones observe renewals and logouts.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
353
354impl std::fmt::Debug for Client {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("Client")
357 .field("url", &self.url)
358 .field("has_storage", &self.storage.is_some())
359 .field("token_path", &self.token_path)
360 .finish()
361 }
362}
363
/// Private context struct for pagination operations
struct FetchContext<'a> {
    /// Dataset being paged through.
    dataset_id: DatasetID,
    /// Annotation set to fetch from, when applicable.
    annotation_set_id: Option<AnnotationSetID>,
    /// Sample groups to include (semantics defined by the list RPC —
    /// confirm against the callers).
    groups: &'a [String],
    /// Type names to request (server-defined strings).
    types: Vec<String>,
    /// Label name → id lookup used while assembling results (assumed —
    /// confirm against the fetch implementation).
    labels: &'a HashMap<String, u64>,
}
372
373impl Client {
374 /// Create a new unauthenticated client with the default saas server.
375 ///
376 /// By default, the client uses [`FileTokenStorage`] for token persistence.
377 /// Use [`with_storage`][Self::with_storage],
378 /// [`with_memory_storage`][Self::with_memory_storage],
379 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
380 /// behavior.
381 ///
382 /// To connect to a different server, use [`with_server`][Self::with_server]
383 /// or [`with_token`][Self::with_token] (tokens include the server
384 /// instance).
385 ///
386 /// This client is created without a token and will need to authenticate
387 /// before using methods that require authentication.
388 ///
389 /// # Examples
390 ///
391 /// ```rust,no_run
392 /// use edgefirst_client::Client;
393 ///
394 /// # fn main() -> Result<(), edgefirst_client::Error> {
395 /// // Create client with default file storage
396 /// let client = Client::new()?;
397 ///
398 /// // Create client without token persistence
399 /// let client = Client::new()?.with_memory_storage();
400 /// # Ok(())
401 /// # }
402 /// ```
    pub fn new() -> Result<Self, Error> {
        log_retry_configuration();

        // Get timeout from environment or use default
        let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(30); // Default 30s timeout for API calls

        // Create single HTTP client with URL-based retry policy
        //
        // The retry policy classifies requests into two categories:
        // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
        //   errors
        // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
        //
        // This allows the same client to handle both API calls and file operations
        // with appropriate retry behavior for each. See retry.rs for details.
        let http = reqwest::Client::builder()
            .connect_timeout(Duration::from_secs(10))
            .timeout(Duration::from_secs(timeout_secs))
            .pool_idle_timeout(Duration::from_secs(90))
            .pool_max_idle_per_host(10)
            .retry(create_retry_policy())
            .build()?;

        // Default to file storage, loading any existing token
        let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
            Ok(file_storage) => Arc::new(file_storage),
            Err(e) => {
                // Fall back to in-memory storage so client construction never
                // fails just because the config directory is unavailable.
                warn!(
                    "Could not initialize file token storage: {}. Using memory storage.",
                    e
                );
                Arc::new(MemoryTokenStorage::new())
            }
        };

        // Try to load existing token from storage
        let token = match storage.load() {
            Ok(Some(t)) => t,
            Ok(None) => String::new(),
            Err(e) => {
                // An unreadable token is non-fatal; start unauthenticated.
                warn!(
                    "Failed to load token from storage: {}. Starting with empty token.",
                    e
                );
                String::new()
            }
        };

        // Default to the production SaaS server; with_server()/with_token()
        // return clients pointed at other instances.
        Ok(Client {
            http,
            url: "https://edgefirst.studio".to_string(),
            token: Arc::new(tokio::sync::RwLock::new(token)),
            storage: Some(storage),
            token_path: None,
        })
    }
462
463 /// Returns a new client connected to the specified server instance.
464 ///
465 /// The server parameter is an instance name that maps to a URL:
466 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
467 /// server)
468 /// - `"test"` → `https://test.edgefirst.studio`
469 /// - `"stage"` → `https://stage.edgefirst.studio`
470 /// - `"dev"` → `https://dev.edgefirst.studio`
471 /// - `"{name}"` → `https://{name}.edgefirst.studio`
472 ///
473 /// # Server Selection Priority
474 ///
475 /// When using the CLI or Python API, server selection follows this
476 /// priority:
477 ///
478 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
479 /// they were issued for. If you have a valid token, its server is used.
480 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
481 /// token is available. If a token exists with a different server, a
482 /// warning is emitted and the token's server takes priority.
483 /// 3. **Default `"saas"`** - If no token and no server specified, the
484 /// production server (`https://edgefirst.studio`) is used.
485 ///
486 /// # Important Notes
487 ///
488 /// - If a token is already set in the client, calling this method will
489 /// **drop the token** as tokens are specific to the server instance.
490 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
491 /// token's server before calling this method.
492 /// - For login operations, call `with_server()` first, then authenticate.
493 ///
494 /// # Examples
495 ///
496 /// ```rust,no_run
497 /// use edgefirst_client::Client;
498 ///
499 /// # fn main() -> Result<(), edgefirst_client::Error> {
500 /// let client = Client::new()?.with_server("test")?;
501 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
502 /// # Ok(())
503 /// # }
504 /// ```
505 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
506 let url = match server {
507 "" | "saas" => "https://edgefirst.studio".to_string(),
508 name => format!("https://{}.edgefirst.studio", name),
509 };
510
511 // Clear token from storage when changing servers to prevent
512 // authentication issues with stale tokens from different instances
513 if let Some(ref storage) = self.storage
514 && let Err(e) = storage.clear()
515 {
516 warn!(
517 "Failed to clear token from storage when changing servers: {}",
518 e
519 );
520 }
521
522 Ok(Client {
523 url,
524 token: Arc::new(tokio::sync::RwLock::new(String::new())),
525 ..self.clone()
526 })
527 }
528
529 /// Returns a new client with the specified token storage backend.
530 ///
531 /// Use this to configure custom token storage, such as platform-specific
532 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
533 ///
534 /// # Examples
535 ///
536 /// ```rust,no_run
537 /// use edgefirst_client::{Client, FileTokenStorage};
538 /// use std::{path::PathBuf, sync::Arc};
539 ///
540 /// # fn main() -> Result<(), edgefirst_client::Error> {
541 /// // Use a custom file path for token storage
542 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
543 /// let client = Client::new()?.with_storage(Arc::new(storage));
544 /// # Ok(())
545 /// # }
546 /// ```
547 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
548 // Try to load existing token from the new storage
549 let token = match storage.load() {
550 Ok(Some(t)) => t,
551 Ok(None) => String::new(),
552 Err(e) => {
553 warn!(
554 "Failed to load token from storage: {}. Starting with empty token.",
555 e
556 );
557 String::new()
558 }
559 };
560
561 Client {
562 token: Arc::new(tokio::sync::RwLock::new(token)),
563 storage: Some(storage),
564 token_path: None,
565 ..self
566 }
567 }
568
569 /// Returns a new client with in-memory token storage (no persistence).
570 ///
571 /// Tokens are stored in memory only and lost when the application exits.
572 /// This is useful for testing or when you want to manage token persistence
573 /// externally.
574 ///
575 /// # Examples
576 ///
577 /// ```rust,no_run
578 /// use edgefirst_client::Client;
579 ///
580 /// # fn main() -> Result<(), edgefirst_client::Error> {
581 /// let client = Client::new()?.with_memory_storage();
582 /// # Ok(())
583 /// # }
584 /// ```
    pub fn with_memory_storage(self) -> Self {
        // Replace any configured backend with a fresh in-memory store and
        // reset the token, since memory storage always starts empty.
        Client {
            token: Arc::new(tokio::sync::RwLock::new(String::new())),
            storage: Some(Arc::new(MemoryTokenStorage::new())),
            token_path: None,
            ..self
        }
    }
593
594 /// Returns a new client with no token storage.
595 ///
596 /// Tokens are not persisted. Use this when you want to manage tokens
597 /// entirely manually.
598 ///
599 /// # Examples
600 ///
601 /// ```rust,no_run
602 /// use edgefirst_client::Client;
603 ///
604 /// # fn main() -> Result<(), edgefirst_client::Error> {
605 /// let client = Client::new()?.with_no_storage();
606 /// # Ok(())
607 /// # }
608 /// ```
    pub fn with_no_storage(self) -> Self {
        // Disable both persistence mechanisms; the in-memory token (if any)
        // is kept as-is.
        Client {
            storage: None,
            token_path: None,
            ..self
        }
    }
616
617 /// Returns a new client authenticated with the provided username and
618 /// password.
619 ///
620 /// The token is automatically persisted to storage (if configured).
621 ///
622 /// # Examples
623 ///
624 /// ```rust,no_run
625 /// use edgefirst_client::Client;
626 ///
627 /// # async fn example() -> Result<(), edgefirst_client::Error> {
628 /// let client = Client::new()?
629 /// .with_server("test")?
630 /// .with_login("user@example.com", "password")
631 /// .await?;
632 /// # Ok(())
633 /// # }
634 /// ```
635 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
636 let params = HashMap::from([("username", username), ("password", password)]);
637 let login: LoginResult = self
638 .rpc_without_auth("auth.login".to_owned(), Some(params))
639 .await?;
640
641 // Validate that the server returned a non-empty token
642 if login.token.is_empty() {
643 return Err(Error::EmptyToken);
644 }
645
646 // Persist token to storage if configured
647 if let Some(ref storage) = self.storage
648 && let Err(e) = storage.store(&login.token)
649 {
650 warn!("Failed to persist token to storage: {}", e);
651 }
652
653 Ok(Client {
654 token: Arc::new(tokio::sync::RwLock::new(login.token)),
655 ..self.clone()
656 })
657 }
658
659 /// Returns a new client which will load and save the token to the specified
660 /// path.
661 ///
662 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
663 /// [`FileTokenStorage`] instead for more flexible token management.
664 ///
665 /// This method is maintained for backwards compatibility with existing
666 /// code. It disables the default storage and uses file-based storage at
667 /// the specified path.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        // Resolve an explicit path, or default to the platform config dir.
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path (legacy): {:?}", token_path);

        // A missing token file simply means the client starts logged out.
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            match self.with_token(&token) {
                Ok(client) => Ok(Client {
                    token_path: Some(token_path),
                    storage: None, // Disable new storage when using legacy token_path
                    ..client
                }),
                Err(e) => {
                    // Token is corrupted or invalid - remove it and continue with no token
                    warn!(
                        "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
                        token_path, e
                    );
                    if let Err(remove_err) = std::fs::remove_file(&token_path) {
                        warn!("Failed to remove corrupted token file: {:?}", remove_err);
                    }
                    Ok(Client {
                        token_path: Some(token_path),
                        storage: None,
                        ..self.clone()
                    })
                }
            }
        } else {
            // No token on disk: remember the path for later saves and keep
            // the legacy path authoritative by disabling the new storage.
            Ok(Client {
                token_path: Some(token_path),
                storage: None,
                ..self.clone()
            })
        }
    }
720
721 /// Returns a new client authenticated with the provided token.
722 ///
723 /// The token is automatically persisted to storage (if configured).
724 /// The server URL is extracted from the token payload.
725 ///
726 /// # Examples
727 ///
728 /// ```rust,no_run
729 /// use edgefirst_client::Client;
730 ///
731 /// # fn main() -> Result<(), edgefirst_client::Error> {
732 /// let client = Client::new()?.with_token("your-jwt-token")?;
733 /// # Ok(())
734 /// # }
735 /// ```
736 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
737 if token.is_empty() {
738 return Ok(self.clone());
739 }
740
741 let token_parts: Vec<&str> = token.split('.').collect();
742 if token_parts.len() != 3 {
743 return Err(Error::InvalidToken);
744 }
745
746 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
747 .decode(token_parts[1])
748 .map_err(|_| Error::InvalidToken)?;
749 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
750 let server = match payload.get("server") {
751 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
752 None => return Err(Error::InvalidToken),
753 };
754
755 // Persist token to storage if configured
756 if let Some(ref storage) = self.storage
757 && let Err(e) = storage.store(token)
758 {
759 warn!("Failed to persist token to storage: {}", e);
760 }
761
762 Ok(Client {
763 url: format!("https://{}.edgefirst.studio", server),
764 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
765 ..self.clone()
766 })
767 }
768
769 /// Persist the current token to storage.
770 ///
771 /// This is automatically called when using [`with_login`][Self::with_login]
772 /// or [`with_token`][Self::with_token], so you typically don't need to call
773 /// this directly.
774 ///
775 /// If using the legacy `token_path` configuration, saves to the file path.
776 /// If using the new storage abstraction, saves to the configured storage.
    pub async fn save_token(&self) -> Result<(), Error> {
        // Hold the read guard for the whole save so we persist a consistent
        // snapshot of the token.
        let token = self.token.read().await;

        // Try new storage first
        if let Some(ref storage) = self.storage {
            storage.store(&token)?;
            debug!("Token saved to storage");
            return Ok(());
        }

        // Fall back to legacy token_path behavior
        let path = self.token_path.clone().unwrap_or_else(|| {
            // Default to the platform config dir, or ./.token as a last resort.
            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .map(|dirs| dirs.config_dir().join("token"))
                .unwrap_or_else(|| PathBuf::from(".token"))
        });

        // Ensure the parent directory exists before writing the token file.
        create_dir_all(path.parent().ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Token path has no parent directory",
            ))
        })?)?;
        let mut file = std::fs::File::create(&path)?;
        file.write_all(token.as_bytes())?;

        debug!("Saved token to {:?}", path);

        Ok(())
    }
807
808 /// Return the version of the EdgeFirst Studio server for the current
809 /// client connection.
810 pub async fn version(&self) -> Result<String, Error> {
811 let version: HashMap<String, String> = self
812 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
813 .await?;
814 let version = version.get("version").ok_or(Error::InvalidResponse)?;
815 Ok(version.to_owned())
816 }
817
818 /// Clear the token used to authenticate the client with the server.
819 ///
820 /// Clears the token from memory and from storage (if configured).
821 /// If using the legacy `token_path` configuration, removes the token file.
    pub async fn logout(&self) -> Result<(), Error> {
        // Scope the write guard so the lock is released before the storage
        // and filesystem operations below.
        {
            let mut token = self.token.write().await;
            *token = "".to_string();
        }

        // Clear from new storage if configured
        if let Some(ref storage) = self.storage
            && let Err(e) = storage.clear()
        {
            warn!("Failed to clear token from storage: {}", e);
        }

        // Also clear legacy token_path if configured
        if let Some(path) = &self.token_path
            && path.exists()
        {
            fs::remove_file(path).await?;
        }

        Ok(())
    }
844
845 /// Return the token used to authenticate the client with the server. When
846 /// logging into the server using a username and password, the token is
847 /// returned by the server and stored in the client for future interactions.
848 pub async fn token(&self) -> String {
849 self.token.read().await.clone()
850 }
851
852 /// Verify the token used to authenticate the client with the server. This
853 /// method is used to ensure that the token is still valid and has not
854 /// expired. If the token is invalid, the server will return an error and
855 /// the client will need to login again.
856 pub async fn verify_token(&self) -> Result<(), Error> {
857 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
858 .await?;
859 Ok::<(), Error>(())
860 }
861
862 /// Renew the token used to authenticate the client with the server.
863 ///
864 /// Refreshes the token before it expires. If the token has already expired,
865 /// the server will return an error and you will need to login again.
866 ///
867 /// The new token is automatically persisted to storage (if configured).
    pub async fn renew_token(&self) -> Result<(), Error> {
        // The username for auth.refresh is read out of the current token's
        // payload rather than passed in by the caller.
        let params = HashMap::from([("username".to_string(), self.username().await?)]);
        let result: LoginResult = self
            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
            .await?;

        // Swap in the renewed token; scope the write guard so it is released
        // before persisting below.
        {
            let mut token = self.token.write().await;
            *token = result.token.clone();
        }

        // Persist to new storage if configured
        if let Some(ref storage) = self.storage
            && let Err(e) = storage.store(&result.token)
        {
            warn!("Failed to persist renewed token to storage: {}", e);
        }

        // Also persist to legacy token_path if configured
        if self.token_path.is_some() {
            self.save_token().await?;
        }

        Ok(())
    }
893
894 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
895 let token = self.token.read().await;
896 if token.is_empty() {
897 return Err(Error::EmptyToken);
898 }
899
900 let token_parts: Vec<&str> = token.split('.').collect();
901 if token_parts.len() != 3 {
902 return Err(Error::InvalidToken);
903 }
904
905 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
906 .decode(token_parts[1])
907 .map_err(|_| Error::InvalidToken)?;
908 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
909 match payload.get(field) {
910 Some(value) => Ok(value.to_owned()),
911 None => Err(Error::InvalidToken),
912 }
913 }
914
    /// Returns the URL of the EdgeFirst Studio server for the current
    /// client, e.g. `https://edgefirst.studio`.
    pub fn url(&self) -> &str {
        &self.url
    }
919
920 /// Returns the server name for the current client.
921 ///
922 /// This extracts the server name from the client's URL:
923 /// - `https://edgefirst.studio` → `"saas"`
924 /// - `https://test.edgefirst.studio` → `"test"`
925 /// - `https://{name}.edgefirst.studio` → `"{name}"`
926 ///
927 /// # Examples
928 ///
929 /// ```rust,no_run
930 /// use edgefirst_client::Client;
931 ///
932 /// # fn main() -> Result<(), edgefirst_client::Error> {
933 /// let client = Client::new()?.with_server("test")?;
934 /// assert_eq!(client.server(), "test");
935 ///
936 /// let client = Client::new()?; // default
937 /// assert_eq!(client.server(), "saas");
938 /// # Ok(())
939 /// # }
940 /// ```
941 pub fn server(&self) -> &str {
942 if self.url == "https://edgefirst.studio" {
943 "saas"
944 } else if let Some(name) = self.url.strip_prefix("https://") {
945 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
946 } else {
947 "saas"
948 }
949 }
950
951 /// Returns the username associated with the current token.
952 pub async fn username(&self) -> Result<String, Error> {
953 match self.token_field("username").await? {
954 serde_json::Value::String(username) => Ok(username),
955 _ => Err(Error::InvalidToken),
956 }
957 }
958
    /// Returns the expiration time for the current token.
    pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
        // The JWT "exp" claim is a numeric Unix timestamp in seconds.
        let ts = match self.token_field("exp").await? {
            serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
            _ => return Err(Error::InvalidToken),
        };

        // Out-of-range timestamps yield None and are treated as invalid.
        match DateTime::<Utc>::from_timestamp_secs(ts) {
            Some(dt) => Ok(dt),
            None => Err(Error::InvalidToken),
        }
    }
971
972 /// Returns the organization information for the current user.
973 pub async fn organization(&self) -> Result<Organization, Error> {
974 self.rpc::<(), Organization>("org.get".to_owned(), None)
975 .await
976 }
977
978 /// Returns a list of projects available to the user. The projects are
979 /// returned as a vector of Project objects. If a name filter is
980 /// provided, only projects matching the filter are returned.
981 ///
982 /// Results are sorted by match quality: exact matches first, then
983 /// case-insensitive exact matches, then shorter names (more specific),
984 /// then alphabetically.
985 ///
986 /// Projects are the top-level organizational unit in EdgeFirst Studio.
987 /// Projects contain datasets, trainers, and trainer sessions. Projects
988 /// are used to group related datasets and trainers together.
989 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
990 let projects = self
991 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
992 .await?;
993 if let Some(name) = name {
994 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
995 } else {
996 Ok(projects)
997 }
998 }
999
1000 /// Return the project with the specified project ID. If the project does
1001 /// not exist, an error is returned.
1002 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1003 let params = HashMap::from([("project_id", project_id)]);
1004 self.rpc("project.get".to_owned(), Some(params)).await
1005 }
1006
1007 /// Returns a list of datasets available to the user. The datasets are
1008 /// returned as a vector of Dataset objects. If a name filter is
1009 /// provided, only datasets matching the filter are returned.
1010 ///
1011 /// Results are sorted by match quality: exact matches first, then
1012 /// case-insensitive exact matches, then shorter names (more specific),
1013 /// then alphabetically. This ensures "Deer" returns before "Deer
1014 /// Roundtrip".
1015 pub async fn datasets(
1016 &self,
1017 project_id: ProjectID,
1018 name: Option<&str>,
1019 ) -> Result<Vec<Dataset>, Error> {
1020 let params = HashMap::from([("project_id", project_id)]);
1021 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1022 if let Some(name) = name {
1023 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1024 } else {
1025 Ok(datasets)
1026 }
1027 }
1028
1029 /// Return the dataset with the specified dataset ID. If the dataset does
1030 /// not exist, an error is returned.
1031 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1032 let params = HashMap::from([("dataset_id", dataset_id)]);
1033 self.rpc("dataset.get".to_owned(), Some(params)).await
1034 }
1035
1036 /// Lists the labels for the specified dataset.
1037 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1038 let params = HashMap::from([("dataset_id", dataset_id)]);
1039 self.rpc("label.list".to_owned(), Some(params)).await
1040 }
1041
1042 /// Add a new label to the dataset with the specified name.
1043 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1044 let new_label = NewLabel {
1045 dataset_id,
1046 labels: vec![NewLabelObject {
1047 name: name.to_owned(),
1048 }],
1049 };
1050 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1051 Ok(())
1052 }
1053
1054 /// Removes the label with the specified ID from the dataset. Label IDs are
1055 /// globally unique so the dataset_id is not required.
1056 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1057 let params = HashMap::from([("label_id", label_id)]);
1058 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1059 Ok(())
1060 }
1061
1062 /// Creates a new dataset in the specified project.
1063 ///
1064 /// # Arguments
1065 ///
1066 /// * `project_id` - The ID of the project to create the dataset in
1067 /// * `name` - The name of the new dataset
1068 /// * `description` - Optional description for the dataset
1069 ///
1070 /// # Returns
1071 ///
1072 /// Returns the dataset ID of the newly created dataset.
1073 pub async fn create_dataset(
1074 &self,
1075 project_id: &str,
1076 name: &str,
1077 description: Option<&str>,
1078 ) -> Result<DatasetID, Error> {
1079 let mut params = HashMap::new();
1080 params.insert("project_id", project_id);
1081 params.insert("name", name);
1082 if let Some(desc) = description {
1083 params.insert("description", desc);
1084 }
1085
1086 #[derive(Deserialize)]
1087 struct CreateDatasetResult {
1088 id: DatasetID,
1089 }
1090
1091 let result: CreateDatasetResult =
1092 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1093 Ok(result.id)
1094 }
1095
1096 /// Deletes a dataset by marking it as deleted.
1097 ///
1098 /// # Arguments
1099 ///
1100 /// * `dataset_id` - The ID of the dataset to delete
1101 ///
1102 /// # Returns
1103 ///
1104 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1105 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1106 let params = HashMap::from([("id", dataset_id)]);
1107 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1108 Ok(())
1109 }
1110
1111 /// Updates the label with the specified ID to have the new name or index.
1112 /// Label IDs cannot be changed. Label IDs are globally unique so the
1113 /// dataset_id is not required.
1114 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1115 #[derive(Serialize)]
1116 struct Params {
1117 dataset_id: DatasetID,
1118 label_id: u64,
1119 label_name: String,
1120 label_index: u64,
1121 }
1122
1123 let _: String = self
1124 .rpc(
1125 "label.update".to_owned(),
1126 Some(Params {
1127 dataset_id: label.dataset_id(),
1128 label_id: label.id(),
1129 label_name: label.name().to_owned(),
1130 label_index: label.index(),
1131 }),
1132 )
1133 .await?;
1134 Ok(())
1135 }
1136
1137 /// Downloads dataset samples to the local filesystem.
1138 ///
1139 /// # Arguments
1140 ///
1141 /// * `dataset_id` - The unique identifier of the dataset
1142 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1143 /// * `file_types` - File types to download (e.g., Image, LidarPcd)
1144 /// * `output` - Local directory to save downloaded files
1145 /// * `flatten` - If true, download all files to output root without
1146 /// sequence subdirectories. When flattening, filenames are prefixed with
1147 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1148 /// unavailable) unless the filename already starts with
1149 /// `{sequence_name}_`, to avoid conflicts between sequences.
1150 /// * `progress` - Optional channel for progress updates
1151 ///
1152 /// # Returns
1153 ///
1154 /// Returns `Ok(())` on success or an error if download fails.
1155 ///
1156 /// # Example
1157 ///
1158 /// ```rust,no_run
1159 /// # use edgefirst_client::{Client, DatasetID, FileType};
1160 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1161 /// let client = Client::new()?.with_token_path(None)?;
1162 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1163 ///
1164 /// // Download with sequence subdirectories (default)
1165 /// client
1166 /// .download_dataset(
1167 /// dataset_id,
1168 /// &[],
1169 /// &[FileType::Image],
1170 /// "./data".into(),
1171 /// false,
1172 /// None,
1173 /// )
1174 /// .await?;
1175 ///
1176 /// // Download flattened (all files in one directory)
1177 /// client
1178 /// .download_dataset(
1179 /// dataset_id,
1180 /// &[],
1181 /// &[FileType::Image],
1182 /// "./data".into(),
1183 /// true,
1184 /// None,
1185 /// )
1186 /// .await?;
1187 /// # Ok(())
1188 /// # }
1189 /// ```
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        flatten: bool,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Fetch the full sample listing up front; the same progress channel is
        // used for the listing phase and the download phase.
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Clone captured state so the download closure can be 'static.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            // Per-task clones of the shared state captured by the outer move
            // closure; each spawned future owns its own copies.
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // For images, the extension is inferred from the file
                        // contents; other file types use their type name.
                        // NOTE(review): infer::get(...).expect(...) panics if
                        // the bytes are not a recognizable format — consider
                        // surfacing this as an Error instead.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        // Determine target directory based on sequence membership and flatten
                        // option
                        // - flatten=false + sequence_name: dataset/sequence_name/
                        // - flatten=false + no sequence: dataset/ (root level)
                        // - flatten=true: dataset/ (all files in output root)
                        // NOTE: group (train/val/test) is NOT used for directory structure
                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = if flatten {
                            output.clone()
                        } else {
                            sequence_dir
                                .as_ref()
                                .map(|seq| output.join(seq))
                                .unwrap_or_else(|| output.clone())
                        };
                        fs::create_dir_all(&target_dir).await?;

                        // Sanitize names so they are safe filesystem components.
                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Construct filename with smart prefixing for flatten mode
                        // When flatten=true and sample belongs to a sequence:
                        // - Check if filename already starts with "{sequence_name}_"
                        // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
                        // - If yes, use filename as-is (already uniquely named)
                        let file_name = if is_image {
                            if let Some(img_name) = image_name {
                                Self::build_filename(
                                    &img_name,
                                    flatten,
                                    sequence_dir.as_ref(),
                                    sample.frame_number(),
                                )
                            } else {
                                // No server-side image name; fall back to the
                                // sample name plus the inferred extension.
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            }
                        } else {
                            let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
                            Self::build_filename(
                                &base_name,
                                flatten,
                                sequence_dir.as_ref(),
                                sample.frame_number(),
                            )
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing payloads are logged, not treated as errors,
                        // so one bad sample does not abort the whole download.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
1300
1301 /// Builds a filename with smart prefixing for flatten mode.
1302 ///
1303 /// When flattening sequences into a single directory, this function ensures
1304 /// unique filenames by checking if the sequence prefix already exists and
1305 /// adding it if necessary.
1306 ///
1307 /// # Logic
1308 ///
1309 /// - If `flatten=false`: returns `base_name` unchanged
1310 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1311 /// - If `flatten=true` and in sequence:
1312 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1313 /// unchanged
1314 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1315 /// `{sequence_name}_{base_name}`
1316 fn build_filename(
1317 base_name: &str,
1318 flatten: bool,
1319 sequence_name: Option<&String>,
1320 frame_number: Option<u32>,
1321 ) -> String {
1322 if !flatten || sequence_name.is_none() {
1323 return base_name.to_string();
1324 }
1325
1326 let seq_name = sequence_name.unwrap();
1327 let prefix = format!("{}_", seq_name);
1328
1329 // Check if already prefixed with sequence name
1330 if base_name.starts_with(&prefix) {
1331 base_name.to_string()
1332 } else {
1333 // Add sequence (and optionally frame) prefix
1334 match frame_number {
1335 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1336 None => format!("{}{}", prefix, base_name),
1337 }
1338 }
1339 }
1340
1341 /// List available annotation sets for the specified dataset.
1342 pub async fn annotation_sets(
1343 &self,
1344 dataset_id: DatasetID,
1345 ) -> Result<Vec<AnnotationSet>, Error> {
1346 let params = HashMap::from([("dataset_id", dataset_id)]);
1347 self.rpc("annset.list".to_owned(), Some(params)).await
1348 }
1349
1350 /// Create a new annotation set for the specified dataset.
1351 ///
1352 /// # Arguments
1353 ///
1354 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1355 /// * `name` - The name of the new annotation set
1356 /// * `description` - Optional description for the annotation set
1357 ///
1358 /// # Returns
1359 ///
1360 /// Returns the annotation set ID of the newly created annotation set.
1361 pub async fn create_annotation_set(
1362 &self,
1363 dataset_id: DatasetID,
1364 name: &str,
1365 description: Option<&str>,
1366 ) -> Result<AnnotationSetID, Error> {
1367 #[derive(Serialize)]
1368 struct Params<'a> {
1369 dataset_id: DatasetID,
1370 name: &'a str,
1371 operator: &'a str,
1372 #[serde(skip_serializing_if = "Option::is_none")]
1373 description: Option<&'a str>,
1374 }
1375
1376 #[derive(Deserialize)]
1377 struct CreateAnnotationSetResult {
1378 id: AnnotationSetID,
1379 }
1380
1381 let username = self.username().await?;
1382 let result: CreateAnnotationSetResult = self
1383 .rpc(
1384 "annset.add".to_owned(),
1385 Some(Params {
1386 dataset_id,
1387 name,
1388 operator: &username,
1389 description,
1390 }),
1391 )
1392 .await?;
1393 Ok(result.id)
1394 }
1395
1396 /// Deletes an annotation set by marking it as deleted.
1397 ///
1398 /// # Arguments
1399 ///
1400 /// * `annotation_set_id` - The ID of the annotation set to delete
1401 ///
1402 /// # Returns
1403 ///
1404 /// Returns `Ok(())` if the annotation set was successfully marked as
1405 /// deleted.
1406 pub async fn delete_annotation_set(
1407 &self,
1408 annotation_set_id: AnnotationSetID,
1409 ) -> Result<(), Error> {
1410 let params = HashMap::from([("id", annotation_set_id)]);
1411 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1412 Ok(())
1413 }
1414
1415 /// Retrieve the annotation set with the specified ID.
1416 pub async fn annotation_set(
1417 &self,
1418 annotation_set_id: AnnotationSetID,
1419 ) -> Result<AnnotationSet, Error> {
1420 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1421 self.rpc("annset.get".to_owned(), Some(params)).await
1422 }
1423
1424 /// Get the annotations for the specified annotation set with the
1425 /// requested annotation types. The annotation types are used to filter
1426 /// the annotations returned. The groups parameter is used to filter for
1427 /// dataset groups (train, val, test). Images which do not have any
1428 /// annotations are also included in the result as long as they are in the
1429 /// requested groups (when specified).
1430 ///
1431 /// The result is a vector of Annotations objects which contain the
1432 /// full dataset along with the annotations for the specified types.
1433 ///
1434 /// To get the annotations as a DataFrame, use the `annotations_dataframe`
1435 /// method instead.
1436 pub async fn annotations(
1437 &self,
1438 annotation_set_id: AnnotationSetID,
1439 groups: &[String],
1440 annotation_types: &[AnnotationType],
1441 progress: Option<Sender<Progress>>,
1442 ) -> Result<Vec<Annotation>, Error> {
1443 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1444 let labels = self
1445 .labels(dataset_id)
1446 .await?
1447 .into_iter()
1448 .map(|label| (label.name().to_string(), label.index()))
1449 .collect::<HashMap<_, _>>();
1450 let total = self
1451 .samples_count(
1452 dataset_id,
1453 Some(annotation_set_id),
1454 annotation_types,
1455 groups,
1456 &[],
1457 )
1458 .await?
1459 .total as usize;
1460
1461 if total == 0 {
1462 return Ok(vec![]);
1463 }
1464
1465 let context = FetchContext {
1466 dataset_id,
1467 annotation_set_id: Some(annotation_set_id),
1468 groups,
1469 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1470 labels: &labels,
1471 };
1472
1473 self.fetch_annotations_paginated(context, total, progress)
1474 .await
1475 }
1476
    /// Fetches annotations page-by-page via `samples.list`, flattening each
    /// page's samples into `Vec<Annotation>` through
    /// `process_sample_annotations`.
    ///
    /// Pagination follows the server's continue-token protocol: a non-empty
    /// `continue_token` in the response means more pages remain; an empty
    /// page terminates the loop regardless of the token. Progress is
    /// reported once per page as samples-processed out of `total`.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            // The token is moved into the request and refreshed from the
            // response below.
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Progress send failures are ignored on purpose: the receiver may
            // have been dropped by the caller.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender signals completion to the progress receiver.
        drop(progress);
        Ok(annotations)
    }
1520
1521 fn process_sample_annotations(
1522 &self,
1523 samples: &[Sample],
1524 labels: &HashMap<String, u64>,
1525 annotations: &mut Vec<Annotation>,
1526 ) {
1527 for sample in samples {
1528 if sample.annotations().is_empty() {
1529 let mut annotation = Annotation::new();
1530 annotation.set_sample_id(sample.id());
1531 annotation.set_name(sample.name());
1532 annotation.set_sequence_name(sample.sequence_name().cloned());
1533 annotation.set_frame_number(sample.frame_number());
1534 annotation.set_group(sample.group().cloned());
1535 annotations.push(annotation);
1536 continue;
1537 }
1538
1539 for annotation in sample.annotations() {
1540 let mut annotation = annotation.clone();
1541 annotation.set_sample_id(sample.id());
1542 annotation.set_name(sample.name());
1543 annotation.set_sequence_name(sample.sequence_name().cloned());
1544 annotation.set_frame_number(sample.frame_number());
1545 annotation.set_group(sample.group().cloned());
1546 Self::set_label_index_from_map(&mut annotation, labels);
1547 annotations.push(annotation);
1548 }
1549 }
1550 }
1551
1552 /// Helper to parse frame number from image_name when sequence_name is
1553 /// present. This ensures frame_number is always derived from the image
1554 /// filename, not from the server's frame_number field (which may be
1555 /// inconsistent).
1556 ///
1557 /// Returns Some(frame_number) if sequence_name is present and frame can be
1558 /// parsed, otherwise None.
1559 fn parse_frame_from_image_name(
1560 image_name: Option<&String>,
1561 sequence_name: Option<&String>,
1562 ) -> Option<u32> {
1563 use std::path::Path;
1564
1565 let sequence = sequence_name?;
1566 let name = image_name?;
1567
1568 // Extract stem (remove extension)
1569 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1570
1571 // Parse frame from format: "sequence_XXX" where XXX is the frame number
1572 stem.strip_prefix(sequence)
1573 .and_then(|suffix| suffix.strip_prefix('_'))
1574 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1575 }
1576
1577 /// Helper to set label index from a label map
1578 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1579 if let Some(label) = annotation.label() {
1580 annotation.set_label_index(Some(labels[label.as_str()]));
1581 }
1582 }
1583
1584 pub async fn samples_count(
1585 &self,
1586 dataset_id: DatasetID,
1587 annotation_set_id: Option<AnnotationSetID>,
1588 annotation_types: &[AnnotationType],
1589 groups: &[String],
1590 types: &[FileType],
1591 ) -> Result<SamplesCountResult, Error> {
1592 let types = annotation_types
1593 .iter()
1594 .map(|t| t.to_string())
1595 .chain(types.iter().map(|t| t.to_string()))
1596 .collect::<Vec<_>>();
1597
1598 let params = SamplesListParams {
1599 dataset_id,
1600 annotation_set_id,
1601 group_names: groups.to_vec(),
1602 types,
1603 continue_token: None,
1604 };
1605
1606 self.rpc("samples.count".to_owned(), Some(params)).await
1607 }
1608
1609 pub async fn samples(
1610 &self,
1611 dataset_id: DatasetID,
1612 annotation_set_id: Option<AnnotationSetID>,
1613 annotation_types: &[AnnotationType],
1614 groups: &[String],
1615 types: &[FileType],
1616 progress: Option<Sender<Progress>>,
1617 ) -> Result<Vec<Sample>, Error> {
1618 let types_vec = annotation_types
1619 .iter()
1620 .map(|t| t.to_string())
1621 .chain(types.iter().map(|t| t.to_string()))
1622 .collect::<Vec<_>>();
1623 let labels = self
1624 .labels(dataset_id)
1625 .await?
1626 .into_iter()
1627 .map(|label| (label.name().to_string(), label.index()))
1628 .collect::<HashMap<_, _>>();
1629 let total = self
1630 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
1631 .await?
1632 .total as usize;
1633
1634 if total == 0 {
1635 return Ok(vec![]);
1636 }
1637
1638 let context = FetchContext {
1639 dataset_id,
1640 annotation_set_id,
1641 groups,
1642 types: types_vec,
1643 labels: &labels,
1644 };
1645
1646 self.fetch_samples_paginated(context, total, progress).await
1647 }
1648
    /// Fetches samples page-by-page via `samples.list`, copying sample
    /// metadata onto each sample's annotations and resolving frame numbers.
    ///
    /// Pagination follows the server's continue-token protocol: a non-empty
    /// `continue_token` in the response means more pages remain; an empty
    /// page terminates the loop regardless of the token. Progress is
    /// reported once per page as samples-fetched out of `total`.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Use server's frame_number if valid (>= 0 after deserialization)
                        // Otherwise parse from image_name as fallback
                        // This ensures we respect explicit frame_number from uploads
                        // while still handling legacy data that only has filename encoding
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            // Set annotation fields from parent sample
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Progress send failures are ignored on purpose: the receiver may
            // have been dropped by the caller.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender signals completion to the progress receiver.
        drop(progress);
        Ok(samples)
    }
1720
1721 /// Populates (imports) samples into a dataset using the `samples.populate2`
1722 /// API.
1723 ///
1724 /// This method creates new samples in the specified dataset, optionally
1725 /// with annotations and sensor data files. For each sample, the `files`
1726 /// field is checked for local file paths. If a filename is a valid path
1727 /// to an existing file, the file will be automatically uploaded to S3
1728 /// using presigned URLs returned by the server. The filename in the
1729 /// request is replaced with the basename (path removed) before sending
1730 /// to the server.
1731 ///
1732 /// # Important Notes
1733 ///
1734 /// - **`annotation_set_id` is REQUIRED** when importing samples with
1735 /// annotations. Without it, the server will accept the request but will
1736 /// not save the annotation data. Use [`Client::annotation_sets`] to query
1737 /// available annotation sets for a dataset, or create a new one via the
1738 /// Studio UI.
1739 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
1740 /// boxes. Divide pixel coordinates by image width/height before creating
1741 /// [`Box2d`](crate::Box2d) annotations.
1742 /// - **Files are uploaded automatically** when the filename is a valid
1743 /// local path. The method will replace the full path with just the
1744 /// basename before sending to the server.
1745 /// - **Image dimensions are extracted automatically** for image files using
1746 /// the `imagesize` crate. The width/height are sent to the server, but
1747 /// note that the server currently doesn't return these fields when
1748 /// fetching samples back.
1749 /// - **UUIDs are generated automatically** if not provided. If you need
1750 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
1751 /// that the server doesn't currently return UUIDs in sample queries.
1752 ///
1753 /// # Arguments
1754 ///
1755 /// * `dataset_id` - The ID of the dataset to populate
1756 /// * `annotation_set_id` - **Required** if samples contain annotations,
1757 /// otherwise they will be ignored. Query with
1758 /// [`Client::annotation_sets`].
1759 /// * `samples` - Vector of samples to import with metadata and file
1760 /// references. For files, use the full local path - it will be uploaded
1761 /// automatically. UUIDs and image dimensions will be
1762 /// auto-generated/extracted if not provided.
1763 ///
1764 /// # Returns
1765 ///
1766 /// Returns the API result with sample UUIDs and upload status.
1767 ///
1768 /// # Example
1769 ///
1770 /// ```no_run
1771 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
1772 ///
1773 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1774 /// # let client = Client::new()?.with_login("user", "pass").await?;
1775 /// # let dataset_id = DatasetID::from(1);
1776 /// // Query available annotation sets for the dataset
1777 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
1778 /// let annotation_set_id = annotation_sets
1779 /// .first()
1780 /// .ok_or_else(|| {
1781 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
1782 /// })?
1783 /// .id();
1784 ///
1785 /// // Create sample with annotation (UUID will be auto-generated)
1786 /// let mut sample = Sample::new();
1787 /// sample.width = Some(1920);
1788 /// sample.height = Some(1080);
1789 /// sample.group = Some("train".to_string());
1790 ///
1791 /// // Add file - use full path to local file, it will be uploaded automatically
1792 /// sample.files = vec![SampleFile::with_filename(
1793 /// "image".to_string(),
1794 /// "/path/to/image.jpg".to_string(),
1795 /// )];
1796 ///
1797 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
1798 /// let mut annotation = Annotation::new();
1799 /// annotation.set_label(Some("person".to_string()));
1800 /// // Normalize pixel coordinates by dividing by image dimensions
1801 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
1802 /// annotation.set_box2d(Some(bbox));
1803 /// sample.annotations = vec![annotation];
1804 ///
1805 /// // Populate with annotation_set_id (REQUIRED for annotations)
1806 /// let result = client
1807 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
1808 /// .await?;
1809 /// # Ok(())
1810 /// # }
1811 /// ```
1812 pub async fn populate_samples(
1813 &self,
1814 dataset_id: DatasetID,
1815 annotation_set_id: Option<AnnotationSetID>,
1816 samples: Vec<Sample>,
1817 progress: Option<Sender<Progress>>,
1818 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
1819 use crate::api::SamplesPopulateParams;
1820
1821 // Track which files need to be uploaded
1822 let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
1823
1824 // Process samples to detect local files and generate UUIDs
1825 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
1826
1827 let has_files_to_upload = !files_to_upload.is_empty();
1828
1829 // Call populate API with presigned_urls=true if we have files to upload
1830 let params = SamplesPopulateParams {
1831 dataset_id,
1832 annotation_set_id,
1833 presigned_urls: Some(has_files_to_upload),
1834 samples,
1835 };
1836
1837 let results: Vec<crate::SamplesPopulateResult> = self
1838 .rpc("samples.populate2".to_owned(), Some(params))
1839 .await?;
1840
1841 // Upload files if we have any
1842 if has_files_to_upload {
1843 self.upload_sample_files(&results, files_to_upload, progress)
1844 .await?;
1845 }
1846
1847 Ok(results)
1848 }
1849
1850 fn prepare_samples_for_upload(
1851 &self,
1852 samples: Vec<Sample>,
1853 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1854 ) -> Result<Vec<Sample>, Error> {
1855 Ok(samples
1856 .into_iter()
1857 .map(|mut sample| {
1858 // Generate UUID if not provided
1859 if sample.uuid.is_none() {
1860 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
1861 }
1862
1863 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
1864
1865 // Process files: detect local paths and queue for upload
1866 let files_copy = sample.files.clone();
1867 let updated_files: Vec<crate::SampleFile> = files_copy
1868 .iter()
1869 .map(|file| {
1870 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
1871 })
1872 .collect();
1873
1874 sample.files = updated_files;
1875 sample
1876 })
1877 .collect())
1878 }
1879
1880 fn process_sample_file(
1881 &self,
1882 file: &crate::SampleFile,
1883 sample_uuid: &str,
1884 sample: &mut Sample,
1885 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1886 ) -> crate::SampleFile {
1887 use std::path::Path;
1888
1889 if let Some(filename) = file.filename() {
1890 let path = Path::new(filename);
1891
1892 // Check if this is a valid local file path
1893 if path.exists()
1894 && path.is_file()
1895 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
1896 {
1897 // For image files, try to extract dimensions if not already set
1898 if file.file_type() == "image"
1899 && (sample.width.is_none() || sample.height.is_none())
1900 && let Ok(size) = imagesize::size(path)
1901 {
1902 sample.width = Some(size.width as u32);
1903 sample.height = Some(size.height as u32);
1904 }
1905
1906 // Store the full path for later upload
1907 files_to_upload.push((
1908 sample_uuid.to_string(),
1909 file.file_type().to_string(),
1910 path.to_path_buf(),
1911 basename.to_string(),
1912 ));
1913
1914 // Return SampleFile with just the basename
1915 return crate::SampleFile::with_filename(
1916 file.file_type().to_string(),
1917 basename.to_string(),
1918 );
1919 }
1920 }
1921 // Return the file unchanged if not a local path
1922 file.clone()
1923 }
1924
1925 async fn upload_sample_files(
1926 &self,
1927 results: &[crate::SamplesPopulateResult],
1928 files_to_upload: Vec<(String, String, PathBuf, String)>,
1929 progress: Option<Sender<Progress>>,
1930 ) -> Result<(), Error> {
1931 // Build a map from (sample_uuid, basename) -> local_path
1932 let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
1933 for (uuid, _file_type, path, basename) in files_to_upload {
1934 upload_map.insert((uuid, basename), path);
1935 }
1936
1937 let http = self.http.clone();
1938
1939 // Extract the data we need for parallel upload
1940 let upload_tasks: Vec<_> = results
1941 .iter()
1942 .map(|result| (result.uuid.clone(), result.urls.clone()))
1943 .collect();
1944
1945 parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
1946 let http = http.clone();
1947 let upload_map = upload_map.clone();
1948
1949 async move {
1950 // Upload all files for this sample
1951 for url_info in &urls {
1952 if let Some(local_path) =
1953 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
1954 {
1955 // Upload the file
1956 upload_file_to_presigned_url(
1957 http.clone(),
1958 &url_info.url,
1959 local_path.clone(),
1960 )
1961 .await?;
1962 }
1963 }
1964
1965 Ok(())
1966 }
1967 })
1968 .await
1969 }
1970
1971 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1972 // Uses default 120s timeout from client
1973 let resp = self.http.get(url).send().await?;
1974
1975 if !resp.status().is_success() {
1976 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
1977 }
1978
1979 let bytes = resp.bytes().await?;
1980 Ok(bytes.to_vec())
1981 }
1982
1983 /// Get the AnnotationGroup for the specified annotation set with the
1984 /// requested annotation types. The annotation type is used to filter
1985 /// the annotations returned. Images which do not have any annotations
1986 /// are included in the result.
1987 ///
1988 /// Get annotations as a DataFrame (2025.01 schema).
1989 ///
1990 /// **DEPRECATED**: Use [`Client::samples_dataframe()`] instead for full
1991 /// 2025.10 schema support including optional metadata columns.
1992 ///
1993 /// The result is a DataFrame following the EdgeFirst Dataset Format
1994 /// definition with 9 columns (original schema). Does not include new
1995 /// optional columns added in 2025.10.
1996 ///
1997 /// # Migration
1998 ///
1999 /// ```rust,no_run
2000 /// # use edgefirst_client::Client;
2001 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2002 /// # let client = Client::new()?;
2003 /// # let dataset_id = 1.into();
2004 /// # let annotation_set_id = 1.into();
2005 /// # let groups = vec![];
2006 /// # let types = vec![];
2007 /// // OLD (deprecated):
2008 /// let df = client
2009 /// .annotations_dataframe(annotation_set_id, &groups, &types, None)
2010 /// .await?;
2011 ///
2012 /// // NEW (recommended):
2013 /// let df = client
2014 /// .samples_dataframe(dataset_id, Some(annotation_set_id), &groups, &types, None)
2015 /// .await?;
2016 /// # Ok(())
2017 /// # }
2018 /// ```
2019 ///
2020 /// To get the annotations as a vector of Annotation objects, use the
2021 /// `annotations` method instead.
2022 #[deprecated(
2023 since = "0.8.0",
2024 note = "Use `samples_dataframe()` for complete 2025.10 schema support"
2025 )]
2026 #[cfg(feature = "polars")]
2027 pub async fn annotations_dataframe(
2028 &self,
2029 annotation_set_id: AnnotationSetID,
2030 groups: &[String],
2031 types: &[AnnotationType],
2032 progress: Option<Sender<Progress>>,
2033 ) -> Result<DataFrame, Error> {
2034 #[allow(deprecated)]
2035 use crate::dataset::annotations_dataframe;
2036
2037 let annotations = self
2038 .annotations(annotation_set_id, groups, types, progress)
2039 .await?;
2040 #[allow(deprecated)]
2041 annotations_dataframe(&annotations)
2042 }
2043
2044 /// Get samples as a DataFrame with complete 2025.10 schema.
2045 ///
2046 /// This is the recommended method for obtaining dataset annotations in
2047 /// DataFrame format. It includes all sample metadata (size, location,
2048 /// pose, degradation) as optional columns.
2049 ///
2050 /// # Arguments
2051 ///
2052 /// * `dataset_id` - Dataset identifier
2053 /// * `annotation_set_id` - Optional annotation set filter
2054 /// * `groups` - Dataset groups to include (train, val, test)
2055 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2056 /// * `progress` - Optional progress callback
2057 ///
2058 /// # Example
2059 ///
2060 /// ```rust,no_run
2061 /// use edgefirst_client::Client;
2062 ///
2063 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2064 /// # let client = Client::new()?;
2065 /// # let dataset_id = 1.into();
2066 /// # let annotation_set_id = 1.into();
2067 /// let df = client
2068 /// .samples_dataframe(
2069 /// dataset_id,
2070 /// Some(annotation_set_id),
2071 /// &["train".to_string()],
2072 /// &[],
2073 /// None,
2074 /// )
2075 /// .await?;
2076 /// println!("DataFrame shape: {:?}", df.shape());
2077 /// # Ok(())
2078 /// # }
2079 /// ```
2080 #[cfg(feature = "polars")]
2081 pub async fn samples_dataframe(
2082 &self,
2083 dataset_id: DatasetID,
2084 annotation_set_id: Option<AnnotationSetID>,
2085 groups: &[String],
2086 types: &[AnnotationType],
2087 progress: Option<Sender<Progress>>,
2088 ) -> Result<DataFrame, Error> {
2089 use crate::dataset::samples_dataframe;
2090
2091 let samples = self
2092 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2093 .await?;
2094 samples_dataframe(&samples)
2095 }
2096
2097 /// List available snapshots. If a name is provided, only snapshots
2098 /// containing that name are returned.
2099 ///
2100 /// Results are sorted by match quality: exact matches first, then
2101 /// case-insensitive exact matches, then shorter descriptions (more
2102 /// specific), then alphabetically.
2103 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2104 let snapshots: Vec<Snapshot> = self
2105 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2106 .await?;
2107 if let Some(name) = name {
2108 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2109 s.description()
2110 }))
2111 } else {
2112 Ok(snapshots)
2113 }
2114 }
2115
2116 /// Get the snapshot with the specified id.
2117 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2118 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2119 self.rpc("snapshots.get".to_owned(), Some(params)).await
2120 }
2121
2122 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2123 ///
2124 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2125 /// pairs) that serve two primary purposes:
2126 ///
2127 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2128 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2129 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2130 /// auto-depth processing.
2131 ///
2132 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2133 /// migration between EdgeFirst Studio instances using the create →
2134 /// download → upload → restore workflow.
2135 ///
2136 /// Large files are automatically chunked into 100MB parts and uploaded
2137 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2138 /// is streamed without loading into memory, maintaining constant memory
2139 /// usage.
2140 ///
2141 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2142 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2143 /// better for large files to avoid timeout issues. Higher values (16-32)
2144 /// are better for many small files.
2145 ///
2146 /// # Arguments
2147 ///
2148 /// * `path` - Local file path to MCAP file or directory containing
2149 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2150 /// * `progress` - Optional channel to receive upload progress updates
2151 ///
2152 /// # Returns
2153 ///
2154 /// Returns a `Snapshot` object with ID, description, status, path, and
2155 /// creation timestamp on success.
2156 ///
2157 /// # Errors
2158 ///
2159 /// Returns an error if:
2160 /// * Path doesn't exist or contains invalid UTF-8
2161 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2162 /// * Upload fails or network error occurs
2163 /// * Server rejects the snapshot
2164 ///
2165 /// # Example
2166 ///
2167 /// ```no_run
2168 /// # use edgefirst_client::{Client, Progress};
2169 /// # use tokio::sync::mpsc;
2170 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2171 /// let client = Client::new()?.with_token_path(None)?;
2172 ///
2173 /// // Upload MCAP file with progress tracking
2174 /// let (tx, mut rx) = mpsc::channel(1);
2175 /// tokio::spawn(async move {
2176 /// while let Some(Progress { current, total }) = rx.recv().await {
2177 /// println!(
2178 /// "Upload: {}/{} bytes ({:.1}%)",
2179 /// current,
2180 /// total,
2181 /// (current as f64 / total as f64) * 100.0
2182 /// );
2183 /// }
2184 /// });
2185 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2186 /// println!("Created snapshot: {:?}", snapshot.id());
2187 ///
2188 /// // Upload dataset directory (no progress)
2189 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2190 /// # Ok(())
2191 /// # }
2192 /// ```
2193 ///
2194 /// # See Also
2195 ///
2196 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2197 /// dataset
2198 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2199 /// data
2200 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2201 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2202 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
2203 pub async fn create_snapshot(
2204 &self,
2205 path: &str,
2206 progress: Option<Sender<Progress>>,
2207 ) -> Result<Snapshot, Error> {
2208 let path = Path::new(path);
2209
2210 if path.is_dir() {
2211 let path_str = path.to_str().ok_or_else(|| {
2212 Error::IoError(std::io::Error::new(
2213 std::io::ErrorKind::InvalidInput,
2214 "Path contains invalid UTF-8",
2215 ))
2216 })?;
2217 return self.create_snapshot_folder(path_str, progress).await;
2218 }
2219
2220 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2221 Error::IoError(std::io::Error::new(
2222 std::io::ErrorKind::InvalidInput,
2223 "Invalid filename",
2224 ))
2225 })?;
2226 let total = path.metadata()?.len() as usize;
2227 let current = Arc::new(AtomicUsize::new(0));
2228
2229 if let Some(progress) = &progress {
2230 let _ = progress.send(Progress { current: 0, total }).await;
2231 }
2232
2233 let params = SnapshotCreateMultipartParams {
2234 snapshot_name: name.to_owned(),
2235 keys: vec![name.to_owned()],
2236 file_sizes: vec![total],
2237 };
2238 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2239 .rpc(
2240 "snapshots.create_upload_url_multipart".to_owned(),
2241 Some(params),
2242 )
2243 .await?;
2244
2245 let snapshot_id = match multipart.get("snapshot_id") {
2246 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2247 _ => return Err(Error::InvalidResponse),
2248 };
2249
2250 let snapshot = self.snapshot(snapshot_id).await?;
2251 let part_prefix = snapshot
2252 .path()
2253 .split("::/")
2254 .last()
2255 .ok_or(Error::InvalidResponse)?
2256 .to_owned();
2257 let part_key = format!("{}/{}", part_prefix, name);
2258 let mut part = match multipart.get(&part_key) {
2259 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2260 _ => return Err(Error::InvalidResponse),
2261 }
2262 .clone();
2263 part.key = Some(part_key);
2264
2265 let params = upload_multipart(
2266 self.http.clone(),
2267 part.clone(),
2268 path.to_path_buf(),
2269 total,
2270 current,
2271 progress.clone(),
2272 )
2273 .await?;
2274
2275 let complete: String = self
2276 .rpc(
2277 "snapshots.complete_multipart_upload".to_owned(),
2278 Some(params),
2279 )
2280 .await?;
2281 debug!("Snapshot Multipart Complete: {:?}", complete);
2282
2283 let params: SnapshotStatusParams = SnapshotStatusParams {
2284 snapshot_id,
2285 status: "available".to_owned(),
2286 };
2287 let _: SnapshotStatusResult = self
2288 .rpc("snapshots.update".to_owned(), Some(params))
2289 .await?;
2290
2291 if let Some(progress) = progress {
2292 drop(progress);
2293 }
2294
2295 self.snapshot(snapshot_id).await
2296 }
2297
2298 async fn create_snapshot_folder(
2299 &self,
2300 path: &str,
2301 progress: Option<Sender<Progress>>,
2302 ) -> Result<Snapshot, Error> {
2303 let path = Path::new(path);
2304 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2305 Error::IoError(std::io::Error::new(
2306 std::io::ErrorKind::InvalidInput,
2307 "Invalid directory name",
2308 ))
2309 })?;
2310
2311 let files = WalkDir::new(path)
2312 .into_iter()
2313 .filter_map(|entry| entry.ok())
2314 .filter(|entry| entry.file_type().is_file())
2315 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
2316 .collect::<Vec<_>>();
2317
2318 let total: usize = files
2319 .iter()
2320 .filter_map(|file| path.join(file).metadata().ok())
2321 .map(|metadata| metadata.len() as usize)
2322 .sum();
2323 let current = Arc::new(AtomicUsize::new(0));
2324
2325 if let Some(progress) = &progress {
2326 let _ = progress.send(Progress { current: 0, total }).await;
2327 }
2328
2329 let keys = files
2330 .iter()
2331 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
2332 .collect::<Vec<_>>();
2333 let file_sizes = files
2334 .iter()
2335 .filter_map(|key| path.join(key).metadata().ok())
2336 .map(|metadata| metadata.len() as usize)
2337 .collect::<Vec<_>>();
2338
2339 let params = SnapshotCreateMultipartParams {
2340 snapshot_name: name.to_owned(),
2341 keys,
2342 file_sizes,
2343 };
2344
2345 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2346 .rpc(
2347 "snapshots.create_upload_url_multipart".to_owned(),
2348 Some(params),
2349 )
2350 .await?;
2351
2352 let snapshot_id = match multipart.get("snapshot_id") {
2353 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2354 _ => return Err(Error::InvalidResponse),
2355 };
2356
2357 let snapshot = self.snapshot(snapshot_id).await?;
2358 let part_prefix = snapshot
2359 .path()
2360 .split("::/")
2361 .last()
2362 .ok_or(Error::InvalidResponse)?
2363 .to_owned();
2364
2365 for file in files {
2366 let file_str = file.to_str().ok_or_else(|| {
2367 Error::IoError(std::io::Error::new(
2368 std::io::ErrorKind::InvalidInput,
2369 "File path contains invalid UTF-8",
2370 ))
2371 })?;
2372 let part_key = format!("{}/{}", part_prefix, file_str);
2373 let mut part = match multipart.get(&part_key) {
2374 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2375 _ => return Err(Error::InvalidResponse),
2376 }
2377 .clone();
2378 part.key = Some(part_key);
2379
2380 let params = upload_multipart(
2381 self.http.clone(),
2382 part.clone(),
2383 path.join(file),
2384 total,
2385 current.clone(),
2386 progress.clone(),
2387 )
2388 .await?;
2389
2390 let complete: String = self
2391 .rpc(
2392 "snapshots.complete_multipart_upload".to_owned(),
2393 Some(params),
2394 )
2395 .await?;
2396 debug!("Snapshot Part Complete: {:?}", complete);
2397 }
2398
2399 let params = SnapshotStatusParams {
2400 snapshot_id,
2401 status: "available".to_owned(),
2402 };
2403 let _: SnapshotStatusResult = self
2404 .rpc("snapshots.update".to_owned(), Some(params))
2405 .await?;
2406
2407 if let Some(progress) = progress {
2408 drop(progress);
2409 }
2410
2411 self.snapshot(snapshot_id).await
2412 }
2413
2414 /// Delete a snapshot from EdgeFirst Studio.
2415 ///
2416 /// Permanently removes a snapshot and its associated data. This operation
2417 /// cannot be undone.
2418 ///
2419 /// # Arguments
2420 ///
2421 /// * `snapshot_id` - The snapshot ID to delete
2422 ///
2423 /// # Errors
2424 ///
2425 /// Returns an error if:
2426 /// * Snapshot doesn't exist
2427 /// * User lacks permission to delete the snapshot
2428 /// * Server error occurs
2429 ///
2430 /// # Example
2431 ///
2432 /// ```no_run
2433 /// # use edgefirst_client::{Client, SnapshotID};
2434 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2435 /// let client = Client::new()?.with_token_path(None)?;
2436 /// let snapshot_id = SnapshotID::from(123);
2437 /// client.delete_snapshot(snapshot_id).await?;
2438 /// # Ok(())
2439 /// # }
2440 /// ```
2441 ///
2442 /// # See Also
2443 ///
2444 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2445 /// * [`snapshots`](Self::snapshots) - List all snapshots
2446 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2447 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2448 let _: String = self
2449 .rpc("snapshots.delete".to_owned(), Some(params))
2450 .await?;
2451 Ok(())
2452 }
2453
2454 /// Create a snapshot from an existing dataset on the server.
2455 ///
2456 /// Triggers server-side snapshot generation which exports the dataset's
2457 /// images and annotations into a downloadable EdgeFirst Dataset Format
2458 /// snapshot.
2459 ///
2460 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
2461 /// while restore creates a dataset from a snapshot, this method creates a
2462 /// snapshot from a dataset.
2463 ///
2464 /// # Arguments
2465 ///
2466 /// * `dataset_id` - The dataset ID to create snapshot from
2467 /// * `description` - Description for the created snapshot
2468 ///
2469 /// # Returns
2470 ///
2471 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
2472 /// for monitoring progress.
2473 ///
2474 /// # Errors
2475 ///
2476 /// Returns an error if:
2477 /// * Dataset doesn't exist
2478 /// * User lacks permission to access the dataset
2479 /// * Server rejects the request
2480 ///
2481 /// # Example
2482 ///
2483 /// ```no_run
2484 /// # use edgefirst_client::{Client, DatasetID};
2485 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2486 /// let client = Client::new()?.with_token_path(None)?;
2487 /// let dataset_id = DatasetID::from(123);
2488 ///
2489 /// // Create snapshot from dataset (all annotation sets)
2490 /// let result = client
2491 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
2492 /// .await?;
2493 /// println!("Created snapshot: {:?}", result.id);
2494 ///
2495 /// // Monitor progress via task ID
2496 /// if let Some(task_id) = result.task_id {
2497 /// println!("Task: {}", task_id);
2498 /// }
2499 /// # Ok(())
2500 /// # }
2501 /// ```
2502 ///
2503 /// # See Also
2504 ///
2505 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
2506 /// snapshot
2507 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2508 /// dataset
2509 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2510 pub async fn create_snapshot_from_dataset(
2511 &self,
2512 dataset_id: DatasetID,
2513 description: &str,
2514 annotation_set_id: Option<AnnotationSetID>,
2515 ) -> Result<SnapshotFromDatasetResult, Error> {
2516 // Resolve annotation_set_id: use provided value or fetch default
2517 let annotation_set_id = match annotation_set_id {
2518 Some(id) => id,
2519 None => {
2520 // Fetch annotation sets and find default ("annotations") or use first
2521 let sets = self.annotation_sets(dataset_id).await?;
2522 if sets.is_empty() {
2523 return Err(Error::InvalidParameters(
2524 "No annotation sets available for dataset".to_owned(),
2525 ));
2526 }
2527 // Look for "annotations" set (default), otherwise use first
2528 sets.iter()
2529 .find(|s| s.name() == "annotations")
2530 .unwrap_or(&sets[0])
2531 .id()
2532 }
2533 };
2534 let params = SnapshotCreateFromDataset {
2535 description: description.to_owned(),
2536 dataset_id,
2537 annotation_set_id,
2538 };
2539 self.rpc("snapshots.create".to_owned(), Some(params)).await
2540 }
2541
2542 /// Download a snapshot from EdgeFirst Studio to local storage.
2543 ///
2544 /// Downloads all files in a snapshot (single MCAP file or directory of
2545 /// EdgeFirst Dataset Format files) to the specified output path. Files are
2546 /// downloaded concurrently with progress tracking.
2547 ///
2548 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2549 /// downloads (default: half of CPU cores, min 2, max 8).
2550 ///
2551 /// # Arguments
2552 ///
2553 /// * `snapshot_id` - The snapshot ID to download
2554 /// * `output` - Local directory path to save downloaded files
2555 /// * `progress` - Optional channel to receive download progress updates
2556 ///
2557 /// # Errors
2558 ///
2559 /// Returns an error if:
2560 /// * Snapshot doesn't exist
2561 /// * Output directory cannot be created
2562 /// * Download fails or network error occurs
2563 ///
2564 /// # Example
2565 ///
2566 /// ```no_run
2567 /// # use edgefirst_client::{Client, SnapshotID, Progress};
2568 /// # use tokio::sync::mpsc;
2569 /// # use std::path::PathBuf;
2570 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2571 /// let client = Client::new()?.with_token_path(None)?;
2572 /// let snapshot_id = SnapshotID::from(123);
2573 ///
2574 /// // Download with progress tracking
2575 /// let (tx, mut rx) = mpsc::channel(1);
2576 /// tokio::spawn(async move {
2577 /// while let Some(Progress { current, total }) = rx.recv().await {
2578 /// println!("Download: {}/{} bytes", current, total);
2579 /// }
2580 /// });
2581 /// client
2582 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
2583 /// .await?;
2584 /// # Ok(())
2585 /// # }
2586 /// ```
2587 ///
2588 /// # See Also
2589 ///
2590 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2591 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2592 /// dataset
2593 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2594 pub async fn download_snapshot(
2595 &self,
2596 snapshot_id: SnapshotID,
2597 output: PathBuf,
2598 progress: Option<Sender<Progress>>,
2599 ) -> Result<(), Error> {
2600 fs::create_dir_all(&output).await?;
2601
2602 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2603 let items: HashMap<String, String> = self
2604 .rpc("snapshots.create_download_url".to_owned(), Some(params))
2605 .await?;
2606
2607 let total = Arc::new(AtomicUsize::new(0));
2608 let current = Arc::new(AtomicUsize::new(0));
2609 let sem = Arc::new(Semaphore::new(max_tasks()));
2610
2611 let tasks = items
2612 .iter()
2613 .map(|(key, url)| {
2614 let http = self.http.clone();
2615 let key = key.clone();
2616 let url = url.clone();
2617 let output = output.clone();
2618 let progress = progress.clone();
2619 let current = current.clone();
2620 let total = total.clone();
2621 let sem = sem.clone();
2622
2623 tokio::spawn(async move {
2624 let _permit = sem.acquire().await.map_err(|_| {
2625 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
2626 })?;
2627 let res = http.get(url).send().await?;
2628 let content_length = res.content_length().unwrap_or(0) as usize;
2629
2630 if let Some(progress) = &progress {
2631 let total = total.fetch_add(content_length, Ordering::SeqCst);
2632 let _ = progress
2633 .send(Progress {
2634 current: current.load(Ordering::SeqCst),
2635 total: total + content_length,
2636 })
2637 .await;
2638 }
2639
2640 let mut file = File::create(output.join(key)).await?;
2641 let mut stream = res.bytes_stream();
2642
2643 while let Some(chunk) = stream.next().await {
2644 let chunk = chunk?;
2645 file.write_all(&chunk).await?;
2646 let len = chunk.len();
2647
2648 if let Some(progress) = &progress {
2649 let total = total.load(Ordering::SeqCst);
2650 let current = current.fetch_add(len, Ordering::SeqCst);
2651
2652 let _ = progress
2653 .send(Progress {
2654 current: current + len,
2655 total,
2656 })
2657 .await;
2658 }
2659 }
2660
2661 Ok::<(), Error>(())
2662 })
2663 })
2664 .collect::<Vec<_>>();
2665
2666 join_all(tasks)
2667 .await
2668 .into_iter()
2669 .collect::<Result<Vec<_>, _>>()?
2670 .into_iter()
2671 .collect::<Result<Vec<_>, _>>()?;
2672
2673 Ok(())
2674 }
2675
2676 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
2677 ///
2678 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
2679 /// the specified project. For MCAP files, supports:
2680 ///
2681 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
2682 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
2683 /// present)
2684 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
2685 /// * **Topic filtering**: Select specific MCAP topics to restore
2686 ///
2687 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
2688 /// dataset structure.
2689 ///
2690 /// # Arguments
2691 ///
2692 /// * `project_id` - Target project ID
2693 /// * `snapshot_id` - Snapshot ID to restore
2694 /// * `topics` - MCAP topics to include (empty = all topics)
2695 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
2696 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
2697 /// * `dataset_name` - Optional custom dataset name
2698 /// * `dataset_description` - Optional dataset description
2699 ///
2700 /// # Returns
2701 ///
2702 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
2703 ///
2704 /// # Errors
2705 ///
2706 /// Returns an error if:
2707 /// * Snapshot or project doesn't exist
2708 /// * Snapshot format is invalid
2709 /// * Server rejects restoration parameters
2710 ///
2711 /// # Example
2712 ///
2713 /// ```no_run
2714 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
2715 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2716 /// let client = Client::new()?.with_token_path(None)?;
2717 /// let project_id = ProjectID::from(1);
2718 /// let snapshot_id = SnapshotID::from(123);
2719 ///
2720 /// // Restore MCAP with AGTG for "person" and "car" detection
2721 /// let result = client
2722 /// .restore_snapshot(
2723 /// project_id,
2724 /// snapshot_id,
2725 /// &[], // All topics
2726 /// &["person".to_string(), "car".to_string()], // AGTG labels
2727 /// true, // Auto-depth
2728 /// Some("Highway Dataset"),
2729 /// Some("Collected on I-95"),
2730 /// )
2731 /// .await?;
2732 /// println!("Restored to dataset: {:?}", result.dataset_id);
2733 /// # Ok(())
2734 /// # }
2735 /// ```
2736 ///
2737 /// # See Also
2738 ///
2739 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2740 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2741 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2742 #[allow(clippy::too_many_arguments)]
2743 pub async fn restore_snapshot(
2744 &self,
2745 project_id: ProjectID,
2746 snapshot_id: SnapshotID,
2747 topics: &[String],
2748 autolabel: &[String],
2749 autodepth: bool,
2750 dataset_name: Option<&str>,
2751 dataset_description: Option<&str>,
2752 ) -> Result<SnapshotRestoreResult, Error> {
2753 let params = SnapshotRestore {
2754 project_id,
2755 snapshot_id,
2756 fps: 1,
2757 autodepth,
2758 agtg_pipeline: !autolabel.is_empty(),
2759 autolabel: autolabel.to_vec(),
2760 topics: topics.to_vec(),
2761 dataset_name: dataset_name.map(|s| s.to_owned()),
2762 dataset_description: dataset_description.map(|s| s.to_owned()),
2763 };
2764 self.rpc("snapshots.restore".to_owned(), Some(params)).await
2765 }
2766
2767 /// Returns a list of experiments available to the user. The experiments
2768 /// are returned as a vector of Experiment objects. If name is provided
2769 /// then only experiments containing this string are returned.
2770 ///
2771 /// Results are sorted by match quality: exact matches first, then
2772 /// case-insensitive exact matches, then shorter names (more specific),
2773 /// then alphabetically.
2774 ///
2775 /// Experiments provide a method of organizing training and validation
2776 /// sessions together and are akin to an Experiment in MLFlow terminology.
2777 /// Each experiment can have multiple trainer sessions associated with it,
2778 /// these would be akin to runs in MLFlow terminology.
2779 pub async fn experiments(
2780 &self,
2781 project_id: ProjectID,
2782 name: Option<&str>,
2783 ) -> Result<Vec<Experiment>, Error> {
2784 let params = HashMap::from([("project_id", project_id)]);
2785 let experiments: Vec<Experiment> =
2786 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
2787 if let Some(name) = name {
2788 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
2789 } else {
2790 Ok(experiments)
2791 }
2792 }
2793
2794 /// Return the experiment with the specified experiment ID. If the
2795 /// experiment does not exist, an error is returned.
2796 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
2797 let params = HashMap::from([("trainer_id", experiment_id)]);
2798 self.rpc("trainer.get".to_owned(), Some(params)).await
2799 }
2800
2801 /// Returns a list of trainer sessions available to the user. The trainer
2802 /// sessions are returned as a vector of TrainingSession objects. If name
2803 /// is provided then only trainer sessions containing this string are
2804 /// returned.
2805 ///
2806 /// Results are sorted by match quality: exact matches first, then
2807 /// case-insensitive exact matches, then shorter names (more specific),
2808 /// then alphabetically.
2809 ///
2810 /// Trainer sessions are akin to runs in MLFlow terminology. These
2811 /// represent an actual training session which will produce metrics and
2812 /// model artifacts.
2813 pub async fn training_sessions(
2814 &self,
2815 experiment_id: ExperimentID,
2816 name: Option<&str>,
2817 ) -> Result<Vec<TrainingSession>, Error> {
2818 let params = HashMap::from([("trainer_id", experiment_id)]);
2819 let sessions: Vec<TrainingSession> = self
2820 .rpc("trainer.session.list".to_owned(), Some(params))
2821 .await?;
2822 if let Some(name) = name {
2823 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
2824 } else {
2825 Ok(sessions)
2826 }
2827 }
2828
2829 /// Return the trainer session with the specified trainer session ID. If
2830 /// the trainer session does not exist, an error is returned.
2831 pub async fn training_session(
2832 &self,
2833 session_id: TrainingSessionID,
2834 ) -> Result<TrainingSession, Error> {
2835 let params = HashMap::from([("trainer_session_id", session_id)]);
2836 self.rpc("trainer.session.get".to_owned(), Some(params))
2837 .await
2838 }
2839
2840 /// List validation sessions for the given project.
2841 pub async fn validation_sessions(
2842 &self,
2843 project_id: ProjectID,
2844 ) -> Result<Vec<ValidationSession>, Error> {
2845 let params = HashMap::from([("project_id", project_id)]);
2846 self.rpc("validate.session.list".to_owned(), Some(params))
2847 .await
2848 }
2849
2850 /// Retrieve a specific validation session.
2851 pub async fn validation_session(
2852 &self,
2853 session_id: ValidationSessionID,
2854 ) -> Result<ValidationSession, Error> {
2855 let params = HashMap::from([("validate_session_id", session_id)]);
2856 self.rpc("validate.session.get".to_owned(), Some(params))
2857 .await
2858 }
2859
2860 /// List the artifacts for the specified trainer session. The artifacts
2861 /// are returned as a vector of strings.
2862 pub async fn artifacts(
2863 &self,
2864 training_session_id: TrainingSessionID,
2865 ) -> Result<Vec<Artifact>, Error> {
2866 let params = HashMap::from([("training_session_id", training_session_id)]);
2867 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2868 .await
2869 }
2870
2871 /// Download the model artifact for the specified trainer session to the
2872 /// specified file path, if path is not provided it will be downloaded to
2873 /// the current directory with the same filename. A progress callback can
2874 /// be provided to monitor the progress of the download over a watch
2875 /// channel.
2876 pub async fn download_artifact(
2877 &self,
2878 training_session_id: TrainingSessionID,
2879 modelname: &str,
2880 filename: Option<PathBuf>,
2881 progress: Option<Sender<Progress>>,
2882 ) -> Result<(), Error> {
2883 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2884 let resp = self
2885 .http
2886 .get(format!(
2887 "{}/download_model?training_session_id={}&file={}",
2888 self.url,
2889 training_session_id.value(),
2890 modelname
2891 ))
2892 .header("Authorization", format!("Bearer {}", self.token().await))
2893 .send()
2894 .await?;
2895 if !resp.status().is_success() {
2896 let err = resp.error_for_status_ref().unwrap_err();
2897 return Err(Error::HttpError(err));
2898 }
2899
2900 if let Some(parent) = filename.parent() {
2901 fs::create_dir_all(parent).await?;
2902 }
2903
2904 if let Some(progress) = progress {
2905 let total = resp.content_length().unwrap_or(0) as usize;
2906 let _ = progress.send(Progress { current: 0, total }).await;
2907
2908 let mut file = File::create(filename).await?;
2909 let mut current = 0;
2910 let mut stream = resp.bytes_stream();
2911
2912 while let Some(item) = stream.next().await {
2913 let chunk = item?;
2914 file.write_all(&chunk).await?;
2915 current += chunk.len();
2916 let _ = progress.send(Progress { current, total }).await;
2917 }
2918 } else {
2919 let body = resp.bytes().await?;
2920 fs::write(filename, body).await?;
2921 }
2922
2923 Ok(())
2924 }
2925
2926 /// Download the model checkpoint associated with the specified trainer
2927 /// session to the specified file path, if path is not provided it will be
2928 /// downloaded to the current directory with the same filename. A progress
2929 /// callback can be provided to monitor the progress of the download over a
2930 /// watch channel.
2931 ///
2932 /// There is no API for listing checkpoints it is expected that trainers are
2933 /// aware of possible checkpoints and their names within the checkpoint
2934 /// folder on the server.
2935 pub async fn download_checkpoint(
2936 &self,
2937 training_session_id: TrainingSessionID,
2938 checkpoint: &str,
2939 filename: Option<PathBuf>,
2940 progress: Option<Sender<Progress>>,
2941 ) -> Result<(), Error> {
2942 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2943 let resp = self
2944 .http
2945 .get(format!(
2946 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2947 self.url,
2948 training_session_id.value(),
2949 checkpoint
2950 ))
2951 .header("Authorization", format!("Bearer {}", self.token().await))
2952 .send()
2953 .await?;
2954 if !resp.status().is_success() {
2955 let err = resp.error_for_status_ref().unwrap_err();
2956 return Err(Error::HttpError(err));
2957 }
2958
2959 if let Some(parent) = filename.parent() {
2960 fs::create_dir_all(parent).await?;
2961 }
2962
2963 if let Some(progress) = progress {
2964 let total = resp.content_length().unwrap_or(0) as usize;
2965 let _ = progress.send(Progress { current: 0, total }).await;
2966
2967 let mut file = File::create(filename).await?;
2968 let mut current = 0;
2969 let mut stream = resp.bytes_stream();
2970
2971 while let Some(item) = stream.next().await {
2972 let chunk = item?;
2973 file.write_all(&chunk).await?;
2974 current += chunk.len();
2975 let _ = progress.send(Progress { current, total }).await;
2976 }
2977 } else {
2978 let body = resp.bytes().await?;
2979 fs::write(filename, body).await?;
2980 }
2981
2982 Ok(())
2983 }
2984
2985 /// Return a list of tasks for the current user.
2986 ///
2987 /// # Arguments
2988 ///
2989 /// * `name` - Optional filter for task name (client-side substring match)
2990 /// * `workflow` - Optional filter for workflow/task type. If provided,
2991 /// filters server-side by exact match. Valid values include: "trainer",
2992 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
2993 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
2994 /// "convertor", "twostage"
2995 /// * `status` - Optional filter for task status (e.g., "running",
2996 /// "complete", "error")
2997 /// * `manager` - Optional filter for task manager type (e.g., "aws",
2998 /// "user", "kubernetes")
2999 pub async fn tasks(
3000 &self,
3001 name: Option<&str>,
3002 workflow: Option<&str>,
3003 status: Option<&str>,
3004 manager: Option<&str>,
3005 ) -> Result<Vec<Task>, Error> {
3006 let mut params = TasksListParams {
3007 continue_token: None,
3008 types: workflow.map(|w| vec![w.to_owned()]),
3009 status: status.map(|s| vec![s.to_owned()]),
3010 manager: manager.map(|m| vec![m.to_owned()]),
3011 };
3012 let mut tasks = Vec::new();
3013
3014 loop {
3015 let result = self
3016 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
3017 .await?;
3018 tasks.extend(result.tasks);
3019
3020 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
3021 params.continue_token = None;
3022 } else {
3023 params.continue_token = result.continue_token;
3024 }
3025
3026 if params.continue_token.is_none() {
3027 break;
3028 }
3029 }
3030
3031 if let Some(name) = name {
3032 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
3033 }
3034
3035 Ok(tasks)
3036 }
3037
3038 /// Retrieve the task information and status.
3039 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
3040 self.rpc(
3041 "task.get".to_owned(),
3042 Some(HashMap::from([("id", task_id)])),
3043 )
3044 .await
3045 }
3046
3047 /// Updates the tasks status.
3048 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
3049 let status = TaskStatus {
3050 task_id,
3051 status: status.to_owned(),
3052 };
3053 self.rpc("docker.update.status".to_owned(), Some(status))
3054 .await
3055 }
3056
3057 /// Defines the stages for the task. The stages are defined as a mapping
3058 /// from stage names to their descriptions. Once stages are defined their
3059 /// status can be updated using the update_stage method.
3060 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
3061 let stages: Vec<HashMap<String, String>> = stages
3062 .iter()
3063 .map(|(key, value)| {
3064 let mut stage_map = HashMap::new();
3065 stage_map.insert(key.to_string(), value.to_string());
3066 stage_map
3067 })
3068 .collect();
3069 let params = TaskStages { task_id, stages };
3070 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
3071 Ok(())
3072 }
3073
3074 /// Updates the progress of the task for the provided stage and status
3075 /// information.
3076 pub async fn update_stage(
3077 &self,
3078 task_id: TaskID,
3079 stage: &str,
3080 status: &str,
3081 message: &str,
3082 percentage: u8,
3083 ) -> Result<(), Error> {
3084 let stage = Stage::new(
3085 Some(task_id),
3086 stage.to_owned(),
3087 Some(status.to_owned()),
3088 Some(message.to_owned()),
3089 percentage,
3090 );
3091 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
3092 Ok(())
3093 }
3094
3095 /// Raw fetch from the Studio server is used for downloading files.
3096 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
3097 let req = self
3098 .http
3099 .get(format!("{}/{}", self.url, query))
3100 .header("User-Agent", "EdgeFirst Client")
3101 .header("Authorization", format!("Bearer {}", self.token().await));
3102 let resp = req.send().await?;
3103
3104 if resp.status().is_success() {
3105 let body = resp.bytes().await?;
3106
3107 if log_enabled!(Level::Trace) {
3108 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
3109 }
3110
3111 Ok(body.to_vec())
3112 } else {
3113 let err = resp.error_for_status_ref().unwrap_err();
3114 Err(Error::HttpError(err))
3115 }
3116 }
3117
3118 /// Sends a multipart post request to the server. This is used by the
3119 /// upload and download APIs which do not use JSON-RPC but instead transfer
3120 /// files using multipart/form-data.
3121 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
3122 let req = self
3123 .http
3124 .post(format!("{}/api?method={}", self.url, method))
3125 .header("Accept", "application/json")
3126 .header("User-Agent", "EdgeFirst Client")
3127 .header("Authorization", format!("Bearer {}", self.token().await))
3128 .multipart(form);
3129 let resp = req.send().await?;
3130
3131 if resp.status().is_success() {
3132 let body = resp.bytes().await?;
3133
3134 if log_enabled!(Level::Trace) {
3135 trace!(
3136 "POST Multipart Response: {}",
3137 String::from_utf8_lossy(&body)
3138 );
3139 }
3140
3141 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
3142 Ok(response) => response,
3143 Err(err) => {
3144 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3145 return Err(err.into());
3146 }
3147 };
3148
3149 if let Some(error) = response.error {
3150 Err(Error::RpcError(error.code, error.message))
3151 } else if let Some(result) = response.result {
3152 Ok(result)
3153 } else {
3154 Err(Error::InvalidResponse)
3155 }
3156 } else {
3157 let err = resp.error_for_status_ref().unwrap_err();
3158 Err(Error::HttpError(err))
3159 }
3160 }
3161
3162 /// Send a JSON-RPC request to the server. The method is the name of the
3163 /// method to call on the server. The params are the parameters to pass to
3164 /// the method. The method and params are serialized into a JSON-RPC
3165 /// request and sent to the server. The response is deserialized into
3166 /// the specified type and returned to the caller.
3167 ///
3168 /// NOTE: This API would generally not be called directly and instead users
3169 /// should use the higher-level methods provided by the client.
3170 pub async fn rpc<Params, RpcResult>(
3171 &self,
3172 method: String,
3173 params: Option<Params>,
3174 ) -> Result<RpcResult, Error>
3175 where
3176 Params: Serialize,
3177 RpcResult: DeserializeOwned,
3178 {
3179 let auth_expires = self.token_expiration().await?;
3180 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
3181 self.renew_token().await?;
3182 }
3183
3184 self.rpc_without_auth(method, params).await
3185 }
3186
3187 async fn rpc_without_auth<Params, RpcResult>(
3188 &self,
3189 method: String,
3190 params: Option<Params>,
3191 ) -> Result<RpcResult, Error>
3192 where
3193 Params: Serialize,
3194 RpcResult: DeserializeOwned,
3195 {
3196 let request = RpcRequest {
3197 method,
3198 params,
3199 ..Default::default()
3200 };
3201
3202 if log_enabled!(Level::Trace) {
3203 trace!(
3204 "RPC Request: {}",
3205 serde_json::ser::to_string_pretty(&request)?
3206 );
3207 }
3208
3209 let url = format!("{}/api", self.url);
3210
3211 // Use client-level timeout (allows retry mechanism to work properly)
3212 // Per-request timeout overrides can prevent retries from functioning
3213 let res = self
3214 .http
3215 .post(&url)
3216 .header("Accept", "application/json")
3217 .header("User-Agent", "EdgeFirst Client")
3218 .header("Authorization", format!("Bearer {}", self.token().await))
3219 .json(&request)
3220 .send()
3221 .await?;
3222
3223 self.process_rpc_response(res).await
3224 }
3225
3226 async fn process_rpc_response<RpcResult>(
3227 &self,
3228 res: reqwest::Response,
3229 ) -> Result<RpcResult, Error>
3230 where
3231 RpcResult: DeserializeOwned,
3232 {
3233 let body = res.bytes().await?;
3234
3235 if log_enabled!(Level::Trace) {
3236 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
3237 }
3238
3239 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
3240 Ok(response) => response,
3241 Err(err) => {
3242 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3243 return Err(err.into());
3244 }
3245 };
3246
3247 // FIXME: Studio Server always returns 999 as the id.
3248 // if request.id.to_string() != response.id {
3249 // return Err(Error::InvalidRpcId(response.id));
3250 // }
3251
3252 if let Some(error) = response.error {
3253 Err(Error::RpcError(error.code, error.message))
3254 } else if let Some(result) = response.result {
3255 Ok(result)
3256 } else {
3257 Err(Error::InvalidResponse)
3258 }
3259 }
3260}
3261
3262/// Process items in parallel with semaphore concurrency control and progress
3263/// tracking.
3264///
3265/// This helper eliminates boilerplate for parallel item processing with:
3266/// - Semaphore limiting concurrent tasks to `max_tasks()` (configurable via
3267/// `MAX_TASKS` environment variable, default: half of CPU cores, min 2, max
3268/// 8)
3269/// - Atomic progress counter with automatic item-level updates
3270/// - Progress updates sent after each item completes (not byte-level streaming)
3271/// - Proper error propagation from spawned tasks
3272///
3273/// Note: This is optimized for discrete items with post-completion progress
3274/// updates. For byte-level streaming progress or custom retry logic, use
3275/// specialized implementations.
3276///
3277/// # Arguments
3278///
3279/// * `items` - Collection of items to process in parallel
3280/// * `progress` - Optional progress channel for tracking completion
3281/// * `work_fn` - Async function to execute for each item
3282///
3283/// # Examples
3284///
3285/// ```rust,ignore
3286/// parallel_foreach_items(samples, progress, |sample| async move {
3287/// // Process sample
3288/// sample.download(&client, file_type).await?;
3289/// Ok(())
3290/// }).await?;
3291/// ```
3292async fn parallel_foreach_items<T, F, Fut>(
3293 items: Vec<T>,
3294 progress: Option<Sender<Progress>>,
3295 work_fn: F,
3296) -> Result<(), Error>
3297where
3298 T: Send + 'static,
3299 F: Fn(T) -> Fut + Send + Sync + 'static,
3300 Fut: Future<Output = Result<(), Error>> + Send + 'static,
3301{
3302 let total = items.len();
3303 let current = Arc::new(AtomicUsize::new(0));
3304 let sem = Arc::new(Semaphore::new(max_tasks()));
3305 let work_fn = Arc::new(work_fn);
3306
3307 let tasks = items
3308 .into_iter()
3309 .map(|item| {
3310 let sem = sem.clone();
3311 let current = current.clone();
3312 let progress = progress.clone();
3313 let work_fn = work_fn.clone();
3314
3315 tokio::spawn(async move {
3316 let _permit = sem.acquire().await.map_err(|_| {
3317 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3318 })?;
3319
3320 // Execute the actual work
3321 work_fn(item).await?;
3322
3323 // Update progress
3324 if let Some(progress) = &progress {
3325 let current = current.fetch_add(1, Ordering::SeqCst);
3326 let _ = progress
3327 .send(Progress {
3328 current: current + 1,
3329 total,
3330 })
3331 .await;
3332 }
3333
3334 Ok::<(), Error>(())
3335 })
3336 })
3337 .collect::<Vec<_>>();
3338
3339 join_all(tasks)
3340 .await
3341 .into_iter()
3342 .collect::<Result<Vec<_>, _>>()?
3343 .into_iter()
3344 .collect::<Result<Vec<_>, _>>()?;
3345
3346 if let Some(progress) = progress {
3347 drop(progress);
3348 }
3349
3350 Ok(())
3351}
3352
3353/// Upload a file to S3 using multipart upload with presigned URLs.
3354///
3355/// Splits a file into chunks (100MB each) and uploads them in parallel using
3356/// S3 multipart upload protocol. Returns completion parameters with ETags for
3357/// finalizing the upload.
3358///
3359/// This function handles:
3360/// - Splitting files into parts based on PART_SIZE (100MB)
3361/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
3362/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
3363/// - Retry logic (handled by reqwest client)
3364/// - Progress tracking across all parts
3365///
3366/// # Arguments
3367///
3368/// * `http` - HTTP client for making requests
3369/// * `part` - Snapshot part info with presigned URLs for each chunk
3370/// * `path` - Local file path to upload
3371/// * `total` - Total bytes across all files for progress calculation
3372/// * `current` - Atomic counter tracking bytes uploaded across all operations
3373/// * `progress` - Optional channel for sending progress updates
3374///
3375/// # Returns
3376///
3377/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
3378async fn upload_multipart(
3379 http: reqwest::Client,
3380 part: SnapshotPart,
3381 path: PathBuf,
3382 total: usize,
3383 current: Arc<AtomicUsize>,
3384 progress: Option<Sender<Progress>>,
3385) -> Result<SnapshotCompleteMultipartParams, Error> {
3386 let filesize = path.metadata()?.len() as usize;
3387 let n_parts = filesize.div_ceil(PART_SIZE);
3388 let sem = Arc::new(Semaphore::new(max_tasks()));
3389
3390 let key = part.key.ok_or(Error::InvalidResponse)?;
3391 let upload_id = part.upload_id;
3392
3393 let urls = part.urls.clone();
3394 // Pre-allocate ETag slots for all parts
3395 let etags = Arc::new(tokio::sync::Mutex::new(vec![
3396 EtagPart {
3397 etag: "".to_owned(),
3398 part_number: 0,
3399 };
3400 n_parts
3401 ]));
3402
3403 // Upload all parts in parallel with concurrency limiting
3404 let tasks = (0..n_parts)
3405 .map(|part| {
3406 let http = http.clone();
3407 let url = urls[part].clone();
3408 let etags = etags.clone();
3409 let path = path.to_owned();
3410 let sem = sem.clone();
3411 let progress = progress.clone();
3412 let current = current.clone();
3413
3414 tokio::spawn(async move {
3415 // Acquire semaphore permit to limit concurrent uploads
3416 let _permit = sem.acquire().await?;
3417
3418 // Upload part (retry is handled by reqwest client)
3419 let etag =
3420 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
3421
3422 // Store ETag for this part (needed to complete multipart upload)
3423 let mut etags = etags.lock().await;
3424 etags[part] = EtagPart {
3425 etag,
3426 part_number: part + 1,
3427 };
3428
3429 // Update progress counter
3430 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
3431 if let Some(progress) = &progress {
3432 let _ = progress
3433 .send(Progress {
3434 current: current + PART_SIZE,
3435 total,
3436 })
3437 .await;
3438 }
3439
3440 Ok::<(), Error>(())
3441 })
3442 })
3443 .collect::<Vec<_>>();
3444
3445 // Wait for all parts to complete
3446 join_all(tasks)
3447 .await
3448 .into_iter()
3449 .collect::<Result<Vec<_>, _>>()?;
3450
3451 Ok(SnapshotCompleteMultipartParams {
3452 key,
3453 upload_id,
3454 etag_list: etags.lock().await.clone(),
3455 })
3456}
3457
3458async fn upload_part(
3459 http: reqwest::Client,
3460 url: String,
3461 path: PathBuf,
3462 part: usize,
3463 n_parts: usize,
3464) -> Result<String, Error> {
3465 let filesize = path.metadata()?.len() as usize;
3466 let mut file = File::open(path).await?;
3467 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
3468 .await?;
3469 let file = file.take(PART_SIZE as u64);
3470
3471 let body_length = if part + 1 == n_parts {
3472 filesize % PART_SIZE
3473 } else {
3474 PART_SIZE
3475 };
3476
3477 let stream = FramedRead::new(file, BytesCodec::new());
3478 let body = Body::wrap_stream(stream);
3479
3480 let resp = http
3481 .put(url.clone())
3482 .header(CONTENT_LENGTH, body_length)
3483 .body(body)
3484 .send()
3485 .await?
3486 .error_for_status()?;
3487
3488 let etag = resp
3489 .headers()
3490 .get("etag")
3491 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3492 .to_str()
3493 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3494 .to_owned();
3495
3496 // Studio Server requires etag without the quotes.
3497 let etag = etag
3498 .strip_prefix("\"")
3499 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3500 let etag = etag
3501 .strip_suffix("\"")
3502 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3503
3504 Ok(etag.to_owned())
3505}
3506
3507/// Upload a complete file to a presigned S3 URL using HTTP PUT.
3508///
3509/// This is used for populate_samples to upload files to S3 after
3510/// receiving presigned URLs from the server.
3511async fn upload_file_to_presigned_url(
3512 http: reqwest::Client,
3513 url: &str,
3514 path: PathBuf,
3515) -> Result<(), Error> {
3516 // Read the entire file into memory
3517 let file_data = fs::read(&path).await?;
3518 let file_size = file_data.len();
3519
3520 // Upload (retry is handled by reqwest client)
3521 let resp = http
3522 .put(url)
3523 .header(CONTENT_LENGTH, file_size)
3524 .body(file_data)
3525 .send()
3526 .await?;
3527
3528 if resp.status().is_success() {
3529 debug!(
3530 "Successfully uploaded file: {:?} ({} bytes)",
3531 path, file_size
3532 );
3533 Ok(())
3534 } else {
3535 let status = resp.status();
3536 let error_text = resp.text().await.unwrap_or_default();
3537 Err(Error::InvalidParameters(format!(
3538 "Upload failed: HTTP {} - {}",
3539 status, error_text
3540 )))
3541 }
3542}
3543
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise pure, offline helpers only: the best-match
    // ordering of filter_and_sort_by_name, the Client::build_filename
    // flattening rules, and token storage clearing. No network access is
    // required.

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = vec![
            // (items, filter, expected_first)
            (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (vec!["test", "TEST", "Test Data"], "test", "test"),
            (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Verify token is loaded
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }
}