edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID, SnapshotRestore,
10 SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
11 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
15 retry::{create_retry_policy, log_retry_configuration},
16 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
17};
18use base64::Engine as _;
19use chrono::{DateTime, Utc};
20use directories::ProjectDirs;
21use futures::{StreamExt as _, future::join_all};
22use log::{Level, debug, error, log_enabled, trace, warn};
23use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use std::{
26 collections::HashMap,
27 ffi::OsStr,
28 fs::create_dir_all,
29 io::{SeekFrom, Write as _},
30 path::{Path, PathBuf},
31 sync::{
32 Arc,
33 atomic::{AtomicUsize, Ordering},
34 },
35 time::Duration,
36 vec,
37};
38use tokio::{
39 fs::{self, File},
40 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
41 sync::{RwLock, Semaphore, mpsc::Sender},
42};
43use tokio_util::codec::{BytesCodec, FramedRead};
44use walkdir::WalkDir;
45
46#[cfg(feature = "polars")]
47use polars::prelude::*;
48
/// Upload chunk size: 100 MiB. Presumably the part size for the multipart
/// snapshot upload flow below — confirm at usage sites outside this view.
/// Declared `const` (not `static`) since it is a compile-time constant.
const PART_SIZE: usize = 100 * 1024 * 1024;
50
/// Number of concurrent tasks to use for parallel operations.
///
/// Honors an explicit `MAX_TASKS` environment override (any parseable
/// value is used verbatim); otherwise derives half the CPU count, clamped
/// to the range [2, 8]. The lower ceiling avoids timeout issues seen with
/// many simultaneous large file uploads.
fn max_tasks() -> usize {
    if let Some(explicit) = std::env::var("MAX_TASKS").ok().and_then(|v| v.parse().ok()) {
        return explicit;
    }
    // Fall back to 4 CPUs when parallelism cannot be queried.
    let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
    (cpus / 2).clamp(2, 8)
}
64
65/// Filters items by name and sorts by match quality.
66///
67/// Match quality priority (best to worst):
68/// 1. Exact match (case-sensitive)
69/// 2. Exact match (case-insensitive)
70/// 3. Substring match (shorter names first, then alphabetically)
71///
72/// This ensures that searching for "Deer" returns "Deer" before
73/// "Deer Roundtrip 20251129" or "Reindeer".
/// Filters items by name (case-insensitive substring) and sorts the
/// survivors by match quality.
///
/// Ranking, best to worst:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Shorter names (more specific), ties broken alphabetically
///
/// So a search for "Deer" yields "Deer" before "Deer Roundtrip 20251129"
/// or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    let mut matches: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&filter_lower))
        .collect();

    matches.sort_by(|a, b| {
        // Rank tuple: `false` sorts first, so an exact match ranks ahead of
        // a non-exact one, then case-insensitive exact, then shorter names.
        let rank = |name: &str| {
            (
                name != filter,
                name.to_lowercase() != filter_lower,
                name.len(),
            )
        };
        let (name_a, name_b) = (get_name(a), get_name(b));
        rank(name_a)
            .cmp(&rank(name_b))
            // Alphabetical tie-break keeps the ordering stable.
            .then_with(|| name_a.cmp(name_b))
    });

    matches
}
114
/// Reduce an arbitrary string to a single safe filesystem path component.
///
/// Leading/trailing whitespace is trimmed, any directory portion is
/// discarded (only the final component is kept), and characters invalid in
/// file names on common platforms are replaced with `_`. Empty results and
/// the special `.` / `..` components collapse to `"unnamed"` so the value
/// can never traverse outside the intended directory.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // "." and ".." survive `file_name()` (it returns None for them, so the
    // raw input falls through) and contain no replaced characters; treating
    // them as names would allow path traversal when joined to a directory.
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
140
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count and total
/// count to enable progress reporting in applications. Percentage math and
/// display formatting are left to the caller, as shown below.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "Progress: {:.1}% ({}/{})",
///     percentage, progress.current, progress.total
/// );
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items.
    pub current: usize,
    /// Total number of items to process.
    pub total: usize,
}
169
/// JSON-RPC 2.0 request envelope sent to the Studio server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier; always 0 here (see the Default impl below).
    id: u64,
    // Protocol version marker; always "2.0".
    jsonrpc: String,
    // Remote method name, e.g. "auth.login" or "dataset.list".
    method: String,
    // Optional method parameters; None when the call takes no arguments.
    params: Option<Params>,
}
177
178impl<T> Default for RpcRequest<T> {
179 fn default() -> Self {
180 RpcRequest {
181 id: 0,
182 jsonrpc: "2.0".to_string(),
183 method: "".to_string(),
184 params: None,
185 }
186 }
187}
188
/// Error object embedded in a failed JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    // Numeric JSON-RPC error code.
    code: i32,
    // Human-readable error description from the server.
    message: String,
}

/// JSON-RPC 2.0 response envelope received from the Studio server.
///
/// One of `error` or `result` is expected to be populated.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): deserialized as a string although requests serialize a
    // numeric `id` — presumably the server echoes ids as strings; confirm.
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    // Present when the call failed.
    error: Option<RpcError>,
    // Present when the call succeeded.
    result: Option<RpcResult>,
}

/// Placeholder result type for RPC calls whose payload is ignored.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
208
/// Request parameters for creating a snapshot from a set of upload keys.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}

/// Server response to a snapshot creation request: the new snapshot id and
/// one upload URL per key.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for creating a multipart snapshot upload. `file_sizes`
/// appears to parallel `keys` so the server can size each part list —
/// confirm at the call site.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}

/// One element of a multipart-create response. Untagged: the server may
/// return either a bare numeric id or a part descriptor in the same field.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters to finalize a multipart upload: the ETag of every uploaded
/// part is echoed back for the given key/upload session.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}

/// ETag + part-number pair identifying one uploaded part. Fields are
/// renamed to match the capitalized S3 CompleteMultipartUpload names.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Descriptor for one multipart upload: optional target key, the upload
/// session id, and one presigned URL per part.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}

/// Parameters for updating a snapshot's status string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by a status update; fields mirror the server's
/// snapshot object and are currently unused beyond deserialization.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}

/// Parameters for listing images within a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    // Presumably restricts the response to ids only — confirm server API.
    only_ids: bool,
}

/// Dataset-scoped filter used by [`ImageListParams`].
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
292
293/// Main client for interacting with EdgeFirst Studio Server.
294///
295/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
296/// and manages authentication, RPC calls, and data operations. It provides
297/// methods for managing projects, datasets, experiments, training sessions,
298/// and various utility functions for data processing.
299///
300/// The client supports multiple authentication methods and can work with both
301/// SaaS and self-hosted EdgeFirst Studio instances.
302///
303/// # Features
304///
305/// - **Authentication**: Token-based authentication with automatic persistence
306/// - **Dataset Management**: Upload, download, and manipulate datasets
307/// - **Project Operations**: Create and manage projects and experiments
308/// - **Training & Validation**: Submit and monitor ML training jobs
309/// - **Data Integration**: Convert between EdgeFirst datasets and popular
310/// formats
311/// - **Progress Tracking**: Real-time progress updates for long-running
312/// operations
313///
314/// # Examples
315///
316/// ```no_run
317/// use edgefirst_client::{Client, DatasetID};
318/// use std::str::FromStr;
319///
320/// # async fn example() -> Result<(), edgefirst_client::Error> {
321/// // Create a new client and authenticate
322/// let mut client = Client::new()?;
323/// let client = client
324/// .with_login("your-email@example.com", "password")
325/// .await?;
326///
327/// // Or use an existing token
328/// let base_client = Client::new()?;
329/// let client = base_client.with_token("your-token-here")?;
330///
331/// // Get organization and projects
332/// let org = client.organization().await?;
333/// let projects = client.projects(None).await?;
334///
335/// // Work with datasets
336/// let dataset_id = DatasetID::from_str("ds-abc123")?;
337/// let dataset = client.dataset(dataset_id).await?;
338/// # Ok(())
339/// # }
340/// ```
// Client is Clone but cannot derive Debug due to dyn TokenStorage
#[derive(Clone)]
pub struct Client {
    // Shared HTTP client; carries the URL-aware retry policy (see `new`).
    http: reqwest::Client,
    // Base URL of the Studio server, e.g. "https://edgefirst.studio".
    url: String,
    // JWT auth token; empty string when unauthenticated. Arc<RwLock<..>>
    // so cloned clients observe token renewals.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
353
354impl std::fmt::Debug for Client {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("Client")
357 .field("url", &self.url)
358 .field("has_storage", &self.storage.is_some())
359 .field("token_path", &self.token_path)
360 .finish()
361 }
362}
363
/// Private context struct for pagination operations.
///
/// Bundles the identifiers and filters shared across successive page
/// fetches so they don't have to be threaded through individually.
struct FetchContext<'a> {
    dataset_id: DatasetID,
    // Optional annotation set to restrict results to.
    annotation_set_id: Option<AnnotationSetID>,
    // Borrowed group-name filter.
    groups: &'a [String],
    // Annotation type names to include.
    types: Vec<String>,
    // Label name → id map; presumably used to resolve label filters —
    // confirm at the usage sites later in this file.
    labels: &'a HashMap<String, u64>,
}
372
373impl Client {
374 /// Create a new unauthenticated client with the default saas server.
375 ///
376 /// By default, the client uses [`FileTokenStorage`] for token persistence.
377 /// Use [`with_storage`][Self::with_storage],
378 /// [`with_memory_storage`][Self::with_memory_storage],
379 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
380 /// behavior.
381 ///
382 /// To connect to a different server, use [`with_server`][Self::with_server]
383 /// or [`with_token`][Self::with_token] (tokens include the server
384 /// instance).
385 ///
386 /// This client is created without a token and will need to authenticate
387 /// before using methods that require authentication.
388 ///
389 /// # Examples
390 ///
391 /// ```rust,no_run
392 /// use edgefirst_client::Client;
393 ///
394 /// # fn main() -> Result<(), edgefirst_client::Error> {
395 /// // Create client with default file storage
396 /// let client = Client::new()?;
397 ///
398 /// // Create client without token persistence
399 /// let client = Client::new()?.with_memory_storage();
400 /// # Ok(())
401 /// # }
402 /// ```
403 pub fn new() -> Result<Self, Error> {
404 log_retry_configuration();
405
406 // Get timeout from environment or use default
407 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
408 .ok()
409 .and_then(|s| s.parse().ok())
410 .unwrap_or(30); // Default 30s timeout for API calls
411
412 // Create single HTTP client with URL-based retry policy
413 //
414 // The retry policy classifies requests into two categories:
415 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
416 // errors
417 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
418 //
419 // This allows the same client to handle both API calls and file operations
420 // with appropriate retry behavior for each. See retry.rs for details.
421 let http = reqwest::Client::builder()
422 .connect_timeout(Duration::from_secs(10))
423 .timeout(Duration::from_secs(timeout_secs))
424 .pool_idle_timeout(Duration::from_secs(90))
425 .pool_max_idle_per_host(10)
426 .retry(create_retry_policy())
427 .build()?;
428
429 // Default to file storage, loading any existing token
430 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
431 Ok(file_storage) => Arc::new(file_storage),
432 Err(e) => {
433 warn!(
434 "Could not initialize file token storage: {}. Using memory storage.",
435 e
436 );
437 Arc::new(MemoryTokenStorage::new())
438 }
439 };
440
441 // Try to load existing token from storage
442 let token = match storage.load() {
443 Ok(Some(t)) => t,
444 Ok(None) => String::new(),
445 Err(e) => {
446 warn!(
447 "Failed to load token from storage: {}. Starting with empty token.",
448 e
449 );
450 String::new()
451 }
452 };
453
454 // Extract server from token if available
455 let url = if !token.is_empty() {
456 match Self::extract_server_from_token(&token) {
457 Ok(server) => format!("https://{}.edgefirst.studio", server),
458 Err(e) => {
459 warn!("Failed to extract server from token: {}. Using default server.", e);
460 "https://edgefirst.studio".to_string()
461 }
462 }
463 } else {
464 "https://edgefirst.studio".to_string()
465 };
466
467 Ok(Client {
468 http,
469 url,
470 token: Arc::new(tokio::sync::RwLock::new(token)),
471 storage: Some(storage),
472 token_path: None,
473 })
474 }
475
476 /// Returns a new client connected to the specified server instance.
477 ///
478 /// The server parameter is an instance name that maps to a URL:
479 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
480 /// server)
481 /// - `"test"` → `https://test.edgefirst.studio`
482 /// - `"stage"` → `https://stage.edgefirst.studio`
483 /// - `"dev"` → `https://dev.edgefirst.studio`
484 /// - `"{name}"` → `https://{name}.edgefirst.studio`
485 ///
486 /// # Server Selection Priority
487 ///
488 /// When using the CLI or Python API, server selection follows this
489 /// priority:
490 ///
491 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
492 /// they were issued for. If you have a valid token, its server is used.
493 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
494 /// token is available. If a token exists with a different server, a
495 /// warning is emitted and the token's server takes priority.
496 /// 3. **Default `"saas"`** - If no token and no server specified, the
497 /// production server (`https://edgefirst.studio`) is used.
498 ///
499 /// # Important Notes
500 ///
501 /// - If a token is already set in the client, calling this method will
502 /// **drop the token** as tokens are specific to the server instance.
503 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
504 /// token's server before calling this method.
505 /// - For login operations, call `with_server()` first, then authenticate.
506 ///
507 /// # Examples
508 ///
509 /// ```rust,no_run
510 /// use edgefirst_client::Client;
511 ///
512 /// # fn main() -> Result<(), edgefirst_client::Error> {
513 /// let client = Client::new()?.with_server("test")?;
514 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
515 /// # Ok(())
516 /// # }
517 /// ```
518 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
519 let url = match server {
520 "" | "saas" => "https://edgefirst.studio".to_string(),
521 name => format!("https://{}.edgefirst.studio", name),
522 };
523
524 // Clear token from storage when changing servers to prevent
525 // authentication issues with stale tokens from different instances
526 if let Some(ref storage) = self.storage
527 && let Err(e) = storage.clear()
528 {
529 warn!(
530 "Failed to clear token from storage when changing servers: {}",
531 e
532 );
533 }
534
535 Ok(Client {
536 url,
537 token: Arc::new(tokio::sync::RwLock::new(String::new())),
538 ..self.clone()
539 })
540 }
541
542 /// Returns a new client with the specified token storage backend.
543 ///
544 /// Use this to configure custom token storage, such as platform-specific
545 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
546 ///
547 /// # Examples
548 ///
549 /// ```rust,no_run
550 /// use edgefirst_client::{Client, FileTokenStorage};
551 /// use std::{path::PathBuf, sync::Arc};
552 ///
553 /// # fn main() -> Result<(), edgefirst_client::Error> {
554 /// // Use a custom file path for token storage
555 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
556 /// let client = Client::new()?.with_storage(Arc::new(storage));
557 /// # Ok(())
558 /// # }
559 /// ```
560 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
561 // Try to load existing token from the new storage
562 let token = match storage.load() {
563 Ok(Some(t)) => t,
564 Ok(None) => String::new(),
565 Err(e) => {
566 warn!(
567 "Failed to load token from storage: {}. Starting with empty token.",
568 e
569 );
570 String::new()
571 }
572 };
573
574 Client {
575 token: Arc::new(tokio::sync::RwLock::new(token)),
576 storage: Some(storage),
577 token_path: None,
578 ..self
579 }
580 }
581
582 /// Returns a new client with in-memory token storage (no persistence).
583 ///
584 /// Tokens are stored in memory only and lost when the application exits.
585 /// This is useful for testing or when you want to manage token persistence
586 /// externally.
587 ///
588 /// # Examples
589 ///
590 /// ```rust,no_run
591 /// use edgefirst_client::Client;
592 ///
593 /// # fn main() -> Result<(), edgefirst_client::Error> {
594 /// let client = Client::new()?.with_memory_storage();
595 /// # Ok(())
596 /// # }
597 /// ```
598 pub fn with_memory_storage(self) -> Self {
599 Client {
600 token: Arc::new(tokio::sync::RwLock::new(String::new())),
601 storage: Some(Arc::new(MemoryTokenStorage::new())),
602 token_path: None,
603 ..self
604 }
605 }
606
607 /// Returns a new client with no token storage.
608 ///
609 /// Tokens are not persisted. Use this when you want to manage tokens
610 /// entirely manually.
611 ///
612 /// # Examples
613 ///
614 /// ```rust,no_run
615 /// use edgefirst_client::Client;
616 ///
617 /// # fn main() -> Result<(), edgefirst_client::Error> {
618 /// let client = Client::new()?.with_no_storage();
619 /// # Ok(())
620 /// # }
621 /// ```
622 pub fn with_no_storage(self) -> Self {
623 Client {
624 storage: None,
625 token_path: None,
626 ..self
627 }
628 }
629
630 /// Returns a new client authenticated with the provided username and
631 /// password.
632 ///
633 /// The token is automatically persisted to storage (if configured).
634 ///
635 /// # Examples
636 ///
637 /// ```rust,no_run
638 /// use edgefirst_client::Client;
639 ///
640 /// # async fn example() -> Result<(), edgefirst_client::Error> {
641 /// let client = Client::new()?
642 /// .with_server("test")?
643 /// .with_login("user@example.com", "password")
644 /// .await?;
645 /// # Ok(())
646 /// # }
647 /// ```
648 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
649 let params = HashMap::from([("username", username), ("password", password)]);
650 let login: LoginResult = self
651 .rpc_without_auth("auth.login".to_owned(), Some(params))
652 .await?;
653
654 // Validate that the server returned a non-empty token
655 if login.token.is_empty() {
656 return Err(Error::EmptyToken);
657 }
658
659 // Persist token to storage if configured
660 if let Some(ref storage) = self.storage
661 && let Err(e) = storage.store(&login.token)
662 {
663 warn!("Failed to persist token to storage: {}", e);
664 }
665
666 Ok(Client {
667 token: Arc::new(tokio::sync::RwLock::new(login.token)),
668 ..self.clone()
669 })
670 }
671
672 /// Returns a new client which will load and save the token to the specified
673 /// path.
674 ///
675 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
676 /// [`FileTokenStorage`] instead for more flexible token management.
677 ///
678 /// This method is maintained for backwards compatibility with existing
679 /// code. It disables the default storage and uses file-based storage at
680 /// the specified path.
681 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
682 let token_path = match token_path {
683 Some(path) => path.to_path_buf(),
684 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
685 .ok_or_else(|| {
686 Error::IoError(std::io::Error::new(
687 std::io::ErrorKind::NotFound,
688 "Could not determine user config directory",
689 ))
690 })?
691 .config_dir()
692 .join("token"),
693 };
694
695 debug!("Using token path (legacy): {:?}", token_path);
696
697 let token = match token_path.exists() {
698 true => std::fs::read_to_string(&token_path)?,
699 false => "".to_string(),
700 };
701
702 if !token.is_empty() {
703 match self.with_token(&token) {
704 Ok(client) => Ok(Client {
705 token_path: Some(token_path),
706 storage: None, // Disable new storage when using legacy token_path
707 ..client
708 }),
709 Err(e) => {
710 // Token is corrupted or invalid - remove it and continue with no token
711 warn!(
712 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
713 token_path, e
714 );
715 if let Err(remove_err) = std::fs::remove_file(&token_path) {
716 warn!("Failed to remove corrupted token file: {:?}", remove_err);
717 }
718 Ok(Client {
719 token_path: Some(token_path),
720 storage: None,
721 ..self.clone()
722 })
723 }
724 }
725 } else {
726 Ok(Client {
727 token_path: Some(token_path),
728 storage: None,
729 ..self.clone()
730 })
731 }
732 }
733
    /// Extract server name from JWT token payload.
    ///
    /// Helper method to parse the JWT token and extract the "server" field
    /// from the payload. Returns the server name (e.g., "test", "stage", "")
    /// or an error if the token is invalid.
    fn extract_server_from_token(token: &str) -> Result<String, Error> {
        let token_parts: Vec<&str> = token.split('.').collect();
        if token_parts.len() != 3 {
            return Err(Error::InvalidToken);
        }

        // NOTE(review): this decodes with the *standard* base64 alphabet,
        // while JWT payloads are normally base64url (RFC 7515). Payloads
        // containing '-' or '_' would fail here — confirm the server's
        // encoding before changing (token_field() uses the same decoder).
        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(token_parts[1])
            .map_err(|_| Error::InvalidToken)?;
        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
        let server = match payload.get("server") {
            Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
            None => return Err(Error::InvalidToken),
        };

        Ok(server)
    }

    /// Returns a new client authenticated with the provided token.
    ///
    /// The token is automatically persisted to storage (if configured).
    /// The server URL is extracted from the token payload.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use edgefirst_client::Client;
    ///
    /// # fn main() -> Result<(), edgefirst_client::Error> {
    /// let client = Client::new()?.with_token("your-jwt-token")?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_token(&self, token: &str) -> Result<Self, Error> {
        // An empty token is a no-op: keep the current client unchanged.
        if token.is_empty() {
            return Ok(self.clone());
        }

        let server = Self::extract_server_from_token(token)?;

        // Persist token to storage if configured
        if let Some(ref storage) = self.storage
            && let Err(e) = storage.store(token)
        {
            warn!("Failed to persist token to storage: {}", e);
        }

        Ok(Client {
            url: format!("https://{}.edgefirst.studio", server),
            token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
            ..self.clone()
        })
    }
792
793 /// Persist the current token to storage.
794 ///
795 /// This is automatically called when using [`with_login`][Self::with_login]
796 /// or [`with_token`][Self::with_token], so you typically don't need to call
797 /// this directly.
798 ///
799 /// If using the legacy `token_path` configuration, saves to the file path.
800 /// If using the new storage abstraction, saves to the configured storage.
801 pub async fn save_token(&self) -> Result<(), Error> {
802 let token = self.token.read().await;
803
804 // Try new storage first
805 if let Some(ref storage) = self.storage {
806 storage.store(&token)?;
807 debug!("Token saved to storage");
808 return Ok(());
809 }
810
811 // Fall back to legacy token_path behavior
812 let path = self.token_path.clone().unwrap_or_else(|| {
813 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
814 .map(|dirs| dirs.config_dir().join("token"))
815 .unwrap_or_else(|| PathBuf::from(".token"))
816 });
817
818 create_dir_all(path.parent().ok_or_else(|| {
819 Error::IoError(std::io::Error::new(
820 std::io::ErrorKind::InvalidInput,
821 "Token path has no parent directory",
822 ))
823 })?)?;
824 let mut file = std::fs::File::create(&path)?;
825 file.write_all(token.as_bytes())?;
826
827 debug!("Saved token to {:?}", path);
828
829 Ok(())
830 }
831
832 /// Return the version of the EdgeFirst Studio server for the current
833 /// client connection.
834 pub async fn version(&self) -> Result<String, Error> {
835 let version: HashMap<String, String> = self
836 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
837 .await?;
838 let version = version.get("version").ok_or(Error::InvalidResponse)?;
839 Ok(version.to_owned())
840 }
841
842 /// Clear the token used to authenticate the client with the server.
843 ///
844 /// Clears the token from memory and from storage (if configured).
845 /// If using the legacy `token_path` configuration, removes the token file.
846 pub async fn logout(&self) -> Result<(), Error> {
847 {
848 let mut token = self.token.write().await;
849 *token = "".to_string();
850 }
851
852 // Clear from new storage if configured
853 if let Some(ref storage) = self.storage
854 && let Err(e) = storage.clear()
855 {
856 warn!("Failed to clear token from storage: {}", e);
857 }
858
859 // Also clear legacy token_path if configured
860 if let Some(path) = &self.token_path
861 && path.exists()
862 {
863 fs::remove_file(path).await?;
864 }
865
866 Ok(())
867 }
868
869 /// Return the token used to authenticate the client with the server. When
870 /// logging into the server using a username and password, the token is
871 /// returned by the server and stored in the client for future interactions.
872 pub async fn token(&self) -> String {
873 self.token.read().await.clone()
874 }
875
876 /// Verify the token used to authenticate the client with the server. This
877 /// method is used to ensure that the token is still valid and has not
878 /// expired. If the token is invalid, the server will return an error and
879 /// the client will need to login again.
880 pub async fn verify_token(&self) -> Result<(), Error> {
881 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
882 .await?;
883 Ok::<(), Error>(())
884 }
885
886 /// Renew the token used to authenticate the client with the server.
887 ///
888 /// Refreshes the token before it expires. If the token has already expired,
889 /// the server will return an error and you will need to login again.
890 ///
891 /// The new token is automatically persisted to storage (if configured).
892 pub async fn renew_token(&self) -> Result<(), Error> {
893 let params = HashMap::from([("username".to_string(), self.username().await?)]);
894 let result: LoginResult = self
895 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
896 .await?;
897
898 {
899 let mut token = self.token.write().await;
900 *token = result.token.clone();
901 }
902
903 // Persist to new storage if configured
904 if let Some(ref storage) = self.storage
905 && let Err(e) = storage.store(&result.token)
906 {
907 warn!("Failed to persist renewed token to storage: {}", e);
908 }
909
910 // Also persist to legacy token_path if configured
911 if self.token_path.is_some() {
912 self.save_token().await?;
913 }
914
915 Ok(())
916 }
917
918 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
919 let token = self.token.read().await;
920 if token.is_empty() {
921 return Err(Error::EmptyToken);
922 }
923
924 let token_parts: Vec<&str> = token.split('.').collect();
925 if token_parts.len() != 3 {
926 return Err(Error::InvalidToken);
927 }
928
929 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
930 .decode(token_parts[1])
931 .map_err(|_| Error::InvalidToken)?;
932 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
933 match payload.get(field) {
934 Some(value) => Ok(value.to_owned()),
935 None => Err(Error::InvalidToken),
936 }
937 }
938
    /// Returns the URL of the EdgeFirst Studio server for the current client.
    ///
    /// Borrowed to avoid a copy; clone the result if an owned URL is needed.
    pub fn url(&self) -> &str {
        &self.url
    }
943
944 /// Returns the server name for the current client.
945 ///
946 /// This extracts the server name from the client's URL:
947 /// - `https://edgefirst.studio` → `"saas"`
948 /// - `https://test.edgefirst.studio` → `"test"`
949 /// - `https://{name}.edgefirst.studio` → `"{name}"`
950 ///
951 /// # Examples
952 ///
953 /// ```rust,no_run
954 /// use edgefirst_client::Client;
955 ///
956 /// # fn main() -> Result<(), edgefirst_client::Error> {
957 /// let client = Client::new()?.with_server("test")?;
958 /// assert_eq!(client.server(), "test");
959 ///
960 /// let client = Client::new()?; // default
961 /// assert_eq!(client.server(), "saas");
962 /// # Ok(())
963 /// # }
964 /// ```
965 pub fn server(&self) -> &str {
966 if self.url == "https://edgefirst.studio" {
967 "saas"
968 } else if let Some(name) = self.url.strip_prefix("https://") {
969 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
970 } else {
971 "saas"
972 }
973 }
974
975 /// Returns the username associated with the current token.
976 pub async fn username(&self) -> Result<String, Error> {
977 match self.token_field("username").await? {
978 serde_json::Value::String(username) => Ok(username),
979 _ => Err(Error::InvalidToken),
980 }
981 }
982
983 /// Returns the expiration time for the current token.
984 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
985 let ts = match self.token_field("exp").await? {
986 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
987 _ => return Err(Error::InvalidToken),
988 };
989
990 match DateTime::<Utc>::from_timestamp_secs(ts) {
991 Some(dt) => Ok(dt),
992 None => Err(Error::InvalidToken),
993 }
994 }
995
996 /// Returns the organization information for the current user.
997 pub async fn organization(&self) -> Result<Organization, Error> {
998 self.rpc::<(), Organization>("org.get".to_owned(), None)
999 .await
1000 }
1001
1002 /// Returns a list of projects available to the user. The projects are
1003 /// returned as a vector of Project objects. If a name filter is
1004 /// provided, only projects matching the filter are returned.
1005 ///
1006 /// Results are sorted by match quality: exact matches first, then
1007 /// case-insensitive exact matches, then shorter names (more specific),
1008 /// then alphabetically.
1009 ///
1010 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1011 /// Projects contain datasets, trainers, and trainer sessions. Projects
1012 /// are used to group related datasets and trainers together.
1013 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1014 let projects = self
1015 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1016 .await?;
1017 if let Some(name) = name {
1018 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1019 } else {
1020 Ok(projects)
1021 }
1022 }
1023
1024 /// Return the project with the specified project ID. If the project does
1025 /// not exist, an error is returned.
1026 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1027 let params = HashMap::from([("project_id", project_id)]);
1028 self.rpc("project.get".to_owned(), Some(params)).await
1029 }
1030
1031 /// Returns a list of datasets available to the user. The datasets are
1032 /// returned as a vector of Dataset objects. If a name filter is
1033 /// provided, only datasets matching the filter are returned.
1034 ///
1035 /// Results are sorted by match quality: exact matches first, then
1036 /// case-insensitive exact matches, then shorter names (more specific),
1037 /// then alphabetically. This ensures "Deer" returns before "Deer
1038 /// Roundtrip".
1039 pub async fn datasets(
1040 &self,
1041 project_id: ProjectID,
1042 name: Option<&str>,
1043 ) -> Result<Vec<Dataset>, Error> {
1044 let params = HashMap::from([("project_id", project_id)]);
1045 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1046 if let Some(name) = name {
1047 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1048 } else {
1049 Ok(datasets)
1050 }
1051 }
1052
1053 /// Return the dataset with the specified dataset ID. If the dataset does
1054 /// not exist, an error is returned.
1055 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1056 let params = HashMap::from([("dataset_id", dataset_id)]);
1057 self.rpc("dataset.get".to_owned(), Some(params)).await
1058 }
1059
1060 /// Lists the labels for the specified dataset.
1061 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1062 let params = HashMap::from([("dataset_id", dataset_id)]);
1063 self.rpc("label.list".to_owned(), Some(params)).await
1064 }
1065
1066 /// Add a new label to the dataset with the specified name.
1067 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1068 let new_label = NewLabel {
1069 dataset_id,
1070 labels: vec![NewLabelObject {
1071 name: name.to_owned(),
1072 }],
1073 };
1074 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1075 Ok(())
1076 }
1077
1078 /// Removes the label with the specified ID from the dataset. Label IDs are
1079 /// globally unique so the dataset_id is not required.
1080 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1081 let params = HashMap::from([("label_id", label_id)]);
1082 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1083 Ok(())
1084 }
1085
1086 /// Creates a new dataset in the specified project.
1087 ///
1088 /// # Arguments
1089 ///
1090 /// * `project_id` - The ID of the project to create the dataset in
1091 /// * `name` - The name of the new dataset
1092 /// * `description` - Optional description for the dataset
1093 ///
1094 /// # Returns
1095 ///
1096 /// Returns the dataset ID of the newly created dataset.
1097 pub async fn create_dataset(
1098 &self,
1099 project_id: &str,
1100 name: &str,
1101 description: Option<&str>,
1102 ) -> Result<DatasetID, Error> {
1103 let mut params = HashMap::new();
1104 params.insert("project_id", project_id);
1105 params.insert("name", name);
1106 if let Some(desc) = description {
1107 params.insert("description", desc);
1108 }
1109
1110 #[derive(Deserialize)]
1111 struct CreateDatasetResult {
1112 id: DatasetID,
1113 }
1114
1115 let result: CreateDatasetResult =
1116 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1117 Ok(result.id)
1118 }
1119
1120 /// Deletes a dataset by marking it as deleted.
1121 ///
1122 /// # Arguments
1123 ///
1124 /// * `dataset_id` - The ID of the dataset to delete
1125 ///
1126 /// # Returns
1127 ///
1128 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1129 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1130 let params = HashMap::from([("id", dataset_id)]);
1131 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1132 Ok(())
1133 }
1134
1135 /// Updates the label with the specified ID to have the new name or index.
1136 /// Label IDs cannot be changed. Label IDs are globally unique so the
1137 /// dataset_id is not required.
1138 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1139 #[derive(Serialize)]
1140 struct Params {
1141 dataset_id: DatasetID,
1142 label_id: u64,
1143 label_name: String,
1144 label_index: u64,
1145 }
1146
1147 let _: String = self
1148 .rpc(
1149 "label.update".to_owned(),
1150 Some(Params {
1151 dataset_id: label.dataset_id(),
1152 label_id: label.id(),
1153 label_name: label.name().to_owned(),
1154 label_index: label.index(),
1155 }),
1156 )
1157 .await?;
1158 Ok(())
1159 }
1160
1161 /// Downloads dataset samples to the local filesystem.
1162 ///
1163 /// # Arguments
1164 ///
1165 /// * `dataset_id` - The unique identifier of the dataset
1166 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1167 /// * `file_types` - File types to download (e.g., Image, LidarPcd)
1168 /// * `output` - Local directory to save downloaded files
1169 /// * `flatten` - If true, download all files to output root without
1170 /// sequence subdirectories. When flattening, filenames are prefixed with
1171 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1172 /// unavailable) unless the filename already starts with
1173 /// `{sequence_name}_`, to avoid conflicts between sequences.
1174 /// * `progress` - Optional channel for progress updates
1175 ///
1176 /// # Returns
1177 ///
1178 /// Returns `Ok(())` on success or an error if download fails.
1179 ///
1180 /// # Example
1181 ///
1182 /// ```rust,no_run
1183 /// # use edgefirst_client::{Client, DatasetID, FileType};
1184 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1185 /// let client = Client::new()?.with_token_path(None)?;
1186 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1187 ///
1188 /// // Download with sequence subdirectories (default)
1189 /// client
1190 /// .download_dataset(
1191 /// dataset_id,
1192 /// &[],
1193 /// &[FileType::Image],
1194 /// "./data".into(),
1195 /// false,
1196 /// None,
1197 /// )
1198 /// .await?;
1199 ///
1200 /// // Download flattened (all files in one directory)
1201 /// client
1202 /// .download_dataset(
1203 /// dataset_id,
1204 /// &[],
1205 /// &[FileType::Image],
1206 /// "./data".into(),
1207 /// true,
1208 /// None,
1209 /// )
1210 /// .await?;
1211 /// # Ok(())
1212 /// # }
1213 /// ```
1214 pub async fn download_dataset(
1215 &self,
1216 dataset_id: DatasetID,
1217 groups: &[String],
1218 file_types: &[FileType],
1219 output: PathBuf,
1220 flatten: bool,
1221 progress: Option<Sender<Progress>>,
1222 ) -> Result<(), Error> {
1223 let samples = self
1224 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1225 .await?;
1226 fs::create_dir_all(&output).await?;
1227
1228 let client = self.clone();
1229 let file_types = file_types.to_vec();
1230 let output = output.clone();
1231
1232 parallel_foreach_items(samples, progress, move |sample| {
1233 let client = client.clone();
1234 let file_types = file_types.clone();
1235 let output = output.clone();
1236
1237 async move {
1238 for file_type in file_types {
1239 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1240 let (file_ext, is_image) = match file_type.clone() {
1241 FileType::Image => (
1242 infer::get(&data)
1243 .expect("Failed to identify image file format for sample")
1244 .extension()
1245 .to_string(),
1246 true,
1247 ),
1248 other => (other.to_string(), false),
1249 };
1250
1251 // Determine target directory based on sequence membership and flatten
1252 // option
1253 // - flatten=false + sequence_name: dataset/sequence_name/
1254 // - flatten=false + no sequence: dataset/ (root level)
1255 // - flatten=true: dataset/ (all files in output root)
1256 // NOTE: group (train/val/test) is NOT used for directory structure
1257 let sequence_dir = sample
1258 .sequence_name()
1259 .map(|name| sanitize_path_component(name));
1260
1261 let target_dir = if flatten {
1262 output.clone()
1263 } else {
1264 sequence_dir
1265 .as_ref()
1266 .map(|seq| output.join(seq))
1267 .unwrap_or_else(|| output.clone())
1268 };
1269 fs::create_dir_all(&target_dir).await?;
1270
1271 let sanitized_sample_name = sample
1272 .name()
1273 .map(|name| sanitize_path_component(&name))
1274 .unwrap_or_else(|| "unknown".to_string());
1275
1276 let image_name = sample.image_name().map(sanitize_path_component);
1277
1278 // Construct filename with smart prefixing for flatten mode
1279 // When flatten=true and sample belongs to a sequence:
1280 // - Check if filename already starts with "{sequence_name}_"
1281 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1282 // - If yes, use filename as-is (already uniquely named)
1283 let file_name = if is_image {
1284 if let Some(img_name) = image_name {
1285 Self::build_filename(
1286 &img_name,
1287 flatten,
1288 sequence_dir.as_ref(),
1289 sample.frame_number(),
1290 )
1291 } else {
1292 format!("{}.{}", sanitized_sample_name, file_ext)
1293 }
1294 } else {
1295 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1296 Self::build_filename(
1297 &base_name,
1298 flatten,
1299 sequence_dir.as_ref(),
1300 sample.frame_number(),
1301 )
1302 };
1303
1304 let file_path = target_dir.join(&file_name);
1305
1306 let mut file = File::create(&file_path).await?;
1307 file.write_all(&data).await?;
1308 } else {
1309 warn!(
1310 "No data for sample: {}",
1311 sample
1312 .id()
1313 .map(|id| id.to_string())
1314 .unwrap_or_else(|| "unknown".to_string())
1315 );
1316 }
1317 }
1318
1319 Ok(())
1320 }
1321 })
1322 .await
1323 }
1324
1325 /// Builds a filename with smart prefixing for flatten mode.
1326 ///
1327 /// When flattening sequences into a single directory, this function ensures
1328 /// unique filenames by checking if the sequence prefix already exists and
1329 /// adding it if necessary.
1330 ///
1331 /// # Logic
1332 ///
1333 /// - If `flatten=false`: returns `base_name` unchanged
1334 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1335 /// - If `flatten=true` and in sequence:
1336 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1337 /// unchanged
1338 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1339 /// `{sequence_name}_{base_name}`
1340 fn build_filename(
1341 base_name: &str,
1342 flatten: bool,
1343 sequence_name: Option<&String>,
1344 frame_number: Option<u32>,
1345 ) -> String {
1346 if !flatten || sequence_name.is_none() {
1347 return base_name.to_string();
1348 }
1349
1350 let seq_name = sequence_name.unwrap();
1351 let prefix = format!("{}_", seq_name);
1352
1353 // Check if already prefixed with sequence name
1354 if base_name.starts_with(&prefix) {
1355 base_name.to_string()
1356 } else {
1357 // Add sequence (and optionally frame) prefix
1358 match frame_number {
1359 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1360 None => format!("{}{}", prefix, base_name),
1361 }
1362 }
1363 }
1364
1365 /// List available annotation sets for the specified dataset.
1366 pub async fn annotation_sets(
1367 &self,
1368 dataset_id: DatasetID,
1369 ) -> Result<Vec<AnnotationSet>, Error> {
1370 let params = HashMap::from([("dataset_id", dataset_id)]);
1371 self.rpc("annset.list".to_owned(), Some(params)).await
1372 }
1373
1374 /// Create a new annotation set for the specified dataset.
1375 ///
1376 /// # Arguments
1377 ///
1378 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1379 /// * `name` - The name of the new annotation set
1380 /// * `description` - Optional description for the annotation set
1381 ///
1382 /// # Returns
1383 ///
1384 /// Returns the annotation set ID of the newly created annotation set.
1385 pub async fn create_annotation_set(
1386 &self,
1387 dataset_id: DatasetID,
1388 name: &str,
1389 description: Option<&str>,
1390 ) -> Result<AnnotationSetID, Error> {
1391 #[derive(Serialize)]
1392 struct Params<'a> {
1393 dataset_id: DatasetID,
1394 name: &'a str,
1395 operator: &'a str,
1396 #[serde(skip_serializing_if = "Option::is_none")]
1397 description: Option<&'a str>,
1398 }
1399
1400 #[derive(Deserialize)]
1401 struct CreateAnnotationSetResult {
1402 id: AnnotationSetID,
1403 }
1404
1405 let username = self.username().await?;
1406 let result: CreateAnnotationSetResult = self
1407 .rpc(
1408 "annset.add".to_owned(),
1409 Some(Params {
1410 dataset_id,
1411 name,
1412 operator: &username,
1413 description,
1414 }),
1415 )
1416 .await?;
1417 Ok(result.id)
1418 }
1419
1420 /// Deletes an annotation set by marking it as deleted.
1421 ///
1422 /// # Arguments
1423 ///
1424 /// * `annotation_set_id` - The ID of the annotation set to delete
1425 ///
1426 /// # Returns
1427 ///
1428 /// Returns `Ok(())` if the annotation set was successfully marked as
1429 /// deleted.
1430 pub async fn delete_annotation_set(
1431 &self,
1432 annotation_set_id: AnnotationSetID,
1433 ) -> Result<(), Error> {
1434 let params = HashMap::from([("id", annotation_set_id)]);
1435 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1436 Ok(())
1437 }
1438
1439 /// Retrieve the annotation set with the specified ID.
1440 pub async fn annotation_set(
1441 &self,
1442 annotation_set_id: AnnotationSetID,
1443 ) -> Result<AnnotationSet, Error> {
1444 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1445 self.rpc("annset.get".to_owned(), Some(params)).await
1446 }
1447
1448 /// Get the annotations for the specified annotation set with the
1449 /// requested annotation types. The annotation types are used to filter
1450 /// the annotations returned. The groups parameter is used to filter for
1451 /// dataset groups (train, val, test). Images which do not have any
1452 /// annotations are also included in the result as long as they are in the
1453 /// requested groups (when specified).
1454 ///
1455 /// The result is a vector of Annotations objects which contain the
1456 /// full dataset along with the annotations for the specified types.
1457 ///
1458 /// To get the annotations as a DataFrame, use the `annotations_dataframe`
1459 /// method instead.
1460 pub async fn annotations(
1461 &self,
1462 annotation_set_id: AnnotationSetID,
1463 groups: &[String],
1464 annotation_types: &[AnnotationType],
1465 progress: Option<Sender<Progress>>,
1466 ) -> Result<Vec<Annotation>, Error> {
1467 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1468 let labels = self
1469 .labels(dataset_id)
1470 .await?
1471 .into_iter()
1472 .map(|label| (label.name().to_string(), label.index()))
1473 .collect::<HashMap<_, _>>();
1474 let total = self
1475 .samples_count(
1476 dataset_id,
1477 Some(annotation_set_id),
1478 annotation_types,
1479 groups,
1480 &[],
1481 )
1482 .await?
1483 .total as usize;
1484
1485 if total == 0 {
1486 return Ok(vec![]);
1487 }
1488
1489 let context = FetchContext {
1490 dataset_id,
1491 annotation_set_id: Some(annotation_set_id),
1492 groups,
1493 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1494 labels: &labels,
1495 };
1496
1497 self.fetch_annotations_paginated(context, total, progress)
1498 .await
1499 }
1500
    /// Fetches all pages of annotations for the given context, flattening
    /// each sample's annotations into a single vector and reporting progress
    /// after every page.
    ///
    /// `total` is the expected sample count (used only for progress
    /// reporting); pagination itself is driven by the server-issued
    /// continue token.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        // None requests the first page; subsequent pages pass the token
        // returned by the previous call.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            // Capture the next-page token before deciding whether to stop.
            continue_token = result.continue_token;

            // An empty page means the server has nothing more to return.
            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            // Progress failures are ignored; the receiver may have gone away.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // Keep paging only while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender closes the progress channel for the receiver.
        drop(progress);
        Ok(annotations)
    }
1544
1545 fn process_sample_annotations(
1546 &self,
1547 samples: &[Sample],
1548 labels: &HashMap<String, u64>,
1549 annotations: &mut Vec<Annotation>,
1550 ) {
1551 for sample in samples {
1552 if sample.annotations().is_empty() {
1553 let mut annotation = Annotation::new();
1554 annotation.set_sample_id(sample.id());
1555 annotation.set_name(sample.name());
1556 annotation.set_sequence_name(sample.sequence_name().cloned());
1557 annotation.set_frame_number(sample.frame_number());
1558 annotation.set_group(sample.group().cloned());
1559 annotations.push(annotation);
1560 continue;
1561 }
1562
1563 for annotation in sample.annotations() {
1564 let mut annotation = annotation.clone();
1565 annotation.set_sample_id(sample.id());
1566 annotation.set_name(sample.name());
1567 annotation.set_sequence_name(sample.sequence_name().cloned());
1568 annotation.set_frame_number(sample.frame_number());
1569 annotation.set_group(sample.group().cloned());
1570 Self::set_label_index_from_map(&mut annotation, labels);
1571 annotations.push(annotation);
1572 }
1573 }
1574 }
1575
1576 /// Helper to parse frame number from image_name when sequence_name is
1577 /// present. This ensures frame_number is always derived from the image
1578 /// filename, not from the server's frame_number field (which may be
1579 /// inconsistent).
1580 ///
1581 /// Returns Some(frame_number) if sequence_name is present and frame can be
1582 /// parsed, otherwise None.
1583 fn parse_frame_from_image_name(
1584 image_name: Option<&String>,
1585 sequence_name: Option<&String>,
1586 ) -> Option<u32> {
1587 use std::path::Path;
1588
1589 let sequence = sequence_name?;
1590 let name = image_name?;
1591
1592 // Extract stem (remove extension)
1593 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1594
1595 // Parse frame from format: "sequence_XXX" where XXX is the frame number
1596 stem.strip_prefix(sequence)
1597 .and_then(|suffix| suffix.strip_prefix('_'))
1598 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1599 }
1600
1601 /// Helper to set label index from a label map
1602 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1603 if let Some(label) = annotation.label() {
1604 annotation.set_label_index(Some(labels[label.as_str()]));
1605 }
1606 }
1607
1608 pub async fn samples_count(
1609 &self,
1610 dataset_id: DatasetID,
1611 annotation_set_id: Option<AnnotationSetID>,
1612 annotation_types: &[AnnotationType],
1613 groups: &[String],
1614 types: &[FileType],
1615 ) -> Result<SamplesCountResult, Error> {
1616 let types = annotation_types
1617 .iter()
1618 .map(|t| t.to_string())
1619 .chain(types.iter().map(|t| t.to_string()))
1620 .collect::<Vec<_>>();
1621
1622 let params = SamplesListParams {
1623 dataset_id,
1624 annotation_set_id,
1625 group_names: groups.to_vec(),
1626 types,
1627 continue_token: None,
1628 };
1629
1630 self.rpc("samples.count".to_owned(), Some(params)).await
1631 }
1632
1633 pub async fn samples(
1634 &self,
1635 dataset_id: DatasetID,
1636 annotation_set_id: Option<AnnotationSetID>,
1637 annotation_types: &[AnnotationType],
1638 groups: &[String],
1639 types: &[FileType],
1640 progress: Option<Sender<Progress>>,
1641 ) -> Result<Vec<Sample>, Error> {
1642 let types_vec = annotation_types
1643 .iter()
1644 .map(|t| t.to_string())
1645 .chain(types.iter().map(|t| t.to_string()))
1646 .collect::<Vec<_>>();
1647 let labels = self
1648 .labels(dataset_id)
1649 .await?
1650 .into_iter()
1651 .map(|label| (label.name().to_string(), label.index()))
1652 .collect::<HashMap<_, _>>();
1653 let total = self
1654 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
1655 .await?
1656 .total as usize;
1657
1658 if total == 0 {
1659 return Ok(vec![]);
1660 }
1661
1662 let context = FetchContext {
1663 dataset_id,
1664 annotation_set_id,
1665 groups,
1666 types: types_vec,
1667 labels: &labels,
1668 };
1669
1670 self.fetch_samples_paginated(context, total, progress).await
1671 }
1672
    /// Fetches all pages of samples for the given context, normalizing each
    /// sample's frame number and propagating sample-level fields onto its
    /// annotations. Progress is reported after every page.
    ///
    /// `total` is the expected sample count (used only for progress
    /// reporting); pagination itself is driven by the server-issued
    /// continue token.
    async fn fetch_samples_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let mut samples = vec![];
        // None requests the first page; subsequent pages pass the token
        // returned by the previous call.
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            // Capture the next-page token before deciding whether to stop.
            continue_token = result.continue_token;

            // An empty page means the server has nothing more to return.
            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Use server's frame_number if valid (>= 0 after deserialization)
                        // Otherwise parse from image_name as fallback
                        // This ensures we respect explicit frame_number from uploads
                        // while still handling legacy data that only has filename encoding
                        let frame_number = s.frame_number.or_else(|| {
                            Self::parse_frame_from_image_name(
                                s.image_name.as_ref(),
                                s.sequence_name.as_ref(),
                            )
                        });

                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            // Set annotation fields from parent sample
                            ann.set_name(s.name());
                            ann.set_group(s.group().cloned());
                            ann.set_sequence_name(s.sequence_name().cloned());
                            ann.set_frame_number(frame_number);
                            Self::set_label_index_from_map(ann, context.labels);
                        }
                        s.with_annotations(anns).with_frame_number(frame_number)
                    })
                    .collect::<Vec<_>>(),
            );

            // Progress failures are ignored; the receiver may have gone away.
            if let Some(progress) = &progress {
                let _ = progress.send(Progress { current, total }).await;
            }

            // Keep paging only while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender closes the progress channel for the receiver.
        drop(progress);
        Ok(samples)
    }
1744
1745 /// Populates (imports) samples into a dataset using the `samples.populate2`
1746 /// API.
1747 ///
1748 /// This method creates new samples in the specified dataset, optionally
1749 /// with annotations and sensor data files. For each sample, the `files`
1750 /// field is checked for local file paths. If a filename is a valid path
1751 /// to an existing file, the file will be automatically uploaded to S3
1752 /// using presigned URLs returned by the server. The filename in the
1753 /// request is replaced with the basename (path removed) before sending
1754 /// to the server.
1755 ///
1756 /// # Important Notes
1757 ///
1758 /// - **`annotation_set_id` is REQUIRED** when importing samples with
1759 /// annotations. Without it, the server will accept the request but will
1760 /// not save the annotation data. Use [`Client::annotation_sets`] to query
1761 /// available annotation sets for a dataset, or create a new one via the
1762 /// Studio UI.
1763 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
1764 /// boxes. Divide pixel coordinates by image width/height before creating
1765 /// [`Box2d`](crate::Box2d) annotations.
1766 /// - **Files are uploaded automatically** when the filename is a valid
1767 /// local path. The method will replace the full path with just the
1768 /// basename before sending to the server.
1769 /// - **Image dimensions are extracted automatically** for image files using
1770 /// the `imagesize` crate. The width/height are sent to the server, but
1771 /// note that the server currently doesn't return these fields when
1772 /// fetching samples back.
1773 /// - **UUIDs are generated automatically** if not provided. If you need
1774 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
1775 /// that the server doesn't currently return UUIDs in sample queries.
1776 ///
1777 /// # Arguments
1778 ///
1779 /// * `dataset_id` - The ID of the dataset to populate
1780 /// * `annotation_set_id` - **Required** if samples contain annotations,
1781 /// otherwise they will be ignored. Query with
1782 /// [`Client::annotation_sets`].
1783 /// * `samples` - Vector of samples to import with metadata and file
1784 /// references. For files, use the full local path - it will be uploaded
1785 /// automatically. UUIDs and image dimensions will be
1786 /// auto-generated/extracted if not provided.
1787 ///
1788 /// # Returns
1789 ///
1790 /// Returns the API result with sample UUIDs and upload status.
1791 ///
1792 /// # Example
1793 ///
1794 /// ```no_run
1795 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
1796 ///
1797 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1798 /// # let client = Client::new()?.with_login("user", "pass").await?;
1799 /// # let dataset_id = DatasetID::from(1);
1800 /// // Query available annotation sets for the dataset
1801 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
1802 /// let annotation_set_id = annotation_sets
1803 /// .first()
1804 /// .ok_or_else(|| {
1805 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
1806 /// })?
1807 /// .id();
1808 ///
1809 /// // Create sample with annotation (UUID will be auto-generated)
1810 /// let mut sample = Sample::new();
1811 /// sample.width = Some(1920);
1812 /// sample.height = Some(1080);
1813 /// sample.group = Some("train".to_string());
1814 ///
1815 /// // Add file - use full path to local file, it will be uploaded automatically
1816 /// sample.files = vec![SampleFile::with_filename(
1817 /// "image".to_string(),
1818 /// "/path/to/image.jpg".to_string(),
1819 /// )];
1820 ///
1821 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
1822 /// let mut annotation = Annotation::new();
1823 /// annotation.set_label(Some("person".to_string()));
1824 /// // Normalize pixel coordinates by dividing by image dimensions
1825 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
1826 /// annotation.set_box2d(Some(bbox));
1827 /// sample.annotations = vec![annotation];
1828 ///
1829 /// // Populate with annotation_set_id (REQUIRED for annotations)
1830 /// let result = client
1831 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
1832 /// .await?;
1833 /// # Ok(())
1834 /// # }
1835 /// ```
1836 pub async fn populate_samples(
1837 &self,
1838 dataset_id: DatasetID,
1839 annotation_set_id: Option<AnnotationSetID>,
1840 samples: Vec<Sample>,
1841 progress: Option<Sender<Progress>>,
1842 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
1843 use crate::api::SamplesPopulateParams;
1844
1845 // Track which files need to be uploaded
1846 let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
1847
1848 // Process samples to detect local files and generate UUIDs
1849 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
1850
1851 let has_files_to_upload = !files_to_upload.is_empty();
1852
1853 // Call populate API with presigned_urls=true if we have files to upload
1854 let params = SamplesPopulateParams {
1855 dataset_id,
1856 annotation_set_id,
1857 presigned_urls: Some(has_files_to_upload),
1858 samples,
1859 };
1860
1861 let results: Vec<crate::SamplesPopulateResult> = self
1862 .rpc("samples.populate2".to_owned(), Some(params))
1863 .await?;
1864
1865 // Upload files if we have any
1866 if has_files_to_upload {
1867 self.upload_sample_files(&results, files_to_upload, progress)
1868 .await?;
1869 }
1870
1871 Ok(results)
1872 }
1873
1874 fn prepare_samples_for_upload(
1875 &self,
1876 samples: Vec<Sample>,
1877 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1878 ) -> Result<Vec<Sample>, Error> {
1879 Ok(samples
1880 .into_iter()
1881 .map(|mut sample| {
1882 // Generate UUID if not provided
1883 if sample.uuid.is_none() {
1884 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
1885 }
1886
1887 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
1888
1889 // Process files: detect local paths and queue for upload
1890 let files_copy = sample.files.clone();
1891 let updated_files: Vec<crate::SampleFile> = files_copy
1892 .iter()
1893 .map(|file| {
1894 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
1895 })
1896 .collect();
1897
1898 sample.files = updated_files;
1899 sample
1900 })
1901 .collect())
1902 }
1903
1904 fn process_sample_file(
1905 &self,
1906 file: &crate::SampleFile,
1907 sample_uuid: &str,
1908 sample: &mut Sample,
1909 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1910 ) -> crate::SampleFile {
1911 use std::path::Path;
1912
1913 if let Some(filename) = file.filename() {
1914 let path = Path::new(filename);
1915
1916 // Check if this is a valid local file path
1917 if path.exists()
1918 && path.is_file()
1919 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
1920 {
1921 // For image files, try to extract dimensions if not already set
1922 if file.file_type() == "image"
1923 && (sample.width.is_none() || sample.height.is_none())
1924 && let Ok(size) = imagesize::size(path)
1925 {
1926 sample.width = Some(size.width as u32);
1927 sample.height = Some(size.height as u32);
1928 }
1929
1930 // Store the full path for later upload
1931 files_to_upload.push((
1932 sample_uuid.to_string(),
1933 file.file_type().to_string(),
1934 path.to_path_buf(),
1935 basename.to_string(),
1936 ));
1937
1938 // Return SampleFile with just the basename
1939 return crate::SampleFile::with_filename(
1940 file.file_type().to_string(),
1941 basename.to_string(),
1942 );
1943 }
1944 }
1945 // Return the file unchanged if not a local path
1946 file.clone()
1947 }
1948
1949 async fn upload_sample_files(
1950 &self,
1951 results: &[crate::SamplesPopulateResult],
1952 files_to_upload: Vec<(String, String, PathBuf, String)>,
1953 progress: Option<Sender<Progress>>,
1954 ) -> Result<(), Error> {
1955 // Build a map from (sample_uuid, basename) -> local_path
1956 let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
1957 for (uuid, _file_type, path, basename) in files_to_upload {
1958 upload_map.insert((uuid, basename), path);
1959 }
1960
1961 let http = self.http.clone();
1962
1963 // Extract the data we need for parallel upload
1964 let upload_tasks: Vec<_> = results
1965 .iter()
1966 .map(|result| (result.uuid.clone(), result.urls.clone()))
1967 .collect();
1968
1969 parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
1970 let http = http.clone();
1971 let upload_map = upload_map.clone();
1972
1973 async move {
1974 // Upload all files for this sample
1975 for url_info in &urls {
1976 if let Some(local_path) =
1977 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
1978 {
1979 // Upload the file
1980 upload_file_to_presigned_url(
1981 http.clone(),
1982 &url_info.url,
1983 local_path.clone(),
1984 )
1985 .await?;
1986 }
1987 }
1988
1989 Ok(())
1990 }
1991 })
1992 .await
1993 }
1994
1995 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1996 // Uses default 120s timeout from client
1997 let resp = self.http.get(url).send().await?;
1998
1999 if !resp.status().is_success() {
2000 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2001 }
2002
2003 let bytes = resp.bytes().await?;
2004 Ok(bytes.to_vec())
2005 }
2006
2007 /// Get the AnnotationGroup for the specified annotation set with the
2008 /// requested annotation types. The annotation type is used to filter
2009 /// the annotations returned. Images which do not have any annotations
2010 /// are included in the result.
2011 ///
2012 /// Get annotations as a DataFrame (2025.01 schema).
2013 ///
2014 /// **DEPRECATED**: Use [`Client::samples_dataframe()`] instead for full
2015 /// 2025.10 schema support including optional metadata columns.
2016 ///
2017 /// The result is a DataFrame following the EdgeFirst Dataset Format
2018 /// definition with 9 columns (original schema). Does not include new
2019 /// optional columns added in 2025.10.
2020 ///
2021 /// # Migration
2022 ///
2023 /// ```rust,no_run
2024 /// # use edgefirst_client::Client;
2025 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2026 /// # let client = Client::new()?;
2027 /// # let dataset_id = 1.into();
2028 /// # let annotation_set_id = 1.into();
2029 /// # let groups = vec![];
2030 /// # let types = vec![];
2031 /// // OLD (deprecated):
2032 /// let df = client
2033 /// .annotations_dataframe(annotation_set_id, &groups, &types, None)
2034 /// .await?;
2035 ///
2036 /// // NEW (recommended):
2037 /// let df = client
2038 /// .samples_dataframe(dataset_id, Some(annotation_set_id), &groups, &types, None)
2039 /// .await?;
2040 /// # Ok(())
2041 /// # }
2042 /// ```
2043 ///
2044 /// To get the annotations as a vector of Annotation objects, use the
2045 /// `annotations` method instead.
2046 #[deprecated(
2047 since = "0.8.0",
2048 note = "Use `samples_dataframe()` for complete 2025.10 schema support"
2049 )]
2050 #[cfg(feature = "polars")]
2051 pub async fn annotations_dataframe(
2052 &self,
2053 annotation_set_id: AnnotationSetID,
2054 groups: &[String],
2055 types: &[AnnotationType],
2056 progress: Option<Sender<Progress>>,
2057 ) -> Result<DataFrame, Error> {
2058 #[allow(deprecated)]
2059 use crate::dataset::annotations_dataframe;
2060
2061 let annotations = self
2062 .annotations(annotation_set_id, groups, types, progress)
2063 .await?;
2064 #[allow(deprecated)]
2065 annotations_dataframe(&annotations)
2066 }
2067
2068 /// Get samples as a DataFrame with complete 2025.10 schema.
2069 ///
2070 /// This is the recommended method for obtaining dataset annotations in
2071 /// DataFrame format. It includes all sample metadata (size, location,
2072 /// pose, degradation) as optional columns.
2073 ///
2074 /// # Arguments
2075 ///
2076 /// * `dataset_id` - Dataset identifier
2077 /// * `annotation_set_id` - Optional annotation set filter
2078 /// * `groups` - Dataset groups to include (train, val, test)
2079 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2080 /// * `progress` - Optional progress callback
2081 ///
2082 /// # Example
2083 ///
2084 /// ```rust,no_run
2085 /// use edgefirst_client::Client;
2086 ///
2087 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2088 /// # let client = Client::new()?;
2089 /// # let dataset_id = 1.into();
2090 /// # let annotation_set_id = 1.into();
2091 /// let df = client
2092 /// .samples_dataframe(
2093 /// dataset_id,
2094 /// Some(annotation_set_id),
2095 /// &["train".to_string()],
2096 /// &[],
2097 /// None,
2098 /// )
2099 /// .await?;
2100 /// println!("DataFrame shape: {:?}", df.shape());
2101 /// # Ok(())
2102 /// # }
2103 /// ```
2104 #[cfg(feature = "polars")]
2105 pub async fn samples_dataframe(
2106 &self,
2107 dataset_id: DatasetID,
2108 annotation_set_id: Option<AnnotationSetID>,
2109 groups: &[String],
2110 types: &[AnnotationType],
2111 progress: Option<Sender<Progress>>,
2112 ) -> Result<DataFrame, Error> {
2113 use crate::dataset::samples_dataframe;
2114
2115 let samples = self
2116 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2117 .await?;
2118 samples_dataframe(&samples)
2119 }
2120
2121 /// List available snapshots. If a name is provided, only snapshots
2122 /// containing that name are returned.
2123 ///
2124 /// Results are sorted by match quality: exact matches first, then
2125 /// case-insensitive exact matches, then shorter descriptions (more
2126 /// specific), then alphabetically.
2127 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2128 let snapshots: Vec<Snapshot> = self
2129 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2130 .await?;
2131 if let Some(name) = name {
2132 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2133 s.description()
2134 }))
2135 } else {
2136 Ok(snapshots)
2137 }
2138 }
2139
2140 /// Get the snapshot with the specified id.
2141 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2142 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2143 self.rpc("snapshots.get".to_owned(), Some(params)).await
2144 }
2145
2146 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2147 ///
2148 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2149 /// pairs) that serve two primary purposes:
2150 ///
2151 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2152 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2153 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2154 /// auto-depth processing.
2155 ///
2156 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2157 /// migration between EdgeFirst Studio instances using the create →
2158 /// download → upload → restore workflow.
2159 ///
2160 /// Large files are automatically chunked into 100MB parts and uploaded
2161 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2162 /// is streamed without loading into memory, maintaining constant memory
2163 /// usage.
2164 ///
2165 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2166 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2167 /// better for large files to avoid timeout issues. Higher values (16-32)
2168 /// are better for many small files.
2169 ///
2170 /// # Arguments
2171 ///
2172 /// * `path` - Local file path to MCAP file or directory containing
2173 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2174 /// * `progress` - Optional channel to receive upload progress updates
2175 ///
2176 /// # Returns
2177 ///
2178 /// Returns a `Snapshot` object with ID, description, status, path, and
2179 /// creation timestamp on success.
2180 ///
2181 /// # Errors
2182 ///
2183 /// Returns an error if:
2184 /// * Path doesn't exist or contains invalid UTF-8
2185 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2186 /// * Upload fails or network error occurs
2187 /// * Server rejects the snapshot
2188 ///
2189 /// # Example
2190 ///
2191 /// ```no_run
2192 /// # use edgefirst_client::{Client, Progress};
2193 /// # use tokio::sync::mpsc;
2194 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2195 /// let client = Client::new()?.with_token_path(None)?;
2196 ///
2197 /// // Upload MCAP file with progress tracking
2198 /// let (tx, mut rx) = mpsc::channel(1);
2199 /// tokio::spawn(async move {
2200 /// while let Some(Progress { current, total }) = rx.recv().await {
2201 /// println!(
2202 /// "Upload: {}/{} bytes ({:.1}%)",
2203 /// current,
2204 /// total,
2205 /// (current as f64 / total as f64) * 100.0
2206 /// );
2207 /// }
2208 /// });
2209 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2210 /// println!("Created snapshot: {:?}", snapshot.id());
2211 ///
2212 /// // Upload dataset directory (no progress)
2213 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2214 /// # Ok(())
2215 /// # }
2216 /// ```
2217 ///
2218 /// # See Also
2219 ///
2220 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2221 /// dataset
2222 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2223 /// data
2224 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2225 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2226 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
2227 pub async fn create_snapshot(
2228 &self,
2229 path: &str,
2230 progress: Option<Sender<Progress>>,
2231 ) -> Result<Snapshot, Error> {
2232 let path = Path::new(path);
2233
2234 if path.is_dir() {
2235 let path_str = path.to_str().ok_or_else(|| {
2236 Error::IoError(std::io::Error::new(
2237 std::io::ErrorKind::InvalidInput,
2238 "Path contains invalid UTF-8",
2239 ))
2240 })?;
2241 return self.create_snapshot_folder(path_str, progress).await;
2242 }
2243
2244 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2245 Error::IoError(std::io::Error::new(
2246 std::io::ErrorKind::InvalidInput,
2247 "Invalid filename",
2248 ))
2249 })?;
2250 let total = path.metadata()?.len() as usize;
2251 let current = Arc::new(AtomicUsize::new(0));
2252
2253 if let Some(progress) = &progress {
2254 let _ = progress.send(Progress { current: 0, total }).await;
2255 }
2256
2257 let params = SnapshotCreateMultipartParams {
2258 snapshot_name: name.to_owned(),
2259 keys: vec![name.to_owned()],
2260 file_sizes: vec![total],
2261 };
2262 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2263 .rpc(
2264 "snapshots.create_upload_url_multipart".to_owned(),
2265 Some(params),
2266 )
2267 .await?;
2268
2269 let snapshot_id = match multipart.get("snapshot_id") {
2270 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2271 _ => return Err(Error::InvalidResponse),
2272 };
2273
2274 let snapshot = self.snapshot(snapshot_id).await?;
2275 let part_prefix = snapshot
2276 .path()
2277 .split("::/")
2278 .last()
2279 .ok_or(Error::InvalidResponse)?
2280 .to_owned();
2281 let part_key = format!("{}/{}", part_prefix, name);
2282 let mut part = match multipart.get(&part_key) {
2283 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2284 _ => return Err(Error::InvalidResponse),
2285 }
2286 .clone();
2287 part.key = Some(part_key);
2288
2289 let params = upload_multipart(
2290 self.http.clone(),
2291 part.clone(),
2292 path.to_path_buf(),
2293 total,
2294 current,
2295 progress.clone(),
2296 )
2297 .await?;
2298
2299 let complete: String = self
2300 .rpc(
2301 "snapshots.complete_multipart_upload".to_owned(),
2302 Some(params),
2303 )
2304 .await?;
2305 debug!("Snapshot Multipart Complete: {:?}", complete);
2306
2307 let params: SnapshotStatusParams = SnapshotStatusParams {
2308 snapshot_id,
2309 status: "available".to_owned(),
2310 };
2311 let _: SnapshotStatusResult = self
2312 .rpc("snapshots.update".to_owned(), Some(params))
2313 .await?;
2314
2315 if let Some(progress) = progress {
2316 drop(progress);
2317 }
2318
2319 self.snapshot(snapshot_id).await
2320 }
2321
2322 async fn create_snapshot_folder(
2323 &self,
2324 path: &str,
2325 progress: Option<Sender<Progress>>,
2326 ) -> Result<Snapshot, Error> {
2327 let path = Path::new(path);
2328 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2329 Error::IoError(std::io::Error::new(
2330 std::io::ErrorKind::InvalidInput,
2331 "Invalid directory name",
2332 ))
2333 })?;
2334
2335 let files = WalkDir::new(path)
2336 .into_iter()
2337 .filter_map(|entry| entry.ok())
2338 .filter(|entry| entry.file_type().is_file())
2339 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
2340 .collect::<Vec<_>>();
2341
2342 let total: usize = files
2343 .iter()
2344 .filter_map(|file| path.join(file).metadata().ok())
2345 .map(|metadata| metadata.len() as usize)
2346 .sum();
2347 let current = Arc::new(AtomicUsize::new(0));
2348
2349 if let Some(progress) = &progress {
2350 let _ = progress.send(Progress { current: 0, total }).await;
2351 }
2352
2353 let keys = files
2354 .iter()
2355 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
2356 .collect::<Vec<_>>();
2357 let file_sizes = files
2358 .iter()
2359 .filter_map(|key| path.join(key).metadata().ok())
2360 .map(|metadata| metadata.len() as usize)
2361 .collect::<Vec<_>>();
2362
2363 let params = SnapshotCreateMultipartParams {
2364 snapshot_name: name.to_owned(),
2365 keys,
2366 file_sizes,
2367 };
2368
2369 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2370 .rpc(
2371 "snapshots.create_upload_url_multipart".to_owned(),
2372 Some(params),
2373 )
2374 .await?;
2375
2376 let snapshot_id = match multipart.get("snapshot_id") {
2377 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2378 _ => return Err(Error::InvalidResponse),
2379 };
2380
2381 let snapshot = self.snapshot(snapshot_id).await?;
2382 let part_prefix = snapshot
2383 .path()
2384 .split("::/")
2385 .last()
2386 .ok_or(Error::InvalidResponse)?
2387 .to_owned();
2388
2389 for file in files {
2390 let file_str = file.to_str().ok_or_else(|| {
2391 Error::IoError(std::io::Error::new(
2392 std::io::ErrorKind::InvalidInput,
2393 "File path contains invalid UTF-8",
2394 ))
2395 })?;
2396 let part_key = format!("{}/{}", part_prefix, file_str);
2397 let mut part = match multipart.get(&part_key) {
2398 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2399 _ => return Err(Error::InvalidResponse),
2400 }
2401 .clone();
2402 part.key = Some(part_key);
2403
2404 let params = upload_multipart(
2405 self.http.clone(),
2406 part.clone(),
2407 path.join(file),
2408 total,
2409 current.clone(),
2410 progress.clone(),
2411 )
2412 .await?;
2413
2414 let complete: String = self
2415 .rpc(
2416 "snapshots.complete_multipart_upload".to_owned(),
2417 Some(params),
2418 )
2419 .await?;
2420 debug!("Snapshot Part Complete: {:?}", complete);
2421 }
2422
2423 let params = SnapshotStatusParams {
2424 snapshot_id,
2425 status: "available".to_owned(),
2426 };
2427 let _: SnapshotStatusResult = self
2428 .rpc("snapshots.update".to_owned(), Some(params))
2429 .await?;
2430
2431 if let Some(progress) = progress {
2432 drop(progress);
2433 }
2434
2435 self.snapshot(snapshot_id).await
2436 }
2437
2438 /// Delete a snapshot from EdgeFirst Studio.
2439 ///
2440 /// Permanently removes a snapshot and its associated data. This operation
2441 /// cannot be undone.
2442 ///
2443 /// # Arguments
2444 ///
2445 /// * `snapshot_id` - The snapshot ID to delete
2446 ///
2447 /// # Errors
2448 ///
2449 /// Returns an error if:
2450 /// * Snapshot doesn't exist
2451 /// * User lacks permission to delete the snapshot
2452 /// * Server error occurs
2453 ///
2454 /// # Example
2455 ///
2456 /// ```no_run
2457 /// # use edgefirst_client::{Client, SnapshotID};
2458 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2459 /// let client = Client::new()?.with_token_path(None)?;
2460 /// let snapshot_id = SnapshotID::from(123);
2461 /// client.delete_snapshot(snapshot_id).await?;
2462 /// # Ok(())
2463 /// # }
2464 /// ```
2465 ///
2466 /// # See Also
2467 ///
2468 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2469 /// * [`snapshots`](Self::snapshots) - List all snapshots
2470 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2471 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2472 let _: String = self
2473 .rpc("snapshots.delete".to_owned(), Some(params))
2474 .await?;
2475 Ok(())
2476 }
2477
2478 /// Create a snapshot from an existing dataset on the server.
2479 ///
2480 /// Triggers server-side snapshot generation which exports the dataset's
2481 /// images and annotations into a downloadable EdgeFirst Dataset Format
2482 /// snapshot.
2483 ///
2484 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
2485 /// while restore creates a dataset from a snapshot, this method creates a
2486 /// snapshot from a dataset.
2487 ///
2488 /// # Arguments
2489 ///
2490 /// * `dataset_id` - The dataset ID to create snapshot from
2491 /// * `description` - Description for the created snapshot
2492 ///
2493 /// # Returns
2494 ///
2495 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
2496 /// for monitoring progress.
2497 ///
2498 /// # Errors
2499 ///
2500 /// Returns an error if:
2501 /// * Dataset doesn't exist
2502 /// * User lacks permission to access the dataset
2503 /// * Server rejects the request
2504 ///
2505 /// # Example
2506 ///
2507 /// ```no_run
2508 /// # use edgefirst_client::{Client, DatasetID};
2509 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2510 /// let client = Client::new()?.with_token_path(None)?;
2511 /// let dataset_id = DatasetID::from(123);
2512 ///
2513 /// // Create snapshot from dataset (all annotation sets)
2514 /// let result = client
2515 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
2516 /// .await?;
2517 /// println!("Created snapshot: {:?}", result.id);
2518 ///
2519 /// // Monitor progress via task ID
2520 /// if let Some(task_id) = result.task_id {
2521 /// println!("Task: {}", task_id);
2522 /// }
2523 /// # Ok(())
2524 /// # }
2525 /// ```
2526 ///
2527 /// # See Also
2528 ///
2529 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
2530 /// snapshot
2531 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2532 /// dataset
2533 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2534 pub async fn create_snapshot_from_dataset(
2535 &self,
2536 dataset_id: DatasetID,
2537 description: &str,
2538 annotation_set_id: Option<AnnotationSetID>,
2539 ) -> Result<SnapshotFromDatasetResult, Error> {
2540 // Resolve annotation_set_id: use provided value or fetch default
2541 let annotation_set_id = match annotation_set_id {
2542 Some(id) => id,
2543 None => {
2544 // Fetch annotation sets and find default ("annotations") or use first
2545 let sets = self.annotation_sets(dataset_id).await?;
2546 if sets.is_empty() {
2547 return Err(Error::InvalidParameters(
2548 "No annotation sets available for dataset".to_owned(),
2549 ));
2550 }
2551 // Look for "annotations" set (default), otherwise use first
2552 sets.iter()
2553 .find(|s| s.name() == "annotations")
2554 .unwrap_or(&sets[0])
2555 .id()
2556 }
2557 };
2558 let params = SnapshotCreateFromDataset {
2559 description: description.to_owned(),
2560 dataset_id,
2561 annotation_set_id,
2562 };
2563 self.rpc("snapshots.create".to_owned(), Some(params)).await
2564 }
2565
    /// Download a snapshot from EdgeFirst Studio to local storage.
    ///
    /// Downloads all files in a snapshot (single MCAP file or directory of
    /// EdgeFirst Dataset Format files) to the specified output path. Files are
    /// downloaded concurrently with progress tracking.
    ///
    /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
    /// downloads (default: half of CPU cores, min 2, max 8).
    ///
    /// # Arguments
    ///
    /// * `snapshot_id` - The snapshot ID to download
    /// * `output` - Local directory path to save downloaded files
    /// * `progress` - Optional channel to receive download progress updates
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * Snapshot doesn't exist
    /// * Output directory cannot be created
    /// * Download fails or network error occurs
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use edgefirst_client::{Client, SnapshotID, Progress};
    /// # use tokio::sync::mpsc;
    /// # use std::path::PathBuf;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = Client::new()?.with_token_path(None)?;
    /// let snapshot_id = SnapshotID::from(123);
    ///
    /// // Download with progress tracking
    /// let (tx, mut rx) = mpsc::channel(1);
    /// tokio::spawn(async move {
    ///     while let Some(Progress { current, total }) = rx.recv().await {
    ///         println!("Download: {}/{} bytes", current, total);
    ///     }
    /// });
    /// client
    ///     .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # See Also
    ///
    /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
    /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
    ///   dataset
    /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // Map of object key -> presigned download URL, one entry per file in
        // the snapshot.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // Shared across tasks: `total` grows as each response reports its
        // Content-Length (so the reported total converges as downloads
        // start); `current` accumulates bytes written by all tasks.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        // Bounds the number of simultaneous downloads (MAX_TASKS override).
        let sem = Arc::new(Semaphore::new(max_tasks()));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                // Clone everything the spawned task needs to own ('static).
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    // Permit is held until the task finishes its download.
                    let _permit = sem.acquire().await.map_err(|_| {
                        Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                    })?;
                    let res = http.get(url).send().await?;
                    let content_length = res.content_length().unwrap_or(0) as usize;

                    if let Some(progress) = &progress {
                        // Publish the enlarged total before streaming begins.
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        let _ = progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await;
                    }

                    // NOTE(review): assumes `key` has no subdirectory
                    // components — File::create does not create parent
                    // directories. Confirm against server-issued keys.
                    let mut file = File::create(output.join(key)).await?;
                    let mut stream = res.bytes_stream();

                    // Stream the body to disk chunk by chunk, reporting
                    // aggregate progress after each write.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk?;
                        file.write_all(&chunk).await?;
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            let _ = progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await;
                        }
                    }

                    Ok::<(), Error>(())
                })
            })
            .collect::<Vec<_>>();

        // First collect surfaces join failures (panicked/cancelled tasks);
        // second surfaces download errors returned from within the tasks.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        Ok(())
    }
2699
2700 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
2701 ///
2702 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
2703 /// the specified project. For MCAP files, supports:
2704 ///
2705 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
2706 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
2707 /// present)
2708 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
2709 /// * **Topic filtering**: Select specific MCAP topics to restore
2710 ///
2711 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
2712 /// dataset structure.
2713 ///
2714 /// # Arguments
2715 ///
2716 /// * `project_id` - Target project ID
2717 /// * `snapshot_id` - Snapshot ID to restore
2718 /// * `topics` - MCAP topics to include (empty = all topics)
2719 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
2720 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
2721 /// * `dataset_name` - Optional custom dataset name
2722 /// * `dataset_description` - Optional dataset description
2723 ///
2724 /// # Returns
2725 ///
2726 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
2727 ///
2728 /// # Errors
2729 ///
2730 /// Returns an error if:
2731 /// * Snapshot or project doesn't exist
2732 /// * Snapshot format is invalid
2733 /// * Server rejects restoration parameters
2734 ///
2735 /// # Example
2736 ///
2737 /// ```no_run
2738 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
2739 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2740 /// let client = Client::new()?.with_token_path(None)?;
2741 /// let project_id = ProjectID::from(1);
2742 /// let snapshot_id = SnapshotID::from(123);
2743 ///
2744 /// // Restore MCAP with AGTG for "person" and "car" detection
2745 /// let result = client
2746 /// .restore_snapshot(
2747 /// project_id,
2748 /// snapshot_id,
2749 /// &[], // All topics
2750 /// &["person".to_string(), "car".to_string()], // AGTG labels
2751 /// true, // Auto-depth
2752 /// Some("Highway Dataset"),
2753 /// Some("Collected on I-95"),
2754 /// )
2755 /// .await?;
2756 /// println!("Restored to dataset: {:?}", result.dataset_id);
2757 /// # Ok(())
2758 /// # }
2759 /// ```
2760 ///
2761 /// # See Also
2762 ///
2763 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2764 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2765 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2766 #[allow(clippy::too_many_arguments)]
2767 pub async fn restore_snapshot(
2768 &self,
2769 project_id: ProjectID,
2770 snapshot_id: SnapshotID,
2771 topics: &[String],
2772 autolabel: &[String],
2773 autodepth: bool,
2774 dataset_name: Option<&str>,
2775 dataset_description: Option<&str>,
2776 ) -> Result<SnapshotRestoreResult, Error> {
2777 let params = SnapshotRestore {
2778 project_id,
2779 snapshot_id,
2780 fps: 1,
2781 autodepth,
2782 agtg_pipeline: !autolabel.is_empty(),
2783 autolabel: autolabel.to_vec(),
2784 topics: topics.to_vec(),
2785 dataset_name: dataset_name.map(|s| s.to_owned()),
2786 dataset_description: dataset_description.map(|s| s.to_owned()),
2787 };
2788 self.rpc("snapshots.restore".to_owned(), Some(params)).await
2789 }
2790
2791 /// Returns a list of experiments available to the user. The experiments
2792 /// are returned as a vector of Experiment objects. If name is provided
2793 /// then only experiments containing this string are returned.
2794 ///
2795 /// Results are sorted by match quality: exact matches first, then
2796 /// case-insensitive exact matches, then shorter names (more specific),
2797 /// then alphabetically.
2798 ///
2799 /// Experiments provide a method of organizing training and validation
2800 /// sessions together and are akin to an Experiment in MLFlow terminology.
2801 /// Each experiment can have multiple trainer sessions associated with it,
2802 /// these would be akin to runs in MLFlow terminology.
2803 pub async fn experiments(
2804 &self,
2805 project_id: ProjectID,
2806 name: Option<&str>,
2807 ) -> Result<Vec<Experiment>, Error> {
2808 let params = HashMap::from([("project_id", project_id)]);
2809 let experiments: Vec<Experiment> =
2810 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
2811 if let Some(name) = name {
2812 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
2813 } else {
2814 Ok(experiments)
2815 }
2816 }
2817
2818 /// Return the experiment with the specified experiment ID. If the
2819 /// experiment does not exist, an error is returned.
2820 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
2821 let params = HashMap::from([("trainer_id", experiment_id)]);
2822 self.rpc("trainer.get".to_owned(), Some(params)).await
2823 }
2824
2825 /// Returns a list of trainer sessions available to the user. The trainer
2826 /// sessions are returned as a vector of TrainingSession objects. If name
2827 /// is provided then only trainer sessions containing this string are
2828 /// returned.
2829 ///
2830 /// Results are sorted by match quality: exact matches first, then
2831 /// case-insensitive exact matches, then shorter names (more specific),
2832 /// then alphabetically.
2833 ///
2834 /// Trainer sessions are akin to runs in MLFlow terminology. These
2835 /// represent an actual training session which will produce metrics and
2836 /// model artifacts.
2837 pub async fn training_sessions(
2838 &self,
2839 experiment_id: ExperimentID,
2840 name: Option<&str>,
2841 ) -> Result<Vec<TrainingSession>, Error> {
2842 let params = HashMap::from([("trainer_id", experiment_id)]);
2843 let sessions: Vec<TrainingSession> = self
2844 .rpc("trainer.session.list".to_owned(), Some(params))
2845 .await?;
2846 if let Some(name) = name {
2847 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
2848 } else {
2849 Ok(sessions)
2850 }
2851 }
2852
2853 /// Return the trainer session with the specified trainer session ID. If
2854 /// the trainer session does not exist, an error is returned.
2855 pub async fn training_session(
2856 &self,
2857 session_id: TrainingSessionID,
2858 ) -> Result<TrainingSession, Error> {
2859 let params = HashMap::from([("trainer_session_id", session_id)]);
2860 self.rpc("trainer.session.get".to_owned(), Some(params))
2861 .await
2862 }
2863
2864 /// List validation sessions for the given project.
2865 pub async fn validation_sessions(
2866 &self,
2867 project_id: ProjectID,
2868 ) -> Result<Vec<ValidationSession>, Error> {
2869 let params = HashMap::from([("project_id", project_id)]);
2870 self.rpc("validate.session.list".to_owned(), Some(params))
2871 .await
2872 }
2873
2874 /// Retrieve a specific validation session.
2875 pub async fn validation_session(
2876 &self,
2877 session_id: ValidationSessionID,
2878 ) -> Result<ValidationSession, Error> {
2879 let params = HashMap::from([("validate_session_id", session_id)]);
2880 self.rpc("validate.session.get".to_owned(), Some(params))
2881 .await
2882 }
2883
2884 /// List the artifacts for the specified trainer session. The artifacts
2885 /// are returned as a vector of strings.
2886 pub async fn artifacts(
2887 &self,
2888 training_session_id: TrainingSessionID,
2889 ) -> Result<Vec<Artifact>, Error> {
2890 let params = HashMap::from([("training_session_id", training_session_id)]);
2891 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2892 .await
2893 }
2894
2895 /// Download the model artifact for the specified trainer session to the
2896 /// specified file path, if path is not provided it will be downloaded to
2897 /// the current directory with the same filename. A progress callback can
2898 /// be provided to monitor the progress of the download over a watch
2899 /// channel.
2900 pub async fn download_artifact(
2901 &self,
2902 training_session_id: TrainingSessionID,
2903 modelname: &str,
2904 filename: Option<PathBuf>,
2905 progress: Option<Sender<Progress>>,
2906 ) -> Result<(), Error> {
2907 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2908 let resp = self
2909 .http
2910 .get(format!(
2911 "{}/download_model?training_session_id={}&file={}",
2912 self.url,
2913 training_session_id.value(),
2914 modelname
2915 ))
2916 .header("Authorization", format!("Bearer {}", self.token().await))
2917 .send()
2918 .await?;
2919 if !resp.status().is_success() {
2920 let err = resp.error_for_status_ref().unwrap_err();
2921 return Err(Error::HttpError(err));
2922 }
2923
2924 if let Some(parent) = filename.parent() {
2925 fs::create_dir_all(parent).await?;
2926 }
2927
2928 if let Some(progress) = progress {
2929 let total = resp.content_length().unwrap_or(0) as usize;
2930 let _ = progress.send(Progress { current: 0, total }).await;
2931
2932 let mut file = File::create(filename).await?;
2933 let mut current = 0;
2934 let mut stream = resp.bytes_stream();
2935
2936 while let Some(item) = stream.next().await {
2937 let chunk = item?;
2938 file.write_all(&chunk).await?;
2939 current += chunk.len();
2940 let _ = progress.send(Progress { current, total }).await;
2941 }
2942 } else {
2943 let body = resp.bytes().await?;
2944 fs::write(filename, body).await?;
2945 }
2946
2947 Ok(())
2948 }
2949
2950 /// Download the model checkpoint associated with the specified trainer
2951 /// session to the specified file path, if path is not provided it will be
2952 /// downloaded to the current directory with the same filename. A progress
2953 /// callback can be provided to monitor the progress of the download over a
2954 /// watch channel.
2955 ///
2956 /// There is no API for listing checkpoints it is expected that trainers are
2957 /// aware of possible checkpoints and their names within the checkpoint
2958 /// folder on the server.
2959 pub async fn download_checkpoint(
2960 &self,
2961 training_session_id: TrainingSessionID,
2962 checkpoint: &str,
2963 filename: Option<PathBuf>,
2964 progress: Option<Sender<Progress>>,
2965 ) -> Result<(), Error> {
2966 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2967 let resp = self
2968 .http
2969 .get(format!(
2970 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2971 self.url,
2972 training_session_id.value(),
2973 checkpoint
2974 ))
2975 .header("Authorization", format!("Bearer {}", self.token().await))
2976 .send()
2977 .await?;
2978 if !resp.status().is_success() {
2979 let err = resp.error_for_status_ref().unwrap_err();
2980 return Err(Error::HttpError(err));
2981 }
2982
2983 if let Some(parent) = filename.parent() {
2984 fs::create_dir_all(parent).await?;
2985 }
2986
2987 if let Some(progress) = progress {
2988 let total = resp.content_length().unwrap_or(0) as usize;
2989 let _ = progress.send(Progress { current: 0, total }).await;
2990
2991 let mut file = File::create(filename).await?;
2992 let mut current = 0;
2993 let mut stream = resp.bytes_stream();
2994
2995 while let Some(item) = stream.next().await {
2996 let chunk = item?;
2997 file.write_all(&chunk).await?;
2998 current += chunk.len();
2999 let _ = progress.send(Progress { current, total }).await;
3000 }
3001 } else {
3002 let body = resp.bytes().await?;
3003 fs::write(filename, body).await?;
3004 }
3005
3006 Ok(())
3007 }
3008
3009 /// Return a list of tasks for the current user.
3010 ///
3011 /// # Arguments
3012 ///
3013 /// * `name` - Optional filter for task name (client-side substring match)
3014 /// * `workflow` - Optional filter for workflow/task type. If provided,
3015 /// filters server-side by exact match. Valid values include: "trainer",
3016 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
3017 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
3018 /// "convertor", "twostage"
3019 /// * `status` - Optional filter for task status (e.g., "running",
3020 /// "complete", "error")
3021 /// * `manager` - Optional filter for task manager type (e.g., "aws",
3022 /// "user", "kubernetes")
3023 pub async fn tasks(
3024 &self,
3025 name: Option<&str>,
3026 workflow: Option<&str>,
3027 status: Option<&str>,
3028 manager: Option<&str>,
3029 ) -> Result<Vec<Task>, Error> {
3030 let mut params = TasksListParams {
3031 continue_token: None,
3032 types: workflow.map(|w| vec![w.to_owned()]),
3033 status: status.map(|s| vec![s.to_owned()]),
3034 manager: manager.map(|m| vec![m.to_owned()]),
3035 };
3036 let mut tasks = Vec::new();
3037
3038 loop {
3039 let result = self
3040 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
3041 .await?;
3042 tasks.extend(result.tasks);
3043
3044 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
3045 params.continue_token = None;
3046 } else {
3047 params.continue_token = result.continue_token;
3048 }
3049
3050 if params.continue_token.is_none() {
3051 break;
3052 }
3053 }
3054
3055 if let Some(name) = name {
3056 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
3057 }
3058
3059 Ok(tasks)
3060 }
3061
3062 /// Retrieve the task information and status.
3063 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
3064 self.rpc(
3065 "task.get".to_owned(),
3066 Some(HashMap::from([("id", task_id)])),
3067 )
3068 .await
3069 }
3070
3071 /// Updates the tasks status.
3072 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
3073 let status = TaskStatus {
3074 task_id,
3075 status: status.to_owned(),
3076 };
3077 self.rpc("docker.update.status".to_owned(), Some(status))
3078 .await
3079 }
3080
3081 /// Defines the stages for the task. The stages are defined as a mapping
3082 /// from stage names to their descriptions. Once stages are defined their
3083 /// status can be updated using the update_stage method.
3084 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
3085 let stages: Vec<HashMap<String, String>> = stages
3086 .iter()
3087 .map(|(key, value)| {
3088 let mut stage_map = HashMap::new();
3089 stage_map.insert(key.to_string(), value.to_string());
3090 stage_map
3091 })
3092 .collect();
3093 let params = TaskStages { task_id, stages };
3094 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
3095 Ok(())
3096 }
3097
3098 /// Updates the progress of the task for the provided stage and status
3099 /// information.
3100 pub async fn update_stage(
3101 &self,
3102 task_id: TaskID,
3103 stage: &str,
3104 status: &str,
3105 message: &str,
3106 percentage: u8,
3107 ) -> Result<(), Error> {
3108 let stage = Stage::new(
3109 Some(task_id),
3110 stage.to_owned(),
3111 Some(status.to_owned()),
3112 Some(message.to_owned()),
3113 percentage,
3114 );
3115 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
3116 Ok(())
3117 }
3118
3119 /// Raw fetch from the Studio server is used for downloading files.
3120 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
3121 let req = self
3122 .http
3123 .get(format!("{}/{}", self.url, query))
3124 .header("User-Agent", "EdgeFirst Client")
3125 .header("Authorization", format!("Bearer {}", self.token().await));
3126 let resp = req.send().await?;
3127
3128 if resp.status().is_success() {
3129 let body = resp.bytes().await?;
3130
3131 if log_enabled!(Level::Trace) {
3132 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
3133 }
3134
3135 Ok(body.to_vec())
3136 } else {
3137 let err = resp.error_for_status_ref().unwrap_err();
3138 Err(Error::HttpError(err))
3139 }
3140 }
3141
3142 /// Sends a multipart post request to the server. This is used by the
3143 /// upload and download APIs which do not use JSON-RPC but instead transfer
3144 /// files using multipart/form-data.
3145 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
3146 let req = self
3147 .http
3148 .post(format!("{}/api?method={}", self.url, method))
3149 .header("Accept", "application/json")
3150 .header("User-Agent", "EdgeFirst Client")
3151 .header("Authorization", format!("Bearer {}", self.token().await))
3152 .multipart(form);
3153 let resp = req.send().await?;
3154
3155 if resp.status().is_success() {
3156 let body = resp.bytes().await?;
3157
3158 if log_enabled!(Level::Trace) {
3159 trace!(
3160 "POST Multipart Response: {}",
3161 String::from_utf8_lossy(&body)
3162 );
3163 }
3164
3165 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
3166 Ok(response) => response,
3167 Err(err) => {
3168 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3169 return Err(err.into());
3170 }
3171 };
3172
3173 if let Some(error) = response.error {
3174 Err(Error::RpcError(error.code, error.message))
3175 } else if let Some(result) = response.result {
3176 Ok(result)
3177 } else {
3178 Err(Error::InvalidResponse)
3179 }
3180 } else {
3181 let err = resp.error_for_status_ref().unwrap_err();
3182 Err(Error::HttpError(err))
3183 }
3184 }
3185
3186 /// Send a JSON-RPC request to the server. The method is the name of the
3187 /// method to call on the server. The params are the parameters to pass to
3188 /// the method. The method and params are serialized into a JSON-RPC
3189 /// request and sent to the server. The response is deserialized into
3190 /// the specified type and returned to the caller.
3191 ///
3192 /// NOTE: This API would generally not be called directly and instead users
3193 /// should use the higher-level methods provided by the client.
3194 pub async fn rpc<Params, RpcResult>(
3195 &self,
3196 method: String,
3197 params: Option<Params>,
3198 ) -> Result<RpcResult, Error>
3199 where
3200 Params: Serialize,
3201 RpcResult: DeserializeOwned,
3202 {
3203 let auth_expires = self.token_expiration().await?;
3204 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
3205 self.renew_token().await?;
3206 }
3207
3208 self.rpc_without_auth(method, params).await
3209 }
3210
3211 async fn rpc_without_auth<Params, RpcResult>(
3212 &self,
3213 method: String,
3214 params: Option<Params>,
3215 ) -> Result<RpcResult, Error>
3216 where
3217 Params: Serialize,
3218 RpcResult: DeserializeOwned,
3219 {
3220 let request = RpcRequest {
3221 method,
3222 params,
3223 ..Default::default()
3224 };
3225
3226 if log_enabled!(Level::Trace) {
3227 trace!(
3228 "RPC Request: {}",
3229 serde_json::ser::to_string_pretty(&request)?
3230 );
3231 }
3232
3233 let url = format!("{}/api", self.url);
3234
3235 // Use client-level timeout (allows retry mechanism to work properly)
3236 // Per-request timeout overrides can prevent retries from functioning
3237 let res = self
3238 .http
3239 .post(&url)
3240 .header("Accept", "application/json")
3241 .header("User-Agent", "EdgeFirst Client")
3242 .header("Authorization", format!("Bearer {}", self.token().await))
3243 .json(&request)
3244 .send()
3245 .await?;
3246
3247 self.process_rpc_response(res).await
3248 }
3249
3250 async fn process_rpc_response<RpcResult>(
3251 &self,
3252 res: reqwest::Response,
3253 ) -> Result<RpcResult, Error>
3254 where
3255 RpcResult: DeserializeOwned,
3256 {
3257 let body = res.bytes().await?;
3258
3259 if log_enabled!(Level::Trace) {
3260 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
3261 }
3262
3263 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
3264 Ok(response) => response,
3265 Err(err) => {
3266 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3267 return Err(err.into());
3268 }
3269 };
3270
3271 // FIXME: Studio Server always returns 999 as the id.
3272 // if request.id.to_string() != response.id {
3273 // return Err(Error::InvalidRpcId(response.id));
3274 // }
3275
3276 if let Some(error) = response.error {
3277 Err(Error::RpcError(error.code, error.message))
3278 } else if let Some(result) = response.result {
3279 Ok(result)
3280 } else {
3281 Err(Error::InvalidResponse)
3282 }
3283 }
3284}
3285
3286/// Process items in parallel with semaphore concurrency control and progress
3287/// tracking.
3288///
3289/// This helper eliminates boilerplate for parallel item processing with:
3290/// - Semaphore limiting concurrent tasks to `max_tasks()` (configurable via
3291/// `MAX_TASKS` environment variable, default: half of CPU cores, min 2, max
3292/// 8)
3293/// - Atomic progress counter with automatic item-level updates
3294/// - Progress updates sent after each item completes (not byte-level streaming)
3295/// - Proper error propagation from spawned tasks
3296///
3297/// Note: This is optimized for discrete items with post-completion progress
3298/// updates. For byte-level streaming progress or custom retry logic, use
3299/// specialized implementations.
3300///
3301/// # Arguments
3302///
3303/// * `items` - Collection of items to process in parallel
3304/// * `progress` - Optional progress channel for tracking completion
3305/// * `work_fn` - Async function to execute for each item
3306///
3307/// # Examples
3308///
3309/// ```rust,ignore
3310/// parallel_foreach_items(samples, progress, |sample| async move {
3311/// // Process sample
3312/// sample.download(&client, file_type).await?;
3313/// Ok(())
3314/// }).await?;
3315/// ```
3316async fn parallel_foreach_items<T, F, Fut>(
3317 items: Vec<T>,
3318 progress: Option<Sender<Progress>>,
3319 work_fn: F,
3320) -> Result<(), Error>
3321where
3322 T: Send + 'static,
3323 F: Fn(T) -> Fut + Send + Sync + 'static,
3324 Fut: Future<Output = Result<(), Error>> + Send + 'static,
3325{
3326 let total = items.len();
3327 let current = Arc::new(AtomicUsize::new(0));
3328 let sem = Arc::new(Semaphore::new(max_tasks()));
3329 let work_fn = Arc::new(work_fn);
3330
3331 let tasks = items
3332 .into_iter()
3333 .map(|item| {
3334 let sem = sem.clone();
3335 let current = current.clone();
3336 let progress = progress.clone();
3337 let work_fn = work_fn.clone();
3338
3339 tokio::spawn(async move {
3340 let _permit = sem.acquire().await.map_err(|_| {
3341 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3342 })?;
3343
3344 // Execute the actual work
3345 work_fn(item).await?;
3346
3347 // Update progress
3348 if let Some(progress) = &progress {
3349 let current = current.fetch_add(1, Ordering::SeqCst);
3350 let _ = progress
3351 .send(Progress {
3352 current: current + 1,
3353 total,
3354 })
3355 .await;
3356 }
3357
3358 Ok::<(), Error>(())
3359 })
3360 })
3361 .collect::<Vec<_>>();
3362
3363 join_all(tasks)
3364 .await
3365 .into_iter()
3366 .collect::<Result<Vec<_>, _>>()?
3367 .into_iter()
3368 .collect::<Result<Vec<_>, _>>()?;
3369
3370 if let Some(progress) = progress {
3371 drop(progress);
3372 }
3373
3374 Ok(())
3375}
3376
3377/// Upload a file to S3 using multipart upload with presigned URLs.
3378///
3379/// Splits a file into chunks (100MB each) and uploads them in parallel using
3380/// S3 multipart upload protocol. Returns completion parameters with ETags for
3381/// finalizing the upload.
3382///
3383/// This function handles:
3384/// - Splitting files into parts based on PART_SIZE (100MB)
3385/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
3386/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
3387/// - Retry logic (handled by reqwest client)
3388/// - Progress tracking across all parts
3389///
3390/// # Arguments
3391///
3392/// * `http` - HTTP client for making requests
3393/// * `part` - Snapshot part info with presigned URLs for each chunk
3394/// * `path` - Local file path to upload
3395/// * `total` - Total bytes across all files for progress calculation
3396/// * `current` - Atomic counter tracking bytes uploaded across all operations
3397/// * `progress` - Optional channel for sending progress updates
3398///
3399/// # Returns
3400///
3401/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
3402async fn upload_multipart(
3403 http: reqwest::Client,
3404 part: SnapshotPart,
3405 path: PathBuf,
3406 total: usize,
3407 current: Arc<AtomicUsize>,
3408 progress: Option<Sender<Progress>>,
3409) -> Result<SnapshotCompleteMultipartParams, Error> {
3410 let filesize = path.metadata()?.len() as usize;
3411 let n_parts = filesize.div_ceil(PART_SIZE);
3412 let sem = Arc::new(Semaphore::new(max_tasks()));
3413
3414 let key = part.key.ok_or(Error::InvalidResponse)?;
3415 let upload_id = part.upload_id;
3416
3417 let urls = part.urls.clone();
3418 // Pre-allocate ETag slots for all parts
3419 let etags = Arc::new(tokio::sync::Mutex::new(vec![
3420 EtagPart {
3421 etag: "".to_owned(),
3422 part_number: 0,
3423 };
3424 n_parts
3425 ]));
3426
3427 // Upload all parts in parallel with concurrency limiting
3428 let tasks = (0..n_parts)
3429 .map(|part| {
3430 let http = http.clone();
3431 let url = urls[part].clone();
3432 let etags = etags.clone();
3433 let path = path.to_owned();
3434 let sem = sem.clone();
3435 let progress = progress.clone();
3436 let current = current.clone();
3437
3438 tokio::spawn(async move {
3439 // Acquire semaphore permit to limit concurrent uploads
3440 let _permit = sem.acquire().await?;
3441
3442 // Upload part (retry is handled by reqwest client)
3443 let etag =
3444 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
3445
3446 // Store ETag for this part (needed to complete multipart upload)
3447 let mut etags = etags.lock().await;
3448 etags[part] = EtagPart {
3449 etag,
3450 part_number: part + 1,
3451 };
3452
3453 // Update progress counter
3454 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
3455 if let Some(progress) = &progress {
3456 let _ = progress
3457 .send(Progress {
3458 current: current + PART_SIZE,
3459 total,
3460 })
3461 .await;
3462 }
3463
3464 Ok::<(), Error>(())
3465 })
3466 })
3467 .collect::<Vec<_>>();
3468
3469 // Wait for all parts to complete
3470 join_all(tasks)
3471 .await
3472 .into_iter()
3473 .collect::<Result<Vec<_>, _>>()?;
3474
3475 Ok(SnapshotCompleteMultipartParams {
3476 key,
3477 upload_id,
3478 etag_list: etags.lock().await.clone(),
3479 })
3480}
3481
3482async fn upload_part(
3483 http: reqwest::Client,
3484 url: String,
3485 path: PathBuf,
3486 part: usize,
3487 n_parts: usize,
3488) -> Result<String, Error> {
3489 let filesize = path.metadata()?.len() as usize;
3490 let mut file = File::open(path).await?;
3491 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
3492 .await?;
3493 let file = file.take(PART_SIZE as u64);
3494
3495 let body_length = if part + 1 == n_parts {
3496 filesize % PART_SIZE
3497 } else {
3498 PART_SIZE
3499 };
3500
3501 let stream = FramedRead::new(file, BytesCodec::new());
3502 let body = Body::wrap_stream(stream);
3503
3504 let resp = http
3505 .put(url.clone())
3506 .header(CONTENT_LENGTH, body_length)
3507 .body(body)
3508 .send()
3509 .await?
3510 .error_for_status()?;
3511
3512 let etag = resp
3513 .headers()
3514 .get("etag")
3515 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3516 .to_str()
3517 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3518 .to_owned();
3519
3520 // Studio Server requires etag without the quotes.
3521 let etag = etag
3522 .strip_prefix("\"")
3523 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3524 let etag = etag
3525 .strip_suffix("\"")
3526 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3527
3528 Ok(etag.to_owned())
3529}
3530
3531/// Upload a complete file to a presigned S3 URL using HTTP PUT.
3532///
3533/// This is used for populate_samples to upload files to S3 after
3534/// receiving presigned URLs from the server.
3535async fn upload_file_to_presigned_url(
3536 http: reqwest::Client,
3537 url: &str,
3538 path: PathBuf,
3539) -> Result<(), Error> {
3540 // Read the entire file into memory
3541 let file_data = fs::read(&path).await?;
3542 let file_size = file_data.len();
3543
3544 // Upload (retry is handled by reqwest client)
3545 let resp = http
3546 .put(url)
3547 .header(CONTENT_LENGTH, file_size)
3548 .body(file_data)
3549 .send()
3550 .await?;
3551
3552 if resp.status().is_success() {
3553 debug!(
3554 "Successfully uploaded file: {:?} ({} bytes)",
3555 path, file_size
3556 );
3557 Ok(())
3558 } else {
3559 let status = resp.status();
3560 let error_text = resp.text().await.unwrap_or_default();
3561 Err(Error::InvalidParameters(format!(
3562 "Upload failed: HTTP {} - {}",
3563 status, error_text
3564 )))
3565 }
3566}
3567
#[cfg(test)]
mod tests {
    //! Unit tests covering `filter_and_sort_by_name` ordering guarantees,
    //! `Client::build_filename` flattening behaviour, and token-storage
    //! handling when switching servers.

    use super::*;

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = vec![
            // (items, filter, expected_first)
            (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (vec!["test", "TEST", "Test Data"], "test", "test"),
            (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Verify token is loaded
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }
}