// edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SampleID, SamplesCountResult, SamplesListParams, SamplesListResult,
9 Snapshot, SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID,
10 SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus,
11 TasksListParams, TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{
15 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
16 },
17 retry::{create_retry_policy, log_retry_configuration},
18 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
19};
20use base64::Engine as _;
21use chrono::{DateTime, Utc};
22use directories::ProjectDirs;
23use futures::{StreamExt as _, future::join_all};
24use log::{Level, debug, error, log_enabled, trace, warn};
25use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use std::{
28 collections::HashMap,
29 ffi::OsStr,
30 fs::create_dir_all,
31 io::{SeekFrom, Write as _},
32 path::{Path, PathBuf},
33 sync::{
34 Arc,
35 atomic::{AtomicUsize, Ordering},
36 },
37 time::Duration,
38 vec,
39};
40use tokio::{
41 fs::{self, File},
42 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
43 sync::{RwLock, Semaphore, mpsc::Sender},
44};
45use tokio_util::codec::{BytesCodec, FramedRead};
46use walkdir::WalkDir;
47
48#[cfg(feature = "polars")]
49use polars::prelude::*;
50
/// Size of each multipart upload part: 100 MiB.
const PART_SIZE: usize = 100 * 1024 * 1024;
52
/// Determine the concurrency level for parallel transfer tasks.
///
/// Honors the `MAX_TASKS` environment variable when it parses as a
/// number; otherwise derives a value from the CPU count.
fn max_tasks() -> usize {
    if let Some(n) = std::env::var("MAX_TASKS").ok().and_then(|v| v.parse().ok()) {
        return n;
    }
    // Default to half the number of CPUs, clamped to [2, 8]. The lower
    // maximum prevents timeout issues with large file uploads.
    let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
    (cpus / 2).clamp(2, 8)
}
66
67/// Filters items by name and sorts by match quality.
68///
69/// Match quality priority (best to worst):
70/// 1. Exact match (case-sensitive)
71/// 2. Exact match (case-insensitive)
72/// 3. Substring match (shorter names first, then alphabetically)
73///
74/// This ensures that searching for "Deer" returns "Deer" before
75/// "Deer Roundtrip 20251129" or "Reindeer".
/// Filters items by name and sorts by match quality.
///
/// Ordering, best match first:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match — shorter names first, ties broken alphabetically
///
/// Guarantees that searching for "Deer" ranks "Deer" ahead of
/// "Deer Roundtrip 20251129" or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let needle = filter.to_lowercase();
    let mut matched: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&needle))
        .collect();

    // Rank key: (not exact, not exact-ci, length). Tuples compare
    // lexicographically and `false` sorts before `true`, so exact
    // matches come first, then case-insensitive exact, then shorter
    // names; alphabetical order breaks any remaining ties.
    matched.sort_by(|a, b| {
        let (name_a, name_b) = (get_name(a), get_name(b));
        let rank = |name: &str| (name != filter, name.to_lowercase() != needle, name.len());
        rank(name_a)
            .cmp(&rank(name_b))
            .then_with(|| name_a.cmp(name_b))
    });

    matched
}
116
/// Produce a string safe to use as a single path component.
///
/// Trims whitespace, reduces the input to its final path component,
/// and replaces characters that are problematic in file names with
/// underscores. Inputs that end up empty yield "unnamed".
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // Keep only the last path component so embedded separators cannot
    // escape the target directory.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed))
        .to_string_lossy();

    let cleaned: String = component
        .chars()
        .map(|c| {
            if matches!(c, '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|') {
                '_'
            } else {
                c
            }
        })
        .collect();

    if cleaned.is_empty() {
        FALLBACK.to_string()
    } else {
        cleaned
    }
}
142
143/// Progress information for long-running operations.
144///
145/// This struct tracks the current progress of operations like file uploads,
146/// downloads, or dataset processing. It provides the current count and total
147/// count to enable progress reporting in applications.
148///
149/// # Examples
150///
151/// ```rust
152/// use edgefirst_client::Progress;
153///
154/// let progress = Progress {
155/// current: 25,
156/// total: 100,
157/// };
158/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
159/// println!(
160/// "Progress: {:.1}% ({}/{})",
161/// percentage, progress.current, progress.total
162/// );
163/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Number of items completed so far.
    pub current: usize,
    /// Total number of items to be processed.
    pub total: usize,
}
171
/// JSON-RPC 2.0 request envelope sent to the Studio server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier (set to 0 by the Default impl).
    id: u64,
    // Protocol version string, always "2.0".
    jsonrpc: String,
    // Remote method name, e.g. "dataset.list".
    method: String,
    // Optional method parameters serialized as `params`.
    params: Option<Params>,
}
179
180impl<T> Default for RpcRequest<T> {
181 fn default() -> Self {
182 RpcRequest {
183 id: 0,
184 jsonrpc: "2.0".to_string(),
185 method: "".to_string(),
186 params: None,
187 }
188 }
189}
190
/// Error object embedded in a failed JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    // Numeric JSON-RPC error code from the server.
    code: i32,
    // Human-readable error description from the server.
    message: String,
}
196
/// JSON-RPC response envelope; `error` or `result` is expected to be
/// populated depending on whether the call failed or succeeded.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // Identifier echoed from the request (returned as a string).
    #[allow(dead_code)]
    id: String,
    // Protocol version string.
    #[allow(dead_code)]
    jsonrpc: String,
    // Populated when the call failed.
    error: Option<RpcError>,
    // Populated when the call succeeded.
    result: Option<RpcResult>,
}
206
/// Placeholder result type for RPC methods returning no payload.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
210
/// Parameters for creating a snapshot from a set of object keys.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    // Display name for the new snapshot.
    snapshot_name: String,
    // Object keys to include in the snapshot.
    keys: Vec<String>,
}
217
/// Server response to snapshot creation: the new snapshot ID and a
/// list of upload URLs.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
224
/// Parameters for initiating multipart snapshot uploads.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    // Display name for the new snapshot.
    snapshot_name: String,
    // Object keys to upload.
    keys: Vec<String>,
    // File size for each key — assumed to parallel `keys`; confirm
    // against the server API.
    file_sizes: Vec<usize>,
}
231
/// One element of the multipart-create response: either a bare numeric
/// snapshot ID or a per-key multipart descriptor. `untagged` lets
/// serde select the variant by JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
238
/// Parameters for completing a multipart upload: identifies the object
/// and lists the ETag recorded for each uploaded part.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}
245
/// ETag / part-number pair reported on multipart completion; the serde
/// renames match the S3-style field names "ETag" and "PartNumber".
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
253
/// Per-object multipart upload descriptor returned by the server.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    // Object key; optional in the response.
    key: Option<String>,
    // Multipart upload session identifier.
    upload_id: String,
    // Presigned part-upload URLs.
    urls: Vec<String>,
}
260
/// Parameters for updating a snapshot's status on the server.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    // New status value, sent as a string.
    status: String,
}
266
/// Snapshot record returned by status-update calls.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    // Date as reported by the server (string form; format unverified).
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
280
/// Parameters for listing images within a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    // Additional key/value filters applied to image files.
    image_files_filter: HashMap<String, String>,
    // When true, request IDs only rather than full records.
    only_ids: bool,
}
288
/// Filter restricting image queries to a single dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
294
295/// Main client for interacting with EdgeFirst Studio Server.
296///
297/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
298/// and manages authentication, RPC calls, and data operations. It provides
299/// methods for managing projects, datasets, experiments, training sessions,
300/// and various utility functions for data processing.
301///
302/// The client supports multiple authentication methods and can work with both
303/// SaaS and self-hosted EdgeFirst Studio instances.
304///
305/// # Features
306///
307/// - **Authentication**: Token-based authentication with automatic persistence
308/// - **Dataset Management**: Upload, download, and manipulate datasets
309/// - **Project Operations**: Create and manage projects and experiments
310/// - **Training & Validation**: Submit and monitor ML training jobs
311/// - **Data Integration**: Convert between EdgeFirst datasets and popular
312/// formats
313/// - **Progress Tracking**: Real-time progress updates for long-running
314/// operations
315///
316/// # Examples
317///
318/// ```no_run
319/// use edgefirst_client::{Client, DatasetID};
320/// use std::str::FromStr;
321///
322/// # async fn example() -> Result<(), edgefirst_client::Error> {
323/// // Create a new client and authenticate
324/// let mut client = Client::new()?;
325/// let client = client
326/// .with_login("your-email@example.com", "password")
327/// .await?;
328///
329/// // Or use an existing token
330/// let base_client = Client::new()?;
331/// let client = base_client.with_token("your-token-here")?;
332///
333/// // Get organization and projects
334/// let org = client.organization().await?;
335/// let projects = client.projects(None).await?;
336///
337/// // Work with datasets
338/// let dataset_id = DatasetID::from_str("ds-abc123")?;
339/// let dataset = client.dataset(dataset_id).await?;
340/// # Ok(())
341/// # }
342/// ```
// Client is Clone but cannot derive Debug due to dyn TokenStorage;
// a manual Debug impl is provided elsewhere in this file.
#[derive(Clone)]
pub struct Client {
    // Shared HTTP client configured with timeouts and a retry policy.
    http: reqwest::Client,
    // Base URL of the Studio server, e.g. "https://edgefirst.studio".
    url: String,
    // Authentication token, shared across clones of this client.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
355
356impl std::fmt::Debug for Client {
357 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358 f.debug_struct("Client")
359 .field("url", &self.url)
360 .field("has_storage", &self.storage.is_some())
361 .field("token_path", &self.token_path)
362 .finish()
363 }
364}
365
/// Private context struct for pagination operations, bundling the
/// filters shared by successive page fetches.
struct FetchContext<'a> {
    // Dataset whose samples are being paged through.
    dataset_id: DatasetID,
    // Optional annotation set to restrict results to.
    annotation_set_id: Option<AnnotationSetID>,
    // Group name filters.
    groups: &'a [String],
    // Type name filters.
    types: Vec<String>,
    // Label lookup table keyed by name — presumably name → numeric ID;
    // verify against callers.
    labels: &'a HashMap<String, u64>,
}
374
375impl Client {
376 /// Create a new unauthenticated client with the default saas server.
377 ///
378 /// By default, the client uses [`FileTokenStorage`] for token persistence.
379 /// Use [`with_storage`][Self::with_storage],
380 /// [`with_memory_storage`][Self::with_memory_storage],
381 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
382 /// behavior.
383 ///
384 /// To connect to a different server, use [`with_server`][Self::with_server]
385 /// or [`with_token`][Self::with_token] (tokens include the server
386 /// instance).
387 ///
388 /// This client is created without a token and will need to authenticate
389 /// before using methods that require authentication.
390 ///
391 /// # Examples
392 ///
393 /// ```rust,no_run
394 /// use edgefirst_client::Client;
395 ///
396 /// # fn main() -> Result<(), edgefirst_client::Error> {
397 /// // Create client with default file storage
398 /// let client = Client::new()?;
399 ///
400 /// // Create client without token persistence
401 /// let client = Client::new()?.with_memory_storage();
402 /// # Ok(())
403 /// # }
404 /// ```
405 pub fn new() -> Result<Self, Error> {
406 log_retry_configuration();
407
408 // Get timeout from environment or use default
409 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
410 .ok()
411 .and_then(|s| s.parse().ok())
412 .unwrap_or(30); // Default 30s timeout for API calls
413
414 // Create single HTTP client with URL-based retry policy
415 //
416 // The retry policy classifies requests into two categories:
417 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
418 // errors
419 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
420 //
421 // This allows the same client to handle both API calls and file operations
422 // with appropriate retry behavior for each. See retry.rs for details.
423 let http = reqwest::Client::builder()
424 .connect_timeout(Duration::from_secs(10))
425 .timeout(Duration::from_secs(timeout_secs))
426 .pool_idle_timeout(Duration::from_secs(90))
427 .pool_max_idle_per_host(10)
428 .retry(create_retry_policy())
429 .build()?;
430
431 // Default to file storage, loading any existing token
432 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
433 Ok(file_storage) => Arc::new(file_storage),
434 Err(e) => {
435 warn!(
436 "Could not initialize file token storage: {}. Using memory storage.",
437 e
438 );
439 Arc::new(MemoryTokenStorage::new())
440 }
441 };
442
443 // Try to load existing token from storage
444 let token = match storage.load() {
445 Ok(Some(t)) => t,
446 Ok(None) => String::new(),
447 Err(e) => {
448 warn!(
449 "Failed to load token from storage: {}. Starting with empty token.",
450 e
451 );
452 String::new()
453 }
454 };
455
456 // Extract server from token if available
457 let url = if !token.is_empty() {
458 match Self::extract_server_from_token(&token) {
459 Ok(server) => format!("https://{}.edgefirst.studio", server),
460 Err(e) => {
461 warn!(
462 "Failed to extract server from token: {}. Using default server.",
463 e
464 );
465 "https://edgefirst.studio".to_string()
466 }
467 }
468 } else {
469 "https://edgefirst.studio".to_string()
470 };
471
472 Ok(Client {
473 http,
474 url,
475 token: Arc::new(tokio::sync::RwLock::new(token)),
476 storage: Some(storage),
477 token_path: None,
478 })
479 }
480
481 /// Returns a new client connected to the specified server instance.
482 ///
483 /// The server parameter is an instance name that maps to a URL:
484 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
485 /// server)
486 /// - `"test"` → `https://test.edgefirst.studio`
487 /// - `"stage"` → `https://stage.edgefirst.studio`
488 /// - `"dev"` → `https://dev.edgefirst.studio`
489 /// - `"{name}"` → `https://{name}.edgefirst.studio`
490 ///
491 /// # Server Selection Priority
492 ///
493 /// When using the CLI or Python API, server selection follows this
494 /// priority:
495 ///
496 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
497 /// they were issued for. If you have a valid token, its server is used.
498 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
499 /// token is available. If a token exists with a different server, a
500 /// warning is emitted and the token's server takes priority.
501 /// 3. **Default `"saas"`** - If no token and no server specified, the
502 /// production server (`https://edgefirst.studio`) is used.
503 ///
504 /// # Important Notes
505 ///
506 /// - If a token is already set in the client, calling this method will
507 /// **drop the token** as tokens are specific to the server instance.
508 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
509 /// token's server before calling this method.
510 /// - For login operations, call `with_server()` first, then authenticate.
511 ///
512 /// # Examples
513 ///
514 /// ```rust,no_run
515 /// use edgefirst_client::Client;
516 ///
517 /// # fn main() -> Result<(), edgefirst_client::Error> {
518 /// let client = Client::new()?.with_server("test")?;
519 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
520 /// # Ok(())
521 /// # }
522 /// ```
523 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
524 let url = match server {
525 "" | "saas" => "https://edgefirst.studio".to_string(),
526 name => format!("https://{}.edgefirst.studio", name),
527 };
528
529 // Clear token from storage when changing servers to prevent
530 // authentication issues with stale tokens from different instances
531 if let Some(ref storage) = self.storage
532 && let Err(e) = storage.clear()
533 {
534 warn!(
535 "Failed to clear token from storage when changing servers: {}",
536 e
537 );
538 }
539
540 Ok(Client {
541 url,
542 token: Arc::new(tokio::sync::RwLock::new(String::new())),
543 ..self.clone()
544 })
545 }
546
547 /// Returns a new client with the specified token storage backend.
548 ///
549 /// Use this to configure custom token storage, such as platform-specific
550 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
551 ///
552 /// # Examples
553 ///
554 /// ```rust,no_run
555 /// use edgefirst_client::{Client, FileTokenStorage};
556 /// use std::{path::PathBuf, sync::Arc};
557 ///
558 /// # fn main() -> Result<(), edgefirst_client::Error> {
559 /// // Use a custom file path for token storage
560 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
561 /// let client = Client::new()?.with_storage(Arc::new(storage));
562 /// # Ok(())
563 /// # }
564 /// ```
565 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
566 // Try to load existing token from the new storage
567 let token = match storage.load() {
568 Ok(Some(t)) => t,
569 Ok(None) => String::new(),
570 Err(e) => {
571 warn!(
572 "Failed to load token from storage: {}. Starting with empty token.",
573 e
574 );
575 String::new()
576 }
577 };
578
579 Client {
580 token: Arc::new(tokio::sync::RwLock::new(token)),
581 storage: Some(storage),
582 token_path: None,
583 ..self
584 }
585 }
586
587 /// Returns a new client with in-memory token storage (no persistence).
588 ///
589 /// Tokens are stored in memory only and lost when the application exits.
590 /// This is useful for testing or when you want to manage token persistence
591 /// externally.
592 ///
593 /// # Examples
594 ///
595 /// ```rust,no_run
596 /// use edgefirst_client::Client;
597 ///
598 /// # fn main() -> Result<(), edgefirst_client::Error> {
599 /// let client = Client::new()?.with_memory_storage();
600 /// # Ok(())
601 /// # }
602 /// ```
603 pub fn with_memory_storage(self) -> Self {
604 Client {
605 token: Arc::new(tokio::sync::RwLock::new(String::new())),
606 storage: Some(Arc::new(MemoryTokenStorage::new())),
607 token_path: None,
608 ..self
609 }
610 }
611
612 /// Returns a new client with no token storage.
613 ///
614 /// Tokens are not persisted. Use this when you want to manage tokens
615 /// entirely manually.
616 ///
617 /// # Examples
618 ///
619 /// ```rust,no_run
620 /// use edgefirst_client::Client;
621 ///
622 /// # fn main() -> Result<(), edgefirst_client::Error> {
623 /// let client = Client::new()?.with_no_storage();
624 /// # Ok(())
625 /// # }
626 /// ```
627 pub fn with_no_storage(self) -> Self {
628 Client {
629 storage: None,
630 token_path: None,
631 ..self
632 }
633 }
634
635 /// Returns a new client authenticated with the provided username and
636 /// password.
637 ///
638 /// The token is automatically persisted to storage (if configured).
639 ///
640 /// # Examples
641 ///
642 /// ```rust,no_run
643 /// use edgefirst_client::Client;
644 ///
645 /// # async fn example() -> Result<(), edgefirst_client::Error> {
646 /// let client = Client::new()?
647 /// .with_server("test")?
648 /// .with_login("user@example.com", "password")
649 /// .await?;
650 /// # Ok(())
651 /// # }
652 /// ```
653 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
654 let params = HashMap::from([("username", username), ("password", password)]);
655 let login: LoginResult = self
656 .rpc_without_auth("auth.login".to_owned(), Some(params))
657 .await?;
658
659 // Validate that the server returned a non-empty token
660 if login.token.is_empty() {
661 return Err(Error::EmptyToken);
662 }
663
664 // Persist token to storage if configured
665 if let Some(ref storage) = self.storage
666 && let Err(e) = storage.store(&login.token)
667 {
668 warn!("Failed to persist token to storage: {}", e);
669 }
670
671 Ok(Client {
672 token: Arc::new(tokio::sync::RwLock::new(login.token)),
673 ..self.clone()
674 })
675 }
676
677 /// Returns a new client which will load and save the token to the specified
678 /// path.
679 ///
680 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
681 /// [`FileTokenStorage`] instead for more flexible token management.
682 ///
683 /// This method is maintained for backwards compatibility with existing
684 /// code. It disables the default storage and uses file-based storage at
685 /// the specified path.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        // Resolve the path: explicit argument, or the platform config dir.
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path (legacy): {:?}", token_path);

        // Read an existing token file, if any; otherwise start empty.
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            match self.with_token(&token) {
                Ok(client) => Ok(Client {
                    token_path: Some(token_path),
                    storage: None, // Disable new storage when using legacy token_path
                    ..client
                }),
                Err(e) => {
                    // Token is corrupted or invalid - remove it and continue with no token
                    warn!(
                        "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
                        token_path, e
                    );
                    if let Err(remove_err) = std::fs::remove_file(&token_path) {
                        warn!("Failed to remove corrupted token file: {:?}", remove_err);
                    }
                    Ok(Client {
                        token_path: Some(token_path),
                        storage: None,
                        ..self.clone()
                    })
                }
            }
        } else {
            // No token on disk: keep the legacy path configured so later
            // save_token() calls know where to write.
            Ok(Client {
                token_path: Some(token_path),
                storage: None,
                ..self.clone()
            })
        }
    }
738
739 /// Returns a new client authenticated with the provided token.
740 ///
741 /// The token is automatically persisted to storage (if configured).
742 /// The server URL is extracted from the token payload.
743 ///
744 /// # Examples
745 ///
746 /// ```rust,no_run
747 /// use edgefirst_client::Client;
748 ///
749 /// # fn main() -> Result<(), edgefirst_client::Error> {
750 /// let client = Client::new()?.with_token("your-jwt-token")?;
751 /// # Ok(())
752 /// # }
753 /// ```
754 /// Extract server name from JWT token payload.
755 ///
756 /// Helper method to parse the JWT token and extract the "server" field
757 /// from the payload. Returns the server name (e.g., "test", "stage", "")
758 /// or an error if the token is invalid.
759 fn extract_server_from_token(token: &str) -> Result<String, Error> {
760 let token_parts: Vec<&str> = token.split('.').collect();
761 if token_parts.len() != 3 {
762 return Err(Error::InvalidToken);
763 }
764
765 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
766 .decode(token_parts[1])
767 .map_err(|_| Error::InvalidToken)?;
768 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
769 let server = match payload.get("server") {
770 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
771 None => return Err(Error::InvalidToken),
772 };
773
774 Ok(server)
775 }
776
777 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
778 if token.is_empty() {
779 return Ok(self.clone());
780 }
781
782 let server = Self::extract_server_from_token(token)?;
783
784 // Persist token to storage if configured
785 if let Some(ref storage) = self.storage
786 && let Err(e) = storage.store(token)
787 {
788 warn!("Failed to persist token to storage: {}", e);
789 }
790
791 Ok(Client {
792 url: format!("https://{}.edgefirst.studio", server),
793 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
794 ..self.clone()
795 })
796 }
797
798 /// Persist the current token to storage.
799 ///
800 /// This is automatically called when using [`with_login`][Self::with_login]
801 /// or [`with_token`][Self::with_token], so you typically don't need to call
802 /// this directly.
803 ///
804 /// If using the legacy `token_path` configuration, saves to the file path.
805 /// If using the new storage abstraction, saves to the configured storage.
    pub async fn save_token(&self) -> Result<(), Error> {
        let token = self.token.read().await;

        // Try new storage first
        if let Some(ref storage) = self.storage {
            storage.store(&token)?;
            debug!("Token saved to storage");
            return Ok(());
        }

        // Fall back to legacy token_path behavior: configured path, then
        // platform config dir, then ".token" in the working directory.
        let path = self.token_path.clone().unwrap_or_else(|| {
            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .map(|dirs| dirs.config_dir().join("token"))
                .unwrap_or_else(|| PathBuf::from(".token"))
        });

        // Ensure the parent directory exists before writing the file.
        create_dir_all(path.parent().ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Token path has no parent directory",
            ))
        })?)?;
        let mut file = std::fs::File::create(&path)?;
        file.write_all(token.as_bytes())?;

        debug!("Saved token to {:?}", path);

        Ok(())
    }
836
837 /// Return the version of the EdgeFirst Studio server for the current
838 /// client connection.
839 pub async fn version(&self) -> Result<String, Error> {
840 let version: HashMap<String, String> = self
841 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
842 .await?;
843 let version = version.get("version").ok_or(Error::InvalidResponse)?;
844 Ok(version.to_owned())
845 }
846
847 /// Clear the token used to authenticate the client with the server.
848 ///
849 /// Clears the token from memory and from storage (if configured).
850 /// If using the legacy `token_path` configuration, removes the token file.
851 pub async fn logout(&self) -> Result<(), Error> {
852 {
853 let mut token = self.token.write().await;
854 *token = "".to_string();
855 }
856
857 // Clear from new storage if configured
858 if let Some(ref storage) = self.storage
859 && let Err(e) = storage.clear()
860 {
861 warn!("Failed to clear token from storage: {}", e);
862 }
863
864 // Also clear legacy token_path if configured
865 if let Some(path) = &self.token_path
866 && path.exists()
867 {
868 fs::remove_file(path).await?;
869 }
870
871 Ok(())
872 }
873
874 /// Return the token used to authenticate the client with the server. When
875 /// logging into the server using a username and password, the token is
876 /// returned by the server and stored in the client for future interactions.
877 pub async fn token(&self) -> String {
878 self.token.read().await.clone()
879 }
880
881 /// Verify the token used to authenticate the client with the server. This
882 /// method is used to ensure that the token is still valid and has not
883 /// expired. If the token is invalid, the server will return an error and
884 /// the client will need to login again.
885 pub async fn verify_token(&self) -> Result<(), Error> {
886 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
887 .await?;
888 Ok::<(), Error>(())
889 }
890
891 /// Renew the token used to authenticate the client with the server.
892 ///
893 /// Refreshes the token before it expires. If the token has already expired,
894 /// the server will return an error and you will need to login again.
895 ///
896 /// The new token is automatically persisted to storage (if configured).
    pub async fn renew_token(&self) -> Result<(), Error> {
        // The refresh endpoint identifies the account by username, which
        // is read from the current token's payload.
        let params = HashMap::from([("username".to_string(), self.username().await?)]);
        let result: LoginResult = self
            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
            .await?;

        // Swap in the new token inside a short-lived write guard.
        {
            let mut token = self.token.write().await;
            *token = result.token.clone();
        }

        // Persist to new storage if configured
        if let Some(ref storage) = self.storage
            && let Err(e) = storage.store(&result.token)
        {
            warn!("Failed to persist renewed token to storage: {}", e);
        }

        // Also persist to legacy token_path if configured
        if self.token_path.is_some() {
            self.save_token().await?;
        }

        Ok(())
    }
922
923 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
924 let token = self.token.read().await;
925 if token.is_empty() {
926 return Err(Error::EmptyToken);
927 }
928
929 let token_parts: Vec<&str> = token.split('.').collect();
930 if token_parts.len() != 3 {
931 return Err(Error::InvalidToken);
932 }
933
934 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
935 .decode(token_parts[1])
936 .map_err(|_| Error::InvalidToken)?;
937 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
938 match payload.get(field) {
939 Some(value) => Ok(value.to_owned()),
940 None => Err(Error::InvalidToken),
941 }
942 }
943
944 /// Returns the URL of the EdgeFirst Studio server for the current client.
945 pub fn url(&self) -> &str {
946 &self.url
947 }
948
949 /// Returns the server name for the current client.
950 ///
951 /// This extracts the server name from the client's URL:
952 /// - `https://edgefirst.studio` → `"saas"`
953 /// - `https://test.edgefirst.studio` → `"test"`
954 /// - `https://{name}.edgefirst.studio` → `"{name}"`
955 ///
956 /// # Examples
957 ///
958 /// ```rust,no_run
959 /// use edgefirst_client::Client;
960 ///
961 /// # fn main() -> Result<(), edgefirst_client::Error> {
962 /// let client = Client::new()?.with_server("test")?;
963 /// assert_eq!(client.server(), "test");
964 ///
965 /// let client = Client::new()?; // default
966 /// assert_eq!(client.server(), "saas");
967 /// # Ok(())
968 /// # }
969 /// ```
970 pub fn server(&self) -> &str {
971 if self.url == "https://edgefirst.studio" {
972 "saas"
973 } else if let Some(name) = self.url.strip_prefix("https://") {
974 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
975 } else {
976 "saas"
977 }
978 }
979
980 /// Returns the username associated with the current token.
981 pub async fn username(&self) -> Result<String, Error> {
982 match self.token_field("username").await? {
983 serde_json::Value::String(username) => Ok(username),
984 _ => Err(Error::InvalidToken),
985 }
986 }
987
988 /// Returns the expiration time for the current token.
989 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
990 let ts = match self.token_field("exp").await? {
991 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
992 _ => return Err(Error::InvalidToken),
993 };
994
995 match DateTime::<Utc>::from_timestamp_secs(ts) {
996 Some(dt) => Ok(dt),
997 None => Err(Error::InvalidToken),
998 }
999 }
1000
1001 /// Returns the organization information for the current user.
1002 pub async fn organization(&self) -> Result<Organization, Error> {
1003 self.rpc::<(), Organization>("org.get".to_owned(), None)
1004 .await
1005 }
1006
1007 /// Returns a list of projects available to the user. The projects are
1008 /// returned as a vector of Project objects. If a name filter is
1009 /// provided, only projects matching the filter are returned.
1010 ///
1011 /// Results are sorted by match quality: exact matches first, then
1012 /// case-insensitive exact matches, then shorter names (more specific),
1013 /// then alphabetically.
1014 ///
1015 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1016 /// Projects contain datasets, trainers, and trainer sessions. Projects
1017 /// are used to group related datasets and trainers together.
1018 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1019 let projects = self
1020 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1021 .await?;
1022 if let Some(name) = name {
1023 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1024 } else {
1025 Ok(projects)
1026 }
1027 }
1028
1029 /// Return the project with the specified project ID. If the project does
1030 /// not exist, an error is returned.
1031 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1032 let params = HashMap::from([("project_id", project_id)]);
1033 self.rpc("project.get".to_owned(), Some(params)).await
1034 }
1035
1036 /// Returns a list of datasets available to the user. The datasets are
1037 /// returned as a vector of Dataset objects. If a name filter is
1038 /// provided, only datasets matching the filter are returned.
1039 ///
1040 /// Results are sorted by match quality: exact matches first, then
1041 /// case-insensitive exact matches, then shorter names (more specific),
1042 /// then alphabetically. This ensures "Deer" returns before "Deer
1043 /// Roundtrip".
1044 pub async fn datasets(
1045 &self,
1046 project_id: ProjectID,
1047 name: Option<&str>,
1048 ) -> Result<Vec<Dataset>, Error> {
1049 let params = HashMap::from([("project_id", project_id)]);
1050 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1051 if let Some(name) = name {
1052 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1053 } else {
1054 Ok(datasets)
1055 }
1056 }
1057
1058 /// Return the dataset with the specified dataset ID. If the dataset does
1059 /// not exist, an error is returned.
1060 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1061 let params = HashMap::from([("dataset_id", dataset_id)]);
1062 self.rpc("dataset.get".to_owned(), Some(params)).await
1063 }
1064
1065 /// Lists the labels for the specified dataset.
1066 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1067 let params = HashMap::from([("dataset_id", dataset_id)]);
1068 self.rpc("label.list".to_owned(), Some(params)).await
1069 }
1070
1071 /// Add a new label to the dataset with the specified name.
1072 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1073 let new_label = NewLabel {
1074 dataset_id,
1075 labels: vec![NewLabelObject {
1076 name: name.to_owned(),
1077 }],
1078 };
1079 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1080 Ok(())
1081 }
1082
1083 /// Removes the label with the specified ID from the dataset. Label IDs are
1084 /// globally unique so the dataset_id is not required.
1085 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1086 let params = HashMap::from([("label_id", label_id)]);
1087 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1088 Ok(())
1089 }
1090
1091 /// Creates a new dataset in the specified project.
1092 ///
1093 /// # Arguments
1094 ///
1095 /// * `project_id` - The ID of the project to create the dataset in
1096 /// * `name` - The name of the new dataset
1097 /// * `description` - Optional description for the dataset
1098 ///
1099 /// # Returns
1100 ///
1101 /// Returns the dataset ID of the newly created dataset.
1102 pub async fn create_dataset(
1103 &self,
1104 project_id: &str,
1105 name: &str,
1106 description: Option<&str>,
1107 ) -> Result<DatasetID, Error> {
1108 let mut params = HashMap::new();
1109 params.insert("project_id", project_id);
1110 params.insert("name", name);
1111 if let Some(desc) = description {
1112 params.insert("description", desc);
1113 }
1114
1115 #[derive(Deserialize)]
1116 struct CreateDatasetResult {
1117 id: DatasetID,
1118 }
1119
1120 let result: CreateDatasetResult =
1121 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1122 Ok(result.id)
1123 }
1124
1125 /// Deletes a dataset by marking it as deleted.
1126 ///
1127 /// # Arguments
1128 ///
1129 /// * `dataset_id` - The ID of the dataset to delete
1130 ///
1131 /// # Returns
1132 ///
1133 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1134 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1135 let params = HashMap::from([("id", dataset_id)]);
1136 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1137 Ok(())
1138 }
1139
1140 /// Updates the label with the specified ID to have the new name or index.
1141 /// Label IDs cannot be changed. Label IDs are globally unique so the
1142 /// dataset_id is not required.
1143 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1144 #[derive(Serialize)]
1145 struct Params {
1146 dataset_id: DatasetID,
1147 label_id: u64,
1148 label_name: String,
1149 label_index: u64,
1150 }
1151
1152 let _: String = self
1153 .rpc(
1154 "label.update".to_owned(),
1155 Some(Params {
1156 dataset_id: label.dataset_id(),
1157 label_id: label.id(),
1158 label_name: label.name().to_owned(),
1159 label_index: label.index(),
1160 }),
1161 )
1162 .await?;
1163 Ok(())
1164 }
1165
1166 /// Lists the groups for the specified dataset.
1167 ///
1168 /// Groups are used to organize samples into logical subsets such as
1169 /// "train", "val", "test", etc. Each sample can belong to at most one
1170 /// group at a time.
1171 ///
1172 /// # Arguments
1173 ///
1174 /// * `dataset_id` - The ID of the dataset to list groups for
1175 ///
1176 /// # Returns
1177 ///
1178 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1179 /// empty vector if no groups have been created yet.
1180 ///
1181 /// # Errors
1182 ///
1183 /// Returns an error if the dataset does not exist or cannot be accessed.
1184 ///
1185 /// # Example
1186 ///
1187 /// ```rust,no_run
1188 /// # use edgefirst_client::{Client, DatasetID};
1189 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1190 /// let client = Client::new()?.with_token_path(None)?;
1191 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1192 ///
1193 /// let groups = client.groups(dataset_id).await?;
1194 /// for group in groups {
1195 /// println!("{}: {}", group.id, group.name);
1196 /// }
1197 /// # Ok(())
1198 /// # }
1199 /// ```
1200 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1201 let params = HashMap::from([("dataset_id", dataset_id)]);
1202 self.rpc("groups.list".to_owned(), Some(params)).await
1203 }
1204
1205 /// Gets an existing group by name or creates a new one.
1206 ///
1207 /// This is a convenience method that first checks if a group with the
1208 /// specified name exists, and creates it if not. This is useful when
1209 /// you need to ensure a group exists before assigning samples to it.
1210 ///
1211 /// # Arguments
1212 ///
1213 /// * `dataset_id` - The ID of the dataset
1214 /// * `name` - The name of the group (e.g., "train", "val", "test")
1215 ///
1216 /// # Returns
1217 ///
1218 /// Returns the group ID (either existing or newly created).
1219 ///
1220 /// # Errors
1221 ///
1222 /// Returns an error if:
1223 /// - The dataset does not exist or cannot be accessed
1224 /// - The group creation fails
1225 ///
1226 /// # Concurrency
1227 ///
1228 /// This method handles concurrent creation attempts gracefully. If another
1229 /// process creates the group between the existence check and creation,
1230 /// this method will return the existing group's ID.
1231 ///
1232 /// # Example
1233 ///
1234 /// ```rust,no_run
1235 /// # use edgefirst_client::{Client, DatasetID};
1236 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1237 /// let client = Client::new()?.with_token_path(None)?;
1238 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1239 ///
1240 /// // Get or create a "train" group
1241 /// let train_group_id = client
1242 /// .get_or_create_group(dataset_id.clone(), "train")
1243 /// .await?;
1244 /// println!("Train group ID: {}", train_group_id);
1245 ///
1246 /// // Calling again returns the same ID
1247 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1248 /// assert_eq!(train_group_id, same_id);
1249 /// # Ok(())
1250 /// # }
1251 /// ```
1252 pub async fn get_or_create_group(
1253 &self,
1254 dataset_id: DatasetID,
1255 name: &str,
1256 ) -> Result<u64, Error> {
1257 // First check if the group already exists
1258 let groups = self.groups(dataset_id).await?;
1259 if let Some(group) = groups.iter().find(|g| g.name == name) {
1260 return Ok(group.id);
1261 }
1262
1263 // Create the group
1264 #[derive(Serialize)]
1265 struct CreateGroupParams {
1266 dataset_id: DatasetID,
1267 group_names: Vec<String>,
1268 group_splits: Vec<i64>,
1269 }
1270
1271 let params = CreateGroupParams {
1272 dataset_id,
1273 group_names: vec![name.to_string()],
1274 group_splits: vec![0], // No automatic splitting
1275 };
1276
1277 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1278 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1279 Ok(group.id)
1280 } else {
1281 // Group might have been created by concurrent call, try fetching again
1282 let groups = self.groups(dataset_id).await?;
1283 groups
1284 .iter()
1285 .find(|g| g.name == name)
1286 .map(|g| g.id)
1287 .ok_or_else(|| {
1288 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1289 })
1290 }
1291 }
1292
1293 /// Sets the group for a sample.
1294 ///
1295 /// Assigns a sample to a specific group. Each sample can belong to at most
1296 /// one group at a time. Setting a new group replaces any existing group
1297 /// assignment.
1298 ///
1299 /// # Arguments
1300 ///
1301 /// * `sample_id` - The ID of the sample (image) to update
1302 /// * `group_id` - The ID of the group to assign. Use
1303 /// [`get_or_create_group`] to obtain a group ID from a name.
1304 ///
1305 /// # Returns
1306 ///
1307 /// Returns `Ok(())` on success.
1308 ///
1309 /// # Errors
1310 ///
1311 /// Returns an error if:
1312 /// - The sample does not exist
1313 /// - The group does not exist
1314 /// - Insufficient permissions to modify the sample
1315 ///
1316 /// # Example
1317 ///
1318 /// ```rust,no_run
1319 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1320 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1321 /// let client = Client::new()?.with_token_path(None)?;
1322 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1323 /// let sample_id: SampleID = 12345.into();
1324 ///
1325 /// // Get or create the "val" group
1326 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1327 ///
1328 /// // Assign the sample to the "val" group
1329 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1330 /// # Ok(())
1331 /// # }
1332 /// ```
1333 ///
1334 /// [`get_or_create_group`]: Self::get_or_create_group
1335 pub async fn set_sample_group_id(
1336 &self,
1337 sample_id: SampleID,
1338 group_id: u64,
1339 ) -> Result<(), Error> {
1340 #[derive(Serialize)]
1341 struct SetGroupParams {
1342 image_id: SampleID,
1343 group_id: u64,
1344 }
1345
1346 let params = SetGroupParams {
1347 image_id: sample_id,
1348 group_id,
1349 };
1350 let _: String = self
1351 .rpc("image.set_group_id".to_owned(), Some(params))
1352 .await?;
1353 Ok(())
1354 }
1355
    /// Downloads dataset samples to the local filesystem.
    ///
    /// # Arguments
    ///
    /// * `dataset_id` - The unique identifier of the dataset
    /// * `groups` - Dataset groups to include (e.g., "train", "val")
    /// * `file_types` - File types to download (e.g., Image, LidarPcd)
    /// * `output` - Local directory to save downloaded files
    /// * `flatten` - If true, download all files to output root without
    ///   sequence subdirectories. When flattening, filenames are prefixed with
    ///   `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
    ///   unavailable) unless the filename already starts with
    ///   `{sequence_name}_`, to avoid conflicts between sequences.
    /// * `progress` - Optional channel for progress updates
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on success or an error if download fails.
    ///
    /// # Panics
    ///
    /// Panics if a downloaded image's format cannot be identified by content
    /// sniffing (see the `expect` below).
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use edgefirst_client::{Client, DatasetID, FileType};
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = Client::new()?.with_token_path(None)?;
    /// let dataset_id: DatasetID = "ds-123".try_into()?;
    ///
    /// // Download with sequence subdirectories (default)
    /// client
    ///     .download_dataset(
    ///         dataset_id,
    ///         &[],
    ///         &[FileType::Image],
    ///         "./data".into(),
    ///         false,
    ///         None,
    ///     )
    ///     .await?;
    ///
    /// // Download flattened (all files in one directory)
    /// client
    ///     .download_dataset(
    ///         dataset_id,
    ///         &[],
    ///         &[FileType::Image],
    ///         "./data".into(),
    ///         true,
    ///         None,
    ///     )
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        flatten: bool,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Fetch sample metadata first; the files themselves are downloaded
        // per sample in the parallel loop below.
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Clone owned copies for the 'static closure passed to the parallel
        // driver; each iteration clones them again for its async task.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, None, move |sample| {
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // Images get their extension from content sniffing;
                        // other file types use the type's own name as the
                        // extension. NOTE(review): the expect panics on
                        // unidentifiable image bytes rather than returning
                        // an error.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        // Determine target directory based on sequence membership and flatten
                        // option
                        // - flatten=false + sequence_name: dataset/sequence_name/
                        // - flatten=false + no sequence: dataset/ (root level)
                        // - flatten=true: dataset/ (all files in output root)
                        // NOTE: group (train/val/test) is NOT used for directory structure
                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = if flatten {
                            output.clone()
                        } else {
                            sequence_dir
                                .as_ref()
                                .map(|seq| output.join(seq))
                                .unwrap_or_else(|| output.clone())
                        };
                        fs::create_dir_all(&target_dir).await?;

                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Construct filename with smart prefixing for flatten mode
                        // When flatten=true and sample belongs to a sequence:
                        // - Check if filename already starts with "{sequence_name}_"
                        // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
                        // - If yes, use filename as-is (already uniquely named)
                        let file_name = if is_image {
                            if let Some(img_name) = image_name {
                                Self::build_filename(
                                    &img_name,
                                    flatten,
                                    sequence_dir.as_ref(),
                                    sample.frame_number(),
                                )
                            } else {
                                // No server-side image name; fall back to the
                                // sample name plus the sniffed extension.
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            }
                        } else {
                            let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
                            Self::build_filename(
                                &base_name,
                                flatten,
                                sequence_dir.as_ref(),
                                sample.frame_number(),
                            )
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing data is logged but does not abort the
                        // remaining downloads.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
1519
1520 /// Builds a filename with smart prefixing for flatten mode.
1521 ///
1522 /// When flattening sequences into a single directory, this function ensures
1523 /// unique filenames by checking if the sequence prefix already exists and
1524 /// adding it if necessary.
1525 ///
1526 /// # Logic
1527 ///
1528 /// - If `flatten=false`: returns `base_name` unchanged
1529 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1530 /// - If `flatten=true` and in sequence:
1531 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1532 /// unchanged
1533 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1534 /// `{sequence_name}_{base_name}`
1535 fn build_filename(
1536 base_name: &str,
1537 flatten: bool,
1538 sequence_name: Option<&String>,
1539 frame_number: Option<u32>,
1540 ) -> String {
1541 if !flatten || sequence_name.is_none() {
1542 return base_name.to_string();
1543 }
1544
1545 let seq_name = sequence_name.unwrap();
1546 let prefix = format!("{}_", seq_name);
1547
1548 // Check if already prefixed with sequence name
1549 if base_name.starts_with(&prefix) {
1550 base_name.to_string()
1551 } else {
1552 // Add sequence (and optionally frame) prefix
1553 match frame_number {
1554 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1555 None => format!("{}{}", prefix, base_name),
1556 }
1557 }
1558 }
1559
1560 /// List available annotation sets for the specified dataset.
1561 pub async fn annotation_sets(
1562 &self,
1563 dataset_id: DatasetID,
1564 ) -> Result<Vec<AnnotationSet>, Error> {
1565 let params = HashMap::from([("dataset_id", dataset_id)]);
1566 self.rpc("annset.list".to_owned(), Some(params)).await
1567 }
1568
1569 /// Create a new annotation set for the specified dataset.
1570 ///
1571 /// # Arguments
1572 ///
1573 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1574 /// * `name` - The name of the new annotation set
1575 /// * `description` - Optional description for the annotation set
1576 ///
1577 /// # Returns
1578 ///
1579 /// Returns the annotation set ID of the newly created annotation set.
1580 pub async fn create_annotation_set(
1581 &self,
1582 dataset_id: DatasetID,
1583 name: &str,
1584 description: Option<&str>,
1585 ) -> Result<AnnotationSetID, Error> {
1586 #[derive(Serialize)]
1587 struct Params<'a> {
1588 dataset_id: DatasetID,
1589 name: &'a str,
1590 operator: &'a str,
1591 #[serde(skip_serializing_if = "Option::is_none")]
1592 description: Option<&'a str>,
1593 }
1594
1595 #[derive(Deserialize)]
1596 struct CreateAnnotationSetResult {
1597 id: AnnotationSetID,
1598 }
1599
1600 let username = self.username().await?;
1601 let result: CreateAnnotationSetResult = self
1602 .rpc(
1603 "annset.add".to_owned(),
1604 Some(Params {
1605 dataset_id,
1606 name,
1607 operator: &username,
1608 description,
1609 }),
1610 )
1611 .await?;
1612 Ok(result.id)
1613 }
1614
1615 /// Deletes an annotation set by marking it as deleted.
1616 ///
1617 /// # Arguments
1618 ///
1619 /// * `annotation_set_id` - The ID of the annotation set to delete
1620 ///
1621 /// # Returns
1622 ///
1623 /// Returns `Ok(())` if the annotation set was successfully marked as
1624 /// deleted.
1625 pub async fn delete_annotation_set(
1626 &self,
1627 annotation_set_id: AnnotationSetID,
1628 ) -> Result<(), Error> {
1629 let params = HashMap::from([("id", annotation_set_id)]);
1630 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1631 Ok(())
1632 }
1633
1634 /// Retrieve the annotation set with the specified ID.
1635 pub async fn annotation_set(
1636 &self,
1637 annotation_set_id: AnnotationSetID,
1638 ) -> Result<AnnotationSet, Error> {
1639 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1640 self.rpc("annset.get".to_owned(), Some(params)).await
1641 }
1642
1643 /// Get the annotations for the specified annotation set with the
1644 /// requested annotation types. The annotation types are used to filter
1645 /// the annotations returned. The groups parameter is used to filter for
1646 /// dataset groups (train, val, test). Images which do not have any
1647 /// annotations are also included in the result as long as they are in the
1648 /// requested groups (when specified).
1649 ///
1650 /// The result is a vector of Annotations objects which contain the
1651 /// full dataset along with the annotations for the specified types.
1652 ///
1653 /// To get the annotations as a DataFrame, use the `annotations_dataframe`
1654 /// method instead.
1655 pub async fn annotations(
1656 &self,
1657 annotation_set_id: AnnotationSetID,
1658 groups: &[String],
1659 annotation_types: &[AnnotationType],
1660 progress: Option<Sender<Progress>>,
1661 ) -> Result<Vec<Annotation>, Error> {
1662 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1663 let labels = self
1664 .labels(dataset_id)
1665 .await?
1666 .into_iter()
1667 .map(|label| (label.name().to_string(), label.index()))
1668 .collect::<HashMap<_, _>>();
1669 let total = self
1670 .samples_count(
1671 dataset_id,
1672 Some(annotation_set_id),
1673 annotation_types,
1674 groups,
1675 &[],
1676 )
1677 .await?
1678 .total as usize;
1679
1680 if total == 0 {
1681 return Ok(vec![]);
1682 }
1683
1684 let context = FetchContext {
1685 dataset_id,
1686 annotation_set_id: Some(annotation_set_id),
1687 groups,
1688 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1689 labels: &labels,
1690 };
1691
1692 self.fetch_annotations_paginated(context, total, progress)
1693 .await
1694 }
1695
    /// Pages through `samples.list` results for the given fetch context and
    /// flattens each page into annotation records via
    /// `process_sample_annotations`.
    ///
    /// `total` is the expected sample count (from `samples.count`) and is
    /// used only for progress reporting on the optional `progress` channel.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        // Opaque server-issued cursor; None requests the first page.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            // The token is moved into the request and replaced from the
            // response below, so each iteration builds params afresh.
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Progress send failures are deliberately ignored (receiver may
            // have been dropped).
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // Continue only while the server returns a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender closes the progress channel for any receiver.
        drop(progress);
        Ok(annotations)
    }
1739
1740 fn process_sample_annotations(
1741 &self,
1742 samples: &[Sample],
1743 labels: &HashMap<String, u64>,
1744 annotations: &mut Vec<Annotation>,
1745 ) {
1746 for sample in samples {
1747 if sample.annotations().is_empty() {
1748 let mut annotation = Annotation::new();
1749 annotation.set_sample_id(sample.id());
1750 annotation.set_name(sample.name());
1751 annotation.set_sequence_name(sample.sequence_name().cloned());
1752 annotation.set_frame_number(sample.frame_number());
1753 annotation.set_group(sample.group().cloned());
1754 annotations.push(annotation);
1755 continue;
1756 }
1757
1758 for annotation in sample.annotations() {
1759 let mut annotation = annotation.clone();
1760 annotation.set_sample_id(sample.id());
1761 annotation.set_name(sample.name());
1762 annotation.set_sequence_name(sample.sequence_name().cloned());
1763 annotation.set_frame_number(sample.frame_number());
1764 annotation.set_group(sample.group().cloned());
1765 Self::set_label_index_from_map(&mut annotation, labels);
1766 annotations.push(annotation);
1767 }
1768 }
1769 }
1770
1771 /// Delete annotations in bulk from specified samples.
1772 ///
1773 /// This method calls the `annotation.bulk.del` API to efficiently remove
1774 /// annotations from multiple samples at once. Useful for clearing
1775 /// annotations before re-importing updated data.
1776 ///
1777 /// # Arguments
1778 /// * `annotation_set_id` - The annotation set containing the annotations
1779 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
1780 /// for masks
1781 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
1782 ///
1783 /// # Example
1784 /// ```no_run
1785 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
1786 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1787 /// # let client = Client::new()?.with_login("user", "pass").await?;
1788 /// let annotation_set_id = AnnotationSetID::from(123);
1789 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
1790 ///
1791 /// client
1792 /// .delete_annotations_bulk(
1793 /// annotation_set_id,
1794 /// &["box".to_string(), "seg".to_string()],
1795 /// &sample_ids,
1796 /// )
1797 /// .await?;
1798 /// # Ok(())
1799 /// # }
1800 /// ```
1801 pub async fn delete_annotations_bulk(
1802 &self,
1803 annotation_set_id: AnnotationSetID,
1804 annotation_types: &[String],
1805 sample_ids: &[SampleID],
1806 ) -> Result<(), Error> {
1807 use crate::api::AnnotationBulkDeleteParams;
1808
1809 let params = AnnotationBulkDeleteParams {
1810 annotation_set_id: annotation_set_id.into(),
1811 annotation_types: annotation_types.to_vec(),
1812 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
1813 delete_all: None,
1814 };
1815
1816 let _: String = self
1817 .rpc("annotation.bulk.del".to_owned(), Some(params))
1818 .await?;
1819 Ok(())
1820 }
1821
1822 /// Add annotations in bulk.
1823 ///
1824 /// This method calls the `annotation.add_bulk` API to efficiently add
1825 /// multiple annotations at once. The annotations must be in server format
1826 /// with image_id references.
1827 ///
1828 /// # Arguments
1829 /// * `annotation_set_id` - The annotation set to add annotations to
1830 /// * `annotations` - Vector of server-format annotations to add
1831 ///
1832 /// # Returns
1833 /// Vector of created annotation records from the server.
1834 pub async fn add_annotations_bulk(
1835 &self,
1836 annotation_set_id: AnnotationSetID,
1837 annotations: Vec<crate::api::ServerAnnotation>,
1838 ) -> Result<Vec<serde_json::Value>, Error> {
1839 use crate::api::AnnotationAddBulkParams;
1840
1841 let params = AnnotationAddBulkParams {
1842 annotation_set_id: annotation_set_id.into(),
1843 annotations,
1844 };
1845
1846 self.rpc("annotation.add_bulk".to_owned(), Some(params))
1847 .await
1848 }
1849
1850 /// Helper to parse frame number from image_name when sequence_name is
1851 /// present. This ensures frame_number is always derived from the image
1852 /// filename, not from the server's frame_number field (which may be
1853 /// inconsistent).
1854 ///
1855 /// Returns Some(frame_number) if sequence_name is present and frame can be
1856 /// parsed, otherwise None.
1857 fn parse_frame_from_image_name(
1858 image_name: Option<&String>,
1859 sequence_name: Option<&String>,
1860 ) -> Option<u32> {
1861 use std::path::Path;
1862
1863 let sequence = sequence_name?;
1864 let name = image_name?;
1865
1866 // Extract stem (remove extension)
1867 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1868
1869 // Parse frame from format: "sequence_XXX" where XXX is the frame number
1870 stem.strip_prefix(sequence)
1871 .and_then(|suffix| suffix.strip_prefix('_'))
1872 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1873 }
1874
1875 /// Helper to set label index from a label map
1876 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1877 if let Some(label) = annotation.label() {
1878 annotation.set_label_index(Some(labels[label.as_str()]));
1879 }
1880 }
1881
1882 pub async fn samples_count(
1883 &self,
1884 dataset_id: DatasetID,
1885 annotation_set_id: Option<AnnotationSetID>,
1886 annotation_types: &[AnnotationType],
1887 groups: &[String],
1888 types: &[FileType],
1889 ) -> Result<SamplesCountResult, Error> {
1890 // Use server type names for API calls (e.g., "box" instead of "box2d")
1891 let types = annotation_types
1892 .iter()
1893 .map(|t| t.as_server_type().to_string())
1894 .chain(types.iter().map(|t| t.to_string()))
1895 .collect::<Vec<_>>();
1896
1897 let params = SamplesListParams {
1898 dataset_id,
1899 annotation_set_id,
1900 group_names: groups.to_vec(),
1901 types,
1902 continue_token: None,
1903 };
1904
1905 self.rpc("samples.count".to_owned(), Some(params)).await
1906 }
1907
1908 pub async fn samples(
1909 &self,
1910 dataset_id: DatasetID,
1911 annotation_set_id: Option<AnnotationSetID>,
1912 annotation_types: &[AnnotationType],
1913 groups: &[String],
1914 types: &[FileType],
1915 progress: Option<Sender<Progress>>,
1916 ) -> Result<Vec<Sample>, Error> {
1917 // Use server type names for API calls (e.g., "box" instead of "box2d")
1918 let types_vec = annotation_types
1919 .iter()
1920 .map(|t| t.as_server_type().to_string())
1921 .chain(types.iter().map(|t| t.to_string()))
1922 .collect::<Vec<_>>();
1923 let labels = self
1924 .labels(dataset_id)
1925 .await?
1926 .into_iter()
1927 .map(|label| (label.name().to_string(), label.index()))
1928 .collect::<HashMap<_, _>>();
1929 let total = self
1930 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
1931 .await?
1932 .total as usize;
1933
1934 if total == 0 {
1935 return Ok(vec![]);
1936 }
1937
1938 let context = FetchContext {
1939 dataset_id,
1940 annotation_set_id,
1941 groups,
1942 types: types_vec,
1943 labels: &labels,
1944 };
1945
1946 self.fetch_samples_paginated(context, total, progress).await
1947 }
1948
    /// Get all sample names in a dataset.
    ///
    /// This is an efficient method for checking which samples already exist,
    /// useful for resuming interrupted imports. It only retrieves sample names
    /// without loading full annotation data.
    ///
    /// # Arguments
    /// * `dataset_id` - The dataset to query
    /// * `groups` - Optional group filter (empty = all groups)
    /// * `progress` - Optional progress channel
    ///
    /// # Returns
    /// A HashSet of the names reported by `Sample::name()` for every sample
    /// in the dataset.
    pub async fn sample_names(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        progress: Option<Sender<Progress>>,
    ) -> Result<std::collections::HashSet<String>, Error> {
        use std::collections::HashSet;

        // Count first so pagination can report meaningful progress.
        let total = self
            .samples_count(dataset_id, None, &[], groups, &[])
            .await?
            .total as usize;

        if total == 0 {
            return Ok(HashSet::new());
        }

        let mut names = HashSet::with_capacity(total);
        // Opaque server-issued cursor; None requests the first page.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id: None,
                types: vec![], // No type filter - we just want names
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            // Collect the names the server reports; samples without a name
            // are skipped.
            for sample in result.samples {
                if let Some(name) = sample.name() {
                    names.insert(name);
                }
            }

            // Progress send failures are deliberately ignored (receiver may
            // have been dropped).
            if let Some(ref p) = progress {
                let _ = p.send(Progress { current, total }).await;
            }

            // Continue only while the server returns a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        Ok(names)
    }
2020
    /// Pages through `samples.list` results for the given fetch context,
    /// normalizing each sample's frame number and copying sample metadata
    /// onto its annotations.
    ///
    /// `total` is the expected sample count (from `samples.count`) and is
    /// used only for progress reporting on the optional `progress` channel.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        // Opaque server-issued cursor; None requests the first page.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Use server's frame_number if valid (>= 0 after deserialization)
                        // Otherwise parse from image_name as fallback
                        // This ensures we respect explicit frame_number from uploads
                        // while still handling legacy data that only has filename encoding
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            // Set annotation fields from parent sample
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Progress send failures are deliberately ignored (receiver may
            // have been dropped).
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // Continue only while the server returns a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender closes the progress channel for any receiver.
        drop(progress);
        Ok(samples)
    }
2092
2093 /// Populates (imports) samples into a dataset using the `samples.populate2`
2094 /// API.
2095 ///
2096 /// This method creates new samples in the specified dataset, optionally
2097 /// with annotations and sensor data files. For each sample, the `files`
2098 /// field is checked for local file paths. If a filename is a valid path
2099 /// to an existing file, the file will be automatically uploaded to S3
2100 /// using presigned URLs returned by the server. The filename in the
2101 /// request is replaced with the basename (path removed) before sending
2102 /// to the server.
2103 ///
2104 /// # Important Notes
2105 ///
2106 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2107 /// annotations. Without it, the server will accept the request but will
2108 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2109 /// available annotation sets for a dataset, or create a new one via the
2110 /// Studio UI.
2111 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2112 /// boxes. Divide pixel coordinates by image width/height before creating
2113 /// [`Box2d`](crate::Box2d) annotations.
2114 /// - **Files are uploaded automatically** when the filename is a valid
2115 /// local path. The method will replace the full path with just the
2116 /// basename before sending to the server.
2117 /// - **Image dimensions are extracted automatically** for image files using
2118 /// the `imagesize` crate. The width/height are sent to the server, but
2119 /// note that the server currently doesn't return these fields when
2120 /// fetching samples back.
2121 /// - **UUIDs are generated automatically** if not provided. If you need
2122 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
2123 /// that the server doesn't currently return UUIDs in sample queries.
2124 ///
2125 /// # Arguments
2126 ///
2127 /// * `dataset_id` - The ID of the dataset to populate
2128 /// * `annotation_set_id` - **Required** if samples contain annotations,
2129 /// otherwise they will be ignored. Query with
2130 /// [`Client::annotation_sets`].
2131 /// * `samples` - Vector of samples to import with metadata and file
2132 /// references. For files, use the full local path - it will be uploaded
2133 /// automatically. UUIDs and image dimensions will be
2134 /// auto-generated/extracted if not provided.
2135 ///
2136 /// # Returns
2137 ///
2138 /// Returns the API result with sample UUIDs and upload status.
2139 ///
2140 /// # Example
2141 ///
2142 /// ```no_run
2143 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2144 ///
2145 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2146 /// # let client = Client::new()?.with_login("user", "pass").await?;
2147 /// # let dataset_id = DatasetID::from(1);
2148 /// // Query available annotation sets for the dataset
2149 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2150 /// let annotation_set_id = annotation_sets
2151 /// .first()
2152 /// .ok_or_else(|| {
2153 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2154 /// })?
2155 /// .id();
2156 ///
2157 /// // Create sample with annotation (UUID will be auto-generated)
2158 /// let mut sample = Sample::new();
2159 /// sample.width = Some(1920);
2160 /// sample.height = Some(1080);
2161 /// sample.group = Some("train".to_string());
2162 ///
2163 /// // Add file - use full path to local file, it will be uploaded automatically
2164 /// sample.files = vec![SampleFile::with_filename(
2165 /// "image".to_string(),
2166 /// "/path/to/image.jpg".to_string(),
2167 /// )];
2168 ///
2169 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2170 /// let mut annotation = Annotation::new();
2171 /// annotation.set_label(Some("person".to_string()));
2172 /// // Normalize pixel coordinates by dividing by image dimensions
2173 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2174 /// annotation.set_box2d(Some(bbox));
2175 /// sample.annotations = vec![annotation];
2176 ///
2177 /// // Populate with annotation_set_id (REQUIRED for annotations)
2178 /// let result = client
2179 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2180 /// .await?;
2181 /// # Ok(())
2182 /// # }
2183 /// ```
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        // Thin wrapper: delegate with `concurrency = None` so file uploads
        // use the default concurrency limit.
        self.populate_samples_with_concurrency(
            dataset_id,
            annotation_set_id,
            samples,
            progress,
            None,
        )
        .await
    }
2200
2201 /// Populate samples with custom upload concurrency.
2202 ///
2203 /// Same as [`populate_samples`](Self::populate_samples) but allows
2204 /// specifying the maximum number of concurrent file uploads. Use this
2205 /// for bulk imports where higher concurrency can significantly reduce
2206 /// upload time.
2207 pub async fn populate_samples_with_concurrency(
2208 &self,
2209 dataset_id: DatasetID,
2210 annotation_set_id: Option<AnnotationSetID>,
2211 samples: Vec<Sample>,
2212 progress: Option<Sender<Progress>>,
2213 concurrency: Option<usize>,
2214 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2215 use crate::api::SamplesPopulateParams;
2216
2217 // Track which files need to be uploaded
2218 let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
2219
2220 // Process samples to detect local files and generate UUIDs
2221 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2222
2223 let has_files_to_upload = !files_to_upload.is_empty();
2224
2225 // Call populate API with presigned_urls=true if we have files to upload
2226 let params = SamplesPopulateParams {
2227 dataset_id,
2228 annotation_set_id,
2229 presigned_urls: Some(has_files_to_upload),
2230 samples,
2231 };
2232
2233 let results: Vec<crate::SamplesPopulateResult> = self
2234 .rpc("samples.populate2".to_owned(), Some(params))
2235 .await?;
2236
2237 // Upload files if we have any
2238 if has_files_to_upload {
2239 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2240 .await?;
2241 }
2242
2243 Ok(results)
2244 }
2245
2246 fn prepare_samples_for_upload(
2247 &self,
2248 samples: Vec<Sample>,
2249 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
2250 ) -> Result<Vec<Sample>, Error> {
2251 Ok(samples
2252 .into_iter()
2253 .map(|mut sample| {
2254 // Generate UUID if not provided
2255 if sample.uuid.is_none() {
2256 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2257 }
2258
2259 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2260
2261 // Process files: detect local paths and queue for upload
2262 let files_copy = sample.files.clone();
2263 let updated_files: Vec<crate::SampleFile> = files_copy
2264 .iter()
2265 .map(|file| {
2266 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2267 })
2268 .collect();
2269
2270 sample.files = updated_files;
2271 sample
2272 })
2273 .collect())
2274 }
2275
2276 fn process_sample_file(
2277 &self,
2278 file: &crate::SampleFile,
2279 sample_uuid: &str,
2280 sample: &mut Sample,
2281 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
2282 ) -> crate::SampleFile {
2283 use std::path::Path;
2284
2285 if let Some(filename) = file.filename() {
2286 let path = Path::new(filename);
2287
2288 // Check if this is a valid local file path
2289 if path.exists()
2290 && path.is_file()
2291 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2292 {
2293 // For image files, try to extract dimensions if not already set
2294 if file.file_type() == "image"
2295 && (sample.width.is_none() || sample.height.is_none())
2296 && let Ok(size) = imagesize::size(path)
2297 {
2298 sample.width = Some(size.width as u32);
2299 sample.height = Some(size.height as u32);
2300 }
2301
2302 // Store the full path for later upload
2303 files_to_upload.push((
2304 sample_uuid.to_string(),
2305 file.file_type().to_string(),
2306 path.to_path_buf(),
2307 basename.to_string(),
2308 ));
2309
2310 // Return SampleFile with just the basename
2311 return crate::SampleFile::with_filename(
2312 file.file_type().to_string(),
2313 basename.to_string(),
2314 );
2315 }
2316 }
2317 // Return the file unchanged if not a local path
2318 file.clone()
2319 }
2320
2321 async fn upload_sample_files(
2322 &self,
2323 results: &[crate::SamplesPopulateResult],
2324 files_to_upload: Vec<(String, String, PathBuf, String)>,
2325 progress: Option<Sender<Progress>>,
2326 concurrency: Option<usize>,
2327 ) -> Result<(), Error> {
2328 // Build a map from (sample_uuid, basename) -> local_path
2329 let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
2330 for (uuid, _file_type, path, basename) in files_to_upload {
2331 upload_map.insert((uuid, basename), path);
2332 }
2333
2334 let http = self.http.clone();
2335
2336 // Extract the data we need for parallel upload
2337 let upload_tasks: Vec<_> = results
2338 .iter()
2339 .map(|result| (result.uuid.clone(), result.urls.clone()))
2340 .collect();
2341
2342 parallel_foreach_items(
2343 upload_tasks,
2344 progress.clone(),
2345 concurrency,
2346 move |(uuid, urls)| {
2347 let http = http.clone();
2348 let upload_map = upload_map.clone();
2349
2350 async move {
2351 // Upload all files for this sample
2352 for url_info in &urls {
2353 if let Some(local_path) =
2354 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2355 {
2356 // Upload the file
2357 upload_file_to_presigned_url(
2358 http.clone(),
2359 &url_info.url,
2360 local_path.clone(),
2361 )
2362 .await?;
2363 }
2364 }
2365
2366 Ok(())
2367 }
2368 },
2369 )
2370 .await
2371 }
2372
2373 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2374 // Uses default 120s timeout from client
2375 let resp = self.http.get(url).send().await?;
2376
2377 if !resp.status().is_success() {
2378 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2379 }
2380
2381 let bytes = resp.bytes().await?;
2382 Ok(bytes.to_vec())
2383 }
2384
    /// Get annotations as a DataFrame (2025.01 schema) for the specified
    /// annotation set with the requested annotation types. The annotation
    /// types are used to filter the annotations returned; images which do
    /// not have any annotations are still included in the result.
2391 ///
2392 /// **DEPRECATED**: Use [`Client::samples_dataframe()`] instead for full
2393 /// 2025.10 schema support including optional metadata columns.
2394 ///
2395 /// The result is a DataFrame following the EdgeFirst Dataset Format
2396 /// definition with 9 columns (original schema). Does not include new
2397 /// optional columns added in 2025.10.
2398 ///
2399 /// # Migration
2400 ///
2401 /// ```rust,no_run
2402 /// # use edgefirst_client::Client;
2403 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2404 /// # let client = Client::new()?;
2405 /// # let dataset_id = 1.into();
2406 /// # let annotation_set_id = 1.into();
2407 /// # let groups = vec![];
2408 /// # let types = vec![];
2409 /// // OLD (deprecated):
2410 /// let df = client
2411 /// .annotations_dataframe(annotation_set_id, &groups, &types, None)
2412 /// .await?;
2413 ///
2414 /// // NEW (recommended):
2415 /// let df = client
2416 /// .samples_dataframe(dataset_id, Some(annotation_set_id), &groups, &types, None)
2417 /// .await?;
2418 /// # Ok(())
2419 /// # }
2420 /// ```
2421 ///
2422 /// To get the annotations as a vector of Annotation objects, use the
2423 /// `annotations` method instead.
2424 #[deprecated(
2425 since = "0.8.0",
2426 note = "Use `samples_dataframe()` for complete 2025.10 schema support"
2427 )]
2428 #[cfg(feature = "polars")]
2429 pub async fn annotations_dataframe(
2430 &self,
2431 annotation_set_id: AnnotationSetID,
2432 groups: &[String],
2433 types: &[AnnotationType],
2434 progress: Option<Sender<Progress>>,
2435 ) -> Result<DataFrame, Error> {
2436 #[allow(deprecated)]
2437 use crate::dataset::annotations_dataframe;
2438
2439 let annotations = self
2440 .annotations(annotation_set_id, groups, types, progress)
2441 .await?;
2442 #[allow(deprecated)]
2443 annotations_dataframe(&annotations)
2444 }
2445
2446 /// Get samples as a DataFrame with complete 2025.10 schema.
2447 ///
2448 /// This is the recommended method for obtaining dataset annotations in
2449 /// DataFrame format. It includes all sample metadata (size, location,
2450 /// pose, degradation) as optional columns.
2451 ///
2452 /// # Arguments
2453 ///
2454 /// * `dataset_id` - Dataset identifier
2455 /// * `annotation_set_id` - Optional annotation set filter
2456 /// * `groups` - Dataset groups to include (train, val, test)
2457 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2458 /// * `progress` - Optional progress callback
2459 ///
2460 /// # Example
2461 ///
2462 /// ```rust,no_run
2463 /// use edgefirst_client::Client;
2464 ///
2465 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2466 /// # let client = Client::new()?;
2467 /// # let dataset_id = 1.into();
2468 /// # let annotation_set_id = 1.into();
2469 /// let df = client
2470 /// .samples_dataframe(
2471 /// dataset_id,
2472 /// Some(annotation_set_id),
2473 /// &["train".to_string()],
2474 /// &[],
2475 /// None,
2476 /// )
2477 /// .await?;
2478 /// println!("DataFrame shape: {:?}", df.shape());
2479 /// # Ok(())
2480 /// # }
2481 /// ```
2482 #[cfg(feature = "polars")]
2483 pub async fn samples_dataframe(
2484 &self,
2485 dataset_id: DatasetID,
2486 annotation_set_id: Option<AnnotationSetID>,
2487 groups: &[String],
2488 types: &[AnnotationType],
2489 progress: Option<Sender<Progress>>,
2490 ) -> Result<DataFrame, Error> {
2491 use crate::dataset::samples_dataframe;
2492
2493 let samples = self
2494 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2495 .await?;
2496 samples_dataframe(&samples)
2497 }
2498
2499 /// List available snapshots. If a name is provided, only snapshots
2500 /// containing that name are returned.
2501 ///
2502 /// Results are sorted by match quality: exact matches first, then
2503 /// case-insensitive exact matches, then shorter descriptions (more
2504 /// specific), then alphabetically.
2505 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2506 let snapshots: Vec<Snapshot> = self
2507 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2508 .await?;
2509 if let Some(name) = name {
2510 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2511 s.description()
2512 }))
2513 } else {
2514 Ok(snapshots)
2515 }
2516 }
2517
2518 /// Get the snapshot with the specified id.
2519 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2520 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2521 self.rpc("snapshots.get".to_owned(), Some(params)).await
2522 }
2523
2524 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2525 ///
2526 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2527 /// pairs) that serve two primary purposes:
2528 ///
2529 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2530 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2531 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2532 /// auto-depth processing.
2533 ///
2534 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2535 /// migration between EdgeFirst Studio instances using the create →
2536 /// download → upload → restore workflow.
2537 ///
2538 /// Large files are automatically chunked into 100MB parts and uploaded
2539 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2540 /// is streamed without loading into memory, maintaining constant memory
2541 /// usage.
2542 ///
2543 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2544 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2545 /// better for large files to avoid timeout issues. Higher values (16-32)
2546 /// are better for many small files.
2547 ///
2548 /// # Arguments
2549 ///
2550 /// * `path` - Local file path to MCAP file or directory containing
2551 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2552 /// * `progress` - Optional channel to receive upload progress updates
2553 ///
2554 /// # Returns
2555 ///
2556 /// Returns a `Snapshot` object with ID, description, status, path, and
2557 /// creation timestamp on success.
2558 ///
2559 /// # Errors
2560 ///
2561 /// Returns an error if:
2562 /// * Path doesn't exist or contains invalid UTF-8
2563 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2564 /// * Upload fails or network error occurs
2565 /// * Server rejects the snapshot
2566 ///
2567 /// # Example
2568 ///
2569 /// ```no_run
2570 /// # use edgefirst_client::{Client, Progress};
2571 /// # use tokio::sync::mpsc;
2572 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2573 /// let client = Client::new()?.with_token_path(None)?;
2574 ///
2575 /// // Upload MCAP file with progress tracking
2576 /// let (tx, mut rx) = mpsc::channel(1);
2577 /// tokio::spawn(async move {
2578 /// while let Some(Progress { current, total }) = rx.recv().await {
2579 /// println!(
2580 /// "Upload: {}/{} bytes ({:.1}%)",
2581 /// current,
2582 /// total,
2583 /// (current as f64 / total as f64) * 100.0
2584 /// );
2585 /// }
2586 /// });
2587 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2588 /// println!("Created snapshot: {:?}", snapshot.id());
2589 ///
2590 /// // Upload dataset directory (no progress)
2591 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2592 /// # Ok(())
2593 /// # }
2594 /// ```
2595 ///
2596 /// # See Also
2597 ///
2598 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2599 /// dataset
2600 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2601 /// data
2602 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2603 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2604 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories take the multi-file upload path; a single file (MCAP)
        // is handled inline below.
        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // Single-file upload: the file's basename doubles as the snapshot
        // name and the (only) upload key.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        // Shared byte counter advanced by the multipart uploader.
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server to create the snapshot record and presigned
        // multipart upload URLs for the single key.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-key part entries.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed "<storage prefix>/<filename>", where the
        // prefix is the tail of the snapshot's storage path after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload the file in chunks; the returned params feed the
        // completion call below.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot to "available" now that all bytes are uploaded.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel before the final fetch so receivers
        // observe end-of-stream.
        if let Some(progress) = progress {
            drop(progress);
        }

        // Return the refreshed snapshot record.
        self.snapshot(snapshot_id).await
    }
2699
2700 async fn create_snapshot_folder(
2701 &self,
2702 path: &str,
2703 progress: Option<Sender<Progress>>,
2704 ) -> Result<Snapshot, Error> {
2705 let path = Path::new(path);
2706 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2707 Error::IoError(std::io::Error::new(
2708 std::io::ErrorKind::InvalidInput,
2709 "Invalid directory name",
2710 ))
2711 })?;
2712
2713 let files = WalkDir::new(path)
2714 .into_iter()
2715 .filter_map(|entry| entry.ok())
2716 .filter(|entry| entry.file_type().is_file())
2717 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
2718 .collect::<Vec<_>>();
2719
2720 let total: usize = files
2721 .iter()
2722 .filter_map(|file| path.join(file).metadata().ok())
2723 .map(|metadata| metadata.len() as usize)
2724 .sum();
2725 let current = Arc::new(AtomicUsize::new(0));
2726
2727 if let Some(progress) = &progress {
2728 let _ = progress.send(Progress { current: 0, total }).await;
2729 }
2730
2731 let keys = files
2732 .iter()
2733 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
2734 .collect::<Vec<_>>();
2735 let file_sizes = files
2736 .iter()
2737 .filter_map(|key| path.join(key).metadata().ok())
2738 .map(|metadata| metadata.len() as usize)
2739 .collect::<Vec<_>>();
2740
2741 let params = SnapshotCreateMultipartParams {
2742 snapshot_name: name.to_owned(),
2743 keys,
2744 file_sizes,
2745 };
2746
2747 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2748 .rpc(
2749 "snapshots.create_upload_url_multipart".to_owned(),
2750 Some(params),
2751 )
2752 .await?;
2753
2754 let snapshot_id = match multipart.get("snapshot_id") {
2755 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2756 _ => return Err(Error::InvalidResponse),
2757 };
2758
2759 let snapshot = self.snapshot(snapshot_id).await?;
2760 let part_prefix = snapshot
2761 .path()
2762 .split("::/")
2763 .last()
2764 .ok_or(Error::InvalidResponse)?
2765 .to_owned();
2766
2767 for file in files {
2768 let file_str = file.to_str().ok_or_else(|| {
2769 Error::IoError(std::io::Error::new(
2770 std::io::ErrorKind::InvalidInput,
2771 "File path contains invalid UTF-8",
2772 ))
2773 })?;
2774 let part_key = format!("{}/{}", part_prefix, file_str);
2775 let mut part = match multipart.get(&part_key) {
2776 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2777 _ => return Err(Error::InvalidResponse),
2778 }
2779 .clone();
2780 part.key = Some(part_key);
2781
2782 let params = upload_multipart(
2783 self.http.clone(),
2784 part.clone(),
2785 path.join(file),
2786 total,
2787 current.clone(),
2788 progress.clone(),
2789 )
2790 .await?;
2791
2792 let complete: String = self
2793 .rpc(
2794 "snapshots.complete_multipart_upload".to_owned(),
2795 Some(params),
2796 )
2797 .await?;
2798 debug!("Snapshot Part Complete: {:?}", complete);
2799 }
2800
2801 let params = SnapshotStatusParams {
2802 snapshot_id,
2803 status: "available".to_owned(),
2804 };
2805 let _: SnapshotStatusResult = self
2806 .rpc("snapshots.update".to_owned(), Some(params))
2807 .await?;
2808
2809 if let Some(progress) = progress {
2810 drop(progress);
2811 }
2812
2813 self.snapshot(snapshot_id).await
2814 }
2815
2816 /// Delete a snapshot from EdgeFirst Studio.
2817 ///
2818 /// Permanently removes a snapshot and its associated data. This operation
2819 /// cannot be undone.
2820 ///
2821 /// # Arguments
2822 ///
2823 /// * `snapshot_id` - The snapshot ID to delete
2824 ///
2825 /// # Errors
2826 ///
2827 /// Returns an error if:
2828 /// * Snapshot doesn't exist
2829 /// * User lacks permission to delete the snapshot
2830 /// * Server error occurs
2831 ///
2832 /// # Example
2833 ///
2834 /// ```no_run
2835 /// # use edgefirst_client::{Client, SnapshotID};
2836 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2837 /// let client = Client::new()?.with_token_path(None)?;
2838 /// let snapshot_id = SnapshotID::from(123);
2839 /// client.delete_snapshot(snapshot_id).await?;
2840 /// # Ok(())
2841 /// # }
2842 /// ```
2843 ///
2844 /// # See Also
2845 ///
2846 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2847 /// * [`snapshots`](Self::snapshots) - List all snapshots
2848 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2849 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2850 let _: String = self
2851 .rpc("snapshots.delete".to_owned(), Some(params))
2852 .await?;
2853 Ok(())
2854 }
2855
2856 /// Create a snapshot from an existing dataset on the server.
2857 ///
2858 /// Triggers server-side snapshot generation which exports the dataset's
2859 /// images and annotations into a downloadable EdgeFirst Dataset Format
2860 /// snapshot.
2861 ///
2862 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
2863 /// while restore creates a dataset from a snapshot, this method creates a
2864 /// snapshot from a dataset.
2865 ///
2866 /// # Arguments
2867 ///
2868 /// * `dataset_id` - The dataset ID to create snapshot from
2869 /// * `description` - Description for the created snapshot
2870 ///
2871 /// # Returns
2872 ///
2873 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
2874 /// for monitoring progress.
2875 ///
2876 /// # Errors
2877 ///
2878 /// Returns an error if:
2879 /// * Dataset doesn't exist
2880 /// * User lacks permission to access the dataset
2881 /// * Server rejects the request
2882 ///
2883 /// # Example
2884 ///
2885 /// ```no_run
2886 /// # use edgefirst_client::{Client, DatasetID};
2887 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2888 /// let client = Client::new()?.with_token_path(None)?;
2889 /// let dataset_id = DatasetID::from(123);
2890 ///
2891 /// // Create snapshot from dataset (all annotation sets)
2892 /// let result = client
2893 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
2894 /// .await?;
2895 /// println!("Created snapshot: {:?}", result.id);
2896 ///
2897 /// // Monitor progress via task ID
2898 /// if let Some(task_id) = result.task_id {
2899 /// println!("Task: {}", task_id);
2900 /// }
2901 /// # Ok(())
2902 /// # }
2903 /// ```
2904 ///
2905 /// # See Also
2906 ///
2907 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
2908 /// snapshot
2909 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2910 /// dataset
2911 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2912 pub async fn create_snapshot_from_dataset(
2913 &self,
2914 dataset_id: DatasetID,
2915 description: &str,
2916 annotation_set_id: Option<AnnotationSetID>,
2917 ) -> Result<SnapshotFromDatasetResult, Error> {
2918 // Resolve annotation_set_id: use provided value or fetch default
2919 let annotation_set_id = match annotation_set_id {
2920 Some(id) => id,
2921 None => {
2922 // Fetch annotation sets and find default ("annotations") or use first
2923 let sets = self.annotation_sets(dataset_id).await?;
2924 if sets.is_empty() {
2925 return Err(Error::InvalidParameters(
2926 "No annotation sets available for dataset".to_owned(),
2927 ));
2928 }
2929 // Look for "annotations" set (default), otherwise use first
2930 sets.iter()
2931 .find(|s| s.name() == "annotations")
2932 .unwrap_or(&sets[0])
2933 .id()
2934 }
2935 };
2936 let params = SnapshotCreateFromDataset {
2937 description: description.to_owned(),
2938 dataset_id,
2939 annotation_set_id,
2940 };
2941 self.rpc("snapshots.create".to_owned(), Some(params)).await
2942 }
2943
2944 /// Download a snapshot from EdgeFirst Studio to local storage.
2945 ///
2946 /// Downloads all files in a snapshot (single MCAP file or directory of
2947 /// EdgeFirst Dataset Format files) to the specified output path. Files are
2948 /// downloaded concurrently with progress tracking.
2949 ///
2950 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2951 /// downloads (default: half of CPU cores, min 2, max 8).
2952 ///
2953 /// # Arguments
2954 ///
2955 /// * `snapshot_id` - The snapshot ID to download
2956 /// * `output` - Local directory path to save downloaded files
2957 /// * `progress` - Optional channel to receive download progress updates
2958 ///
2959 /// # Errors
2960 ///
2961 /// Returns an error if:
2962 /// * Snapshot doesn't exist
2963 /// * Output directory cannot be created
2964 /// * Download fails or network error occurs
2965 ///
2966 /// # Example
2967 ///
2968 /// ```no_run
2969 /// # use edgefirst_client::{Client, SnapshotID, Progress};
2970 /// # use tokio::sync::mpsc;
2971 /// # use std::path::PathBuf;
2972 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2973 /// let client = Client::new()?.with_token_path(None)?;
2974 /// let snapshot_id = SnapshotID::from(123);
2975 ///
2976 /// // Download with progress tracking
2977 /// let (tx, mut rx) = mpsc::channel(1);
2978 /// tokio::spawn(async move {
2979 /// while let Some(Progress { current, total }) = rx.recv().await {
2980 /// println!("Download: {}/{} bytes", current, total);
2981 /// }
2982 /// });
2983 /// client
2984 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
2985 /// .await?;
2986 /// # Ok(())
2987 /// # }
2988 /// ```
2989 ///
2990 /// # See Also
2991 ///
2992 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2993 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2994 /// dataset
2995 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Make sure the destination directory exists before spawning tasks.
        fs::create_dir_all(&output).await?;

        // Request a presigned download URL for every file in the snapshot,
        // keyed by the file's name within the snapshot.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // Byte counters shared by all download tasks: `total` accumulates
        // each response's Content-Length as it becomes known; `current`
        // tracks bytes written to disk so far.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        // Cap the number of concurrent downloads (MAX_TASKS env override).
        let sem = Arc::new(Semaphore::new(max_tasks()));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                // Clone shared handles so each spawned task owns its state.
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    let _permit = sem.acquire().await.map_err(|_| {
                        Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                    })?;
                    let res = http.get(url).send().await?;
                    // Content-Length may be absent; fall back to zero (the
                    // reported total will then undercount for this file).
                    let content_length = res.content_length().unwrap_or(0) as usize;

                    if let Some(progress) = &progress {
                        // fetch_add returns the previous value, so the new
                        // grand total is `total + content_length`.
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        let _ = progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await;
                    }

                    let mut file = File::create(output.join(key)).await?;
                    let mut stream = res.bytes_stream();

                    // Stream the body to disk chunk by chunk, emitting one
                    // progress update per chunk. Send errors are ignored:
                    // a dropped receiver must not abort the download.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk?;
                        file.write_all(&chunk).await?;
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            let _ = progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await;
                        }
                    }

                    Ok::<(), Error>(())
                })
            })
            .collect::<Vec<_>>();

        // First collect surfaces task join (panic/cancel) failures, second
        // surfaces download errors returned by the tasks themselves.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        Ok(())
    }
3077
3078 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
3079 ///
3080 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
3081 /// the specified project. For MCAP files, supports:
3082 ///
3083 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
3084 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
3085 /// present)
3086 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
3087 /// * **Topic filtering**: Select specific MCAP topics to restore
3088 ///
3089 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
3090 /// dataset structure.
3091 ///
3092 /// # Arguments
3093 ///
3094 /// * `project_id` - Target project ID
3095 /// * `snapshot_id` - Snapshot ID to restore
3096 /// * `topics` - MCAP topics to include (empty = all topics)
3097 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
3098 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
3099 /// * `dataset_name` - Optional custom dataset name
3100 /// * `dataset_description` - Optional dataset description
3101 ///
3102 /// # Returns
3103 ///
3104 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
3105 ///
3106 /// # Errors
3107 ///
3108 /// Returns an error if:
3109 /// * Snapshot or project doesn't exist
3110 /// * Snapshot format is invalid
3111 /// * Server rejects restoration parameters
3112 ///
3113 /// # Example
3114 ///
3115 /// ```no_run
3116 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
3117 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3118 /// let client = Client::new()?.with_token_path(None)?;
3119 /// let project_id = ProjectID::from(1);
3120 /// let snapshot_id = SnapshotID::from(123);
3121 ///
3122 /// // Restore MCAP with AGTG for "person" and "car" detection
3123 /// let result = client
3124 /// .restore_snapshot(
3125 /// project_id,
3126 /// snapshot_id,
3127 /// &[], // All topics
3128 /// &["person".to_string(), "car".to_string()], // AGTG labels
3129 /// true, // Auto-depth
3130 /// Some("Highway Dataset"),
3131 /// Some("Collected on I-95"),
3132 /// )
3133 /// .await?;
3134 /// println!("Restored to dataset: {:?}", result.dataset_id);
3135 /// # Ok(())
3136 /// # }
3137 /// ```
3138 ///
3139 /// # See Also
3140 ///
3141 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3142 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3143 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3144 #[allow(clippy::too_many_arguments)]
3145 pub async fn restore_snapshot(
3146 &self,
3147 project_id: ProjectID,
3148 snapshot_id: SnapshotID,
3149 topics: &[String],
3150 autolabel: &[String],
3151 autodepth: bool,
3152 dataset_name: Option<&str>,
3153 dataset_description: Option<&str>,
3154 ) -> Result<SnapshotRestoreResult, Error> {
3155 let params = SnapshotRestore {
3156 project_id,
3157 snapshot_id,
3158 fps: 1,
3159 autodepth,
3160 agtg_pipeline: !autolabel.is_empty(),
3161 autolabel: autolabel.to_vec(),
3162 topics: topics.to_vec(),
3163 dataset_name: dataset_name.map(|s| s.to_owned()),
3164 dataset_description: dataset_description.map(|s| s.to_owned()),
3165 };
3166 self.rpc("snapshots.restore".to_owned(), Some(params)).await
3167 }
3168
3169 /// Returns a list of experiments available to the user. The experiments
3170 /// are returned as a vector of Experiment objects. If name is provided
3171 /// then only experiments containing this string are returned.
3172 ///
3173 /// Results are sorted by match quality: exact matches first, then
3174 /// case-insensitive exact matches, then shorter names (more specific),
3175 /// then alphabetically.
3176 ///
3177 /// Experiments provide a method of organizing training and validation
3178 /// sessions together and are akin to an Experiment in MLFlow terminology.
3179 /// Each experiment can have multiple trainer sessions associated with it,
3180 /// these would be akin to runs in MLFlow terminology.
3181 pub async fn experiments(
3182 &self,
3183 project_id: ProjectID,
3184 name: Option<&str>,
3185 ) -> Result<Vec<Experiment>, Error> {
3186 let params = HashMap::from([("project_id", project_id)]);
3187 let experiments: Vec<Experiment> =
3188 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
3189 if let Some(name) = name {
3190 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
3191 } else {
3192 Ok(experiments)
3193 }
3194 }
3195
3196 /// Return the experiment with the specified experiment ID. If the
3197 /// experiment does not exist, an error is returned.
3198 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
3199 let params = HashMap::from([("trainer_id", experiment_id)]);
3200 self.rpc("trainer.get".to_owned(), Some(params)).await
3201 }
3202
3203 /// Returns a list of trainer sessions available to the user. The trainer
3204 /// sessions are returned as a vector of TrainingSession objects. If name
3205 /// is provided then only trainer sessions containing this string are
3206 /// returned.
3207 ///
3208 /// Results are sorted by match quality: exact matches first, then
3209 /// case-insensitive exact matches, then shorter names (more specific),
3210 /// then alphabetically.
3211 ///
3212 /// Trainer sessions are akin to runs in MLFlow terminology. These
3213 /// represent an actual training session which will produce metrics and
3214 /// model artifacts.
3215 pub async fn training_sessions(
3216 &self,
3217 experiment_id: ExperimentID,
3218 name: Option<&str>,
3219 ) -> Result<Vec<TrainingSession>, Error> {
3220 let params = HashMap::from([("trainer_id", experiment_id)]);
3221 let sessions: Vec<TrainingSession> = self
3222 .rpc("trainer.session.list".to_owned(), Some(params))
3223 .await?;
3224 if let Some(name) = name {
3225 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
3226 } else {
3227 Ok(sessions)
3228 }
3229 }
3230
3231 /// Return the trainer session with the specified trainer session ID. If
3232 /// the trainer session does not exist, an error is returned.
3233 pub async fn training_session(
3234 &self,
3235 session_id: TrainingSessionID,
3236 ) -> Result<TrainingSession, Error> {
3237 let params = HashMap::from([("trainer_session_id", session_id)]);
3238 self.rpc("trainer.session.get".to_owned(), Some(params))
3239 .await
3240 }
3241
3242 /// List validation sessions for the given project.
3243 pub async fn validation_sessions(
3244 &self,
3245 project_id: ProjectID,
3246 ) -> Result<Vec<ValidationSession>, Error> {
3247 let params = HashMap::from([("project_id", project_id)]);
3248 self.rpc("validate.session.list".to_owned(), Some(params))
3249 .await
3250 }
3251
3252 /// Retrieve a specific validation session.
3253 pub async fn validation_session(
3254 &self,
3255 session_id: ValidationSessionID,
3256 ) -> Result<ValidationSession, Error> {
3257 let params = HashMap::from([("validate_session_id", session_id)]);
3258 self.rpc("validate.session.get".to_owned(), Some(params))
3259 .await
3260 }
3261
3262 /// List the artifacts for the specified trainer session. The artifacts
3263 /// are returned as a vector of strings.
3264 pub async fn artifacts(
3265 &self,
3266 training_session_id: TrainingSessionID,
3267 ) -> Result<Vec<Artifact>, Error> {
3268 let params = HashMap::from([("training_session_id", training_session_id)]);
3269 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
3270 .await
3271 }
3272
3273 /// Download the model artifact for the specified trainer session to the
3274 /// specified file path, if path is not provided it will be downloaded to
3275 /// the current directory with the same filename. A progress callback can
3276 /// be provided to monitor the progress of the download over a watch
3277 /// channel.
3278 pub async fn download_artifact(
3279 &self,
3280 training_session_id: TrainingSessionID,
3281 modelname: &str,
3282 filename: Option<PathBuf>,
3283 progress: Option<Sender<Progress>>,
3284 ) -> Result<(), Error> {
3285 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
3286 let resp = self
3287 .http
3288 .get(format!(
3289 "{}/download_model?training_session_id={}&file={}",
3290 self.url,
3291 training_session_id.value(),
3292 modelname
3293 ))
3294 .header("Authorization", format!("Bearer {}", self.token().await))
3295 .send()
3296 .await?;
3297 if !resp.status().is_success() {
3298 let err = resp.error_for_status_ref().unwrap_err();
3299 return Err(Error::HttpError(err));
3300 }
3301
3302 if let Some(parent) = filename.parent() {
3303 fs::create_dir_all(parent).await?;
3304 }
3305
3306 if let Some(progress) = progress {
3307 let total = resp.content_length().unwrap_or(0) as usize;
3308 let _ = progress.send(Progress { current: 0, total }).await;
3309
3310 let mut file = File::create(filename).await?;
3311 let mut current = 0;
3312 let mut stream = resp.bytes_stream();
3313
3314 while let Some(item) = stream.next().await {
3315 let chunk = item?;
3316 file.write_all(&chunk).await?;
3317 current += chunk.len();
3318 let _ = progress.send(Progress { current, total }).await;
3319 }
3320 } else {
3321 let body = resp.bytes().await?;
3322 fs::write(filename, body).await?;
3323 }
3324
3325 Ok(())
3326 }
3327
3328 /// Download the model checkpoint associated with the specified trainer
3329 /// session to the specified file path, if path is not provided it will be
3330 /// downloaded to the current directory with the same filename. A progress
3331 /// callback can be provided to monitor the progress of the download over a
3332 /// watch channel.
3333 ///
3334 /// There is no API for listing checkpoints it is expected that trainers are
3335 /// aware of possible checkpoints and their names within the checkpoint
3336 /// folder on the server.
3337 pub async fn download_checkpoint(
3338 &self,
3339 training_session_id: TrainingSessionID,
3340 checkpoint: &str,
3341 filename: Option<PathBuf>,
3342 progress: Option<Sender<Progress>>,
3343 ) -> Result<(), Error> {
3344 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
3345 let resp = self
3346 .http
3347 .get(format!(
3348 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
3349 self.url,
3350 training_session_id.value(),
3351 checkpoint
3352 ))
3353 .header("Authorization", format!("Bearer {}", self.token().await))
3354 .send()
3355 .await?;
3356 if !resp.status().is_success() {
3357 let err = resp.error_for_status_ref().unwrap_err();
3358 return Err(Error::HttpError(err));
3359 }
3360
3361 if let Some(parent) = filename.parent() {
3362 fs::create_dir_all(parent).await?;
3363 }
3364
3365 if let Some(progress) = progress {
3366 let total = resp.content_length().unwrap_or(0) as usize;
3367 let _ = progress.send(Progress { current: 0, total }).await;
3368
3369 let mut file = File::create(filename).await?;
3370 let mut current = 0;
3371 let mut stream = resp.bytes_stream();
3372
3373 while let Some(item) = stream.next().await {
3374 let chunk = item?;
3375 file.write_all(&chunk).await?;
3376 current += chunk.len();
3377 let _ = progress.send(Progress { current, total }).await;
3378 }
3379 } else {
3380 let body = resp.bytes().await?;
3381 fs::write(filename, body).await?;
3382 }
3383
3384 Ok(())
3385 }
3386
3387 /// Return a list of tasks for the current user.
3388 ///
3389 /// # Arguments
3390 ///
3391 /// * `name` - Optional filter for task name (client-side substring match)
3392 /// * `workflow` - Optional filter for workflow/task type. If provided,
3393 /// filters server-side by exact match. Valid values include: "trainer",
3394 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
3395 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
3396 /// "convertor", "twostage"
3397 /// * `status` - Optional filter for task status (e.g., "running",
3398 /// "complete", "error")
3399 /// * `manager` - Optional filter for task manager type (e.g., "aws",
3400 /// "user", "kubernetes")
3401 pub async fn tasks(
3402 &self,
3403 name: Option<&str>,
3404 workflow: Option<&str>,
3405 status: Option<&str>,
3406 manager: Option<&str>,
3407 ) -> Result<Vec<Task>, Error> {
3408 let mut params = TasksListParams {
3409 continue_token: None,
3410 types: workflow.map(|w| vec![w.to_owned()]),
3411 status: status.map(|s| vec![s.to_owned()]),
3412 manager: manager.map(|m| vec![m.to_owned()]),
3413 };
3414 let mut tasks = Vec::new();
3415
3416 loop {
3417 let result = self
3418 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
3419 .await?;
3420 tasks.extend(result.tasks);
3421
3422 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
3423 params.continue_token = None;
3424 } else {
3425 params.continue_token = result.continue_token;
3426 }
3427
3428 if params.continue_token.is_none() {
3429 break;
3430 }
3431 }
3432
3433 if let Some(name) = name {
3434 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
3435 }
3436
3437 Ok(tasks)
3438 }
3439
3440 /// Retrieve the task information and status.
3441 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
3442 self.rpc(
3443 "task.get".to_owned(),
3444 Some(HashMap::from([("id", task_id)])),
3445 )
3446 .await
3447 }
3448
3449 /// Updates the tasks status.
3450 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
3451 let status = TaskStatus {
3452 task_id,
3453 status: status.to_owned(),
3454 };
3455 self.rpc("docker.update.status".to_owned(), Some(status))
3456 .await
3457 }
3458
3459 /// Defines the stages for the task. The stages are defined as a mapping
3460 /// from stage names to their descriptions. Once stages are defined their
3461 /// status can be updated using the update_stage method.
3462 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
3463 let stages: Vec<HashMap<String, String>> = stages
3464 .iter()
3465 .map(|(key, value)| {
3466 let mut stage_map = HashMap::new();
3467 stage_map.insert(key.to_string(), value.to_string());
3468 stage_map
3469 })
3470 .collect();
3471 let params = TaskStages { task_id, stages };
3472 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
3473 Ok(())
3474 }
3475
3476 /// Updates the progress of the task for the provided stage and status
3477 /// information.
3478 pub async fn update_stage(
3479 &self,
3480 task_id: TaskID,
3481 stage: &str,
3482 status: &str,
3483 message: &str,
3484 percentage: u8,
3485 ) -> Result<(), Error> {
3486 let stage = Stage::new(
3487 Some(task_id),
3488 stage.to_owned(),
3489 Some(status.to_owned()),
3490 Some(message.to_owned()),
3491 percentage,
3492 );
3493 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
3494 Ok(())
3495 }
3496
3497 /// Raw fetch from the Studio server is used for downloading files.
3498 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
3499 let req = self
3500 .http
3501 .get(format!("{}/{}", self.url, query))
3502 .header("User-Agent", "EdgeFirst Client")
3503 .header("Authorization", format!("Bearer {}", self.token().await));
3504 let resp = req.send().await?;
3505
3506 if resp.status().is_success() {
3507 let body = resp.bytes().await?;
3508
3509 if log_enabled!(Level::Trace) {
3510 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
3511 }
3512
3513 Ok(body.to_vec())
3514 } else {
3515 let err = resp.error_for_status_ref().unwrap_err();
3516 Err(Error::HttpError(err))
3517 }
3518 }
3519
3520 /// Sends a multipart post request to the server. This is used by the
3521 /// upload and download APIs which do not use JSON-RPC but instead transfer
3522 /// files using multipart/form-data.
3523 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
3524 let req = self
3525 .http
3526 .post(format!("{}/api?method={}", self.url, method))
3527 .header("Accept", "application/json")
3528 .header("User-Agent", "EdgeFirst Client")
3529 .header("Authorization", format!("Bearer {}", self.token().await))
3530 .multipart(form);
3531 let resp = req.send().await?;
3532
3533 if resp.status().is_success() {
3534 let body = resp.bytes().await?;
3535
3536 if log_enabled!(Level::Trace) {
3537 trace!(
3538 "POST Multipart Response: {}",
3539 String::from_utf8_lossy(&body)
3540 );
3541 }
3542
3543 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
3544 Ok(response) => response,
3545 Err(err) => {
3546 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3547 return Err(err.into());
3548 }
3549 };
3550
3551 if let Some(error) = response.error {
3552 Err(Error::RpcError(error.code, error.message))
3553 } else if let Some(result) = response.result {
3554 Ok(result)
3555 } else {
3556 Err(Error::InvalidResponse)
3557 }
3558 } else {
3559 let err = resp.error_for_status_ref().unwrap_err();
3560 Err(Error::HttpError(err))
3561 }
3562 }
3563
3564 /// Send a JSON-RPC request to the server. The method is the name of the
3565 /// method to call on the server. The params are the parameters to pass to
3566 /// the method. The method and params are serialized into a JSON-RPC
3567 /// request and sent to the server. The response is deserialized into
3568 /// the specified type and returned to the caller.
3569 ///
3570 /// NOTE: This API would generally not be called directly and instead users
3571 /// should use the higher-level methods provided by the client.
    pub async fn rpc<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        // Proactively renew the auth token when it expires within the next
        // hour so the request (and any follow-up work) does not fail
        // mid-flight with an expired token.
        let auth_expires = self.token_expiration().await?;
        if auth_expires <= Utc::now() + Duration::from_secs(3600) {
            self.renew_token().await?;
        }

        self.rpc_without_auth(method, params).await
    }
3588
    /// Send a JSON-RPC request without checking or renewing the auth token
    /// first. Transport failures (timeout, connection refused) and retryable
    /// HTTP statuses (408, 429, 500, 502, 503, 504) are retried with
    /// exponential backoff; the retry count defaults to 3 and can be
    /// overridden with the `EDGEFIRST_MAX_RETRIES` environment variable.
    async fn rpc_without_auth<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(3usize);

        let url = format!("{}/api", self.url);

        // Serialize request body once before retry loop to avoid Clone bound on Params
        let request = RpcRequest {
            method: method.clone(),
            params,
            ..Default::default()
        };

        if log_enabled!(Level::Trace) {
            trace!(
                "RPC Request: {}",
                serde_json::ser::to_string_pretty(&request)?
            );
        }

        let request_body = serde_json::to_vec(&request)?;
        let mut last_error: Option<Error> = None;

        // Attempt 0 is the initial try; attempts 1..=max_retries are retries.
        for attempt in 0..=max_retries {
            if attempt > 0 {
                // Exponential backoff: 1s, 2s, 4s, 8s, ... (shift clamped to
                // 4, so the delay never exceeds 16s)
                let delay = Duration::from_secs(1 << (attempt - 1).min(4));
                warn!(
                    "Retry {}/{} for RPC '{}' after {:?}",
                    attempt, max_retries, method, delay
                );
                tokio::time::sleep(delay).await;
            }

            let result = self
                .http
                .post(&url)
                .header("Accept", "application/json")
                .header("Content-Type", "application/json")
                .header("User-Agent", "EdgeFirst Client")
                .header("Authorization", format!("Bearer {}", self.token().await))
                .body(request_body.clone())
                .send()
                .await;

            match result {
                Ok(res) => {
                    let status = res.status();
                    let status_code = status.as_u16();

                    // Check for retryable HTTP status codes before processing response
                    if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
                        && attempt < max_retries
                    {
                        warn!(
                            "RPC '{}' failed with HTTP {} (retrying)",
                            method, status_code
                        );
                        last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
                        continue;
                    }

                    // Process the response
                    match self.process_rpc_response(res).await {
                        Ok(result) => {
                            if attempt > 0 {
                                debug!("RPC '{}' succeeded on retry {}", method, attempt);
                            }
                            return Ok(result);
                        }
                        Err(e) => {
                            // Don't retry client errors (4xx except 408, 429)
                            // or RPC/decode errors: the request reached the
                            // server, so retrying would not help.
                            if attempt > 0 {
                                error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                            }
                            return Err(e);
                        }
                    }
                }
                Err(e) => {
                    // Transport error (timeout, connection failure, etc.)
                    let is_timeout = e.is_timeout();
                    let is_connect = e.is_connect();

                    if (is_timeout || is_connect) && attempt < max_retries {
                        warn!(
                            "RPC '{}' transport error (retrying): {}",
                            method,
                            if is_timeout {
                                "timeout"
                            } else {
                                "connection failed"
                            }
                        );
                        last_error = Some(Error::HttpError(e));
                        continue;
                    }

                    // Non-retryable transport error, or retries exhausted.
                    if attempt > 0 {
                        error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                    }
                    return Err(Error::HttpError(e));
                }
            }
        }

        // Should not reach here: the loop only exits via `continue` on the
        // final attempt, which always sets last_error first.
        Err(last_error.unwrap_or_else(|| {
            Error::InvalidParameters(format!(
                "RPC '{}' failed after {} retries",
                method, max_retries
            ))
        }))
    }
3713
3714 async fn process_rpc_response<RpcResult>(
3715 &self,
3716 res: reqwest::Response,
3717 ) -> Result<RpcResult, Error>
3718 where
3719 RpcResult: DeserializeOwned,
3720 {
3721 let body = res.bytes().await?;
3722
3723 if log_enabled!(Level::Trace) {
3724 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
3725 }
3726
3727 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
3728 Ok(response) => response,
3729 Err(err) => {
3730 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3731 return Err(err.into());
3732 }
3733 };
3734
3735 // FIXME: Studio Server always returns 999 as the id.
3736 // if request.id.to_string() != response.id {
3737 // return Err(Error::InvalidRpcId(response.id));
3738 // }
3739
3740 if let Some(error) = response.error {
3741 Err(Error::RpcError(error.code, error.message))
3742 } else if let Some(result) = response.result {
3743 Ok(result)
3744 } else {
3745 Err(Error::InvalidResponse)
3746 }
3747 }
3748}
3749
3750/// Process items in parallel with semaphore concurrency control and progress
3751/// tracking.
3752///
3753/// This helper eliminates boilerplate for parallel item processing with:
3754/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
3755/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
3756/// - Atomic progress counter with automatic item-level updates
3757/// - Progress updates sent after each item completes (not byte-level streaming)
3758/// - Proper error propagation from spawned tasks
3759///
3760/// Note: This is optimized for discrete items with post-completion progress
3761/// updates. For byte-level streaming progress or custom retry logic, use
3762/// specialized implementations.
3763///
3764/// # Arguments
3765///
3766/// * `items` - Collection of items to process in parallel
3767/// * `progress` - Optional progress channel for tracking completion
3768/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
3769/// * `work_fn` - Async function to execute for each item
3770///
3771/// # Examples
3772///
3773/// ```rust,ignore
3774/// // Use default concurrency
3775/// parallel_foreach_items(samples, progress, None, |sample| async move {
3776/// sample.download(&client, file_type).await?;
3777/// Ok(())
3778/// }).await?;
3779/// ```
3780async fn parallel_foreach_items<T, F, Fut>(
3781 items: Vec<T>,
3782 progress: Option<Sender<Progress>>,
3783 concurrency: Option<usize>,
3784 work_fn: F,
3785) -> Result<(), Error>
3786where
3787 T: Send + 'static,
3788 F: Fn(T) -> Fut + Send + Sync + 'static,
3789 Fut: Future<Output = Result<(), Error>> + Send + 'static,
3790{
3791 let total = items.len();
3792 let current = Arc::new(AtomicUsize::new(0));
3793 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
3794 let work_fn = Arc::new(work_fn);
3795
3796 let tasks = items
3797 .into_iter()
3798 .map(|item| {
3799 let sem = sem.clone();
3800 let current = current.clone();
3801 let progress = progress.clone();
3802 let work_fn = work_fn.clone();
3803
3804 tokio::spawn(async move {
3805 let _permit = sem.acquire().await.map_err(|_| {
3806 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3807 })?;
3808
3809 // Execute the actual work
3810 work_fn(item).await?;
3811
3812 // Update progress
3813 if let Some(progress) = &progress {
3814 let current = current.fetch_add(1, Ordering::SeqCst);
3815 let _ = progress
3816 .send(Progress {
3817 current: current + 1,
3818 total,
3819 })
3820 .await;
3821 }
3822
3823 Ok::<(), Error>(())
3824 })
3825 })
3826 .collect::<Vec<_>>();
3827
3828 join_all(tasks)
3829 .await
3830 .into_iter()
3831 .collect::<Result<Vec<_>, _>>()?
3832 .into_iter()
3833 .collect::<Result<Vec<_>, _>>()?;
3834
3835 if let Some(progress) = progress {
3836 drop(progress);
3837 }
3838
3839 Ok(())
3840}
3841
3842/// Upload a file to S3 using multipart upload with presigned URLs.
3843///
3844/// Splits a file into chunks (100MB each) and uploads them in parallel using
3845/// S3 multipart upload protocol. Returns completion parameters with ETags for
3846/// finalizing the upload.
3847///
3848/// This function handles:
3849/// - Splitting files into parts based on PART_SIZE (100MB)
3850/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
3851/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
3852/// - Retry logic (handled by reqwest client)
3853/// - Progress tracking across all parts
3854///
3855/// # Arguments
3856///
3857/// * `http` - HTTP client for making requests
3858/// * `part` - Snapshot part info with presigned URLs for each chunk
3859/// * `path` - Local file path to upload
3860/// * `total` - Total bytes across all files for progress calculation
3861/// * `current` - Atomic counter tracking bytes uploaded across all operations
3862/// * `progress` - Optional channel for sending progress updates
3863///
3864/// # Returns
3865///
3866/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
3867async fn upload_multipart(
3868 http: reqwest::Client,
3869 part: SnapshotPart,
3870 path: PathBuf,
3871 total: usize,
3872 current: Arc<AtomicUsize>,
3873 progress: Option<Sender<Progress>>,
3874) -> Result<SnapshotCompleteMultipartParams, Error> {
3875 let filesize = path.metadata()?.len() as usize;
3876 let n_parts = filesize.div_ceil(PART_SIZE);
3877 let sem = Arc::new(Semaphore::new(max_tasks()));
3878
3879 let key = part.key.ok_or(Error::InvalidResponse)?;
3880 let upload_id = part.upload_id;
3881
3882 let urls = part.urls.clone();
3883 // Pre-allocate ETag slots for all parts
3884 let etags = Arc::new(tokio::sync::Mutex::new(vec![
3885 EtagPart {
3886 etag: "".to_owned(),
3887 part_number: 0,
3888 };
3889 n_parts
3890 ]));
3891
3892 // Upload all parts in parallel with concurrency limiting
3893 let tasks = (0..n_parts)
3894 .map(|part| {
3895 let http = http.clone();
3896 let url = urls[part].clone();
3897 let etags = etags.clone();
3898 let path = path.to_owned();
3899 let sem = sem.clone();
3900 let progress = progress.clone();
3901 let current = current.clone();
3902
3903 tokio::spawn(async move {
3904 // Acquire semaphore permit to limit concurrent uploads
3905 let _permit = sem.acquire().await?;
3906
3907 // Upload part (retry is handled by reqwest client)
3908 let etag =
3909 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
3910
3911 // Store ETag for this part (needed to complete multipart upload)
3912 let mut etags = etags.lock().await;
3913 etags[part] = EtagPart {
3914 etag,
3915 part_number: part + 1,
3916 };
3917
3918 // Update progress counter
3919 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
3920 if let Some(progress) = &progress {
3921 let _ = progress
3922 .send(Progress {
3923 current: current + PART_SIZE,
3924 total,
3925 })
3926 .await;
3927 }
3928
3929 Ok::<(), Error>(())
3930 })
3931 })
3932 .collect::<Vec<_>>();
3933
3934 // Wait for all parts to complete
3935 join_all(tasks)
3936 .await
3937 .into_iter()
3938 .collect::<Result<Vec<_>, _>>()?;
3939
3940 Ok(SnapshotCompleteMultipartParams {
3941 key,
3942 upload_id,
3943 etag_list: etags.lock().await.clone(),
3944 })
3945}
3946
3947async fn upload_part(
3948 http: reqwest::Client,
3949 url: String,
3950 path: PathBuf,
3951 part: usize,
3952 n_parts: usize,
3953) -> Result<String, Error> {
3954 let filesize = path.metadata()?.len() as usize;
3955 let mut file = File::open(path).await?;
3956 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
3957 .await?;
3958 let file = file.take(PART_SIZE as u64);
3959
3960 let body_length = if part + 1 == n_parts {
3961 filesize % PART_SIZE
3962 } else {
3963 PART_SIZE
3964 };
3965
3966 let stream = FramedRead::new(file, BytesCodec::new());
3967 let body = Body::wrap_stream(stream);
3968
3969 let resp = http
3970 .put(url.clone())
3971 .header(CONTENT_LENGTH, body_length)
3972 .body(body)
3973 .send()
3974 .await?
3975 .error_for_status()?;
3976
3977 let etag = resp
3978 .headers()
3979 .get("etag")
3980 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3981 .to_str()
3982 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3983 .to_owned();
3984
3985 // Studio Server requires etag without the quotes.
3986 let etag = etag
3987 .strip_prefix("\"")
3988 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3989 let etag = etag
3990 .strip_suffix("\"")
3991 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3992
3993 Ok(etag.to_owned())
3994}
3995
/// Upload a complete file to a presigned S3 URL using HTTP PUT.
///
/// This is used for populate_samples to upload files to S3 after
/// receiving presigned URLs from the server.
///
/// Includes explicit retry logic with exponential backoff for transient
/// failures. The retry budget is read from the `EDGEFIRST_MAX_RETRIES`
/// environment variable (default 3). The whole file is buffered in
/// memory for the duration of the call so every retry resends identical
/// bytes without re-reading from disk.
///
/// # Errors
///
/// Returns an error if the file cannot be read, if a non-retryable HTTP
/// status or transport error occurs, or once the retry budget is
/// exhausted.
async fn upload_file_to_presigned_url(
    http: reqwest::Client,
    url: &str,
    path: PathBuf,
) -> Result<(), Error> {
    // A missing or non-numeric EDGEFIRST_MAX_RETRIES silently falls back
    // to the default of 3 retries (4 attempts total).
    let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(3usize);

    // Read the entire file into memory once. Each attempt clones this
    // buffer because reqwest takes ownership of the request body.
    let file_data = fs::read(&path).await?;
    let file_size = file_data.len();
    let filename = path.file_name().unwrap_or_default().to_string_lossy();

    let mut last_error: Option<Error> = None;

    // Attempt 0 is the initial try; attempts 1..=max_retries are retries.
    for attempt in 0..=max_retries {
        if attempt > 0 {
            // Exponential backoff: 1s, 2s, 4s, 8s, capped at 16s by the
            // `.min(4)` on the shift amount.
            let delay = Duration::from_secs(1 << (attempt - 1).min(4));
            warn!(
                "Retry {}/{} for upload '{}' after {:?}",
                attempt, max_retries, filename, delay
            );
            tokio::time::sleep(delay).await;
        }

        // Attempt upload
        let result = http
            .put(url)
            .header(CONTENT_LENGTH, file_size)
            .body(file_data.clone())
            .send()
            .await;

        match result {
            Ok(resp) => {
                if resp.status().is_success() {
                    if attempt > 0 {
                        debug!(
                            "Upload '{}' succeeded on retry {} ({} bytes)",
                            filename, attempt, file_size
                        );
                    } else {
                        debug!(
                            "Successfully uploaded file: {} ({} bytes)",
                            filename, file_size
                        );
                    }
                    return Ok(());
                }

                let status = resp.status();
                let status_code = status.as_u16();

                // Check if error is retryable. 408/429/5xx are standard
                // transient statuses; NOTE(review): 409 and 423 are
                // presumably included for S3-style conflict/lock races —
                // confirm against server behavior.
                let is_retryable =
                    matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);

                if is_retryable && attempt < max_retries {
                    // Consume the body for the log message; an unreadable
                    // body degrades to an empty string rather than failing.
                    let error_text = resp.text().await.unwrap_or_default();
                    warn!(
                        "Upload '{}' failed with HTTP {} (retryable): {}",
                        filename, status_code, error_text
                    );
                    last_error = Some(Error::InvalidParameters(format!(
                        "Upload failed: HTTP {} - {}",
                        status, error_text
                    )));
                    continue;
                }

                // Non-retryable error or max retries exceeded
                let error_text = resp.text().await.unwrap_or_default();
                if attempt > 0 {
                    error!(
                        "Upload '{}' failed after {} retries: HTTP {} - {}",
                        filename, attempt, status, error_text
                    );
                }
                return Err(Error::InvalidParameters(format!(
                    "Upload failed: HTTP {} - {}",
                    status, error_text
                )));
            }
            Err(e) => {
                // Transport error (timeout, connection failure, etc.)
                // Only timeouts and connection failures are retried;
                // other transport errors fail immediately.
                let is_timeout = e.is_timeout();
                let is_connect = e.is_connect();

                if (is_timeout || is_connect) && attempt < max_retries {
                    warn!(
                        "Upload '{}' transport error (retrying): {}",
                        filename,
                        if is_timeout {
                            "timeout"
                        } else {
                            "connection failed"
                        }
                    );
                    last_error = Some(Error::HttpError(e));
                    continue;
                }

                // Non-retryable or max retries exceeded
                if attempt > 0 {
                    error!(
                        "Upload '{}' failed after {} retries: {}",
                        filename, attempt, e
                    );
                }
                return Err(Error::HttpError(e));
            }
        }
    }

    // Should not reach here (the final loop iteration always returns),
    // but return the last recorded error defensively if we do.
    Err(last_error.unwrap_or_else(|| {
        Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
    }))
}
4125
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact matches must sort ahead of substring matches, with the
    /// case-sensitive exact match first and case-insensitive second.
    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    /// Shorter (more specific) names must sort before longer ones.
    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    /// Substring filtering must ignore case.
    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    /// No matching items yields an empty result, not an error.
    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    /// Same-length names fall back to alphabetical ordering.
    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    /// flatten=false must pass the base name through untouched.
    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    /// flatten=true without a sequence has nothing to prefix.
    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    /// flatten=true in a sequence adds "<sequence>_<frame>_" when the
    /// filename is not already prefixed.
    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    /// flatten=true with a sequence but no frame prefixes the sequence only.
    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    /// A filename already carrying the sequence prefix is left unchanged.
    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    /// An existing sequence prefix wins even when the frame numbers differ.
    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    /// A sequence name appearing mid-filename (not as prefix) still gets
    /// the prefix added.
    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    /// Prefixing must never disturb single or compound file extensions.
    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    /// Already-sanitized names (underscores only) compose cleanly.
    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    /// An exact match must always beat longer names containing the filter.
    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    /// Case-sensitive exact match outranks case-insensitive exact matches.
    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    /// Realistic snapshot-name lookup: exact match must come first.
    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    /// Realistic dataset-name lookup: exact, then case-insensitive exact,
    /// then shorter-before-longer.
    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    /// Table-driven check: across several scenarios the first result is
    /// always the best match.
    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = vec![
            // (items, filter, expected_first)
            (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (vec!["test", "TEST", "Test Data"], "test", "test"),
            (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    /// Switching servers must wipe any stored token from the client's
    /// storage backend.
    ///
    /// `MemoryTokenStorage` and the `TokenStorage` trait are already in
    /// scope via `use super::*;` (imported at the top of the file), so no
    /// function-local import is needed.
    #[test]
    fn test_with_server_clears_storage() {
        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Verify token is loaded
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }
}
4440}