edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID, SnapshotRestore,
10 SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
11 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
15 retry::{create_retry_policy, log_retry_configuration},
16 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
17};
18use base64::Engine as _;
19use chrono::{DateTime, Utc};
20use directories::ProjectDirs;
21use futures::{StreamExt as _, future::join_all};
22use log::{Level, debug, error, log_enabled, trace, warn};
23use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use std::{
26 collections::HashMap,
27 ffi::OsStr,
28 fs::create_dir_all,
29 io::{SeekFrom, Write as _},
30 path::{Path, PathBuf},
31 sync::{
32 Arc,
33 atomic::{AtomicUsize, Ordering},
34 },
35 time::Duration,
36 vec,
37};
38use tokio::{
39 fs::{self, File},
40 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
41 sync::{RwLock, Semaphore, mpsc::Sender},
42};
43use tokio_util::codec::{BytesCodec, FramedRead};
44use walkdir::WalkDir;
45
46#[cfg(feature = "polars")]
47use polars::prelude::*;
48
/// Size of each chunk for multipart uploads, in bytes (100 MiB).
// `const` instead of `static`: the value is a plain `Copy` integer with no
// interior mutability, so a compile-time constant is the idiomatic form.
const PART_SIZE: usize = 100 * 1024 * 1024;
50
/// Maximum number of concurrent tasks for parallel operations.
///
/// Honors the `MAX_TASKS` environment variable when it parses as a number;
/// otherwise derives a default from the CPU count: half the available CPUs,
/// clamped to 2..=8. The low upper bound prevents timeout issues with large
/// file uploads.
fn max_tasks() -> usize {
    if let Some(configured) = std::env::var("MAX_TASKS").ok().and_then(|v| v.parse().ok()) {
        return configured;
    }

    // Fall back to half the CPU count, minimum 2, maximum 8.
    let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
    (cpus / 2).clamp(2, 8)
}
64
/// Filters items whose name contains `filter` (case-insensitively) and
/// orders the survivors by match quality.
///
/// Match quality priority (best to worst):
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match (shorter names first, then alphabetically)
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let needle = filter.to_lowercase();

    let mut matches: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&needle))
        .collect();

    // A composite key reproduces the priority order above: `false` sorts
    // before `true`, so "is NOT an exact match" places exact matches first,
    // then case-insensitive exact matches, then shorter names, and finally
    // alphabetical order as the stable tiebreak.
    matches.sort_by_cached_key(|item| {
        let name = get_name(item);
        (
            name != filter,
            name.to_lowercase() != needle,
            name.len(),
            name.to_owned(),
        )
    });

    matches
}
114
/// Sanitizes a string for safe use as a single path component.
///
/// Trims whitespace, keeps only the final path component (discarding any
/// directory prefix), replaces characters that are invalid in file names on
/// common platforms with `_`, and rejects the directory-traversal
/// components `.` and `..`. Returns `"unnamed"` when nothing usable
/// remains.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // Keep only the last component so inputs like "a/b" cannot introduce
    // subdirectories; fall back to the trimmed input when `file_name()`
    // yields nothing (e.g. for "..").
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // "." and ".." pass the character filter but would act as current- or
    // parent-directory references when joined into a path, so never return
    // them as a component.
    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
140
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count and total
/// count to enable progress reporting in applications.
///
/// Note: this is plain data — no invariant between `current` and `total`
/// (such as `current <= total`) is enforced by the type itself.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "Progress: {:.1}% ({}/{})",
///     percentage, progress.current, progress.total
/// );
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items.
    pub current: usize,
    /// Total number of items to process.
    pub total: usize,
}
169
/// JSON-RPC 2.0 request envelope sent to the server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier; defaults to 0 (see the `Default` impl).
    id: u64,
    // Protocol version string; always "2.0".
    jsonrpc: String,
    // Fully qualified RPC method name, e.g. "dataset.list".
    method: String,
    // Optional method parameters; serialized as JSON `null` when `None`.
    params: Option<Params>,
}
177
178impl<T> Default for RpcRequest<T> {
179 fn default() -> Self {
180 RpcRequest {
181 id: 0,
182 jsonrpc: "2.0".to_string(),
183 method: "".to_string(),
184 params: None,
185 }
186 }
187}
188
/// Error object embedded in a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    // Numeric JSON-RPC error code reported by the server.
    code: i32,
    // Human-readable error description from the server.
    message: String,
}
194
/// JSON-RPC 2.0 response envelope.
///
/// NOTE(review): `id` is deserialized as a string here, while requests send
/// a numeric `id` — presumably the server echoes it back stringified;
/// confirm against the server implementation.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    // Populated when the call failed.
    error: Option<RpcError>,
    // Populated when the call succeeded.
    result: Option<RpcResult>,
}
204
/// Placeholder result type for RPC calls whose response body is an empty
/// object.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
208
/// Parameters for a (single-part) snapshot creation RPC.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    // Name for the new snapshot.
    snapshot_name: String,
    // Object keys of the files that make up the snapshot.
    keys: Vec<String>,
}
215
/// Server response to a (single-part) snapshot creation request.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    // Identifier assigned to the new snapshot.
    snapshot_id: SnapshotID,
    // Upload URLs, presumably one per requested key — TODO confirm.
    urls: Vec<String>,
}
222
/// Parameters for creating a snapshot via multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    // Name for the new snapshot.
    snapshot_name: String,
    // Object keys of the files that make up the snapshot.
    keys: Vec<String>,
    // Size in bytes of each file, parallel to `keys`, so the server can
    // allocate the right number of part URLs.
    file_sizes: Vec<usize>,
}
229
/// One element of a multipart snapshot-create response: either a numeric
/// snapshot id or a per-file part descriptor. `#[serde(untagged)]` lets the
/// JSON shape decide which variant is deserialized.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
236
/// Parameters for completing a multipart upload of one snapshot file.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    // Object key of the file whose upload is being completed.
    key: String,
    // Multipart upload session identifier issued at creation time.
    upload_id: String,
    // ETags of the uploaded parts, in part-number order.
    etag_list: Vec<EtagPart>,
}
243
/// A completed part of a multipart upload, reported back to the server.
/// Field names are renamed to the capitalization the storage API expects.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    // ETag returned by the storage backend for this uploaded part.
    #[serde(rename = "ETag")]
    etag: String,
    // Part number within the multipart upload (assumed 1-based, per the S3
    // convention — TODO confirm against the upload code).
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
251
/// Per-file descriptor in a multipart snapshot-create response.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    // Object key this descriptor applies to, when the server provides it.
    key: Option<String>,
    // Multipart upload session identifier for this file.
    upload_id: String,
    // Presigned URLs, presumably one per part to upload — TODO confirm.
    urls: Vec<String>,
}
258
/// Parameters for updating a snapshot's status on the server.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    // Snapshot to update.
    snapshot_id: SnapshotID,
    // New status value; the accepted strings are defined by the server API.
    status: String,
}
264
/// Snapshot record returned by a status update call. All fields are
/// currently unread (`dead_code`) but kept for deserialization and
/// debugging via `Debug`.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
278
/// Parameters for listing images in a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    // Dataset-level filter (see [`ImagesFilter`]).
    images_filter: ImagesFilter,
    // Additional key/value filters on image files; semantics defined by the
    // server API — TODO confirm accepted keys.
    image_files_filter: HashMap<String, String>,
    // When true, request only image IDs rather than full records.
    only_ids: bool,
}
286
/// Dataset filter used inside [`ImageListParams`].
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    // Restrict results to images belonging to this dataset.
    dataset_id: DatasetID,
}
292
293/// Main client for interacting with EdgeFirst Studio Server.
294///
295/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
296/// and manages authentication, RPC calls, and data operations. It provides
297/// methods for managing projects, datasets, experiments, training sessions,
298/// and various utility functions for data processing.
299///
300/// The client supports multiple authentication methods and can work with both
301/// SaaS and self-hosted EdgeFirst Studio instances.
302///
303/// # Features
304///
305/// - **Authentication**: Token-based authentication with automatic persistence
306/// - **Dataset Management**: Upload, download, and manipulate datasets
307/// - **Project Operations**: Create and manage projects and experiments
308/// - **Training & Validation**: Submit and monitor ML training jobs
309/// - **Data Integration**: Convert between EdgeFirst datasets and popular
310/// formats
311/// - **Progress Tracking**: Real-time progress updates for long-running
312/// operations
313///
314/// # Examples
315///
316/// ```no_run
317/// use edgefirst_client::{Client, DatasetID};
318/// use std::str::FromStr;
319///
320/// # async fn example() -> Result<(), edgefirst_client::Error> {
321/// // Create a new client and authenticate
322/// let mut client = Client::new()?;
323/// let client = client
324/// .with_login("your-email@example.com", "password")
325/// .await?;
326///
327/// // Or use an existing token
328/// let base_client = Client::new()?;
329/// let client = base_client.with_token("your-token-here")?;
330///
331/// // Get organization and projects
332/// let org = client.organization().await?;
333/// let projects = client.projects(None).await?;
334///
335/// // Work with datasets
336/// let dataset_id = DatasetID::from_str("ds-abc123")?;
337/// let dataset = client.dataset(dataset_id).await?;
338/// # Ok(())
339/// # }
340/// ```
341/// Client is Clone but cannot derive Debug due to dyn TokenStorage
#[derive(Clone)]
pub struct Client {
    /// Shared HTTP client; configured with timeouts and a URL-based retry
    /// policy in [`Client::new`].
    http: reqwest::Client,
    /// Base URL of the EdgeFirst Studio server, e.g.
    /// `https://edgefirst.studio`.
    url: String,
    /// Current authentication token; empty string when unauthenticated.
    /// Shared behind `Arc<RwLock>` so clones observe renewals and logout.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
353
354impl std::fmt::Debug for Client {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("Client")
357 .field("url", &self.url)
358 .field("has_storage", &self.storage.is_some())
359 .field("token_path", &self.token_path)
360 .finish()
361 }
362}
363
/// Private context struct for pagination operations.
///
/// Bundles the parameters shared across paged fetch calls so they can be
/// passed around as one unit instead of a long argument list.
struct FetchContext<'a> {
    // Dataset being paged through.
    dataset_id: DatasetID,
    // Optional annotation set to restrict results to.
    annotation_set_id: Option<AnnotationSetID>,
    // Group names to include; assumed to be sample groups — TODO confirm
    // against the fetch implementation.
    groups: &'a [String],
    // Type filter strings; exact semantics defined where this is consumed.
    types: Vec<String>,
    // Label name → id mapping; presumably used to resolve annotation
    // labels — TODO confirm against the fetch implementation.
    labels: &'a HashMap<String, u64>,
}
372
373impl Client {
374 /// Create a new unauthenticated client with the default saas server.
375 ///
376 /// By default, the client uses [`FileTokenStorage`] for token persistence.
377 /// Use [`with_storage`][Self::with_storage],
378 /// [`with_memory_storage`][Self::with_memory_storage],
379 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
380 /// behavior.
381 ///
382 /// To connect to a different server, use [`with_server`][Self::with_server]
383 /// or [`with_token`][Self::with_token] (tokens include the server
384 /// instance).
385 ///
386 /// This client is created without a token and will need to authenticate
387 /// before using methods that require authentication.
388 ///
389 /// # Examples
390 ///
391 /// ```rust,no_run
392 /// use edgefirst_client::Client;
393 ///
394 /// # fn main() -> Result<(), edgefirst_client::Error> {
395 /// // Create client with default file storage
396 /// let client = Client::new()?;
397 ///
398 /// // Create client without token persistence
399 /// let client = Client::new()?.with_memory_storage();
400 /// # Ok(())
401 /// # }
402 /// ```
403 pub fn new() -> Result<Self, Error> {
404 log_retry_configuration();
405
406 // Get timeout from environment or use default
407 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
408 .ok()
409 .and_then(|s| s.parse().ok())
410 .unwrap_or(30); // Default 30s timeout for API calls
411
412 // Create single HTTP client with URL-based retry policy
413 //
414 // The retry policy classifies requests into two categories:
415 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
416 // errors
417 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
418 //
419 // This allows the same client to handle both API calls and file operations
420 // with appropriate retry behavior for each. See retry.rs for details.
421 let http = reqwest::Client::builder()
422 .connect_timeout(Duration::from_secs(10))
423 .timeout(Duration::from_secs(timeout_secs))
424 .pool_idle_timeout(Duration::from_secs(90))
425 .pool_max_idle_per_host(10)
426 .retry(create_retry_policy())
427 .build()?;
428
429 // Default to file storage, loading any existing token
430 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
431 Ok(file_storage) => Arc::new(file_storage),
432 Err(e) => {
433 warn!(
434 "Could not initialize file token storage: {}. Using memory storage.",
435 e
436 );
437 Arc::new(MemoryTokenStorage::new())
438 }
439 };
440
441 // Try to load existing token from storage
442 let token = match storage.load() {
443 Ok(Some(t)) => t,
444 Ok(None) => String::new(),
445 Err(e) => {
446 warn!(
447 "Failed to load token from storage: {}. Starting with empty token.",
448 e
449 );
450 String::new()
451 }
452 };
453
454 Ok(Client {
455 http,
456 url: "https://edgefirst.studio".to_string(),
457 token: Arc::new(tokio::sync::RwLock::new(token)),
458 storage: Some(storage),
459 token_path: None,
460 })
461 }
462
463 /// Returns a new client connected to the specified server instance.
464 ///
465 /// The server parameter is an instance name that maps to a URL:
466 /// - `""` or `"saas"` → `https://edgefirst.studio`
467 /// - `"test"` → `https://test.edgefirst.studio`
468 /// - `"stage"` → `https://stage.edgefirst.studio`
469 /// - `"dev"` → `https://dev.edgefirst.studio`
470 /// - `"{name}"` → `https://{name}.edgefirst.studio`
471 ///
472 /// If a token is already set in the client, it will be dropped as tokens
473 /// are specific to the server instance.
474 ///
475 /// # Examples
476 ///
477 /// ```rust,no_run
478 /// use edgefirst_client::Client;
479 ///
480 /// # fn main() -> Result<(), edgefirst_client::Error> {
481 /// let client = Client::new()?.with_server("test")?;
482 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
483 /// # Ok(())
484 /// # }
485 /// ```
486 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
487 let url = match server {
488 "" | "saas" => "https://edgefirst.studio".to_string(),
489 name => format!("https://{}.edgefirst.studio", name),
490 };
491
492 // Clear token from storage when changing servers to prevent
493 // authentication issues with stale tokens from different instances
494 if let Some(ref storage) = self.storage
495 && let Err(e) = storage.clear()
496 {
497 warn!(
498 "Failed to clear token from storage when changing servers: {}",
499 e
500 );
501 }
502
503 Ok(Client {
504 url,
505 token: Arc::new(tokio::sync::RwLock::new(String::new())),
506 ..self.clone()
507 })
508 }
509
510 /// Returns a new client with the specified token storage backend.
511 ///
512 /// Use this to configure custom token storage, such as platform-specific
513 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
514 ///
515 /// # Examples
516 ///
517 /// ```rust,no_run
518 /// use edgefirst_client::{Client, FileTokenStorage};
519 /// use std::{path::PathBuf, sync::Arc};
520 ///
521 /// # fn main() -> Result<(), edgefirst_client::Error> {
522 /// // Use a custom file path for token storage
523 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
524 /// let client = Client::new()?.with_storage(Arc::new(storage));
525 /// # Ok(())
526 /// # }
527 /// ```
528 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
529 // Try to load existing token from the new storage
530 let token = match storage.load() {
531 Ok(Some(t)) => t,
532 Ok(None) => String::new(),
533 Err(e) => {
534 warn!(
535 "Failed to load token from storage: {}. Starting with empty token.",
536 e
537 );
538 String::new()
539 }
540 };
541
542 Client {
543 token: Arc::new(tokio::sync::RwLock::new(token)),
544 storage: Some(storage),
545 token_path: None,
546 ..self
547 }
548 }
549
550 /// Returns a new client with in-memory token storage (no persistence).
551 ///
552 /// Tokens are stored in memory only and lost when the application exits.
553 /// This is useful for testing or when you want to manage token persistence
554 /// externally.
555 ///
556 /// # Examples
557 ///
558 /// ```rust,no_run
559 /// use edgefirst_client::Client;
560 ///
561 /// # fn main() -> Result<(), edgefirst_client::Error> {
562 /// let client = Client::new()?.with_memory_storage();
563 /// # Ok(())
564 /// # }
565 /// ```
566 pub fn with_memory_storage(self) -> Self {
567 Client {
568 token: Arc::new(tokio::sync::RwLock::new(String::new())),
569 storage: Some(Arc::new(MemoryTokenStorage::new())),
570 token_path: None,
571 ..self
572 }
573 }
574
575 /// Returns a new client with no token storage.
576 ///
577 /// Tokens are not persisted. Use this when you want to manage tokens
578 /// entirely manually.
579 ///
580 /// # Examples
581 ///
582 /// ```rust,no_run
583 /// use edgefirst_client::Client;
584 ///
585 /// # fn main() -> Result<(), edgefirst_client::Error> {
586 /// let client = Client::new()?.with_no_storage();
587 /// # Ok(())
588 /// # }
589 /// ```
590 pub fn with_no_storage(self) -> Self {
591 Client {
592 storage: None,
593 token_path: None,
594 ..self
595 }
596 }
597
598 /// Returns a new client authenticated with the provided username and
599 /// password.
600 ///
601 /// The token is automatically persisted to storage (if configured).
602 ///
603 /// # Examples
604 ///
605 /// ```rust,no_run
606 /// use edgefirst_client::Client;
607 ///
608 /// # async fn example() -> Result<(), edgefirst_client::Error> {
609 /// let client = Client::new()?
610 /// .with_server("test")?
611 /// .with_login("user@example.com", "password")
612 /// .await?;
613 /// # Ok(())
614 /// # }
615 /// ```
616 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
617 let params = HashMap::from([("username", username), ("password", password)]);
618 let login: LoginResult = self
619 .rpc_without_auth("auth.login".to_owned(), Some(params))
620 .await?;
621
622 // Validate that the server returned a non-empty token
623 if login.token.is_empty() {
624 return Err(Error::EmptyToken);
625 }
626
627 // Persist token to storage if configured
628 if let Some(ref storage) = self.storage
629 && let Err(e) = storage.store(&login.token)
630 {
631 warn!("Failed to persist token to storage: {}", e);
632 }
633
634 Ok(Client {
635 token: Arc::new(tokio::sync::RwLock::new(login.token)),
636 ..self.clone()
637 })
638 }
639
640 /// Returns a new client which will load and save the token to the specified
641 /// path.
642 ///
643 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
644 /// [`FileTokenStorage`] instead for more flexible token management.
645 ///
646 /// This method is maintained for backwards compatibility with existing
647 /// code. It disables the default storage and uses file-based storage at
648 /// the specified path.
    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
        // Resolve the token file location: an explicit path wins, otherwise
        // fall back to the platform config directory for EdgeFirst Studio.
        let token_path = match token_path {
            Some(path) => path.to_path_buf(),
            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .ok_or_else(|| {
                    Error::IoError(std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "Could not determine user config directory",
                    ))
                })?
                .config_dir()
                .join("token"),
        };

        debug!("Using token path (legacy): {:?}", token_path);

        // Missing file is not an error: it simply means "not logged in yet".
        let token = match token_path.exists() {
            true => std::fs::read_to_string(&token_path)?,
            false => "".to_string(),
        };

        if !token.is_empty() {
            // Validate the token via with_token(); on success, switch to
            // legacy file persistence by disabling the storage backend.
            match self.with_token(&token) {
                Ok(client) => Ok(Client {
                    token_path: Some(token_path),
                    storage: None, // Disable new storage when using legacy token_path
                    ..client
                }),
                Err(e) => {
                    // Token is corrupted or invalid - remove it and continue with no token
                    warn!(
                        "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
                        token_path, e
                    );
                    if let Err(remove_err) = std::fs::remove_file(&token_path) {
                        warn!("Failed to remove corrupted token file: {:?}", remove_err);
                    }
                    // Continue unauthenticated but keep the legacy path so a
                    // later login persists to the same file.
                    Ok(Client {
                        token_path: Some(token_path),
                        storage: None,
                        ..self.clone()
                    })
                }
            }
        } else {
            // No token yet: configure the legacy path for future saves.
            Ok(Client {
                token_path: Some(token_path),
                storage: None,
                ..self.clone()
            })
        }
    }
701
702 /// Returns a new client authenticated with the provided token.
703 ///
704 /// The token is automatically persisted to storage (if configured).
705 /// The server URL is extracted from the token payload.
706 ///
707 /// # Examples
708 ///
709 /// ```rust,no_run
710 /// use edgefirst_client::Client;
711 ///
712 /// # fn main() -> Result<(), edgefirst_client::Error> {
713 /// let client = Client::new()?.with_token("your-jwt-token")?;
714 /// # Ok(())
715 /// # }
716 /// ```
717 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
718 if token.is_empty() {
719 return Ok(self.clone());
720 }
721
722 let token_parts: Vec<&str> = token.split('.').collect();
723 if token_parts.len() != 3 {
724 return Err(Error::InvalidToken);
725 }
726
727 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
728 .decode(token_parts[1])
729 .map_err(|_| Error::InvalidToken)?;
730 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
731 let server = match payload.get("server") {
732 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
733 None => return Err(Error::InvalidToken),
734 };
735
736 // Persist token to storage if configured
737 if let Some(ref storage) = self.storage
738 && let Err(e) = storage.store(token)
739 {
740 warn!("Failed to persist token to storage: {}", e);
741 }
742
743 Ok(Client {
744 url: format!("https://{}.edgefirst.studio", server),
745 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
746 ..self.clone()
747 })
748 }
749
750 /// Persist the current token to storage.
751 ///
752 /// This is automatically called when using [`with_login`][Self::with_login]
753 /// or [`with_token`][Self::with_token], so you typically don't need to call
754 /// this directly.
755 ///
756 /// If using the legacy `token_path` configuration, saves to the file path.
757 /// If using the new storage abstraction, saves to the configured storage.
758 pub async fn save_token(&self) -> Result<(), Error> {
759 let token = self.token.read().await;
760
761 // Try new storage first
762 if let Some(ref storage) = self.storage {
763 storage.store(&token)?;
764 debug!("Token saved to storage");
765 return Ok(());
766 }
767
768 // Fall back to legacy token_path behavior
769 let path = self.token_path.clone().unwrap_or_else(|| {
770 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
771 .map(|dirs| dirs.config_dir().join("token"))
772 .unwrap_or_else(|| PathBuf::from(".token"))
773 });
774
775 create_dir_all(path.parent().ok_or_else(|| {
776 Error::IoError(std::io::Error::new(
777 std::io::ErrorKind::InvalidInput,
778 "Token path has no parent directory",
779 ))
780 })?)?;
781 let mut file = std::fs::File::create(&path)?;
782 file.write_all(token.as_bytes())?;
783
784 debug!("Saved token to {:?}", path);
785
786 Ok(())
787 }
788
789 /// Return the version of the EdgeFirst Studio server for the current
790 /// client connection.
791 pub async fn version(&self) -> Result<String, Error> {
792 let version: HashMap<String, String> = self
793 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
794 .await?;
795 let version = version.get("version").ok_or(Error::InvalidResponse)?;
796 Ok(version.to_owned())
797 }
798
799 /// Clear the token used to authenticate the client with the server.
800 ///
801 /// Clears the token from memory and from storage (if configured).
802 /// If using the legacy `token_path` configuration, removes the token file.
803 pub async fn logout(&self) -> Result<(), Error> {
804 {
805 let mut token = self.token.write().await;
806 *token = "".to_string();
807 }
808
809 // Clear from new storage if configured
810 if let Some(ref storage) = self.storage
811 && let Err(e) = storage.clear()
812 {
813 warn!("Failed to clear token from storage: {}", e);
814 }
815
816 // Also clear legacy token_path if configured
817 if let Some(path) = &self.token_path
818 && path.exists()
819 {
820 fs::remove_file(path).await?;
821 }
822
823 Ok(())
824 }
825
826 /// Return the token used to authenticate the client with the server. When
827 /// logging into the server using a username and password, the token is
828 /// returned by the server and stored in the client for future interactions.
829 pub async fn token(&self) -> String {
830 self.token.read().await.clone()
831 }
832
833 /// Verify the token used to authenticate the client with the server. This
834 /// method is used to ensure that the token is still valid and has not
835 /// expired. If the token is invalid, the server will return an error and
836 /// the client will need to login again.
837 pub async fn verify_token(&self) -> Result<(), Error> {
838 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
839 .await?;
840 Ok::<(), Error>(())
841 }
842
843 /// Renew the token used to authenticate the client with the server.
844 ///
845 /// Refreshes the token before it expires. If the token has already expired,
846 /// the server will return an error and you will need to login again.
847 ///
848 /// The new token is automatically persisted to storage (if configured).
849 pub async fn renew_token(&self) -> Result<(), Error> {
850 let params = HashMap::from([("username".to_string(), self.username().await?)]);
851 let result: LoginResult = self
852 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
853 .await?;
854
855 {
856 let mut token = self.token.write().await;
857 *token = result.token.clone();
858 }
859
860 // Persist to new storage if configured
861 if let Some(ref storage) = self.storage
862 && let Err(e) = storage.store(&result.token)
863 {
864 warn!("Failed to persist renewed token to storage: {}", e);
865 }
866
867 // Also persist to legacy token_path if configured
868 if self.token_path.is_some() {
869 self.save_token().await?;
870 }
871
872 Ok(())
873 }
874
875 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
876 let token = self.token.read().await;
877 if token.is_empty() {
878 return Err(Error::EmptyToken);
879 }
880
881 let token_parts: Vec<&str> = token.split('.').collect();
882 if token_parts.len() != 3 {
883 return Err(Error::InvalidToken);
884 }
885
886 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
887 .decode(token_parts[1])
888 .map_err(|_| Error::InvalidToken)?;
889 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
890 match payload.get(field) {
891 Some(value) => Ok(value.to_owned()),
892 None => Err(Error::InvalidToken),
893 }
894 }
895
896 /// Returns the URL of the EdgeFirst Studio server for the current client.
897 pub fn url(&self) -> &str {
898 &self.url
899 }
900
901 /// Returns the username associated with the current token.
902 pub async fn username(&self) -> Result<String, Error> {
903 match self.token_field("username").await? {
904 serde_json::Value::String(username) => Ok(username),
905 _ => Err(Error::InvalidToken),
906 }
907 }
908
909 /// Returns the expiration time for the current token.
910 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
911 let ts = match self.token_field("exp").await? {
912 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
913 _ => return Err(Error::InvalidToken),
914 };
915
916 match DateTime::<Utc>::from_timestamp_secs(ts) {
917 Some(dt) => Ok(dt),
918 None => Err(Error::InvalidToken),
919 }
920 }
921
922 /// Returns the organization information for the current user.
923 pub async fn organization(&self) -> Result<Organization, Error> {
924 self.rpc::<(), Organization>("org.get".to_owned(), None)
925 .await
926 }
927
928 /// Returns a list of projects available to the user. The projects are
929 /// returned as a vector of Project objects. If a name filter is
930 /// provided, only projects matching the filter are returned.
931 ///
932 /// Results are sorted by match quality: exact matches first, then
933 /// case-insensitive exact matches, then shorter names (more specific),
934 /// then alphabetically.
935 ///
936 /// Projects are the top-level organizational unit in EdgeFirst Studio.
937 /// Projects contain datasets, trainers, and trainer sessions. Projects
938 /// are used to group related datasets and trainers together.
939 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
940 let projects = self
941 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
942 .await?;
943 if let Some(name) = name {
944 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
945 } else {
946 Ok(projects)
947 }
948 }
949
950 /// Return the project with the specified project ID. If the project does
951 /// not exist, an error is returned.
952 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
953 let params = HashMap::from([("project_id", project_id)]);
954 self.rpc("project.get".to_owned(), Some(params)).await
955 }
956
957 /// Returns a list of datasets available to the user. The datasets are
958 /// returned as a vector of Dataset objects. If a name filter is
959 /// provided, only datasets matching the filter are returned.
960 ///
961 /// Results are sorted by match quality: exact matches first, then
962 /// case-insensitive exact matches, then shorter names (more specific),
963 /// then alphabetically. This ensures "Deer" returns before "Deer
964 /// Roundtrip".
965 pub async fn datasets(
966 &self,
967 project_id: ProjectID,
968 name: Option<&str>,
969 ) -> Result<Vec<Dataset>, Error> {
970 let params = HashMap::from([("project_id", project_id)]);
971 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
972 if let Some(name) = name {
973 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
974 } else {
975 Ok(datasets)
976 }
977 }
978
979 /// Return the dataset with the specified dataset ID. If the dataset does
980 /// not exist, an error is returned.
981 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
982 let params = HashMap::from([("dataset_id", dataset_id)]);
983 self.rpc("dataset.get".to_owned(), Some(params)).await
984 }
985
986 /// Lists the labels for the specified dataset.
987 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
988 let params = HashMap::from([("dataset_id", dataset_id)]);
989 self.rpc("label.list".to_owned(), Some(params)).await
990 }
991
992 /// Add a new label to the dataset with the specified name.
993 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
994 let new_label = NewLabel {
995 dataset_id,
996 labels: vec![NewLabelObject {
997 name: name.to_owned(),
998 }],
999 };
1000 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1001 Ok(())
1002 }
1003
1004 /// Removes the label with the specified ID from the dataset. Label IDs are
1005 /// globally unique so the dataset_id is not required.
1006 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1007 let params = HashMap::from([("label_id", label_id)]);
1008 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1009 Ok(())
1010 }
1011
1012 /// Creates a new dataset in the specified project.
1013 ///
1014 /// # Arguments
1015 ///
1016 /// * `project_id` - The ID of the project to create the dataset in
1017 /// * `name` - The name of the new dataset
1018 /// * `description` - Optional description for the dataset
1019 ///
1020 /// # Returns
1021 ///
1022 /// Returns the dataset ID of the newly created dataset.
1023 pub async fn create_dataset(
1024 &self,
1025 project_id: &str,
1026 name: &str,
1027 description: Option<&str>,
1028 ) -> Result<DatasetID, Error> {
1029 let mut params = HashMap::new();
1030 params.insert("project_id", project_id);
1031 params.insert("name", name);
1032 if let Some(desc) = description {
1033 params.insert("description", desc);
1034 }
1035
1036 #[derive(Deserialize)]
1037 struct CreateDatasetResult {
1038 id: DatasetID,
1039 }
1040
1041 let result: CreateDatasetResult =
1042 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1043 Ok(result.id)
1044 }
1045
1046 /// Deletes a dataset by marking it as deleted.
1047 ///
1048 /// # Arguments
1049 ///
1050 /// * `dataset_id` - The ID of the dataset to delete
1051 ///
1052 /// # Returns
1053 ///
1054 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1055 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1056 let params = HashMap::from([("id", dataset_id)]);
1057 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1058 Ok(())
1059 }
1060
1061 /// Updates the label with the specified ID to have the new name or index.
1062 /// Label IDs cannot be changed. Label IDs are globally unique so the
1063 /// dataset_id is not required.
1064 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1065 #[derive(Serialize)]
1066 struct Params {
1067 dataset_id: DatasetID,
1068 label_id: u64,
1069 label_name: String,
1070 label_index: u64,
1071 }
1072
1073 let _: String = self
1074 .rpc(
1075 "label.update".to_owned(),
1076 Some(Params {
1077 dataset_id: label.dataset_id(),
1078 label_id: label.id(),
1079 label_name: label.name().to_owned(),
1080 label_index: label.index(),
1081 }),
1082 )
1083 .await?;
1084 Ok(())
1085 }
1086
1087 /// Downloads dataset samples to the local filesystem.
1088 ///
1089 /// # Arguments
1090 ///
1091 /// * `dataset_id` - The unique identifier of the dataset
1092 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1093 /// * `file_types` - File types to download (e.g., Image, LidarPcd)
1094 /// * `output` - Local directory to save downloaded files
1095 /// * `flatten` - If true, download all files to output root without
1096 /// sequence subdirectories. When flattening, filenames are prefixed with
1097 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1098 /// unavailable) unless the filename already starts with
1099 /// `{sequence_name}_`, to avoid conflicts between sequences.
1100 /// * `progress` - Optional channel for progress updates
1101 ///
1102 /// # Returns
1103 ///
1104 /// Returns `Ok(())` on success or an error if download fails.
1105 ///
1106 /// # Example
1107 ///
1108 /// ```rust,no_run
1109 /// # use edgefirst_client::{Client, DatasetID, FileType};
1110 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1111 /// let client = Client::new()?.with_token_path(None)?;
1112 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1113 ///
1114 /// // Download with sequence subdirectories (default)
1115 /// client
1116 /// .download_dataset(
1117 /// dataset_id,
1118 /// &[],
1119 /// &[FileType::Image],
1120 /// "./data".into(),
1121 /// false,
1122 /// None,
1123 /// )
1124 /// .await?;
1125 ///
1126 /// // Download flattened (all files in one directory)
1127 /// client
1128 /// .download_dataset(
1129 /// dataset_id,
1130 /// &[],
1131 /// &[FileType::Image],
1132 /// "./data".into(),
1133 /// true,
1134 /// None,
1135 /// )
1136 /// .await?;
1137 /// # Ok(())
1138 /// # }
1139 /// ```
    pub async fn download_dataset(
        &self,
        dataset_id: DatasetID,
        groups: &[String],
        file_types: &[FileType],
        output: PathBuf,
        flatten: bool,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Fetch the full sample listing first so the total count is known for
        // progress reporting before any file transfers begin.
        let samples = self
            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
            .await?;
        fs::create_dir_all(&output).await?;

        // Clone captured state so the per-sample closure below can be moved
        // into the parallel worker tasks.
        let client = self.clone();
        let file_types = file_types.to_vec();
        let output = output.clone();

        parallel_foreach_items(samples, progress, move |sample| {
            // Per-task clones: each spawned future needs its own copies.
            let client = client.clone();
            let file_types = file_types.clone();
            let output = output.clone();

            async move {
                for file_type in file_types {
                    if let Some(data) = sample.download(&client, file_type.clone()).await? {
                        // For images, sniff the real extension from the bytes;
                        // other file types use the type name as extension.
                        // NOTE(review): infer::get().expect() panics on an
                        // unrecognized image format — confirm this is the
                        // intended failure mode for corrupt downloads.
                        let (file_ext, is_image) = match file_type.clone() {
                            FileType::Image => (
                                infer::get(&data)
                                    .expect("Failed to identify image file format for sample")
                                    .extension()
                                    .to_string(),
                                true,
                            ),
                            other => (other.to_string(), false),
                        };

                        // Determine target directory based on sequence membership and flatten
                        // option
                        // - flatten=false + sequence_name: dataset/sequence_name/
                        // - flatten=false + no sequence: dataset/ (root level)
                        // - flatten=true: dataset/ (all files in output root)
                        // NOTE: group (train/val/test) is NOT used for directory structure
                        let sequence_dir = sample
                            .sequence_name()
                            .map(|name| sanitize_path_component(name));

                        let target_dir = if flatten {
                            output.clone()
                        } else {
                            sequence_dir
                                .as_ref()
                                .map(|seq| output.join(seq))
                                .unwrap_or_else(|| output.clone())
                        };
                        fs::create_dir_all(&target_dir).await?;

                        // Sanitize names so they are safe as path components.
                        let sanitized_sample_name = sample
                            .name()
                            .map(|name| sanitize_path_component(&name))
                            .unwrap_or_else(|| "unknown".to_string());

                        let image_name = sample.image_name().map(sanitize_path_component);

                        // Construct filename with smart prefixing for flatten mode
                        // When flatten=true and sample belongs to a sequence:
                        // - Check if filename already starts with "{sequence_name}_"
                        // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
                        // - If yes, use filename as-is (already uniquely named)
                        let file_name = if is_image {
                            if let Some(img_name) = image_name {
                                Self::build_filename(
                                    &img_name,
                                    flatten,
                                    sequence_dir.as_ref(),
                                    sample.frame_number(),
                                )
                            } else {
                                // No recorded image name: fall back to the
                                // sanitized sample name.
                                format!("{}.{}", sanitized_sample_name, file_ext)
                            }
                        } else {
                            let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
                            Self::build_filename(
                                &base_name,
                                flatten,
                                sequence_dir.as_ref(),
                                sample.frame_number(),
                            )
                        };

                        let file_path = target_dir.join(&file_name);

                        let mut file = File::create(&file_path).await?;
                        file.write_all(&data).await?;
                    } else {
                        // Missing payload is logged but does not abort the
                        // rest of the download.
                        warn!(
                            "No data for sample: {}",
                            sample
                                .id()
                                .map(|id| id.to_string())
                                .unwrap_or_else(|| "unknown".to_string())
                        );
                    }
                }

                Ok(())
            }
        })
        .await
    }
1250
1251 /// Builds a filename with smart prefixing for flatten mode.
1252 ///
1253 /// When flattening sequences into a single directory, this function ensures
1254 /// unique filenames by checking if the sequence prefix already exists and
1255 /// adding it if necessary.
1256 ///
1257 /// # Logic
1258 ///
1259 /// - If `flatten=false`: returns `base_name` unchanged
1260 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1261 /// - If `flatten=true` and in sequence:
1262 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1263 /// unchanged
1264 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1265 /// `{sequence_name}_{base_name}`
1266 fn build_filename(
1267 base_name: &str,
1268 flatten: bool,
1269 sequence_name: Option<&String>,
1270 frame_number: Option<u32>,
1271 ) -> String {
1272 if !flatten || sequence_name.is_none() {
1273 return base_name.to_string();
1274 }
1275
1276 let seq_name = sequence_name.unwrap();
1277 let prefix = format!("{}_", seq_name);
1278
1279 // Check if already prefixed with sequence name
1280 if base_name.starts_with(&prefix) {
1281 base_name.to_string()
1282 } else {
1283 // Add sequence (and optionally frame) prefix
1284 match frame_number {
1285 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1286 None => format!("{}{}", prefix, base_name),
1287 }
1288 }
1289 }
1290
1291 /// List available annotation sets for the specified dataset.
1292 pub async fn annotation_sets(
1293 &self,
1294 dataset_id: DatasetID,
1295 ) -> Result<Vec<AnnotationSet>, Error> {
1296 let params = HashMap::from([("dataset_id", dataset_id)]);
1297 self.rpc("annset.list".to_owned(), Some(params)).await
1298 }
1299
1300 /// Create a new annotation set for the specified dataset.
1301 ///
1302 /// # Arguments
1303 ///
1304 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1305 /// * `name` - The name of the new annotation set
1306 /// * `description` - Optional description for the annotation set
1307 ///
1308 /// # Returns
1309 ///
1310 /// Returns the annotation set ID of the newly created annotation set.
1311 pub async fn create_annotation_set(
1312 &self,
1313 dataset_id: DatasetID,
1314 name: &str,
1315 description: Option<&str>,
1316 ) -> Result<AnnotationSetID, Error> {
1317 #[derive(Serialize)]
1318 struct Params<'a> {
1319 dataset_id: DatasetID,
1320 name: &'a str,
1321 operator: &'a str,
1322 #[serde(skip_serializing_if = "Option::is_none")]
1323 description: Option<&'a str>,
1324 }
1325
1326 #[derive(Deserialize)]
1327 struct CreateAnnotationSetResult {
1328 id: AnnotationSetID,
1329 }
1330
1331 let username = self.username().await?;
1332 let result: CreateAnnotationSetResult = self
1333 .rpc(
1334 "annset.add".to_owned(),
1335 Some(Params {
1336 dataset_id,
1337 name,
1338 operator: &username,
1339 description,
1340 }),
1341 )
1342 .await?;
1343 Ok(result.id)
1344 }
1345
1346 /// Deletes an annotation set by marking it as deleted.
1347 ///
1348 /// # Arguments
1349 ///
1350 /// * `annotation_set_id` - The ID of the annotation set to delete
1351 ///
1352 /// # Returns
1353 ///
1354 /// Returns `Ok(())` if the annotation set was successfully marked as
1355 /// deleted.
1356 pub async fn delete_annotation_set(
1357 &self,
1358 annotation_set_id: AnnotationSetID,
1359 ) -> Result<(), Error> {
1360 let params = HashMap::from([("id", annotation_set_id)]);
1361 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1362 Ok(())
1363 }
1364
1365 /// Retrieve the annotation set with the specified ID.
1366 pub async fn annotation_set(
1367 &self,
1368 annotation_set_id: AnnotationSetID,
1369 ) -> Result<AnnotationSet, Error> {
1370 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1371 self.rpc("annset.get".to_owned(), Some(params)).await
1372 }
1373
1374 /// Get the annotations for the specified annotation set with the
1375 /// requested annotation types. The annotation types are used to filter
1376 /// the annotations returned. The groups parameter is used to filter for
1377 /// dataset groups (train, val, test). Images which do not have any
1378 /// annotations are also included in the result as long as they are in the
1379 /// requested groups (when specified).
1380 ///
1381 /// The result is a vector of Annotations objects which contain the
1382 /// full dataset along with the annotations for the specified types.
1383 ///
1384 /// To get the annotations as a DataFrame, use the `annotations_dataframe`
1385 /// method instead.
1386 pub async fn annotations(
1387 &self,
1388 annotation_set_id: AnnotationSetID,
1389 groups: &[String],
1390 annotation_types: &[AnnotationType],
1391 progress: Option<Sender<Progress>>,
1392 ) -> Result<Vec<Annotation>, Error> {
1393 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1394 let labels = self
1395 .labels(dataset_id)
1396 .await?
1397 .into_iter()
1398 .map(|label| (label.name().to_string(), label.index()))
1399 .collect::<HashMap<_, _>>();
1400 let total = self
1401 .samples_count(
1402 dataset_id,
1403 Some(annotation_set_id),
1404 annotation_types,
1405 groups,
1406 &[],
1407 )
1408 .await?
1409 .total as usize;
1410
1411 if total == 0 {
1412 return Ok(vec![]);
1413 }
1414
1415 let context = FetchContext {
1416 dataset_id,
1417 annotation_set_id: Some(annotation_set_id),
1418 groups,
1419 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1420 labels: &labels,
1421 };
1422
1423 self.fetch_annotations_paginated(context, total, progress)
1424 .await
1425 }
1426
    /// Pages through the `samples.list` RPC, flattening each page of samples
    /// into Annotation records until the server stops returning a
    /// continuation token. `total` is only used for progress reporting.
    async fn fetch_annotations_paginated(
        &self,
        context: FetchContext<'_>,
        total: usize,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        loop {
            let params = SamplesListParams {
                dataset_id: context.dataset_id,
                annotation_set_id: context.annotation_set_id,
                types: context.types.clone(),
                group_names: context.groups.to_vec(),
                // Token is moved into the request; it is reassigned from the
                // response below before it is read again.
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            self.process_sample_annotations(&result.samples, context.labels, &mut annotations);

            if let Some(progress) = &progress {
                // Progress send failures are non-fatal: the receiver may have
                // been dropped.
                let _ = progress.send(Progress { current, total }).await;
            }

            // An empty or absent token signals that the listing is complete.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Close the progress channel before returning the accumulated result.
        drop(progress);
        Ok(annotations)
    }
1470
1471 fn process_sample_annotations(
1472 &self,
1473 samples: &[Sample],
1474 labels: &HashMap<String, u64>,
1475 annotations: &mut Vec<Annotation>,
1476 ) {
1477 for sample in samples {
1478 if sample.annotations().is_empty() {
1479 let mut annotation = Annotation::new();
1480 annotation.set_sample_id(sample.id());
1481 annotation.set_name(sample.name());
1482 annotation.set_sequence_name(sample.sequence_name().cloned());
1483 annotation.set_frame_number(sample.frame_number());
1484 annotation.set_group(sample.group().cloned());
1485 annotations.push(annotation);
1486 continue;
1487 }
1488
1489 for annotation in sample.annotations() {
1490 let mut annotation = annotation.clone();
1491 annotation.set_sample_id(sample.id());
1492 annotation.set_name(sample.name());
1493 annotation.set_sequence_name(sample.sequence_name().cloned());
1494 annotation.set_frame_number(sample.frame_number());
1495 annotation.set_group(sample.group().cloned());
1496 Self::set_label_index_from_map(&mut annotation, labels);
1497 annotations.push(annotation);
1498 }
1499 }
1500 }
1501
1502 /// Helper to parse frame number from image_name when sequence_name is
1503 /// present. This ensures frame_number is always derived from the image
1504 /// filename, not from the server's frame_number field (which may be
1505 /// inconsistent).
1506 ///
1507 /// Returns Some(frame_number) if sequence_name is present and frame can be
1508 /// parsed, otherwise None.
1509 fn parse_frame_from_image_name(
1510 image_name: Option<&String>,
1511 sequence_name: Option<&String>,
1512 ) -> Option<u32> {
1513 use std::path::Path;
1514
1515 let sequence = sequence_name?;
1516 let name = image_name?;
1517
1518 // Extract stem (remove extension)
1519 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
1520
1521 // Parse frame from format: "sequence_XXX" where XXX is the frame number
1522 stem.strip_prefix(sequence)
1523 .and_then(|suffix| suffix.strip_prefix('_'))
1524 .and_then(|frame_str| frame_str.parse::<u32>().ok())
1525 }
1526
1527 /// Helper to set label index from a label map
1528 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
1529 if let Some(label) = annotation.label() {
1530 annotation.set_label_index(Some(labels[label.as_str()]));
1531 }
1532 }
1533
1534 pub async fn samples_count(
1535 &self,
1536 dataset_id: DatasetID,
1537 annotation_set_id: Option<AnnotationSetID>,
1538 annotation_types: &[AnnotationType],
1539 groups: &[String],
1540 types: &[FileType],
1541 ) -> Result<SamplesCountResult, Error> {
1542 let types = annotation_types
1543 .iter()
1544 .map(|t| t.to_string())
1545 .chain(types.iter().map(|t| t.to_string()))
1546 .collect::<Vec<_>>();
1547
1548 let params = SamplesListParams {
1549 dataset_id,
1550 annotation_set_id,
1551 group_names: groups.to_vec(),
1552 types,
1553 continue_token: None,
1554 };
1555
1556 self.rpc("samples.count".to_owned(), Some(params)).await
1557 }
1558
1559 pub async fn samples(
1560 &self,
1561 dataset_id: DatasetID,
1562 annotation_set_id: Option<AnnotationSetID>,
1563 annotation_types: &[AnnotationType],
1564 groups: &[String],
1565 types: &[FileType],
1566 progress: Option<Sender<Progress>>,
1567 ) -> Result<Vec<Sample>, Error> {
1568 let types_vec = annotation_types
1569 .iter()
1570 .map(|t| t.to_string())
1571 .chain(types.iter().map(|t| t.to_string()))
1572 .collect::<Vec<_>>();
1573 let labels = self
1574 .labels(dataset_id)
1575 .await?
1576 .into_iter()
1577 .map(|label| (label.name().to_string(), label.index()))
1578 .collect::<HashMap<_, _>>();
1579 let total = self
1580 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
1581 .await?
1582 .total as usize;
1583
1584 if total == 0 {
1585 return Ok(vec![]);
1586 }
1587
1588 let context = FetchContext {
1589 dataset_id,
1590 annotation_set_id,
1591 groups,
1592 types: types_vec,
1593 labels: &labels,
1594 };
1595
1596 self.fetch_samples_paginated(context, total, progress).await
1597 }
1598
1599 async fn fetch_samples_paginated(
1600 &self,
1601 context: FetchContext<'_>,
1602 total: usize,
1603 progress: Option<Sender<Progress>>,
1604 ) -> Result<Vec<Sample>, Error> {
1605 let mut samples = vec![];
1606 let mut continue_token: Option<String> = None;
1607 let mut current = 0;
1608
1609 loop {
1610 let params = SamplesListParams {
1611 dataset_id: context.dataset_id,
1612 annotation_set_id: context.annotation_set_id,
1613 types: context.types.clone(),
1614 group_names: context.groups.to_vec(),
1615 continue_token: continue_token.clone(),
1616 };
1617
1618 let result: SamplesListResult =
1619 self.rpc("samples.list".to_owned(), Some(params)).await?;
1620 current += result.samples.len();
1621 continue_token = result.continue_token;
1622
1623 if result.samples.is_empty() {
1624 break;
1625 }
1626
1627 samples.append(
1628 &mut result
1629 .samples
1630 .into_iter()
1631 .map(|s| {
1632 // Use server's frame_number if valid (>= 0 after deserialization)
1633 // Otherwise parse from image_name as fallback
1634 // This ensures we respect explicit frame_number from uploads
1635 // while still handling legacy data that only has filename encoding
1636 let frame_number = s.frame_number.or_else(|| {
1637 Self::parse_frame_from_image_name(
1638 s.image_name.as_ref(),
1639 s.sequence_name.as_ref(),
1640 )
1641 });
1642
1643 let mut anns = s.annotations().to_vec();
1644 for ann in &mut anns {
1645 // Set annotation fields from parent sample
1646 ann.set_name(s.name());
1647 ann.set_group(s.group().cloned());
1648 ann.set_sequence_name(s.sequence_name().cloned());
1649 ann.set_frame_number(frame_number);
1650 Self::set_label_index_from_map(ann, context.labels);
1651 }
1652 s.with_annotations(anns).with_frame_number(frame_number)
1653 })
1654 .collect::<Vec<_>>(),
1655 );
1656
1657 if let Some(progress) = &progress {
1658 let _ = progress.send(Progress { current, total }).await;
1659 }
1660
1661 match &continue_token {
1662 Some(token) if !token.is_empty() => continue,
1663 _ => break,
1664 }
1665 }
1666
1667 drop(progress);
1668 Ok(samples)
1669 }
1670
1671 /// Populates (imports) samples into a dataset using the `samples.populate2`
1672 /// API.
1673 ///
1674 /// This method creates new samples in the specified dataset, optionally
1675 /// with annotations and sensor data files. For each sample, the `files`
1676 /// field is checked for local file paths. If a filename is a valid path
1677 /// to an existing file, the file will be automatically uploaded to S3
1678 /// using presigned URLs returned by the server. The filename in the
1679 /// request is replaced with the basename (path removed) before sending
1680 /// to the server.
1681 ///
1682 /// # Important Notes
1683 ///
1684 /// - **`annotation_set_id` is REQUIRED** when importing samples with
1685 /// annotations. Without it, the server will accept the request but will
1686 /// not save the annotation data. Use [`Client::annotation_sets`] to query
1687 /// available annotation sets for a dataset, or create a new one via the
1688 /// Studio UI.
1689 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
1690 /// boxes. Divide pixel coordinates by image width/height before creating
1691 /// [`Box2d`](crate::Box2d) annotations.
1692 /// - **Files are uploaded automatically** when the filename is a valid
1693 /// local path. The method will replace the full path with just the
1694 /// basename before sending to the server.
1695 /// - **Image dimensions are extracted automatically** for image files using
1696 /// the `imagesize` crate. The width/height are sent to the server, but
1697 /// note that the server currently doesn't return these fields when
1698 /// fetching samples back.
1699 /// - **UUIDs are generated automatically** if not provided. If you need
1700 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
1701 /// that the server doesn't currently return UUIDs in sample queries.
1702 ///
1703 /// # Arguments
1704 ///
1705 /// * `dataset_id` - The ID of the dataset to populate
1706 /// * `annotation_set_id` - **Required** if samples contain annotations,
1707 /// otherwise they will be ignored. Query with
1708 /// [`Client::annotation_sets`].
1709 /// * `samples` - Vector of samples to import with metadata and file
1710 /// references. For files, use the full local path - it will be uploaded
1711 /// automatically. UUIDs and image dimensions will be
1712 /// auto-generated/extracted if not provided.
1713 ///
1714 /// # Returns
1715 ///
1716 /// Returns the API result with sample UUIDs and upload status.
1717 ///
1718 /// # Example
1719 ///
1720 /// ```no_run
1721 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
1722 ///
1723 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1724 /// # let client = Client::new()?.with_login("user", "pass").await?;
1725 /// # let dataset_id = DatasetID::from(1);
1726 /// // Query available annotation sets for the dataset
1727 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
1728 /// let annotation_set_id = annotation_sets
1729 /// .first()
1730 /// .ok_or_else(|| {
1731 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
1732 /// })?
1733 /// .id();
1734 ///
1735 /// // Create sample with annotation (UUID will be auto-generated)
1736 /// let mut sample = Sample::new();
1737 /// sample.width = Some(1920);
1738 /// sample.height = Some(1080);
1739 /// sample.group = Some("train".to_string());
1740 ///
1741 /// // Add file - use full path to local file, it will be uploaded automatically
1742 /// sample.files = vec![SampleFile::with_filename(
1743 /// "image".to_string(),
1744 /// "/path/to/image.jpg".to_string(),
1745 /// )];
1746 ///
1747 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
1748 /// let mut annotation = Annotation::new();
1749 /// annotation.set_label(Some("person".to_string()));
1750 /// // Normalize pixel coordinates by dividing by image dimensions
1751 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
1752 /// annotation.set_box2d(Some(bbox));
1753 /// sample.annotations = vec![annotation];
1754 ///
1755 /// // Populate with annotation_set_id (REQUIRED for annotations)
1756 /// let result = client
1757 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
1758 /// .await?;
1759 /// # Ok(())
1760 /// # }
1761 /// ```
1762 pub async fn populate_samples(
1763 &self,
1764 dataset_id: DatasetID,
1765 annotation_set_id: Option<AnnotationSetID>,
1766 samples: Vec<Sample>,
1767 progress: Option<Sender<Progress>>,
1768 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
1769 use crate::api::SamplesPopulateParams;
1770
1771 // Track which files need to be uploaded
1772 let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
1773
1774 // Process samples to detect local files and generate UUIDs
1775 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
1776
1777 let has_files_to_upload = !files_to_upload.is_empty();
1778
1779 // Call populate API with presigned_urls=true if we have files to upload
1780 let params = SamplesPopulateParams {
1781 dataset_id,
1782 annotation_set_id,
1783 presigned_urls: Some(has_files_to_upload),
1784 samples,
1785 };
1786
1787 let results: Vec<crate::SamplesPopulateResult> = self
1788 .rpc("samples.populate2".to_owned(), Some(params))
1789 .await?;
1790
1791 // Upload files if we have any
1792 if has_files_to_upload {
1793 self.upload_sample_files(&results, files_to_upload, progress)
1794 .await?;
1795 }
1796
1797 Ok(results)
1798 }
1799
1800 fn prepare_samples_for_upload(
1801 &self,
1802 samples: Vec<Sample>,
1803 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1804 ) -> Result<Vec<Sample>, Error> {
1805 Ok(samples
1806 .into_iter()
1807 .map(|mut sample| {
1808 // Generate UUID if not provided
1809 if sample.uuid.is_none() {
1810 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
1811 }
1812
1813 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
1814
1815 // Process files: detect local paths and queue for upload
1816 let files_copy = sample.files.clone();
1817 let updated_files: Vec<crate::SampleFile> = files_copy
1818 .iter()
1819 .map(|file| {
1820 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
1821 })
1822 .collect();
1823
1824 sample.files = updated_files;
1825 sample
1826 })
1827 .collect())
1828 }
1829
1830 fn process_sample_file(
1831 &self,
1832 file: &crate::SampleFile,
1833 sample_uuid: &str,
1834 sample: &mut Sample,
1835 files_to_upload: &mut Vec<(String, String, PathBuf, String)>,
1836 ) -> crate::SampleFile {
1837 use std::path::Path;
1838
1839 if let Some(filename) = file.filename() {
1840 let path = Path::new(filename);
1841
1842 // Check if this is a valid local file path
1843 if path.exists()
1844 && path.is_file()
1845 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
1846 {
1847 // For image files, try to extract dimensions if not already set
1848 if file.file_type() == "image"
1849 && (sample.width.is_none() || sample.height.is_none())
1850 && let Ok(size) = imagesize::size(path)
1851 {
1852 sample.width = Some(size.width as u32);
1853 sample.height = Some(size.height as u32);
1854 }
1855
1856 // Store the full path for later upload
1857 files_to_upload.push((
1858 sample_uuid.to_string(),
1859 file.file_type().to_string(),
1860 path.to_path_buf(),
1861 basename.to_string(),
1862 ));
1863
1864 // Return SampleFile with just the basename
1865 return crate::SampleFile::with_filename(
1866 file.file_type().to_string(),
1867 basename.to_string(),
1868 );
1869 }
1870 }
1871 // Return the file unchanged if not a local path
1872 file.clone()
1873 }
1874
    /// Uploads queued local files to the presigned URLs returned by
    /// `samples.populate2`. Files are matched to URLs by (sample UUID,
    /// basename); uploads run in parallel with one task per sample.
    async fn upload_sample_files(
        &self,
        results: &[crate::SamplesPopulateResult],
        files_to_upload: Vec<(String, String, PathBuf, String)>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        // Build a map from (sample_uuid, basename) -> local_path
        let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
        for (uuid, _file_type, path, basename) in files_to_upload {
            upload_map.insert((uuid, basename), path);
        }

        let http = self.http.clone();

        // Extract the data we need for parallel upload
        let upload_tasks: Vec<_> = results
            .iter()
            .map(|result| (result.uuid.clone(), result.urls.clone()))
            .collect();

        parallel_foreach_items(upload_tasks, progress.clone(), move |(uuid, urls)| {
            let http = http.clone();
            // NOTE(review): the whole map is cloned per task; fine for
            // typical batch sizes, an Arc would avoid the copies.
            let upload_map = upload_map.clone();

            async move {
                // Upload all files for this sample
                for url_info in &urls {
                    if let Some(local_path) =
                        upload_map.get(&(uuid.clone(), url_info.filename.clone()))
                    {
                        // Upload the file
                        upload_file_to_presigned_url(
                            http.clone(),
                            &url_info.url,
                            local_path.clone(),
                        )
                        .await?;
                    }
                    // URLs with no matching queued file are skipped silently.
                }

                Ok(())
            }
        })
        .await
    }
1920
1921 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1922 // Uses default 120s timeout from client
1923 let resp = self.http.get(url).send().await?;
1924
1925 if !resp.status().is_success() {
1926 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
1927 }
1928
1929 let bytes = resp.bytes().await?;
1930 Ok(bytes.to_vec())
1931 }
1932
1933 /// Get the AnnotationGroup for the specified annotation set with the
1934 /// requested annotation types. The annotation type is used to filter
1935 /// the annotations returned. Images which do not have any annotations
1936 /// are included in the result.
1937 ///
1938 /// Get annotations as a DataFrame (2025.01 schema).
1939 ///
1940 /// **DEPRECATED**: Use [`Client::samples_dataframe()`] instead for full
1941 /// 2025.10 schema support including optional metadata columns.
1942 ///
1943 /// The result is a DataFrame following the EdgeFirst Dataset Format
1944 /// definition with 9 columns (original schema). Does not include new
1945 /// optional columns added in 2025.10.
1946 ///
1947 /// # Migration
1948 ///
1949 /// ```rust,no_run
1950 /// # use edgefirst_client::Client;
1951 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1952 /// # let client = Client::new()?;
1953 /// # let dataset_id = 1.into();
1954 /// # let annotation_set_id = 1.into();
1955 /// # let groups = vec![];
1956 /// # let types = vec![];
1957 /// // OLD (deprecated):
1958 /// let df = client
1959 /// .annotations_dataframe(annotation_set_id, &groups, &types, None)
1960 /// .await?;
1961 ///
1962 /// // NEW (recommended):
1963 /// let df = client
1964 /// .samples_dataframe(dataset_id, Some(annotation_set_id), &groups, &types, None)
1965 /// .await?;
1966 /// # Ok(())
1967 /// # }
1968 /// ```
1969 ///
1970 /// To get the annotations as a vector of Annotation objects, use the
1971 /// `annotations` method instead.
1972 #[deprecated(
1973 since = "0.8.0",
1974 note = "Use `samples_dataframe()` for complete 2025.10 schema support"
1975 )]
1976 #[cfg(feature = "polars")]
1977 pub async fn annotations_dataframe(
1978 &self,
1979 annotation_set_id: AnnotationSetID,
1980 groups: &[String],
1981 types: &[AnnotationType],
1982 progress: Option<Sender<Progress>>,
1983 ) -> Result<DataFrame, Error> {
1984 #[allow(deprecated)]
1985 use crate::dataset::annotations_dataframe;
1986
1987 let annotations = self
1988 .annotations(annotation_set_id, groups, types, progress)
1989 .await?;
1990 #[allow(deprecated)]
1991 annotations_dataframe(&annotations)
1992 }
1993
1994 /// Get samples as a DataFrame with complete 2025.10 schema.
1995 ///
1996 /// This is the recommended method for obtaining dataset annotations in
1997 /// DataFrame format. It includes all sample metadata (size, location,
1998 /// pose, degradation) as optional columns.
1999 ///
2000 /// # Arguments
2001 ///
2002 /// * `dataset_id` - Dataset identifier
2003 /// * `annotation_set_id` - Optional annotation set filter
2004 /// * `groups` - Dataset groups to include (train, val, test)
2005 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2006 /// * `progress` - Optional progress callback
2007 ///
2008 /// # Example
2009 ///
2010 /// ```rust,no_run
2011 /// use edgefirst_client::Client;
2012 ///
2013 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2014 /// # let client = Client::new()?;
2015 /// # let dataset_id = 1.into();
2016 /// # let annotation_set_id = 1.into();
2017 /// let df = client
2018 /// .samples_dataframe(
2019 /// dataset_id,
2020 /// Some(annotation_set_id),
2021 /// &["train".to_string()],
2022 /// &[],
2023 /// None,
2024 /// )
2025 /// .await?;
2026 /// println!("DataFrame shape: {:?}", df.shape());
2027 /// # Ok(())
2028 /// # }
2029 /// ```
2030 #[cfg(feature = "polars")]
2031 pub async fn samples_dataframe(
2032 &self,
2033 dataset_id: DatasetID,
2034 annotation_set_id: Option<AnnotationSetID>,
2035 groups: &[String],
2036 types: &[AnnotationType],
2037 progress: Option<Sender<Progress>>,
2038 ) -> Result<DataFrame, Error> {
2039 use crate::dataset::samples_dataframe;
2040
2041 let samples = self
2042 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2043 .await?;
2044 samples_dataframe(&samples)
2045 }
2046
2047 /// List available snapshots. If a name is provided, only snapshots
2048 /// containing that name are returned.
2049 ///
2050 /// Results are sorted by match quality: exact matches first, then
2051 /// case-insensitive exact matches, then shorter descriptions (more
2052 /// specific), then alphabetically.
2053 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2054 let snapshots: Vec<Snapshot> = self
2055 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2056 .await?;
2057 if let Some(name) = name {
2058 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2059 s.description()
2060 }))
2061 } else {
2062 Ok(snapshots)
2063 }
2064 }
2065
2066 /// Get the snapshot with the specified id.
2067 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2068 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2069 self.rpc("snapshots.get".to_owned(), Some(params)).await
2070 }
2071
2072 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2073 ///
2074 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2075 /// pairs) that serve two primary purposes:
2076 ///
2077 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2078 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2079 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2080 /// auto-depth processing.
2081 ///
2082 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2083 /// migration between EdgeFirst Studio instances using the create →
2084 /// download → upload → restore workflow.
2085 ///
2086 /// Large files are automatically chunked into 100MB parts and uploaded
2087 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2088 /// is streamed without loading into memory, maintaining constant memory
2089 /// usage.
2090 ///
2091 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2092 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2093 /// better for large files to avoid timeout issues. Higher values (16-32)
2094 /// are better for many small files.
2095 ///
2096 /// # Arguments
2097 ///
2098 /// * `path` - Local file path to MCAP file or directory containing
2099 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2100 /// * `progress` - Optional channel to receive upload progress updates
2101 ///
2102 /// # Returns
2103 ///
2104 /// Returns a `Snapshot` object with ID, description, status, path, and
2105 /// creation timestamp on success.
2106 ///
2107 /// # Errors
2108 ///
2109 /// Returns an error if:
2110 /// * Path doesn't exist or contains invalid UTF-8
2111 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2112 /// * Upload fails or network error occurs
2113 /// * Server rejects the snapshot
2114 ///
2115 /// # Example
2116 ///
2117 /// ```no_run
2118 /// # use edgefirst_client::{Client, Progress};
2119 /// # use tokio::sync::mpsc;
2120 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2121 /// let client = Client::new()?.with_token_path(None)?;
2122 ///
2123 /// // Upload MCAP file with progress tracking
2124 /// let (tx, mut rx) = mpsc::channel(1);
2125 /// tokio::spawn(async move {
2126 /// while let Some(Progress { current, total }) = rx.recv().await {
2127 /// println!(
2128 /// "Upload: {}/{} bytes ({:.1}%)",
2129 /// current,
2130 /// total,
2131 /// (current as f64 / total as f64) * 100.0
2132 /// );
2133 /// }
2134 /// });
2135 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2136 /// println!("Created snapshot: {:?}", snapshot.id());
2137 ///
2138 /// // Upload dataset directory (no progress)
2139 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2140 /// # Ok(())
2141 /// # }
2142 /// ```
2143 ///
2144 /// # See Also
2145 ///
2146 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2147 /// dataset
2148 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2149 /// data
2150 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2151 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2152 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories (EdgeFirst Dataset Format) take the multi-file folder
        // flow; everything below handles the single-file case.
        if path.is_dir() {
            let path_str = path.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "Path contains invalid UTF-8",
                ))
            })?;
            return self.create_snapshot_folder(path_str, progress).await;
        }

        // The file's basename doubles as both the snapshot name and the
        // single upload key.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid filename",
            ))
        })?;
        let total = path.metadata()?.len() as usize;
        // Shared byte counter updated by upload_multipart for progress.
        let current = Arc::new(AtomicUsize::new(0));

        // Emit an initial 0/total event so consumers can render immediately.
        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Ask the server to create the snapshot record and presigned
        // multipart upload URLs for the single key.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The result map is keyed by "<storage prefix>/<filename>", where
        // the prefix is the portion of the snapshot path after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Stream the file in chunks to the presigned URLs; returns the
        // parameters needed to finalize the multipart upload server-side.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot from its initial state to "available" now that
        // all bytes are uploaded and the multipart upload is finalized.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Drop the sender early to close the progress channel before the
        // final snapshot fetch, so receivers see completion promptly.
        if let Some(progress) = progress {
            drop(progress);
        }

        // Re-fetch so the caller receives the post-update snapshot state.
        self.snapshot(snapshot_id).await
    }
2247
    /// Upload every file under `path` as a multi-file snapshot.
    ///
    /// Internal helper for [`create_snapshot`](Self::create_snapshot) used
    /// when the given path is a directory. Files are uploaded sequentially,
    /// each via S3 multipart upload; the shared `current` counter lets the
    /// progress channel report aggregate bytes across all files.
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        // The directory's basename becomes the snapshot name.
        let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
            Error::IoError(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Invalid directory name",
            ))
        })?;

        // Collect every regular file, recursively, as paths relative to the
        // snapshot root (unreadable entries are silently skipped).
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
            .collect::<Vec<_>>();

        // Aggregate byte total across all files for progress reporting.
        let total: usize = files
            .iter()
            .filter_map(|file| path.join(file).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        // Emit an initial 0/total event so consumers can render immediately.
        if let Some(progress) = &progress {
            let _ = progress.send(Progress { current: 0, total }).await;
        }

        // Relative paths become the upload keys; sizes are gathered in the
        // same order so keys[i] corresponds to file_sizes[i].
        let keys = files
            .iter()
            .filter_map(|key| key.to_str().map(|s| s.to_owned()))
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .filter_map(|key| path.join(key).metadata().ok())
            .map(|metadata| metadata.len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        // Server creates the snapshot record and returns presigned multipart
        // upload URLs for every requested key.
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The result map is keyed by "<storage prefix>/<relative path>",
        // where the prefix is the snapshot path after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot
            .path()
            .split("::/")
            .last()
            .ok_or(Error::InvalidResponse)?
            .to_owned();

        // Upload and finalize each file in turn; upload_multipart handles
        // the chunk-level concurrency within a single file.
        for file in files {
            let file_str = file.to_str().ok_or_else(|| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "File path contains invalid UTF-8",
                ))
            })?;
            let part_key = format!("{}/{}", part_prefix, file_str);
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // All files uploaded: mark the snapshot available.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel before the final fetch so receivers
        // observe completion promptly.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
2363
2364 /// Delete a snapshot from EdgeFirst Studio.
2365 ///
2366 /// Permanently removes a snapshot and its associated data. This operation
2367 /// cannot be undone.
2368 ///
2369 /// # Arguments
2370 ///
2371 /// * `snapshot_id` - The snapshot ID to delete
2372 ///
2373 /// # Errors
2374 ///
2375 /// Returns an error if:
2376 /// * Snapshot doesn't exist
2377 /// * User lacks permission to delete the snapshot
2378 /// * Server error occurs
2379 ///
2380 /// # Example
2381 ///
2382 /// ```no_run
2383 /// # use edgefirst_client::{Client, SnapshotID};
2384 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2385 /// let client = Client::new()?.with_token_path(None)?;
2386 /// let snapshot_id = SnapshotID::from(123);
2387 /// client.delete_snapshot(snapshot_id).await?;
2388 /// # Ok(())
2389 /// # }
2390 /// ```
2391 ///
2392 /// # See Also
2393 ///
2394 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2395 /// * [`snapshots`](Self::snapshots) - List all snapshots
2396 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
2397 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2398 let _: String = self
2399 .rpc("snapshots.delete".to_owned(), Some(params))
2400 .await?;
2401 Ok(())
2402 }
2403
2404 /// Create a snapshot from an existing dataset on the server.
2405 ///
2406 /// Triggers server-side snapshot generation which exports the dataset's
2407 /// images and annotations into a downloadable EdgeFirst Dataset Format
2408 /// snapshot.
2409 ///
2410 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
2411 /// while restore creates a dataset from a snapshot, this method creates a
2412 /// snapshot from a dataset.
2413 ///
2414 /// # Arguments
2415 ///
2416 /// * `dataset_id` - The dataset ID to create snapshot from
2417 /// * `description` - Description for the created snapshot
2418 ///
2419 /// # Returns
2420 ///
2421 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
2422 /// for monitoring progress.
2423 ///
2424 /// # Errors
2425 ///
2426 /// Returns an error if:
2427 /// * Dataset doesn't exist
2428 /// * User lacks permission to access the dataset
2429 /// * Server rejects the request
2430 ///
2431 /// # Example
2432 ///
2433 /// ```no_run
2434 /// # use edgefirst_client::{Client, DatasetID};
2435 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2436 /// let client = Client::new()?.with_token_path(None)?;
2437 /// let dataset_id = DatasetID::from(123);
2438 ///
2439 /// // Create snapshot from dataset (all annotation sets)
2440 /// let result = client
2441 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
2442 /// .await?;
2443 /// println!("Created snapshot: {:?}", result.id);
2444 ///
2445 /// // Monitor progress via task ID
2446 /// if let Some(task_id) = result.task_id {
2447 /// println!("Task: {}", task_id);
2448 /// }
2449 /// # Ok(())
2450 /// # }
2451 /// ```
2452 ///
2453 /// # See Also
2454 ///
2455 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
2456 /// snapshot
2457 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2458 /// dataset
2459 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2460 pub async fn create_snapshot_from_dataset(
2461 &self,
2462 dataset_id: DatasetID,
2463 description: &str,
2464 annotation_set_id: Option<AnnotationSetID>,
2465 ) -> Result<SnapshotFromDatasetResult, Error> {
2466 // Resolve annotation_set_id: use provided value or fetch default
2467 let annotation_set_id = match annotation_set_id {
2468 Some(id) => id,
2469 None => {
2470 // Fetch annotation sets and find default ("annotations") or use first
2471 let sets = self.annotation_sets(dataset_id).await?;
2472 if sets.is_empty() {
2473 return Err(Error::InvalidParameters(
2474 "No annotation sets available for dataset".to_owned(),
2475 ));
2476 }
2477 // Look for "annotations" set (default), otherwise use first
2478 sets.iter()
2479 .find(|s| s.name() == "annotations")
2480 .unwrap_or(&sets[0])
2481 .id()
2482 }
2483 };
2484 let params = SnapshotCreateFromDataset {
2485 description: description.to_owned(),
2486 dataset_id,
2487 annotation_set_id,
2488 };
2489 self.rpc("snapshots.create".to_owned(), Some(params)).await
2490 }
2491
2492 /// Download a snapshot from EdgeFirst Studio to local storage.
2493 ///
2494 /// Downloads all files in a snapshot (single MCAP file or directory of
2495 /// EdgeFirst Dataset Format files) to the specified output path. Files are
2496 /// downloaded concurrently with progress tracking.
2497 ///
2498 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2499 /// downloads (default: half of CPU cores, min 2, max 8).
2500 ///
2501 /// # Arguments
2502 ///
2503 /// * `snapshot_id` - The snapshot ID to download
2504 /// * `output` - Local directory path to save downloaded files
2505 /// * `progress` - Optional channel to receive download progress updates
2506 ///
2507 /// # Errors
2508 ///
2509 /// Returns an error if:
2510 /// * Snapshot doesn't exist
2511 /// * Output directory cannot be created
2512 /// * Download fails or network error occurs
2513 ///
2514 /// # Example
2515 ///
2516 /// ```no_run
2517 /// # use edgefirst_client::{Client, SnapshotID, Progress};
2518 /// # use tokio::sync::mpsc;
2519 /// # use std::path::PathBuf;
2520 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2521 /// let client = Client::new()?.with_token_path(None)?;
2522 /// let snapshot_id = SnapshotID::from(123);
2523 ///
2524 /// // Download with progress tracking
2525 /// let (tx, mut rx) = mpsc::channel(1);
2526 /// tokio::spawn(async move {
2527 /// while let Some(Progress { current, total }) = rx.recv().await {
2528 /// println!("Download: {}/{} bytes", current, total);
2529 /// }
2530 /// });
2531 /// client
2532 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
2533 /// .await?;
2534 /// # Ok(())
2535 /// # }
2536 /// ```
2537 ///
2538 /// # See Also
2539 ///
2540 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2541 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2542 /// dataset
2543 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // The server returns a map of file key -> presigned download URL.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // `total` grows as each task learns its content length, so the
        // reported total converges as downloads start. `current` tracks
        // bytes written across all tasks.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        // Bound concurrent downloads to max_tasks() (MAX_TASKS env override).
        let sem = Arc::new(Semaphore::new(max_tasks()));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                // Clone everything the spawned task needs to own ('static).
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    let _permit = sem.acquire().await.map_err(|_| {
                        Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
                    })?;
                    let res = http.get(url).send().await?;
                    let content_length = res.content_length().unwrap_or(0) as usize;

                    // Fold this file's size into the running total and emit
                    // an updated progress event before streaming begins.
                    if let Some(progress) = &progress {
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        let _ = progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await;
                    }

                    // NOTE(review): assumes `key` names a file directly under
                    // `output` (no nested subdirectories) — TODO confirm the
                    // server never returns keys containing path separators.
                    let mut file = File::create(output.join(key)).await?;
                    let mut stream = res.bytes_stream();

                    // Stream chunks to disk, advancing the shared byte
                    // counter and reporting progress per chunk.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk?;
                        file.write_all(&chunk).await?;
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            let _ = progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await;
                        }
                    }

                    Ok::<(), Error>(())
                })
            })
            .collect::<Vec<_>>();

        // First collect surfaces task join errors (panics/cancellation),
        // second surfaces the first download error from any task.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?
            .into_iter()
            .collect::<Result<Vec<_>, _>>()?;

        Ok(())
    }
2625
2626 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
2627 ///
2628 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
2629 /// the specified project. For MCAP files, supports:
2630 ///
2631 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
2632 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
2633 /// present)
2634 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
2635 /// * **Topic filtering**: Select specific MCAP topics to restore
2636 ///
2637 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
2638 /// dataset structure.
2639 ///
2640 /// # Arguments
2641 ///
2642 /// * `project_id` - Target project ID
2643 /// * `snapshot_id` - Snapshot ID to restore
2644 /// * `topics` - MCAP topics to include (empty = all topics)
2645 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
2646 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
2647 /// * `dataset_name` - Optional custom dataset name
2648 /// * `dataset_description` - Optional dataset description
2649 ///
2650 /// # Returns
2651 ///
2652 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
2653 ///
2654 /// # Errors
2655 ///
2656 /// Returns an error if:
2657 /// * Snapshot or project doesn't exist
2658 /// * Snapshot format is invalid
2659 /// * Server rejects restoration parameters
2660 ///
2661 /// # Example
2662 ///
2663 /// ```no_run
2664 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
2665 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2666 /// let client = Client::new()?.with_token_path(None)?;
2667 /// let project_id = ProjectID::from(1);
2668 /// let snapshot_id = SnapshotID::from(123);
2669 ///
2670 /// // Restore MCAP with AGTG for "person" and "car" detection
2671 /// let result = client
2672 /// .restore_snapshot(
2673 /// project_id,
2674 /// snapshot_id,
2675 /// &[], // All topics
2676 /// &["person".to_string(), "car".to_string()], // AGTG labels
2677 /// true, // Auto-depth
2678 /// Some("Highway Dataset"),
2679 /// Some("Collected on I-95"),
2680 /// )
2681 /// .await?;
2682 /// println!("Restored to dataset: {:?}", result.dataset_id);
2683 /// # Ok(())
2684 /// # }
2685 /// ```
2686 ///
2687 /// # See Also
2688 ///
2689 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
2690 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2691 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2692 #[allow(clippy::too_many_arguments)]
2693 pub async fn restore_snapshot(
2694 &self,
2695 project_id: ProjectID,
2696 snapshot_id: SnapshotID,
2697 topics: &[String],
2698 autolabel: &[String],
2699 autodepth: bool,
2700 dataset_name: Option<&str>,
2701 dataset_description: Option<&str>,
2702 ) -> Result<SnapshotRestoreResult, Error> {
2703 let params = SnapshotRestore {
2704 project_id,
2705 snapshot_id,
2706 fps: 1,
2707 autodepth,
2708 agtg_pipeline: !autolabel.is_empty(),
2709 autolabel: autolabel.to_vec(),
2710 topics: topics.to_vec(),
2711 dataset_name: dataset_name.map(|s| s.to_owned()),
2712 dataset_description: dataset_description.map(|s| s.to_owned()),
2713 };
2714 self.rpc("snapshots.restore".to_owned(), Some(params)).await
2715 }
2716
2717 /// Returns a list of experiments available to the user. The experiments
2718 /// are returned as a vector of Experiment objects. If name is provided
2719 /// then only experiments containing this string are returned.
2720 ///
2721 /// Results are sorted by match quality: exact matches first, then
2722 /// case-insensitive exact matches, then shorter names (more specific),
2723 /// then alphabetically.
2724 ///
2725 /// Experiments provide a method of organizing training and validation
2726 /// sessions together and are akin to an Experiment in MLFlow terminology.
2727 /// Each experiment can have multiple trainer sessions associated with it,
2728 /// these would be akin to runs in MLFlow terminology.
2729 pub async fn experiments(
2730 &self,
2731 project_id: ProjectID,
2732 name: Option<&str>,
2733 ) -> Result<Vec<Experiment>, Error> {
2734 let params = HashMap::from([("project_id", project_id)]);
2735 let experiments: Vec<Experiment> =
2736 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
2737 if let Some(name) = name {
2738 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
2739 } else {
2740 Ok(experiments)
2741 }
2742 }
2743
2744 /// Return the experiment with the specified experiment ID. If the
2745 /// experiment does not exist, an error is returned.
2746 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
2747 let params = HashMap::from([("trainer_id", experiment_id)]);
2748 self.rpc("trainer.get".to_owned(), Some(params)).await
2749 }
2750
2751 /// Returns a list of trainer sessions available to the user. The trainer
2752 /// sessions are returned as a vector of TrainingSession objects. If name
2753 /// is provided then only trainer sessions containing this string are
2754 /// returned.
2755 ///
2756 /// Results are sorted by match quality: exact matches first, then
2757 /// case-insensitive exact matches, then shorter names (more specific),
2758 /// then alphabetically.
2759 ///
2760 /// Trainer sessions are akin to runs in MLFlow terminology. These
2761 /// represent an actual training session which will produce metrics and
2762 /// model artifacts.
2763 pub async fn training_sessions(
2764 &self,
2765 experiment_id: ExperimentID,
2766 name: Option<&str>,
2767 ) -> Result<Vec<TrainingSession>, Error> {
2768 let params = HashMap::from([("trainer_id", experiment_id)]);
2769 let sessions: Vec<TrainingSession> = self
2770 .rpc("trainer.session.list".to_owned(), Some(params))
2771 .await?;
2772 if let Some(name) = name {
2773 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
2774 } else {
2775 Ok(sessions)
2776 }
2777 }
2778
2779 /// Return the trainer session with the specified trainer session ID. If
2780 /// the trainer session does not exist, an error is returned.
2781 pub async fn training_session(
2782 &self,
2783 session_id: TrainingSessionID,
2784 ) -> Result<TrainingSession, Error> {
2785 let params = HashMap::from([("trainer_session_id", session_id)]);
2786 self.rpc("trainer.session.get".to_owned(), Some(params))
2787 .await
2788 }
2789
2790 /// List validation sessions for the given project.
2791 pub async fn validation_sessions(
2792 &self,
2793 project_id: ProjectID,
2794 ) -> Result<Vec<ValidationSession>, Error> {
2795 let params = HashMap::from([("project_id", project_id)]);
2796 self.rpc("validate.session.list".to_owned(), Some(params))
2797 .await
2798 }
2799
2800 /// Retrieve a specific validation session.
2801 pub async fn validation_session(
2802 &self,
2803 session_id: ValidationSessionID,
2804 ) -> Result<ValidationSession, Error> {
2805 let params = HashMap::from([("validate_session_id", session_id)]);
2806 self.rpc("validate.session.get".to_owned(), Some(params))
2807 .await
2808 }
2809
2810 /// List the artifacts for the specified trainer session. The artifacts
2811 /// are returned as a vector of strings.
2812 pub async fn artifacts(
2813 &self,
2814 training_session_id: TrainingSessionID,
2815 ) -> Result<Vec<Artifact>, Error> {
2816 let params = HashMap::from([("training_session_id", training_session_id)]);
2817 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
2818 .await
2819 }
2820
2821 /// Download the model artifact for the specified trainer session to the
2822 /// specified file path, if path is not provided it will be downloaded to
2823 /// the current directory with the same filename. A progress callback can
2824 /// be provided to monitor the progress of the download over a watch
2825 /// channel.
2826 pub async fn download_artifact(
2827 &self,
2828 training_session_id: TrainingSessionID,
2829 modelname: &str,
2830 filename: Option<PathBuf>,
2831 progress: Option<Sender<Progress>>,
2832 ) -> Result<(), Error> {
2833 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
2834 let resp = self
2835 .http
2836 .get(format!(
2837 "{}/download_model?training_session_id={}&file={}",
2838 self.url,
2839 training_session_id.value(),
2840 modelname
2841 ))
2842 .header("Authorization", format!("Bearer {}", self.token().await))
2843 .send()
2844 .await?;
2845 if !resp.status().is_success() {
2846 let err = resp.error_for_status_ref().unwrap_err();
2847 return Err(Error::HttpError(err));
2848 }
2849
2850 if let Some(parent) = filename.parent() {
2851 fs::create_dir_all(parent).await?;
2852 }
2853
2854 if let Some(progress) = progress {
2855 let total = resp.content_length().unwrap_or(0) as usize;
2856 let _ = progress.send(Progress { current: 0, total }).await;
2857
2858 let mut file = File::create(filename).await?;
2859 let mut current = 0;
2860 let mut stream = resp.bytes_stream();
2861
2862 while let Some(item) = stream.next().await {
2863 let chunk = item?;
2864 file.write_all(&chunk).await?;
2865 current += chunk.len();
2866 let _ = progress.send(Progress { current, total }).await;
2867 }
2868 } else {
2869 let body = resp.bytes().await?;
2870 fs::write(filename, body).await?;
2871 }
2872
2873 Ok(())
2874 }
2875
2876 /// Download the model checkpoint associated with the specified trainer
2877 /// session to the specified file path, if path is not provided it will be
2878 /// downloaded to the current directory with the same filename. A progress
2879 /// callback can be provided to monitor the progress of the download over a
2880 /// watch channel.
2881 ///
2882 /// There is no API for listing checkpoints it is expected that trainers are
2883 /// aware of possible checkpoints and their names within the checkpoint
2884 /// folder on the server.
2885 pub async fn download_checkpoint(
2886 &self,
2887 training_session_id: TrainingSessionID,
2888 checkpoint: &str,
2889 filename: Option<PathBuf>,
2890 progress: Option<Sender<Progress>>,
2891 ) -> Result<(), Error> {
2892 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
2893 let resp = self
2894 .http
2895 .get(format!(
2896 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
2897 self.url,
2898 training_session_id.value(),
2899 checkpoint
2900 ))
2901 .header("Authorization", format!("Bearer {}", self.token().await))
2902 .send()
2903 .await?;
2904 if !resp.status().is_success() {
2905 let err = resp.error_for_status_ref().unwrap_err();
2906 return Err(Error::HttpError(err));
2907 }
2908
2909 if let Some(parent) = filename.parent() {
2910 fs::create_dir_all(parent).await?;
2911 }
2912
2913 if let Some(progress) = progress {
2914 let total = resp.content_length().unwrap_or(0) as usize;
2915 let _ = progress.send(Progress { current: 0, total }).await;
2916
2917 let mut file = File::create(filename).await?;
2918 let mut current = 0;
2919 let mut stream = resp.bytes_stream();
2920
2921 while let Some(item) = stream.next().await {
2922 let chunk = item?;
2923 file.write_all(&chunk).await?;
2924 current += chunk.len();
2925 let _ = progress.send(Progress { current, total }).await;
2926 }
2927 } else {
2928 let body = resp.bytes().await?;
2929 fs::write(filename, body).await?;
2930 }
2931
2932 Ok(())
2933 }
2934
2935 /// Return a list of tasks for the current user.
2936 ///
2937 /// # Arguments
2938 ///
2939 /// * `name` - Optional filter for task name (client-side substring match)
2940 /// * `workflow` - Optional filter for workflow/task type. If provided,
2941 /// filters server-side by exact match. Valid values include: "trainer",
2942 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
2943 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
2944 /// "convertor", "twostage"
2945 /// * `status` - Optional filter for task status (e.g., "running",
2946 /// "complete", "error")
2947 /// * `manager` - Optional filter for task manager type (e.g., "aws",
2948 /// "user", "kubernetes")
2949 pub async fn tasks(
2950 &self,
2951 name: Option<&str>,
2952 workflow: Option<&str>,
2953 status: Option<&str>,
2954 manager: Option<&str>,
2955 ) -> Result<Vec<Task>, Error> {
2956 let mut params = TasksListParams {
2957 continue_token: None,
2958 types: workflow.map(|w| vec![w.to_owned()]),
2959 status: status.map(|s| vec![s.to_owned()]),
2960 manager: manager.map(|m| vec![m.to_owned()]),
2961 };
2962 let mut tasks = Vec::new();
2963
2964 loop {
2965 let result = self
2966 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
2967 .await?;
2968 tasks.extend(result.tasks);
2969
2970 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
2971 params.continue_token = None;
2972 } else {
2973 params.continue_token = result.continue_token;
2974 }
2975
2976 if params.continue_token.is_none() {
2977 break;
2978 }
2979 }
2980
2981 if let Some(name) = name {
2982 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
2983 }
2984
2985 Ok(tasks)
2986 }
2987
2988 /// Retrieve the task information and status.
2989 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
2990 self.rpc(
2991 "task.get".to_owned(),
2992 Some(HashMap::from([("id", task_id)])),
2993 )
2994 .await
2995 }
2996
2997 /// Updates the tasks status.
2998 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
2999 let status = TaskStatus {
3000 task_id,
3001 status: status.to_owned(),
3002 };
3003 self.rpc("docker.update.status".to_owned(), Some(status))
3004 .await
3005 }
3006
3007 /// Defines the stages for the task. The stages are defined as a mapping
3008 /// from stage names to their descriptions. Once stages are defined their
3009 /// status can be updated using the update_stage method.
3010 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
3011 let stages: Vec<HashMap<String, String>> = stages
3012 .iter()
3013 .map(|(key, value)| {
3014 let mut stage_map = HashMap::new();
3015 stage_map.insert(key.to_string(), value.to_string());
3016 stage_map
3017 })
3018 .collect();
3019 let params = TaskStages { task_id, stages };
3020 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
3021 Ok(())
3022 }
3023
3024 /// Updates the progress of the task for the provided stage and status
3025 /// information.
3026 pub async fn update_stage(
3027 &self,
3028 task_id: TaskID,
3029 stage: &str,
3030 status: &str,
3031 message: &str,
3032 percentage: u8,
3033 ) -> Result<(), Error> {
3034 let stage = Stage::new(
3035 Some(task_id),
3036 stage.to_owned(),
3037 Some(status.to_owned()),
3038 Some(message.to_owned()),
3039 percentage,
3040 );
3041 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
3042 Ok(())
3043 }
3044
3045 /// Raw fetch from the Studio server is used for downloading files.
3046 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
3047 let req = self
3048 .http
3049 .get(format!("{}/{}", self.url, query))
3050 .header("User-Agent", "EdgeFirst Client")
3051 .header("Authorization", format!("Bearer {}", self.token().await));
3052 let resp = req.send().await?;
3053
3054 if resp.status().is_success() {
3055 let body = resp.bytes().await?;
3056
3057 if log_enabled!(Level::Trace) {
3058 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
3059 }
3060
3061 Ok(body.to_vec())
3062 } else {
3063 let err = resp.error_for_status_ref().unwrap_err();
3064 Err(Error::HttpError(err))
3065 }
3066 }
3067
3068 /// Sends a multipart post request to the server. This is used by the
3069 /// upload and download APIs which do not use JSON-RPC but instead transfer
3070 /// files using multipart/form-data.
3071 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
3072 let req = self
3073 .http
3074 .post(format!("{}/api?method={}", self.url, method))
3075 .header("Accept", "application/json")
3076 .header("User-Agent", "EdgeFirst Client")
3077 .header("Authorization", format!("Bearer {}", self.token().await))
3078 .multipart(form);
3079 let resp = req.send().await?;
3080
3081 if resp.status().is_success() {
3082 let body = resp.bytes().await?;
3083
3084 if log_enabled!(Level::Trace) {
3085 trace!(
3086 "POST Multipart Response: {}",
3087 String::from_utf8_lossy(&body)
3088 );
3089 }
3090
3091 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
3092 Ok(response) => response,
3093 Err(err) => {
3094 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3095 return Err(err.into());
3096 }
3097 };
3098
3099 if let Some(error) = response.error {
3100 Err(Error::RpcError(error.code, error.message))
3101 } else if let Some(result) = response.result {
3102 Ok(result)
3103 } else {
3104 Err(Error::InvalidResponse)
3105 }
3106 } else {
3107 let err = resp.error_for_status_ref().unwrap_err();
3108 Err(Error::HttpError(err))
3109 }
3110 }
3111
3112 /// Send a JSON-RPC request to the server. The method is the name of the
3113 /// method to call on the server. The params are the parameters to pass to
3114 /// the method. The method and params are serialized into a JSON-RPC
3115 /// request and sent to the server. The response is deserialized into
3116 /// the specified type and returned to the caller.
3117 ///
3118 /// NOTE: This API would generally not be called directly and instead users
3119 /// should use the higher-level methods provided by the client.
3120 pub async fn rpc<Params, RpcResult>(
3121 &self,
3122 method: String,
3123 params: Option<Params>,
3124 ) -> Result<RpcResult, Error>
3125 where
3126 Params: Serialize,
3127 RpcResult: DeserializeOwned,
3128 {
3129 let auth_expires = self.token_expiration().await?;
3130 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
3131 self.renew_token().await?;
3132 }
3133
3134 self.rpc_without_auth(method, params).await
3135 }
3136
3137 async fn rpc_without_auth<Params, RpcResult>(
3138 &self,
3139 method: String,
3140 params: Option<Params>,
3141 ) -> Result<RpcResult, Error>
3142 where
3143 Params: Serialize,
3144 RpcResult: DeserializeOwned,
3145 {
3146 let request = RpcRequest {
3147 method,
3148 params,
3149 ..Default::default()
3150 };
3151
3152 if log_enabled!(Level::Trace) {
3153 trace!(
3154 "RPC Request: {}",
3155 serde_json::ser::to_string_pretty(&request)?
3156 );
3157 }
3158
3159 let url = format!("{}/api", self.url);
3160
3161 // Use client-level timeout (allows retry mechanism to work properly)
3162 // Per-request timeout overrides can prevent retries from functioning
3163 let res = self
3164 .http
3165 .post(&url)
3166 .header("Accept", "application/json")
3167 .header("User-Agent", "EdgeFirst Client")
3168 .header("Authorization", format!("Bearer {}", self.token().await))
3169 .json(&request)
3170 .send()
3171 .await?;
3172
3173 self.process_rpc_response(res).await
3174 }
3175
3176 async fn process_rpc_response<RpcResult>(
3177 &self,
3178 res: reqwest::Response,
3179 ) -> Result<RpcResult, Error>
3180 where
3181 RpcResult: DeserializeOwned,
3182 {
3183 let body = res.bytes().await?;
3184
3185 if log_enabled!(Level::Trace) {
3186 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
3187 }
3188
3189 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
3190 Ok(response) => response,
3191 Err(err) => {
3192 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
3193 return Err(err.into());
3194 }
3195 };
3196
3197 // FIXME: Studio Server always returns 999 as the id.
3198 // if request.id.to_string() != response.id {
3199 // return Err(Error::InvalidRpcId(response.id));
3200 // }
3201
3202 if let Some(error) = response.error {
3203 Err(Error::RpcError(error.code, error.message))
3204 } else if let Some(result) = response.result {
3205 Ok(result)
3206 } else {
3207 Err(Error::InvalidResponse)
3208 }
3209 }
3210}
3211
3212/// Process items in parallel with semaphore concurrency control and progress
3213/// tracking.
3214///
3215/// This helper eliminates boilerplate for parallel item processing with:
3216/// - Semaphore limiting concurrent tasks to `max_tasks()` (configurable via
3217/// `MAX_TASKS` environment variable, default: half of CPU cores, min 2, max
3218/// 8)
3219/// - Atomic progress counter with automatic item-level updates
3220/// - Progress updates sent after each item completes (not byte-level streaming)
3221/// - Proper error propagation from spawned tasks
3222///
3223/// Note: This is optimized for discrete items with post-completion progress
3224/// updates. For byte-level streaming progress or custom retry logic, use
3225/// specialized implementations.
3226///
3227/// # Arguments
3228///
3229/// * `items` - Collection of items to process in parallel
3230/// * `progress` - Optional progress channel for tracking completion
3231/// * `work_fn` - Async function to execute for each item
3232///
3233/// # Examples
3234///
3235/// ```rust,ignore
3236/// parallel_foreach_items(samples, progress, |sample| async move {
3237/// // Process sample
3238/// sample.download(&client, file_type).await?;
3239/// Ok(())
3240/// }).await?;
3241/// ```
3242async fn parallel_foreach_items<T, F, Fut>(
3243 items: Vec<T>,
3244 progress: Option<Sender<Progress>>,
3245 work_fn: F,
3246) -> Result<(), Error>
3247where
3248 T: Send + 'static,
3249 F: Fn(T) -> Fut + Send + Sync + 'static,
3250 Fut: Future<Output = Result<(), Error>> + Send + 'static,
3251{
3252 let total = items.len();
3253 let current = Arc::new(AtomicUsize::new(0));
3254 let sem = Arc::new(Semaphore::new(max_tasks()));
3255 let work_fn = Arc::new(work_fn);
3256
3257 let tasks = items
3258 .into_iter()
3259 .map(|item| {
3260 let sem = sem.clone();
3261 let current = current.clone();
3262 let progress = progress.clone();
3263 let work_fn = work_fn.clone();
3264
3265 tokio::spawn(async move {
3266 let _permit = sem.acquire().await.map_err(|_| {
3267 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3268 })?;
3269
3270 // Execute the actual work
3271 work_fn(item).await?;
3272
3273 // Update progress
3274 if let Some(progress) = &progress {
3275 let current = current.fetch_add(1, Ordering::SeqCst);
3276 let _ = progress
3277 .send(Progress {
3278 current: current + 1,
3279 total,
3280 })
3281 .await;
3282 }
3283
3284 Ok::<(), Error>(())
3285 })
3286 })
3287 .collect::<Vec<_>>();
3288
3289 join_all(tasks)
3290 .await
3291 .into_iter()
3292 .collect::<Result<Vec<_>, _>>()?
3293 .into_iter()
3294 .collect::<Result<Vec<_>, _>>()?;
3295
3296 if let Some(progress) = progress {
3297 drop(progress);
3298 }
3299
3300 Ok(())
3301}
3302
3303/// Upload a file to S3 using multipart upload with presigned URLs.
3304///
3305/// Splits a file into chunks (100MB each) and uploads them in parallel using
3306/// S3 multipart upload protocol. Returns completion parameters with ETags for
3307/// finalizing the upload.
3308///
3309/// This function handles:
3310/// - Splitting files into parts based on PART_SIZE (100MB)
3311/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
3312/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
3313/// - Retry logic (handled by reqwest client)
3314/// - Progress tracking across all parts
3315///
3316/// # Arguments
3317///
3318/// * `http` - HTTP client for making requests
3319/// * `part` - Snapshot part info with presigned URLs for each chunk
3320/// * `path` - Local file path to upload
3321/// * `total` - Total bytes across all files for progress calculation
3322/// * `current` - Atomic counter tracking bytes uploaded across all operations
3323/// * `progress` - Optional channel for sending progress updates
3324///
3325/// # Returns
3326///
3327/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
3328async fn upload_multipart(
3329 http: reqwest::Client,
3330 part: SnapshotPart,
3331 path: PathBuf,
3332 total: usize,
3333 current: Arc<AtomicUsize>,
3334 progress: Option<Sender<Progress>>,
3335) -> Result<SnapshotCompleteMultipartParams, Error> {
3336 let filesize = path.metadata()?.len() as usize;
3337 let n_parts = filesize.div_ceil(PART_SIZE);
3338 let sem = Arc::new(Semaphore::new(max_tasks()));
3339
3340 let key = part.key.ok_or(Error::InvalidResponse)?;
3341 let upload_id = part.upload_id;
3342
3343 let urls = part.urls.clone();
3344 // Pre-allocate ETag slots for all parts
3345 let etags = Arc::new(tokio::sync::Mutex::new(vec![
3346 EtagPart {
3347 etag: "".to_owned(),
3348 part_number: 0,
3349 };
3350 n_parts
3351 ]));
3352
3353 // Upload all parts in parallel with concurrency limiting
3354 let tasks = (0..n_parts)
3355 .map(|part| {
3356 let http = http.clone();
3357 let url = urls[part].clone();
3358 let etags = etags.clone();
3359 let path = path.to_owned();
3360 let sem = sem.clone();
3361 let progress = progress.clone();
3362 let current = current.clone();
3363
3364 tokio::spawn(async move {
3365 // Acquire semaphore permit to limit concurrent uploads
3366 let _permit = sem.acquire().await?;
3367
3368 // Upload part (retry is handled by reqwest client)
3369 let etag =
3370 upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await?;
3371
3372 // Store ETag for this part (needed to complete multipart upload)
3373 let mut etags = etags.lock().await;
3374 etags[part] = EtagPart {
3375 etag,
3376 part_number: part + 1,
3377 };
3378
3379 // Update progress counter
3380 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
3381 if let Some(progress) = &progress {
3382 let _ = progress
3383 .send(Progress {
3384 current: current + PART_SIZE,
3385 total,
3386 })
3387 .await;
3388 }
3389
3390 Ok::<(), Error>(())
3391 })
3392 })
3393 .collect::<Vec<_>>();
3394
3395 // Wait for all parts to complete
3396 join_all(tasks)
3397 .await
3398 .into_iter()
3399 .collect::<Result<Vec<_>, _>>()?;
3400
3401 Ok(SnapshotCompleteMultipartParams {
3402 key,
3403 upload_id,
3404 etag_list: etags.lock().await.clone(),
3405 })
3406}
3407
3408async fn upload_part(
3409 http: reqwest::Client,
3410 url: String,
3411 path: PathBuf,
3412 part: usize,
3413 n_parts: usize,
3414) -> Result<String, Error> {
3415 let filesize = path.metadata()?.len() as usize;
3416 let mut file = File::open(path).await?;
3417 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
3418 .await?;
3419 let file = file.take(PART_SIZE as u64);
3420
3421 let body_length = if part + 1 == n_parts {
3422 filesize % PART_SIZE
3423 } else {
3424 PART_SIZE
3425 };
3426
3427 let stream = FramedRead::new(file, BytesCodec::new());
3428 let body = Body::wrap_stream(stream);
3429
3430 let resp = http
3431 .put(url.clone())
3432 .header(CONTENT_LENGTH, body_length)
3433 .body(body)
3434 .send()
3435 .await?
3436 .error_for_status()?;
3437
3438 let etag = resp
3439 .headers()
3440 .get("etag")
3441 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
3442 .to_str()
3443 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
3444 .to_owned();
3445
3446 // Studio Server requires etag without the quotes.
3447 let etag = etag
3448 .strip_prefix("\"")
3449 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
3450 let etag = etag
3451 .strip_suffix("\"")
3452 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
3453
3454 Ok(etag.to_owned())
3455}
3456
3457/// Upload a complete file to a presigned S3 URL using HTTP PUT.
3458///
3459/// This is used for populate_samples to upload files to S3 after
3460/// receiving presigned URLs from the server.
3461async fn upload_file_to_presigned_url(
3462 http: reqwest::Client,
3463 url: &str,
3464 path: PathBuf,
3465) -> Result<(), Error> {
3466 // Read the entire file into memory
3467 let file_data = fs::read(&path).await?;
3468 let file_size = file_data.len();
3469
3470 // Upload (retry is handled by reqwest client)
3471 let resp = http
3472 .put(url)
3473 .header(CONTENT_LENGTH, file_size)
3474 .body(file_data)
3475 .send()
3476 .await?;
3477
3478 if resp.status().is_success() {
3479 debug!(
3480 "Successfully uploaded file: {:?} ({} bytes)",
3481 path, file_size
3482 );
3483 Ok(())
3484 } else {
3485 let status = resp.status();
3486 let error_text = resp.text().await.unwrap_or_default();
3487 Err(Error::InvalidParameters(format!(
3488 "Upload failed: HTTP {} - {}",
3489 status, error_text
3490 )))
3491 }
3492}
3493
#[cfg(test)]
mod tests {
    use super::*;

    // These tests cover three pure helpers that need no server connection:
    // - `filter_and_sort_by_name`: ordering guarantees for name searches
    //   (exact match first, then case-insensitive exact, then shorter names,
    //   then alphabetical)
    // - `Client::build_filename`: naming rules for flattened downloads
    // - `Client::with_server`: token storage is cleared on server change

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next
        // (relative order of "DEER" and "deer" is not part of the contract)
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = vec![
            // (items, filter, expected_first)
            (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (vec!["test", "TEST", "Test Data"], "test", "test"),
            (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Verify token is loaded
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared (switching servers must invalidate any
        // token issued by the previous server)
        assert_eq!(storage.load().unwrap(), None);
    }
}
3808}