edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SampleID, SamplesCountResult, SamplesListParams, SamplesListResult,
9 Snapshot, SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID,
10 SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus,
11 TasksListParams, TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{
15 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
16 },
17 retry::{create_retry_policy, log_retry_configuration},
18 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
19};
20use base64::Engine as _;
21use chrono::{DateTime, Utc};
22use directories::ProjectDirs;
23use futures::{StreamExt as _, future::join_all};
24use log::{Level, debug, error, log_enabled, trace, warn};
25use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use std::{
28 collections::HashMap,
29 ffi::OsStr,
30 fs::create_dir_all,
31 io::{SeekFrom, Write as _},
32 path::{Path, PathBuf},
33 sync::{
34 Arc,
35 atomic::{AtomicUsize, Ordering},
36 },
37 time::Duration,
38 vec,
39};
40use tokio::{
41 fs::{self, File},
42 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
43 sync::{RwLock, Semaphore, mpsc::Sender},
44};
45use tokio_util::codec::{BytesCodec, FramedRead};
46use walkdir::WalkDir;
47
48#[cfg(feature = "polars")]
49use polars::prelude::*;
50
/// Size of a single upload part in bytes (100 MiB).
///
/// Presumably used to split large files for multipart uploads; the consumers
/// are outside this chunk — confirm before relying on it.
static PART_SIZE: usize = 100 * 1024 * 1024;
52
/// Source for file content during upload - either a local path or raw bytes.
///
/// Note that cloning a [`FileSource::Bytes`] copies the whole buffer, since
/// the derived `Clone` deep-copies the `Vec<u8>`.
#[derive(Clone)]
enum FileSource {
    /// File content from a local filesystem path.
    Path(PathBuf),
    /// File content as raw bytes (e.g., from a ZIP archive).
    Bytes(Vec<u8>),
}
61
/// Maximum number of concurrent tasks for general work.
///
/// Reads the `MAX_TASKS` environment variable (surrounding whitespace is
/// ignored). When the variable is unset, is not an unsigned integer, or is
/// `0`, the default is half the available CPUs clamped to `2..=8`. If the CPU
/// count cannot be queried, 4 CPUs are assumed.
///
/// Zero is rejected because a concurrency limit of zero cannot make progress
/// when used as a worker/permit count.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
74
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment variable
/// (surrounding whitespace is ignored). Unset, unparsable, or `0` values fall
/// back to the default of 8. Zero is rejected because it would allow no part
/// uploads to run at all.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8) // Default to 8 concurrent part uploads
}
85
/// Filters items by name and sorts by match quality.
///
/// An item is kept when its name contains `filter` case-insensitively.
/// Kept items are ordered by match quality (best to worst):
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match (shorter names first, then alphabetically)
///
/// Name length is measured in bytes. The sort is stable, and an empty
/// `filter` keeps every item.
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    // Rank each item once up front. Lowercasing allocates, so doing it inside
    // the comparator would repeat the work for every comparison of the sort.
    // `false < true`, so the "is NOT an exact match" flags put exact matches
    // first.
    let mut ranked: Vec<((bool, bool, usize), T)> = items
        .into_iter()
        .filter_map(|item| {
            let name = get_name(&item);
            let name_lower = name.to_lowercase();
            if !name_lower.contains(&filter_lower) {
                return None;
            }
            let rank = (name != filter, name_lower != filter_lower, name.len());
            Some((rank, item))
        })
        .collect();

    // Alphabetical order on the original name breaks ties for stability.
    ranked.sort_by(|(rank_a, a), (rank_b, b)| {
        rank_a
            .cmp(rank_b)
            .then_with(|| get_name(a).cmp(get_name(b)))
    });

    ranked.into_iter().map(|(_, item)| item).collect()
}
135
/// Turns an arbitrary name into a single, safe filesystem path component.
///
/// Steps:
/// - Surrounding whitespace is trimmed.
/// - Only the final path component is kept, so `"dir/file"` becomes `"file"`.
/// - Characters that are reserved on common filesystems
///   (`/ \ : * ? " < > |`) and control characters are replaced with `_`.
///
/// Empty input, and the special names `.` and `..`, yield `"unnamed"`. This
/// guarantees the result never refers to the current or parent directory
/// when joined onto a path.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // `file_name` is None for inputs ending in `..`/`.` (or a root); fall back
    // to the raw text, which the special-name check below then catches.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            // NUL and other control characters are invalid (or dangerous) in
            // file names on most platforms.
            c if c.is_control() => '_',
            _ => c,
        })
        .collect();

    // "." and ".." are directory references, not names: joining them onto a
    // directory would escape into (or alias) that directory's parent.
    if matches!(sanitized.as_str(), "" | "." | "..") {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
161
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count, total
/// count, and an optional status string to enable progress reporting in
/// applications.
///
/// # Multi-Stage Progress
///
/// The `status` field enables multi-stage progress tracking. When an operation
/// has multiple phases, the status field changes to indicate the current phase.
/// Applications should detect status changes to reset their progress display.
///
/// # Operation Progress Details
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Basic progress display:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "{}: {:.1}% ({}/{})",
///     progress.status.as_deref().unwrap_or("Progress"),
///     percentage,
///     progress.current,
///     progress.total
/// );
/// ```
///
/// Multi-stage progress handling (e.g., for `download_dataset`):
///
/// ```rust,ignore
/// let mut last_status: Option<String> = None;
///
/// while let Some(progress) = rx.recv().await {
///     // Detect stage change and reset progress bar
///     if progress.status != last_status {
///         if let Some(ref status) = progress.status {
///             println!("\n{}", status);
///         }
///         last_status = progress.status.clone();
///     }
///
///     let pct = (progress.current as f64 / progress.total as f64) * 100.0;
///     print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items or bytes.
    ///
    /// Uses the same unit as [`total`](Self::total); the unit depends on the
    /// operation (see the table above).
    pub current: usize,
    /// Total number of items or bytes to process.
    ///
    /// When deriving a percentage from `current / total`, guard against a
    /// zero `total` in your own code.
    pub total: usize,
    /// Optional status describing the current operation phase.
    ///
    /// When this value changes from `None` to `Some(...)` or between different
    /// values, it indicates a new phase has started. Applications should reset
    /// their progress display when the status changes.
    ///
    /// Currently only [`Client::download_dataset`] uses status changes:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// All other operations use `None` throughout.
    pub status: Option<String>,
}
256
257#[derive(Serialize)]
258struct RpcRequest<Params> {
259 id: u64,
260 jsonrpc: String,
261 method: String,
262 params: Option<Params>,
263}
264
265impl<T> Default for RpcRequest<T> {
266 fn default() -> Self {
267 RpcRequest {
268 id: 0,
269 jsonrpc: "2.0".to_string(),
270 method: "".to_string(),
271 params: None,
272 }
273 }
274}
275
/// Error object carried by a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric error code reported by the server.
    code: i32,
    /// Human-readable error message reported by the server.
    message: String,
}

/// JSON-RPC response envelope, generic over the expected `result` payload.
///
/// Both `error` and `result` are optional at the type level; presumably the
/// server sets exactly one of them (confirm against the RPC caller).
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): requests send `id` as a `u64` (see `RpcRequest`), but the
    // response `id` is deserialized as a `String`. This only works if the
    // server echoes ids as JSON strings — confirm.
    // `dead_code` allowed: the field is deserialized but apparently never read.
    #[allow(dead_code)]
    id: String,
    // `dead_code` allowed: the field is deserialized but apparently never read.
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Empty result payload, for RPC methods whose result content is ignored.
// `dead_code` allowed: presumably only ever constructed by deserialization.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
295
/// Parameters for a single-part snapshot creation request.
// `dead_code` allowed: appears unused in favour of the multipart variant
// (usage is outside this chunk — confirm before removing).
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    /// Name for the new snapshot.
    snapshot_name: String,
    /// Object keys (file names) to be included in the snapshot.
    keys: Vec<String>,
}

/// Response to a single-part snapshot creation request.
// `dead_code` allowed: see `SnapshotCreateParams`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    /// Identifier of the created snapshot.
    snapshot_id: SnapshotID,
    /// URLs returned by the server, presumably one upload URL per key —
    /// confirm against the API.
    urls: Vec<String>,
}
309
/// Parameters for a multipart snapshot creation request.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    /// Name for the new snapshot.
    snapshot_name: String,
    /// Object keys (file names) to upload, presumably parallel to `file_sizes`.
    keys: Vec<String>,
    /// Size of each file in bytes, presumably in the same order as `keys`.
    file_sizes: Vec<usize>,
    /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    snapshot_type: Option<String>,
}

/// One value in a multipart-creation response: either a numeric id or the
/// upload details for one file.
///
/// The enum is `untagged`, so serde tries `Id` first and falls back to `Part`.
/// The variant order is therefore significant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters for completing one file's multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    /// Object key of the uploaded file.
    key: String,
    /// Upload id obtained from the corresponding `SnapshotPart`.
    upload_id: String,
    /// ETag of every uploaded part, with its part number.
    etag_list: Vec<EtagPart>,
}

/// ETag and part number of one uploaded part.
///
/// The serialized names `ETag`/`PartNumber` presumably mirror S3's
/// `CompleteMultipartUpload` part fields.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Upload details for one file of a multipart snapshot.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    /// Object key, when provided by the server.
    key: Option<String>,
    /// Multipart upload id to pass back on completion.
    upload_id: String,
    /// Part URLs, presumably one pre-signed URL per part — confirm.
    urls: Vec<String>,
}
348
/// Parameters for updating a snapshot's status.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    /// Snapshot whose status is being updated.
    snapshot_id: SnapshotID,
    /// New status value; valid strings are defined by the server.
    status: String,
}

/// Snapshot record returned by a status update.
///
/// All fields are deserialized but none appear to be read, hence the
/// per-field `dead_code` allowances.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    // Kept as the raw server string; the format is not parsed here.
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
368
/// Parameters for an image-list request.
// `dead_code` allowed: no use is visible in this chunk — confirm before removing.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    /// Filter applied to images (currently only by dataset).
    images_filter: ImagesFilter,
    /// Additional key/value filters on image files; semantics defined by the
    /// server.
    image_files_filter: HashMap<String, String>,
    /// Presumably asks the server to return ids only instead of full records.
    only_ids: bool,
}

/// Image filter restricting results to one dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
382
/// Main client for interacting with EdgeFirst Studio Server.
///
/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
/// and manages authentication, RPC calls, and data operations. It provides
/// methods for managing projects, datasets, experiments, training sessions,
/// and various utility functions for data processing.
///
/// The client supports multiple authentication methods and can work with both
/// SaaS and self-hosted EdgeFirst Studio instances.
///
/// # Features
///
/// - **Authentication**: Token-based authentication with automatic persistence
/// - **Dataset Management**: Upload, download, and manipulate datasets
/// - **Project Operations**: Create and manage projects and experiments
/// - **Training & Validation**: Submit and monitor ML training jobs
/// - **Data Integration**: Convert between EdgeFirst datasets and popular
///   formats
/// - **Progress Tracking**: Real-time progress updates for long-running
///   operations
///
/// # Examples
///
/// ```no_run
/// use edgefirst_client::{Client, DatasetID};
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), edgefirst_client::Error> {
/// // Create a new client and authenticate
/// let client = Client::new()?;
/// let client = client
///     .with_login("your-email@example.com", "password")
///     .await?;
///
/// // Or use an existing token
/// let base_client = Client::new()?;
/// let client = base_client.with_token("your-token-here")?;
///
/// // Get organization and projects
/// let org = client.organization().await?;
/// let projects = client.projects(None).await?;
///
/// // Work with datasets
/// let dataset_id = DatasetID::from_str("ds-abc123")?;
/// let dataset = client.dataset(dataset_id).await?;
/// # Ok(())
/// # }
/// ```
// Client is Clone but cannot derive Debug due to `dyn TokenStorage`; a manual
// `Debug` impl (which deliberately omits the token) is provided instead.
#[derive(Clone)]
pub struct Client {
    /// HTTP client for API calls (request timeout and retry policy are
    /// configured in `Client::new`).
    http: reqwest::Client,
    /// HTTP client for large file uploads (no request timeout)
    upload_http: reqwest::Client,
    /// Base URL of the Studio instance, e.g. `https://edgefirst.studio`.
    url: String,
    /// Authentication token (empty when unauthenticated). `Clone`d copies
    /// share it through the `Arc`; builder methods such as `with_login`
    /// install a fresh one.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
445
446impl std::fmt::Debug for Client {
447 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448 f.debug_struct("Client")
449 .field("url", &self.url)
450 .field("has_storage", &self.storage.is_some())
451 .field("token_path", &self.token_path)
452 .finish()
453 }
454}
455
/// Private context struct for pagination operations.
///
/// Groups query parameters so they can be passed as one value. Presumably the
/// same parameters are reused for every page request — confirm against the
/// pagination code, which is outside this chunk.
struct FetchContext<'a> {
    /// Dataset being queried.
    dataset_id: DatasetID,
    /// Optional annotation set to restrict the query to.
    annotation_set_id: Option<AnnotationSetID>,
    /// Group names to filter by (borrowed from the caller).
    groups: &'a [String],
    /// Type names to request; the expected values are defined by the server.
    types: Vec<String>,
    /// Label lookup table, presumably label name → numeric label id — verify
    /// against callers.
    labels: &'a HashMap<String, u64>,
}
464
465impl Client {
466 /// Create a new unauthenticated client with the default saas server.
467 ///
468 /// By default, the client uses [`FileTokenStorage`] for token persistence.
469 /// Use [`with_storage`][Self::with_storage],
470 /// [`with_memory_storage`][Self::with_memory_storage],
471 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
472 /// behavior.
473 ///
474 /// To connect to a different server, use [`with_server`][Self::with_server]
475 /// or [`with_token`][Self::with_token] (tokens include the server
476 /// instance).
477 ///
478 /// This client is created without a token and will need to authenticate
479 /// before using methods that require authentication.
480 ///
481 /// # Examples
482 ///
483 /// ```rust,no_run
484 /// use edgefirst_client::Client;
485 ///
486 /// # fn main() -> Result<(), edgefirst_client::Error> {
487 /// // Create client with default file storage
488 /// let client = Client::new()?;
489 ///
490 /// // Create client without token persistence
491 /// let client = Client::new()?.with_memory_storage();
492 /// # Ok(())
493 /// # }
494 /// ```
495 pub fn new() -> Result<Self, Error> {
496 log_retry_configuration();
497
498 // Get timeout from environment or use default
499 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
500 .ok()
501 .and_then(|s| s.parse().ok())
502 .unwrap_or(30); // Default 30s timeout for API calls
503
504 // Create single HTTP client with URL-based retry policy
505 //
506 // The retry policy classifies requests into two categories:
507 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
508 // errors
509 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
510 //
511 // This allows the same client to handle both API calls and file operations
512 // with appropriate retry behavior for each. See retry.rs for details.
513 let http = reqwest::Client::builder()
514 .connect_timeout(Duration::from_secs(10))
515 .timeout(Duration::from_secs(timeout_secs))
516 .pool_idle_timeout(Duration::from_secs(90))
517 .pool_max_idle_per_host(10)
518 .retry(create_retry_policy())
519 .build()?;
520
521 // Separate HTTP client for large file uploads - no request timeout
522 // since upload duration depends on file size and network speed
523 let upload_http = reqwest::Client::builder()
524 .connect_timeout(Duration::from_secs(30))
525 .pool_idle_timeout(Duration::from_secs(90))
526 .pool_max_idle_per_host(10)
527 .build()?;
528
529 // Default to file storage, loading any existing token
530 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
531 Ok(file_storage) => Arc::new(file_storage),
532 Err(e) => {
533 warn!(
534 "Could not initialize file token storage: {}. Using memory storage.",
535 e
536 );
537 Arc::new(MemoryTokenStorage::new())
538 }
539 };
540
541 // Try to load existing token from storage
542 let token = match storage.load() {
543 Ok(Some(t)) => t,
544 Ok(None) => String::new(),
545 Err(e) => {
546 warn!(
547 "Failed to load token from storage: {}. Starting with empty token.",
548 e
549 );
550 String::new()
551 }
552 };
553
554 // Extract server from token if available
555 let url = if !token.is_empty() {
556 match Self::extract_server_from_token(&token) {
557 Ok(server) => format!("https://{}.edgefirst.studio", server),
558 Err(e) => {
559 warn!(
560 "Failed to extract server from token: {}. Using default server.",
561 e
562 );
563 "https://edgefirst.studio".to_string()
564 }
565 }
566 } else {
567 "https://edgefirst.studio".to_string()
568 };
569
570 Ok(Client {
571 http,
572 upload_http,
573 url,
574 token: Arc::new(tokio::sync::RwLock::new(token)),
575 storage: Some(storage),
576 token_path: None,
577 })
578 }
579
580 /// Returns a new client connected to the specified server instance.
581 ///
582 /// The server parameter is an instance name that maps to a URL:
583 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
584 /// server)
585 /// - `"test"` → `https://test.edgefirst.studio`
586 /// - `"stage"` → `https://stage.edgefirst.studio`
587 /// - `"dev"` → `https://dev.edgefirst.studio`
588 /// - `"{name}"` → `https://{name}.edgefirst.studio`
589 ///
590 /// # Server Selection Priority
591 ///
592 /// When using the CLI or Python API, server selection follows this
593 /// priority:
594 ///
595 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
596 /// they were issued for. If you have a valid token, its server is used.
597 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
598 /// token is available. If a token exists with a different server, a
599 /// warning is emitted and the token's server takes priority.
600 /// 3. **Default `"saas"`** - If no token and no server specified, the
601 /// production server (`https://edgefirst.studio`) is used.
602 ///
603 /// # Important Notes
604 ///
605 /// - If a token is already set in the client, calling this method will
606 /// **drop the token** as tokens are specific to the server instance.
607 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
608 /// token's server before calling this method.
609 /// - For login operations, call `with_server()` first, then authenticate.
610 ///
611 /// # Examples
612 ///
613 /// ```rust,no_run
614 /// use edgefirst_client::Client;
615 ///
616 /// # fn main() -> Result<(), edgefirst_client::Error> {
617 /// let client = Client::new()?.with_server("test")?;
618 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
619 /// # Ok(())
620 /// # }
621 /// ```
622 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
623 let url = match server {
624 "" | "saas" => "https://edgefirst.studio".to_string(),
625 name => format!("https://{}.edgefirst.studio", name),
626 };
627
628 // Clear token from storage when changing servers to prevent
629 // authentication issues with stale tokens from different instances
630 if let Some(ref storage) = self.storage
631 && let Err(e) = storage.clear()
632 {
633 warn!(
634 "Failed to clear token from storage when changing servers: {}",
635 e
636 );
637 }
638
639 Ok(Client {
640 url,
641 token: Arc::new(tokio::sync::RwLock::new(String::new())),
642 ..self.clone()
643 })
644 }
645
646 /// Returns a new client with the specified token storage backend.
647 ///
648 /// Use this to configure custom token storage, such as platform-specific
649 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
650 ///
651 /// # Examples
652 ///
653 /// ```rust,no_run
654 /// use edgefirst_client::{Client, FileTokenStorage};
655 /// use std::{path::PathBuf, sync::Arc};
656 ///
657 /// # fn main() -> Result<(), edgefirst_client::Error> {
658 /// // Use a custom file path for token storage
659 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
660 /// let client = Client::new()?.with_storage(Arc::new(storage));
661 /// # Ok(())
662 /// # }
663 /// ```
664 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
665 // Try to load existing token from the new storage
666 let token = match storage.load() {
667 Ok(Some(t)) => t,
668 Ok(None) => String::new(),
669 Err(e) => {
670 warn!(
671 "Failed to load token from storage: {}. Starting with empty token.",
672 e
673 );
674 String::new()
675 }
676 };
677
678 Client {
679 token: Arc::new(tokio::sync::RwLock::new(token)),
680 storage: Some(storage),
681 token_path: None,
682 ..self
683 }
684 }
685
686 /// Returns a new client with in-memory token storage (no persistence).
687 ///
688 /// Tokens are stored in memory only and lost when the application exits.
689 /// This is useful for testing or when you want to manage token persistence
690 /// externally.
691 ///
692 /// # Examples
693 ///
694 /// ```rust,no_run
695 /// use edgefirst_client::Client;
696 ///
697 /// # fn main() -> Result<(), edgefirst_client::Error> {
698 /// let client = Client::new()?.with_memory_storage();
699 /// # Ok(())
700 /// # }
701 /// ```
702 pub fn with_memory_storage(self) -> Self {
703 Client {
704 token: Arc::new(tokio::sync::RwLock::new(String::new())),
705 storage: Some(Arc::new(MemoryTokenStorage::new())),
706 token_path: None,
707 ..self
708 }
709 }
710
711 /// Returns a new client with no token storage.
712 ///
713 /// Tokens are not persisted. Use this when you want to manage tokens
714 /// entirely manually.
715 ///
716 /// # Examples
717 ///
718 /// ```rust,no_run
719 /// use edgefirst_client::Client;
720 ///
721 /// # fn main() -> Result<(), edgefirst_client::Error> {
722 /// let client = Client::new()?.with_no_storage();
723 /// # Ok(())
724 /// # }
725 /// ```
726 pub fn with_no_storage(self) -> Self {
727 Client {
728 storage: None,
729 token_path: None,
730 ..self
731 }
732 }
733
734 /// Returns a new client authenticated with the provided username and
735 /// password.
736 ///
737 /// The token is automatically persisted to storage (if configured).
738 ///
739 /// # Examples
740 ///
741 /// ```rust,no_run
742 /// use edgefirst_client::Client;
743 ///
744 /// # async fn example() -> Result<(), edgefirst_client::Error> {
745 /// let client = Client::new()?
746 /// .with_server("test")?
747 /// .with_login("user@example.com", "password")
748 /// .await?;
749 /// # Ok(())
750 /// # }
751 /// ```
752 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
753 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
754 let params = HashMap::from([("username", username), ("password", password)]);
755 let login: LoginResult = self
756 .rpc_without_auth("auth.login".to_owned(), Some(params))
757 .await?;
758
759 // Validate that the server returned a non-empty token
760 if login.token.is_empty() {
761 return Err(Error::EmptyToken);
762 }
763
764 // Persist token to storage if configured
765 if let Some(ref storage) = self.storage
766 && let Err(e) = storage.store(&login.token)
767 {
768 warn!("Failed to persist token to storage: {}", e);
769 }
770
771 Ok(Client {
772 token: Arc::new(tokio::sync::RwLock::new(login.token)),
773 ..self.clone()
774 })
775 }
776
777 /// Returns a new client which will load and save the token to the specified
778 /// path.
779 ///
780 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
781 /// [`FileTokenStorage`] instead for more flexible token management.
782 ///
783 /// This method is maintained for backwards compatibility with existing
784 /// code. It disables the default storage and uses file-based storage at
785 /// the specified path.
786 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
787 let token_path = match token_path {
788 Some(path) => path.to_path_buf(),
789 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
790 .ok_or_else(|| {
791 Error::IoError(std::io::Error::new(
792 std::io::ErrorKind::NotFound,
793 "Could not determine user config directory",
794 ))
795 })?
796 .config_dir()
797 .join("token"),
798 };
799
800 debug!("Using token path (legacy): {:?}", token_path);
801
802 let token = match token_path.exists() {
803 true => std::fs::read_to_string(&token_path)?,
804 false => "".to_string(),
805 };
806
807 if !token.is_empty() {
808 match self.with_token(&token) {
809 Ok(client) => Ok(Client {
810 token_path: Some(token_path),
811 storage: None, // Disable new storage when using legacy token_path
812 ..client
813 }),
814 Err(e) => {
815 // Token is corrupted or invalid - remove it and continue with no token
816 warn!(
817 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
818 token_path, e
819 );
820 if let Err(remove_err) = std::fs::remove_file(&token_path) {
821 warn!("Failed to remove corrupted token file: {:?}", remove_err);
822 }
823 // Clear any token from default storage to ensure we don't use it
824 Ok(Client {
825 token_path: Some(token_path),
826 storage: None,
827 token: Arc::new(RwLock::new("".to_string())),
828 ..self.clone()
829 })
830 }
831 }
832 } else {
833 // No token in the legacy file - clear any token from default storage
834 Ok(Client {
835 token_path: Some(token_path),
836 storage: None,
837 token: Arc::new(RwLock::new("".to_string())),
838 ..self.clone()
839 })
840 }
841 }
842
843 /// Returns a new client authenticated with the provided token.
844 ///
845 /// The token is automatically persisted to storage (if configured).
846 /// The server URL is extracted from the token payload.
847 ///
848 /// # Examples
849 ///
850 /// ```rust,no_run
851 /// use edgefirst_client::Client;
852 ///
853 /// # fn main() -> Result<(), edgefirst_client::Error> {
854 /// let client = Client::new()?.with_token("your-jwt-token")?;
855 /// # Ok(())
856 /// # }
857 /// ```
858 /// Extract server name from JWT token payload.
859 ///
860 /// Helper method to parse the JWT token and extract the "server" field
861 /// from the payload. Returns the server name (e.g., "test", "stage", "")
862 /// or an error if the token is invalid.
863 fn extract_server_from_token(token: &str) -> Result<String, Error> {
864 let token_parts: Vec<&str> = token.split('.').collect();
865 if token_parts.len() != 3 {
866 return Err(Error::InvalidToken);
867 }
868
869 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
870 .decode(token_parts[1])
871 .map_err(|_| Error::InvalidToken)?;
872 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
873 let server = match payload.get("server") {
874 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
875 None => return Err(Error::InvalidToken),
876 };
877
878 Ok(server)
879 }
880
881 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
882 if token.is_empty() {
883 return Ok(self.clone());
884 }
885
886 let server = Self::extract_server_from_token(token)?;
887
888 // Persist token to storage if configured
889 if let Some(ref storage) = self.storage
890 && let Err(e) = storage.store(token)
891 {
892 warn!("Failed to persist token to storage: {}", e);
893 }
894
895 Ok(Client {
896 url: format!("https://{}.edgefirst.studio", server),
897 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
898 ..self.clone()
899 })
900 }
901
902 /// Persist the current token to storage.
903 ///
904 /// This is automatically called when using [`with_login`][Self::with_login]
905 /// or [`with_token`][Self::with_token], so you typically don't need to call
906 /// this directly.
907 ///
908 /// If using the legacy `token_path` configuration, saves to the file path.
909 /// If using the new storage abstraction, saves to the configured storage.
910 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
911 pub async fn save_token(&self) -> Result<(), Error> {
912 let token = self.token.read().await;
913
914 // Try new storage first
915 if let Some(ref storage) = self.storage {
916 storage.store(&token)?;
917 debug!("Token saved to storage");
918 return Ok(());
919 }
920
921 // Fall back to legacy token_path behavior
922 let path = self.token_path.clone().unwrap_or_else(|| {
923 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
924 .map(|dirs| dirs.config_dir().join("token"))
925 .unwrap_or_else(|| PathBuf::from(".token"))
926 });
927
928 create_dir_all(path.parent().ok_or_else(|| {
929 Error::IoError(std::io::Error::new(
930 std::io::ErrorKind::InvalidInput,
931 "Token path has no parent directory",
932 ))
933 })?)?;
934 let mut file = std::fs::File::create(&path)?;
935 file.write_all(token.as_bytes())?;
936
937 debug!("Saved token to {:?}", path);
938
939 Ok(())
940 }
941
942 /// Return the version of the EdgeFirst Studio server for the current
943 /// client connection.
944 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
945 pub async fn version(&self) -> Result<String, Error> {
946 let version: HashMap<String, String> = self
947 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
948 .await?;
949 let version = version.get("version").ok_or(Error::InvalidResponse)?;
950 Ok(version.to_owned())
951 }
952
953 /// Clear the token used to authenticate the client with the server.
954 ///
955 /// Clears the token from memory and from storage (if configured).
956 /// If using the legacy `token_path` configuration, removes the token file.
957 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
958 pub async fn logout(&self) -> Result<(), Error> {
959 {
960 let mut token = self.token.write().await;
961 *token = "".to_string();
962 }
963
964 // Clear from new storage if configured
965 if let Some(ref storage) = self.storage
966 && let Err(e) = storage.clear()
967 {
968 warn!("Failed to clear token from storage: {}", e);
969 }
970
971 // Also clear legacy token_path if configured
972 if let Some(path) = &self.token_path
973 && path.exists()
974 {
975 fs::remove_file(path).await?;
976 }
977
978 Ok(())
979 }
980
981 /// Return the token used to authenticate the client with the server. When
982 /// logging into the server using a username and password, the token is
983 /// returned by the server and stored in the client for future interactions.
984 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
985 pub async fn token(&self) -> String {
986 self.token.read().await.clone()
987 }
988
989 /// Verify the token used to authenticate the client with the server. This
990 /// method is used to ensure that the token is still valid and has not
991 /// expired. If the token is invalid, the server will return an error and
992 /// the client will need to login again.
993 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
994 pub async fn verify_token(&self) -> Result<(), Error> {
995 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
996 .await?;
997 Ok::<(), Error>(())
998 }
999
1000 /// Renew the token used to authenticate the client with the server.
1001 ///
1002 /// Refreshes the token before it expires. If the token has already expired,
1003 /// the server will return an error and you will need to login again.
1004 ///
1005 /// The new token is automatically persisted to storage (if configured).
1006 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1007 pub async fn renew_token(&self) -> Result<(), Error> {
1008 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1009 let result: LoginResult = self
1010 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1011 .await?;
1012
1013 {
1014 let mut token = self.token.write().await;
1015 *token = result.token.clone();
1016 }
1017
1018 // Persist to new storage if configured
1019 if let Some(ref storage) = self.storage
1020 && let Err(e) = storage.store(&result.token)
1021 {
1022 warn!("Failed to persist renewed token to storage: {}", e);
1023 }
1024
1025 // Also persist to legacy token_path if configured
1026 if self.token_path.is_some() {
1027 self.save_token().await?;
1028 }
1029
1030 Ok(())
1031 }
1032
1033 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1034 let token = self.token.read().await;
1035 if token.is_empty() {
1036 return Err(Error::EmptyToken);
1037 }
1038
1039 let token_parts: Vec<&str> = token.split('.').collect();
1040 if token_parts.len() != 3 {
1041 return Err(Error::InvalidToken);
1042 }
1043
1044 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1045 .decode(token_parts[1])
1046 .map_err(|_| Error::InvalidToken)?;
1047 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1048 match payload.get(field) {
1049 Some(value) => Ok(value.to_owned()),
1050 None => Err(Error::InvalidToken),
1051 }
1052 }
1053
1054 /// Returns the URL of the EdgeFirst Studio server for the current client.
1055 pub fn url(&self) -> &str {
1056 &self.url
1057 }
1058
1059 /// Returns the server name for the current client.
1060 ///
1061 /// This extracts the server name from the client's URL:
1062 /// - `https://edgefirst.studio` → `"saas"`
1063 /// - `https://test.edgefirst.studio` → `"test"`
1064 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1065 ///
1066 /// # Examples
1067 ///
1068 /// ```rust,no_run
1069 /// use edgefirst_client::Client;
1070 ///
1071 /// # fn main() -> Result<(), edgefirst_client::Error> {
1072 /// let client = Client::new()?.with_server("test")?;
1073 /// assert_eq!(client.server(), "test");
1074 ///
1075 /// let client = Client::new()?; // default
1076 /// assert_eq!(client.server(), "saas");
1077 /// # Ok(())
1078 /// # }
1079 /// ```
1080 pub fn server(&self) -> &str {
1081 if self.url == "https://edgefirst.studio" {
1082 "saas"
1083 } else if let Some(name) = self.url.strip_prefix("https://") {
1084 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1085 } else {
1086 "saas"
1087 }
1088 }
1089
1090 /// Returns the username associated with the current token.
1091 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1092 pub async fn username(&self) -> Result<String, Error> {
1093 match self.token_field("username").await? {
1094 serde_json::Value::String(username) => Ok(username),
1095 _ => Err(Error::InvalidToken),
1096 }
1097 }
1098
1099 /// Returns the expiration time for the current token.
1100 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1101 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1102 let ts = match self.token_field("exp").await? {
1103 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1104 _ => return Err(Error::InvalidToken),
1105 };
1106
1107 match DateTime::<Utc>::from_timestamp(ts, 0) {
1108 Some(dt) => Ok(dt),
1109 None => Err(Error::InvalidToken),
1110 }
1111 }
1112
1113 /// Returns the organization information for the current user.
1114 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1115 pub async fn organization(&self) -> Result<Organization, Error> {
1116 self.rpc::<(), Organization>("org.get".to_owned(), None)
1117 .await
1118 }
1119
1120 /// Returns a list of projects available to the user. The projects are
1121 /// returned as a vector of Project objects. If a name filter is
1122 /// provided, only projects matching the filter are returned.
1123 ///
1124 /// Results are sorted by match quality: exact matches first, then
1125 /// case-insensitive exact matches, then shorter names (more specific),
1126 /// then alphabetically.
1127 ///
1128 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1129 /// Projects contain datasets, trainers, and trainer sessions. Projects
1130 /// are used to group related datasets and trainers together.
1131 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1132 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1133 let projects = self
1134 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1135 .await?;
1136 if let Some(name) = name {
1137 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1138 } else {
1139 Ok(projects)
1140 }
1141 }
1142
1143 /// Return the project with the specified project ID. If the project does
1144 /// not exist, an error is returned.
1145 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1146 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1147 let params = HashMap::from([("project_id", project_id)]);
1148 self.rpc("project.get".to_owned(), Some(params)).await
1149 }
1150
1151 /// Returns a list of datasets available to the user. The datasets are
1152 /// returned as a vector of Dataset objects. If a name filter is
1153 /// provided, only datasets matching the filter are returned.
1154 ///
1155 /// Results are sorted by match quality: exact matches first, then
1156 /// case-insensitive exact matches, then shorter names (more specific),
1157 /// then alphabetically. This ensures "Deer" returns before "Deer
1158 /// Roundtrip".
1159 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1160 pub async fn datasets(
1161 &self,
1162 project_id: ProjectID,
1163 name: Option<&str>,
1164 ) -> Result<Vec<Dataset>, Error> {
1165 let params = HashMap::from([("project_id", project_id)]);
1166 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1167 if let Some(name) = name {
1168 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1169 } else {
1170 Ok(datasets)
1171 }
1172 }
1173
1174 /// Return the dataset with the specified dataset ID. If the dataset does
1175 /// not exist, an error is returned.
1176 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1177 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1178 let params = HashMap::from([("dataset_id", dataset_id)]);
1179 self.rpc("dataset.get".to_owned(), Some(params)).await
1180 }
1181
1182 /// Lists the labels for the specified dataset.
1183 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1184 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1185 let params = HashMap::from([("dataset_id", dataset_id)]);
1186 self.rpc("label.list".to_owned(), Some(params)).await
1187 }
1188
1189 /// Add a new label to the dataset with the specified name.
1190 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1191 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1192 let new_label = NewLabel {
1193 dataset_id,
1194 labels: vec![NewLabelObject {
1195 name: name.to_owned(),
1196 }],
1197 };
1198 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1199 Ok(())
1200 }
1201
1202 /// Removes the label with the specified ID from the dataset. Label IDs are
1203 /// globally unique so the dataset_id is not required.
1204 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1205 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1206 let params = HashMap::from([("label_id", label_id)]);
1207 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1208 Ok(())
1209 }
1210
1211 /// Creates a new dataset in the specified project.
1212 ///
1213 /// # Arguments
1214 ///
1215 /// * `project_id` - The ID of the project to create the dataset in
1216 /// * `name` - The name of the new dataset
1217 /// * `description` - Optional description for the dataset
1218 ///
1219 /// # Returns
1220 ///
1221 /// Returns the dataset ID of the newly created dataset.
1222 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1223 pub async fn create_dataset(
1224 &self,
1225 project_id: &str,
1226 name: &str,
1227 description: Option<&str>,
1228 ) -> Result<DatasetID, Error> {
1229 let mut params = HashMap::new();
1230 params.insert("project_id", project_id);
1231 params.insert("name", name);
1232 if let Some(desc) = description {
1233 params.insert("description", desc);
1234 }
1235
1236 #[derive(Deserialize)]
1237 struct CreateDatasetResult {
1238 id: DatasetID,
1239 }
1240
1241 let result: CreateDatasetResult =
1242 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1243 Ok(result.id)
1244 }
1245
1246 /// Deletes a dataset by marking it as deleted.
1247 ///
1248 /// # Arguments
1249 ///
1250 /// * `dataset_id` - The ID of the dataset to delete
1251 ///
1252 /// # Returns
1253 ///
1254 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1255 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1256 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1257 let params = HashMap::from([("id", dataset_id)]);
1258 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1259 Ok(())
1260 }
1261
1262 /// Updates the label with the specified ID to have the new name or index.
1263 /// Label IDs cannot be changed. Label IDs are globally unique so the
1264 /// dataset_id is not required.
1265 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1266 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1267 #[derive(Serialize)]
1268 struct Params {
1269 dataset_id: DatasetID,
1270 label_id: u64,
1271 label_name: String,
1272 label_index: u64,
1273 }
1274
1275 let _: String = self
1276 .rpc(
1277 "label.update".to_owned(),
1278 Some(Params {
1279 dataset_id: label.dataset_id(),
1280 label_id: label.id(),
1281 label_name: label.name().to_owned(),
1282 label_index: label.index(),
1283 }),
1284 )
1285 .await?;
1286 Ok(())
1287 }
1288
1289 /// Lists the groups for the specified dataset.
1290 ///
1291 /// Groups are used to organize samples into logical subsets such as
1292 /// "train", "val", "test", etc. Each sample can belong to at most one
1293 /// group at a time.
1294 ///
1295 /// # Arguments
1296 ///
1297 /// * `dataset_id` - The ID of the dataset to list groups for
1298 ///
1299 /// # Returns
1300 ///
1301 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1302 /// empty vector if no groups have been created yet.
1303 ///
1304 /// # Errors
1305 ///
1306 /// Returns an error if the dataset does not exist or cannot be accessed.
1307 ///
1308 /// # Example
1309 ///
1310 /// ```rust,no_run
1311 /// # use edgefirst_client::{Client, DatasetID};
1312 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1313 /// let client = Client::new()?.with_token_path(None)?;
1314 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1315 ///
1316 /// let groups = client.groups(dataset_id).await?;
1317 /// for group in groups {
1318 /// println!("{}: {}", group.id, group.name);
1319 /// }
1320 /// # Ok(())
1321 /// # }
1322 /// ```
1323 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1324 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1325 let params = HashMap::from([("dataset_id", dataset_id)]);
1326 self.rpc("groups.list".to_owned(), Some(params)).await
1327 }
1328
1329 /// Gets an existing group by name or creates a new one.
1330 ///
1331 /// This is a convenience method that first checks if a group with the
1332 /// specified name exists, and creates it if not. This is useful when
1333 /// you need to ensure a group exists before assigning samples to it.
1334 ///
1335 /// # Arguments
1336 ///
1337 /// * `dataset_id` - The ID of the dataset
1338 /// * `name` - The name of the group (e.g., "train", "val", "test")
1339 ///
1340 /// # Returns
1341 ///
1342 /// Returns the group ID (either existing or newly created).
1343 ///
1344 /// # Errors
1345 ///
1346 /// Returns an error if:
1347 /// - The dataset does not exist or cannot be accessed
1348 /// - The group creation fails
1349 ///
1350 /// # Concurrency
1351 ///
1352 /// This method handles concurrent creation attempts gracefully. If another
1353 /// process creates the group between the existence check and creation,
1354 /// this method will return the existing group's ID.
1355 ///
1356 /// # Example
1357 ///
1358 /// ```rust,no_run
1359 /// # use edgefirst_client::{Client, DatasetID};
1360 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1361 /// let client = Client::new()?.with_token_path(None)?;
1362 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1363 ///
1364 /// // Get or create a "train" group
1365 /// let train_group_id = client
1366 /// .get_or_create_group(dataset_id.clone(), "train")
1367 /// .await?;
1368 /// println!("Train group ID: {}", train_group_id);
1369 ///
1370 /// // Calling again returns the same ID
1371 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1372 /// assert_eq!(train_group_id, same_id);
1373 /// # Ok(())
1374 /// # }
1375 /// ```
1376 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1377 pub async fn get_or_create_group(
1378 &self,
1379 dataset_id: DatasetID,
1380 name: &str,
1381 ) -> Result<u64, Error> {
1382 // First check if the group already exists
1383 let groups = self.groups(dataset_id).await?;
1384 if let Some(group) = groups.iter().find(|g| g.name == name) {
1385 return Ok(group.id);
1386 }
1387
1388 // Create the group
1389 #[derive(Serialize)]
1390 struct CreateGroupParams {
1391 dataset_id: DatasetID,
1392 group_names: Vec<String>,
1393 group_splits: Vec<i64>,
1394 }
1395
1396 let params = CreateGroupParams {
1397 dataset_id,
1398 group_names: vec![name.to_string()],
1399 group_splits: vec![0], // No automatic splitting
1400 };
1401
1402 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1403 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1404 Ok(group.id)
1405 } else {
1406 // Group might have been created by concurrent call, try fetching again
1407 let groups = self.groups(dataset_id).await?;
1408 groups
1409 .iter()
1410 .find(|g| g.name == name)
1411 .map(|g| g.id)
1412 .ok_or_else(|| {
1413 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1414 })
1415 }
1416 }
1417
1418 /// Sets the group for a sample.
1419 ///
1420 /// Assigns a sample to a specific group. Each sample can belong to at most
1421 /// one group at a time. Setting a new group replaces any existing group
1422 /// assignment.
1423 ///
1424 /// # Arguments
1425 ///
1426 /// * `sample_id` - The ID of the sample (image) to update
1427 /// * `group_id` - The ID of the group to assign. Use
1428 /// [`get_or_create_group`] to obtain a group ID from a name.
1429 ///
1430 /// # Returns
1431 ///
1432 /// Returns `Ok(())` on success.
1433 ///
1434 /// # Errors
1435 ///
1436 /// Returns an error if:
1437 /// - The sample does not exist
1438 /// - The group does not exist
1439 /// - Insufficient permissions to modify the sample
1440 ///
1441 /// # Example
1442 ///
1443 /// ```rust,no_run
1444 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1445 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1446 /// let client = Client::new()?.with_token_path(None)?;
1447 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1448 /// let sample_id: SampleID = 12345.into();
1449 ///
1450 /// // Get or create the "val" group
1451 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1452 ///
1453 /// // Assign the sample to the "val" group
1454 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1455 /// # Ok(())
1456 /// # }
1457 /// ```
1458 ///
1459 /// [`get_or_create_group`]: Self::get_or_create_group
1460 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1461 pub async fn set_sample_group_id(
1462 &self,
1463 sample_id: SampleID,
1464 group_id: u64,
1465 ) -> Result<(), Error> {
1466 #[derive(Serialize)]
1467 struct SetGroupParams {
1468 image_id: SampleID,
1469 group_id: u64,
1470 }
1471
1472 let params = SetGroupParams {
1473 image_id: sample_id,
1474 group_id,
1475 };
1476 let _: String = self
1477 .rpc("image.set_group_id".to_owned(), Some(params))
1478 .await?;
1479 Ok(())
1480 }
1481
1482 /// Downloads dataset samples to the local filesystem.
1483 ///
1484 /// # Arguments
1485 ///
1486 /// * `dataset_id` - The unique identifier of the dataset
1487 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1488 /// * `file_types` - File types to download. Supported types:
1489 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1490 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1491 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1492 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1493 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1494 /// - `FileType::RadarCube` - Radar cube data (.png format)
1495 /// - `FileType::All` - All sensor types (expands to all of the above)
1496 /// * `output` - Local directory to save downloaded files
1497 /// * `flatten` - If true, download all files to output root without
1498 /// sequence subdirectories. When flattening, filenames are prefixed with
1499 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1500 /// unavailable) unless the filename already starts with
1501 /// `{sequence_name}_`, to avoid conflicts between sequences.
1502 /// * `progress` - Optional channel for progress updates
1503 ///
1504 /// # Progress
1505 ///
1506 /// This operation has two phases with distinct progress reporting:
1507 ///
1508 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1509 /// from the server. Progress counts samples fetched.
1510 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1511 /// files to disk. Progress counts samples completed (each sample may
1512 /// have multiple files for different sensor types).
1513 ///
1514 /// Applications should detect the status change from `None` to
1515 /// `"Downloading"` to reset their progress bar for the second phase.
1516 ///
1517 /// # Returns
1518 ///
1519 /// Returns `Ok(())` on success or an error if download fails.
1520 ///
1521 /// # Example
1522 ///
1523 /// ```rust,no_run
1524 /// # use edgefirst_client::{Client, DatasetID, FileType};
1525 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1526 /// let client = Client::new()?.with_token_path(None)?;
1527 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1528 ///
1529 /// // Download with sequence subdirectories (default)
1530 /// client
1531 /// .download_dataset(
1532 /// dataset_id,
1533 /// &[],
1534 /// &[FileType::Image],
1535 /// "./data".into(),
1536 /// false,
1537 /// None,
1538 /// )
1539 /// .await?;
1540 ///
1541 /// // Download flattened (all files in one directory)
1542 /// client
1543 /// .download_dataset(
1544 /// dataset_id,
1545 /// &[],
1546 /// &[FileType::Image],
1547 /// "./data".into(),
1548 /// true,
1549 /// None,
1550 /// )
1551 /// .await?;
1552 ///
1553 /// // Download all sensor types
1554 /// client
1555 /// .download_dataset(
1556 /// dataset_id,
1557 /// &[],
1558 /// &FileType::expand_types(&[FileType::All]),
1559 /// "./data".into(),
1560 /// false,
1561 /// None,
1562 /// )
1563 /// .await?;
1564 /// # Ok(())
1565 /// # }
1566 /// ```
1567 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1568 pub async fn download_dataset(
1569 &self,
1570 dataset_id: DatasetID,
1571 groups: &[String],
1572 file_types: &[FileType],
1573 output: PathBuf,
1574 flatten: bool,
1575 progress: Option<Sender<Progress>>,
1576 ) -> Result<(), Error> {
1577 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1578 let samples = self
1579 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1580 .await?;
1581 fs::create_dir_all(&output).await?;
1582
1583 // Phase 2: Download actual files using direct semaphore pattern
1584 let total = samples.len();
1585 let current = Arc::new(AtomicUsize::new(0));
1586 let sem = Arc::new(Semaphore::new(max_tasks()));
1587
1588 // Send initial progress for download phase
1589 if let Some(ref progress) = progress {
1590 let _ = progress
1591 .send(Progress {
1592 current: 0,
1593 total,
1594 status: Some("Downloading".to_string()),
1595 })
1596 .await;
1597 }
1598
1599 let tasks = samples
1600 .into_iter()
1601 .map(|sample| {
1602 let client = self.clone();
1603 let file_types = file_types.to_vec();
1604 let output = output.clone();
1605 let progress = progress.clone();
1606 let current = current.clone();
1607 let sem = sem.clone();
1608
1609 tokio::spawn(async move {
1610 let _permit = sem.acquire().await.map_err(|_| {
1611 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1612 })?;
1613
1614 for file_type in &file_types {
1615 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1616 let (file_ext, is_image) = match file_type {
1617 FileType::Image => (
1618 infer::get(&data)
1619 .expect("Failed to identify image file format for sample")
1620 .extension()
1621 .to_string(),
1622 true,
1623 ),
1624 other => (other.file_extension().to_string(), false),
1625 };
1626
1627 // Determine target directory based on sequence membership and
1628 // flatten option
1629 // - flatten=false + sequence_name: dataset/sequence_name/
1630 // - flatten=false + no sequence: dataset/ (root level)
1631 // - flatten=true: dataset/ (all files in output root)
1632 // NOTE: group (train/val/test) is NOT used for directory structure
1633 let sequence_dir = sample
1634 .sequence_name()
1635 .map(|name| sanitize_path_component(name));
1636
1637 let target_dir = if flatten {
1638 output.clone()
1639 } else {
1640 sequence_dir
1641 .as_ref()
1642 .map(|seq| output.join(seq))
1643 .unwrap_or_else(|| output.clone())
1644 };
1645 fs::create_dir_all(&target_dir).await?;
1646
1647 let sanitized_sample_name = sample
1648 .name()
1649 .map(|name| sanitize_path_component(&name))
1650 .unwrap_or_else(|| "unknown".to_string());
1651
1652 let image_name = sample.image_name().map(sanitize_path_component);
1653
1654 // Construct filename with smart prefixing for flatten mode
1655 // When flatten=true and sample belongs to a sequence:
1656 // - Check if filename already starts with "{sequence_name}_"
1657 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1658 // - If yes, use filename as-is (already uniquely named)
1659 let file_name = if is_image {
1660 if let Some(img_name) = image_name {
1661 Client::build_filename(
1662 &img_name,
1663 flatten,
1664 sequence_dir.as_ref(),
1665 sample.frame_number(),
1666 )
1667 } else {
1668 format!("{}.{}", sanitized_sample_name, file_ext)
1669 }
1670 } else {
1671 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1672 Client::build_filename(
1673 &base_name,
1674 flatten,
1675 sequence_dir.as_ref(),
1676 sample.frame_number(),
1677 )
1678 };
1679
1680 let file_path = target_dir.join(&file_name);
1681
1682 let mut file = File::create(&file_path).await?;
1683 file.write_all(&data).await?;
1684 }
1685 }
1686
1687 // Update progress after sample completes
1688 if let Some(progress) = &progress {
1689 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
1690 let _ = progress
1691 .send(Progress {
1692 current: completed,
1693 total,
1694 status: Some("Downloading".to_string()),
1695 })
1696 .await;
1697 }
1698
1699 Ok::<(), Error>(())
1700 })
1701 })
1702 .collect::<Vec<_>>();
1703
1704 join_all(tasks)
1705 .await
1706 .into_iter()
1707 .collect::<Result<Vec<_>, _>>()?
1708 .into_iter()
1709 .collect::<Result<Vec<_>, _>>()?;
1710
1711 Ok(())
1712 }
1713
1714 /// Builds a filename with smart prefixing for flatten mode.
1715 ///
1716 /// When flattening sequences into a single directory, this function ensures
1717 /// unique filenames by checking if the sequence prefix already exists and
1718 /// adding it if necessary.
1719 ///
1720 /// # Logic
1721 ///
1722 /// - If `flatten=false`: returns `base_name` unchanged
1723 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1724 /// - If `flatten=true` and in sequence:
1725 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1726 /// unchanged
1727 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1728 /// `{sequence_name}_{base_name}`
1729 fn build_filename(
1730 base_name: &str,
1731 flatten: bool,
1732 sequence_name: Option<&String>,
1733 frame_number: Option<u32>,
1734 ) -> String {
1735 if !flatten || sequence_name.is_none() {
1736 return base_name.to_string();
1737 }
1738
1739 let seq_name = sequence_name.unwrap();
1740 let prefix = format!("{}_", seq_name);
1741
1742 // Check if already prefixed with sequence name
1743 if base_name.starts_with(&prefix) {
1744 base_name.to_string()
1745 } else {
1746 // Add sequence (and optionally frame) prefix
1747 match frame_number {
1748 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1749 None => format!("{}{}", prefix, base_name),
1750 }
1751 }
1752 }
1753
1754 /// List available annotation sets for the specified dataset.
1755 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1756 pub async fn annotation_sets(
1757 &self,
1758 dataset_id: DatasetID,
1759 ) -> Result<Vec<AnnotationSet>, Error> {
1760 let params = HashMap::from([("dataset_id", dataset_id)]);
1761 self.rpc("annset.list".to_owned(), Some(params)).await
1762 }
1763
1764 /// Create a new annotation set for the specified dataset.
1765 ///
1766 /// # Arguments
1767 ///
1768 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1769 /// * `name` - The name of the new annotation set
1770 /// * `description` - Optional description for the annotation set
1771 ///
1772 /// # Returns
1773 ///
1774 /// Returns the annotation set ID of the newly created annotation set.
1775 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1776 pub async fn create_annotation_set(
1777 &self,
1778 dataset_id: DatasetID,
1779 name: &str,
1780 description: Option<&str>,
1781 ) -> Result<AnnotationSetID, Error> {
1782 #[derive(Serialize)]
1783 struct Params<'a> {
1784 dataset_id: DatasetID,
1785 name: &'a str,
1786 operator: &'a str,
1787 #[serde(skip_serializing_if = "Option::is_none")]
1788 description: Option<&'a str>,
1789 }
1790
1791 #[derive(Deserialize)]
1792 struct CreateAnnotationSetResult {
1793 id: AnnotationSetID,
1794 }
1795
1796 let username = self.username().await?;
1797 let result: CreateAnnotationSetResult = self
1798 .rpc(
1799 "annset.add".to_owned(),
1800 Some(Params {
1801 dataset_id,
1802 name,
1803 operator: &username,
1804 description,
1805 }),
1806 )
1807 .await?;
1808 Ok(result.id)
1809 }
1810
1811 /// Deletes an annotation set by marking it as deleted.
1812 ///
1813 /// # Arguments
1814 ///
1815 /// * `annotation_set_id` - The ID of the annotation set to delete
1816 ///
1817 /// # Returns
1818 ///
1819 /// Returns `Ok(())` if the annotation set was successfully marked as
1820 /// deleted.
1821 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1822 pub async fn delete_annotation_set(
1823 &self,
1824 annotation_set_id: AnnotationSetID,
1825 ) -> Result<(), Error> {
1826 let params = HashMap::from([("id", annotation_set_id)]);
1827 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1828 Ok(())
1829 }
1830
1831 /// Retrieve the annotation set with the specified ID.
1832 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1833 pub async fn annotation_set(
1834 &self,
1835 annotation_set_id: AnnotationSetID,
1836 ) -> Result<AnnotationSet, Error> {
1837 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1838 self.rpc("annset.get".to_owned(), Some(params)).await
1839 }
1840
1841 /// Get the annotations for the specified annotation set with the
1842 /// requested annotation types. The annotation types are used to filter
1843 /// the annotations returned. The groups parameter is used to filter for
1844 /// dataset groups (train, val, test). Images which do not have any
1845 /// annotations are also included in the result as long as they are in the
1846 /// requested groups (when specified).
1847 ///
1848 /// The result is a vector of Annotations objects which contain the
1849 /// full dataset along with the annotations for the specified types.
1850 ///
1851 /// # Progress
1852 ///
1853 /// Reports progress with `status: None` as samples are fetched and
1854 /// processed for their annotations. Progress unit is samples processed
1855 /// (not individual annotations).
1856 ///
1857 /// To get the annotations as a DataFrame, use the `samples_dataframe`
1858 /// method instead.
1859 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1860 pub async fn annotations(
1861 &self,
1862 annotation_set_id: AnnotationSetID,
1863 groups: &[String],
1864 annotation_types: &[AnnotationType],
1865 progress: Option<Sender<Progress>>,
1866 ) -> Result<Vec<Annotation>, Error> {
1867 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1868 let labels = self
1869 .labels(dataset_id)
1870 .await?
1871 .into_iter()
1872 .map(|label| (label.name().to_string(), label.index()))
1873 .collect::<HashMap<_, _>>();
1874 let total = self
1875 .samples_count(
1876 dataset_id,
1877 Some(annotation_set_id),
1878 annotation_types,
1879 groups,
1880 &[],
1881 )
1882 .await?
1883 .total as usize;
1884
1885 if total == 0 {
1886 return Ok(vec![]);
1887 }
1888
1889 let context = FetchContext {
1890 dataset_id,
1891 annotation_set_id: Some(annotation_set_id),
1892 groups,
1893 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1894 labels: &labels,
1895 };
1896
1897 self.fetch_annotations_paginated(context, total, progress)
1898 .await
1899 }
1900
1901 async fn fetch_annotations_paginated(
1902 &self,
1903 context: FetchContext<'_>,
1904 total: usize,
1905 progress: Option<Sender<Progress>>,
1906 ) -> Result<Vec<Annotation>, Error> {
1907 let mut annotations = vec![];
1908 let mut continue_token: Option<String> = None;
1909 let mut current = 0;
1910
1911 loop {
1912 let params = SamplesListParams {
1913 dataset_id: context.dataset_id,
1914 annotation_set_id: context.annotation_set_id,
1915 types: context.types.clone(),
1916 group_names: context.groups.to_vec(),
1917 continue_token,
1918 };
1919
1920 let result: SamplesListResult =
1921 self.rpc("samples.list".to_owned(), Some(params)).await?;
1922 current += result.samples.len();
1923 continue_token = result.continue_token;
1924
1925 if result.samples.is_empty() {
1926 break;
1927 }
1928
1929 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
1930
1931 if let Some(progress) = &progress {
1932 let _ = progress
1933 .send(Progress {
1934 current,
1935 total,
1936 status: None,
1937 })
1938 .await;
1939 }
1940
1941 match &continue_token {
1942 Some(token) if !token.is_empty() => continue,
1943 _ => break,
1944 }
1945 }
1946
1947 drop(progress);
1948 Ok(annotations)
1949 }
1950
1951 fn process_sample_annotations(
1952 &self,
1953 samples: &[Sample],
1954 labels: &HashMap<String, u64>,
1955 annotations: &mut Vec<Annotation>,
1956 ) {
1957 for sample in samples {
1958 if sample.annotations().is_empty() {
1959 let mut annotation = Annotation::new();
1960 annotation.set_sample_id(sample.id());
1961 annotation.set_name(sample.name());
1962 annotation.set_sequence_name(sample.sequence_name().cloned());
1963 annotation.set_frame_number(sample.frame_number());
1964 annotation.set_group(sample.group().cloned());
1965 annotations.push(annotation);
1966 continue;
1967 }
1968
1969 for annotation in sample.annotations() {
1970 let mut annotation = annotation.clone();
1971 annotation.set_sample_id(sample.id());
1972 annotation.set_name(sample.name());
1973 annotation.set_sequence_name(sample.sequence_name().cloned());
1974 annotation.set_frame_number(sample.frame_number());
1975 annotation.set_group(sample.group().cloned());
1976 Self::set_label_index_from_map(&mut annotation, labels);
1977 annotations.push(annotation);
1978 }
1979 }
1980 }
1981
1982 /// Delete annotations in bulk from specified samples.
1983 ///
1984 /// This method calls the `annotation.bulk.del` API to efficiently remove
1985 /// annotations from multiple samples at once. Useful for clearing
1986 /// annotations before re-importing updated data.
1987 ///
1988 /// # Arguments
1989 /// * `annotation_set_id` - The annotation set containing the annotations
1990 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
1991 /// for masks
1992 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
1993 ///
1994 /// # Example
1995 /// ```no_run
1996 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
1997 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1998 /// # let client = Client::new()?.with_login("user", "pass").await?;
1999 /// let annotation_set_id = AnnotationSetID::from(123);
2000 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2001 ///
2002 /// client
2003 /// .delete_annotations_bulk(
2004 /// annotation_set_id,
2005 /// &["box".to_string(), "seg".to_string()],
2006 /// &sample_ids,
2007 /// )
2008 /// .await?;
2009 /// # Ok(())
2010 /// # }
2011 /// ```
2012 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2013 pub async fn delete_annotations_bulk(
2014 &self,
2015 annotation_set_id: AnnotationSetID,
2016 annotation_types: &[String],
2017 sample_ids: &[SampleID],
2018 ) -> Result<(), Error> {
2019 use crate::api::AnnotationBulkDeleteParams;
2020
2021 let params = AnnotationBulkDeleteParams {
2022 annotation_set_id: annotation_set_id.into(),
2023 annotation_types: annotation_types.to_vec(),
2024 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2025 delete_all: None,
2026 };
2027
2028 let _: String = self
2029 .rpc("annotation.bulk.del".to_owned(), Some(params))
2030 .await?;
2031 Ok(())
2032 }
2033
2034 /// Add annotations in bulk.
2035 ///
2036 /// This method calls the `annotation.add_bulk` API to efficiently add
2037 /// multiple annotations at once. The annotations must be in server format
2038 /// with image_id references.
2039 ///
2040 /// # Arguments
2041 /// * `annotation_set_id` - The annotation set to add annotations to
2042 /// * `annotations` - Vector of server-format annotations to add
2043 ///
2044 /// # Returns
2045 /// Vector of created annotation records from the server.
2046 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2047 pub async fn add_annotations_bulk(
2048 &self,
2049 annotation_set_id: AnnotationSetID,
2050 annotations: Vec<crate::api::ServerAnnotation>,
2051 ) -> Result<Vec<serde_json::Value>, Error> {
2052 use crate::api::AnnotationAddBulkParams;
2053
2054 let params = AnnotationAddBulkParams {
2055 annotation_set_id: annotation_set_id.into(),
2056 annotations,
2057 };
2058
2059 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2060 .await
2061 }
2062
2063 /// Helper to parse frame number from image_name when sequence_name is
2064 /// present. This ensures frame_number is always derived from the image
2065 /// filename, not from the server's frame_number field (which may be
2066 /// inconsistent).
2067 ///
2068 /// Returns Some(frame_number) if sequence_name is present and frame can be
2069 /// parsed, otherwise None.
2070 fn parse_frame_from_image_name(
2071 image_name: Option<&String>,
2072 sequence_name: Option<&String>,
2073 ) -> Option<u32> {
2074 use std::path::Path;
2075
2076 let sequence = sequence_name?;
2077 let name = image_name?;
2078
2079 // Extract stem (remove extension)
2080 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2081
2082 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2083 stem.strip_prefix(sequence)
2084 .and_then(|suffix| suffix.strip_prefix('_'))
2085 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2086 }
2087
2088 /// Helper to set label index from a label map
2089 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2090 if let Some(label) = annotation.label() {
2091 annotation.set_label_index(Some(labels[label.as_str()]));
2092 }
2093 }
2094
2095 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2096 pub async fn samples_count(
2097 &self,
2098 dataset_id: DatasetID,
2099 annotation_set_id: Option<AnnotationSetID>,
2100 annotation_types: &[AnnotationType],
2101 groups: &[String],
2102 types: &[FileType],
2103 ) -> Result<SamplesCountResult, Error> {
2104 // Use server type names for API calls (e.g., "box" instead of "box2d")
2105 let types = annotation_types
2106 .iter()
2107 .map(|t| t.as_server_type().to_string())
2108 .chain(types.iter().map(|t| t.to_string()))
2109 .collect::<Vec<_>>();
2110
2111 let params = SamplesListParams {
2112 dataset_id,
2113 annotation_set_id,
2114 group_names: groups.to_vec(),
2115 types,
2116 continue_token: None,
2117 };
2118
2119 self.rpc("samples.count".to_owned(), Some(params)).await
2120 }
2121
2122 /// Fetches samples from a dataset with optional annotation and file type
2123 /// filters.
2124 ///
2125 /// # Arguments
2126 ///
2127 /// * `dataset_id` - The dataset to fetch samples from
2128 /// * `annotation_set_id` - Optional annotation set to include annotations
2129 /// from
2130 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2131 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2132 /// * `types` - File types to include metadata for
2133 /// * `progress` - Optional channel for progress updates
2134 ///
2135 /// # Progress
2136 ///
2137 /// Reports progress with `status: None` as samples are fetched from the
2138 /// server in paginated batches. Progress unit is samples fetched.
2139 ///
2140 /// # Returns
2141 ///
2142 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2143 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2144 pub async fn samples(
2145 &self,
2146 dataset_id: DatasetID,
2147 annotation_set_id: Option<AnnotationSetID>,
2148 annotation_types: &[AnnotationType],
2149 groups: &[String],
2150 types: &[FileType],
2151 progress: Option<Sender<Progress>>,
2152 ) -> Result<Vec<Sample>, Error> {
2153 // Use server type names for API calls (e.g., "box" instead of "box2d")
2154 let types_vec = annotation_types
2155 .iter()
2156 .map(|t| t.as_server_type().to_string())
2157 .chain(types.iter().map(|t| t.to_string()))
2158 .collect::<Vec<_>>();
2159 let labels = self
2160 .labels(dataset_id)
2161 .await?
2162 .into_iter()
2163 .map(|label| (label.name().to_string(), label.index()))
2164 .collect::<HashMap<_, _>>();
2165 let total = self
2166 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2167 .await?
2168 .total as usize;
2169
2170 if total == 0 {
2171 return Ok(vec![]);
2172 }
2173
2174 let context = FetchContext {
2175 dataset_id,
2176 annotation_set_id,
2177 groups,
2178 types: types_vec,
2179 labels: &labels,
2180 };
2181
2182 self.fetch_samples_paginated(context, total, progress).await
2183 }
2184
2185 /// Get all sample names in a dataset.
2186 ///
2187 /// This is an efficient method for checking which samples already exist,
2188 /// useful for resuming interrupted imports. It only retrieves sample names
2189 /// without loading full annotation data.
2190 ///
2191 /// # Arguments
2192 ///
2193 /// * `dataset_id` - The dataset to query
2194 /// * `groups` - Optional group filter (empty = all groups)
2195 /// * `progress` - Optional progress channel
2196 ///
2197 /// # Progress
2198 ///
2199 /// Reports progress with `status: None` as sample names are fetched from
2200 /// the server in paginated batches. Progress unit is samples fetched.
2201 ///
2202 /// # Returns
2203 ///
2204 /// A HashSet of sample names (image_name field) that exist in the dataset.
2205 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2206 pub async fn sample_names(
2207 &self,
2208 dataset_id: DatasetID,
2209 groups: &[String],
2210 progress: Option<Sender<Progress>>,
2211 ) -> Result<std::collections::HashSet<String>, Error> {
2212 use std::collections::HashSet;
2213
2214 let total = self
2215 .samples_count(dataset_id, None, &[], groups, &[])
2216 .await?
2217 .total as usize;
2218
2219 if total == 0 {
2220 return Ok(HashSet::new());
2221 }
2222
2223 let mut names = HashSet::with_capacity(total);
2224 let mut continue_token: Option<String> = None;
2225 let mut current = 0;
2226
2227 loop {
2228 let params = SamplesListParams {
2229 dataset_id,
2230 annotation_set_id: None,
2231 types: vec![], // No type filter - we just want names
2232 group_names: groups.to_vec(),
2233 continue_token: continue_token.clone(),
2234 };
2235
2236 let result: SamplesListResult =
2237 self.rpc("samples.list".to_owned(), Some(params)).await?;
2238 current += result.samples.len();
2239 continue_token = result.continue_token;
2240
2241 if result.samples.is_empty() {
2242 break;
2243 }
2244
2245 // Extract sample names (normalized without extension)
2246 for sample in result.samples {
2247 if let Some(name) = sample.name() {
2248 names.insert(name);
2249 }
2250 }
2251
2252 if let Some(ref p) = progress {
2253 let _ = p
2254 .send(Progress {
2255 current,
2256 total,
2257 status: None,
2258 })
2259 .await;
2260 }
2261
2262 match &continue_token {
2263 Some(token) if !token.is_empty() => continue,
2264 _ => break,
2265 }
2266 }
2267
2268 Ok(names)
2269 }
2270
2271 async fn fetch_samples_paginated(
2272 &self,
2273 context: FetchContext<'_>,
2274 total: usize,
2275 progress: Option<Sender<Progress>>,
2276 ) -> Result<Vec<Sample>, Error> {
2277 let mut samples = vec![];
2278 let mut continue_token: Option<String> = None;
2279 let mut current = 0;
2280
2281 loop {
2282 let params = SamplesListParams {
2283 dataset_id: context.dataset_id,
2284 annotation_set_id: context.annotation_set_id,
2285 types: context.types.clone(),
2286 group_names: context.groups.to_vec(),
2287 continue_token: continue_token.clone(),
2288 };
2289
2290 let result: SamplesListResult =
2291 self.rpc("samples.list".to_owned(), Some(params)).await?;
2292 current += result.samples.len();
2293 continue_token = result.continue_token;
2294
2295 if result.samples.is_empty() {
2296 break;
2297 }
2298
2299 samples.append(
2300 &mut result
2301 .samples
2302 .into_iter()
2303 .map(|s| {
2304 // Use server's frame_number if valid (>= 0 after deserialization)
2305 // Otherwise parse from image_name as fallback
2306 // This ensures we respect explicit frame_number from uploads
2307 // while still handling legacy data that only has filename encoding
2308 let frame_number = s.frame_number.or_else(|| {
2309 Self::parse_frame_from_image_name(
2310 s.image_name.as_ref(),
2311 s.sequence_name.as_ref(),
2312 )
2313 });
2314
2315 let mut anns = s.annotations().to_vec();
2316 for ann in &mut anns {
2317 // Set annotation fields from parent sample
2318 ann.set_name(s.name());
2319 ann.set_group(s.group().cloned());
2320 ann.set_sequence_name(s.sequence_name().cloned());
2321 ann.set_frame_number(frame_number);
2322 Self::set_label_index_from_map(ann, context.labels);
2323 }
2324 s.with_annotations(anns).with_frame_number(frame_number)
2325 })
2326 .collect::<Vec<_>>(),
2327 );
2328
2329 if let Some(progress) = &progress {
2330 let _ = progress
2331 .send(Progress {
2332 current,
2333 total,
2334 status: None,
2335 })
2336 .await;
2337 }
2338
2339 match &continue_token {
2340 Some(token) if !token.is_empty() => continue,
2341 _ => break,
2342 }
2343 }
2344
2345 drop(progress);
2346 Ok(samples)
2347 }
2348
2349 /// Populates (imports) samples into a dataset using the `samples.populate2`
2350 /// API.
2351 ///
2352 /// This method creates new samples in the specified dataset, optionally
2353 /// with annotations and sensor data files. For each sample, the `files`
2354 /// field is checked for local file paths. If a filename is a valid path
2355 /// to an existing file, the file will be automatically uploaded to S3
2356 /// using presigned URLs returned by the server. The filename in the
2357 /// request is replaced with the basename (path removed) before sending
2358 /// to the server.
2359 ///
2360 /// # Important Notes
2361 ///
2362 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2363 /// annotations. Without it, the server will accept the request but will
2364 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2365 /// available annotation sets for a dataset, or create a new one via the
2366 /// Studio UI.
2367 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2368 /// boxes. Divide pixel coordinates by image width/height before creating
2369 /// [`Box2d`](crate::Box2d) annotations.
2370 /// - **Files are uploaded automatically** when the filename is a valid
2371 /// local path. The method will replace the full path with just the
2372 /// basename before sending to the server.
2373 /// - **Image dimensions are extracted automatically** for image files using
2374 /// the `imagesize` crate. The width/height are sent to the server, but
2375 /// note that the server currently doesn't return these fields when
2376 /// fetching samples back.
2377 /// - **UUIDs are generated automatically** if not provided. If you need
2378 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
2379 /// that the server doesn't currently return UUIDs in sample queries.
2380 ///
2381 /// # Arguments
2382 ///
2383 /// * `dataset_id` - The ID of the dataset to populate
2384 /// * `annotation_set_id` - **Required** if samples contain annotations,
2385 /// otherwise they will be ignored. Query with
2386 /// [`Client::annotation_sets`].
2387 /// * `samples` - Vector of samples to import with metadata and file
2388 /// references. For files, use the full local path - it will be uploaded
2389 /// automatically. UUIDs and image dimensions will be
2390 /// auto-generated/extracted if not provided.
2391 /// * `progress` - Optional channel for progress updates
2392 ///
2393 /// # Progress
2394 ///
2395 /// Reports progress with `status: None` as each sample's files are
2396 /// uploaded. Progress unit is samples (not individual files). Each
2397 /// sample may contain multiple files (image, lidar, radar, etc.) which
2398 /// are all uploaded before the sample is counted as complete.
2399 ///
2400 /// # Returns
2401 ///
2402 /// Returns the API result with sample UUIDs and upload status.
2403 ///
2404 /// # Example
2405 ///
2406 /// ```no_run
2407 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2408 ///
2409 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2410 /// # let client = Client::new()?.with_login("user", "pass").await?;
2411 /// # let dataset_id = DatasetID::from(1);
2412 /// // Query available annotation sets for the dataset
2413 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2414 /// let annotation_set_id = annotation_sets
2415 /// .first()
2416 /// .ok_or_else(|| {
2417 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2418 /// })?
2419 /// .id();
2420 ///
2421 /// // Create sample with annotation (UUID will be auto-generated)
2422 /// let mut sample = Sample::new();
2423 /// sample.width = Some(1920);
2424 /// sample.height = Some(1080);
2425 /// sample.group = Some("train".to_string());
2426 ///
2427 /// // Add file - use full path to local file, it will be uploaded automatically
2428 /// sample.files = vec![SampleFile::with_filename(
2429 /// "image".to_string(),
2430 /// "/path/to/image.jpg".to_string(),
2431 /// )];
2432 ///
2433 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2434 /// let mut annotation = Annotation::new();
2435 /// annotation.set_label(Some("person".to_string()));
2436 /// // Normalize pixel coordinates by dividing by image dimensions
2437 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2438 /// annotation.set_box2d(Some(bbox));
2439 /// sample.annotations = vec![annotation];
2440 ///
2441 /// // Populate with annotation_set_id (REQUIRED for annotations)
2442 /// let result = client
2443 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2444 /// .await?;
2445 /// # Ok(())
2446 /// # }
2447 /// ```
2448 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2449 pub async fn populate_samples(
2450 &self,
2451 dataset_id: DatasetID,
2452 annotation_set_id: Option<AnnotationSetID>,
2453 samples: Vec<Sample>,
2454 progress: Option<Sender<Progress>>,
2455 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2456 self.populate_samples_with_concurrency(
2457 dataset_id,
2458 annotation_set_id,
2459 samples,
2460 progress,
2461 None,
2462 )
2463 .await
2464 }
2465
2466 /// Populate samples with custom upload concurrency.
2467 ///
2468 /// Same as [`populate_samples`](Self::populate_samples) but allows
2469 /// specifying the maximum number of concurrent file uploads. Use this
2470 /// for bulk imports where higher concurrency can significantly reduce
2471 /// upload time.
2472 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2473 pub async fn populate_samples_with_concurrency(
2474 &self,
2475 dataset_id: DatasetID,
2476 annotation_set_id: Option<AnnotationSetID>,
2477 samples: Vec<Sample>,
2478 progress: Option<Sender<Progress>>,
2479 concurrency: Option<usize>,
2480 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2481 use crate::api::SamplesPopulateParams;
2482
2483 // Track which files need to be uploaded
2484 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2485
2486 // Process samples to detect local files and generate UUIDs
2487 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2488
2489 let has_files_to_upload = !files_to_upload.is_empty();
2490
2491 // Call populate API with presigned_urls=true if we have files to upload
2492 let params = SamplesPopulateParams {
2493 dataset_id,
2494 annotation_set_id,
2495 presigned_urls: Some(has_files_to_upload),
2496 samples,
2497 };
2498
2499 let results: Vec<crate::SamplesPopulateResult> = self
2500 .rpc("samples.populate2".to_owned(), Some(params))
2501 .await?;
2502
2503 // Upload files if we have any
2504 if has_files_to_upload {
2505 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2506 .await?;
2507 }
2508
2509 Ok(results)
2510 }
2511
2512 fn prepare_samples_for_upload(
2513 &self,
2514 samples: Vec<Sample>,
2515 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2516 ) -> Result<Vec<Sample>, Error> {
2517 Ok(samples
2518 .into_iter()
2519 .map(|mut sample| {
2520 // Generate UUID if not provided
2521 if sample.uuid.is_none() {
2522 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2523 }
2524
2525 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2526
2527 // Process files: detect local paths and queue for upload
2528 let files_copy = sample.files.clone();
2529 let updated_files: Vec<crate::SampleFile> = files_copy
2530 .iter()
2531 .map(|file| {
2532 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2533 })
2534 .collect();
2535
2536 sample.files = updated_files;
2537 sample
2538 })
2539 .collect())
2540 }
2541
2542 fn process_sample_file(
2543 &self,
2544 file: &crate::SampleFile,
2545 sample_uuid: &str,
2546 sample: &mut Sample,
2547 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2548 ) -> crate::SampleFile {
2549 use std::path::Path;
2550
2551 // Handle files with raw bytes (e.g., from ZIP archives)
2552 if let Some(bytes) = file.bytes()
2553 && let Some(filename) = file.filename()
2554 {
2555 // For image files with bytes, try to extract dimensions if not already set
2556 if file.file_type() == "image"
2557 && (sample.width.is_none() || sample.height.is_none())
2558 && let Ok(size) = imagesize::blob_size(bytes)
2559 {
2560 sample.width = Some(size.width as u32);
2561 sample.height = Some(size.height as u32);
2562 }
2563
2564 // Store the bytes for later upload
2565 files_to_upload.push((
2566 sample_uuid.to_string(),
2567 file.file_type().to_string(),
2568 FileSource::Bytes(bytes.to_vec()),
2569 filename.to_string(),
2570 ));
2571
2572 // Return SampleFile with just the filename
2573 return crate::SampleFile::with_filename(
2574 file.file_type().to_string(),
2575 filename.to_string(),
2576 );
2577 }
2578
2579 // Handle files with local paths
2580 if let Some(filename) = file.filename() {
2581 let path = Path::new(filename);
2582
2583 // Check if this is a valid local file path
2584 if path.exists()
2585 && path.is_file()
2586 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2587 {
2588 // For image files, try to extract dimensions if not already set
2589 if file.file_type() == "image"
2590 && (sample.width.is_none() || sample.height.is_none())
2591 && let Ok(size) = imagesize::size(path)
2592 {
2593 sample.width = Some(size.width as u32);
2594 sample.height = Some(size.height as u32);
2595 }
2596
2597 // Store the full path for later upload
2598 files_to_upload.push((
2599 sample_uuid.to_string(),
2600 file.file_type().to_string(),
2601 FileSource::Path(path.to_path_buf()),
2602 basename.to_string(),
2603 ));
2604
2605 // Return SampleFile with just the basename
2606 return crate::SampleFile::with_filename(
2607 file.file_type().to_string(),
2608 basename.to_string(),
2609 );
2610 }
2611 }
2612 // Return the file unchanged if not a local path
2613 file.clone()
2614 }
2615
2616 async fn upload_sample_files(
2617 &self,
2618 results: &[crate::SamplesPopulateResult],
2619 files_to_upload: Vec<(String, String, FileSource, String)>,
2620 progress: Option<Sender<Progress>>,
2621 concurrency: Option<usize>,
2622 ) -> Result<(), Error> {
2623 // Build a map from (sample_uuid, basename) -> file source
2624 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2625 for (uuid, _file_type, source, basename) in files_to_upload {
2626 upload_map.insert((uuid, basename), source);
2627 }
2628
2629 let http = self.http.clone();
2630
2631 // Extract the data we need for parallel upload
2632 let upload_tasks: Vec<_> = results
2633 .iter()
2634 .map(|result| (result.uuid.clone(), result.urls.clone()))
2635 .collect();
2636
2637 parallel_foreach_items(
2638 upload_tasks,
2639 progress.clone(),
2640 concurrency,
2641 move |(uuid, urls)| {
2642 let http = http.clone();
2643 let upload_map = upload_map.clone();
2644
2645 async move {
2646 // Upload all files for this sample
2647 for url_info in &urls {
2648 if let Some(source) =
2649 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2650 {
2651 match source {
2652 FileSource::Path(path) => {
2653 upload_file_to_presigned_url(
2654 http.clone(),
2655 &url_info.url,
2656 path.clone(),
2657 )
2658 .await?;
2659 }
2660 FileSource::Bytes(bytes) => {
2661 upload_bytes_to_presigned_url(
2662 http.clone(),
2663 &url_info.url,
2664 bytes.clone(),
2665 &url_info.filename,
2666 )
2667 .await?;
2668 }
2669 }
2670 }
2671 }
2672
2673 Ok(())
2674 }
2675 },
2676 )
2677 .await
2678 }
2679
2680 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2681 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2682 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
2683 if !url.starts_with("http://") && !url.starts_with("https://") {
2684 return Err(Error::InvalidParameters(format!(
2685 "Invalid URL (must be absolute): {}",
2686 url
2687 )));
2688 }
2689
2690 // Uses default 120s timeout from client
2691 let resp = self.http.get(url).send().await?;
2692
2693 if !resp.status().is_success() {
2694 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2695 }
2696
2697 let bytes = resp.bytes().await?;
2698 Ok(bytes.to_vec())
2699 }
2700
2701 /// Get samples as a DataFrame with complete 2025.10 schema.
2702 ///
2703 /// This is the recommended method for obtaining dataset annotations in
2704 /// DataFrame format. It includes all sample metadata (size, location,
2705 /// pose, degradation) as optional columns.
2706 ///
2707 /// # Arguments
2708 ///
2709 /// * `dataset_id` - Dataset identifier
2710 /// * `annotation_set_id` - Optional annotation set filter
2711 /// * `groups` - Dataset groups to include (train, val, test)
2712 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2713 /// * `progress` - Optional progress callback
2714 ///
2715 /// # Progress
2716 ///
2717 /// Reports progress with `status: None` as samples are fetched from the
2718 /// server in paginated batches. Progress unit is samples fetched. This
2719 /// method delegates to [`samples()`](Self::samples) and shares its
2720 /// progress behavior.
2721 ///
2722 /// # Example
2723 ///
2724 /// ```rust,no_run
2725 /// use edgefirst_client::Client;
2726 ///
2727 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2728 /// # let client = Client::new()?;
2729 /// # let dataset_id = 1.into();
2730 /// # let annotation_set_id = 1.into();
2731 /// let df = client
2732 /// .samples_dataframe(
2733 /// dataset_id,
2734 /// Some(annotation_set_id),
2735 /// &["train".to_string()],
2736 /// &[],
2737 /// None,
2738 /// )
2739 /// .await?;
2740 /// println!("DataFrame shape: {:?}", df.shape());
2741 /// # Ok(())
2742 /// # }
2743 /// ```
2744 #[cfg(feature = "polars")]
2745 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2746 pub async fn samples_dataframe(
2747 &self,
2748 dataset_id: DatasetID,
2749 annotation_set_id: Option<AnnotationSetID>,
2750 groups: &[String],
2751 types: &[AnnotationType],
2752 progress: Option<Sender<Progress>>,
2753 ) -> Result<DataFrame, Error> {
2754 use crate::dataset::samples_dataframe;
2755
2756 let samples = self
2757 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2758 .await?;
2759 samples_dataframe(&samples)
2760 }
2761
2762 /// List available snapshots. If a name is provided, only snapshots
2763 /// containing that name are returned.
2764 ///
2765 /// Results are sorted by match quality: exact matches first, then
2766 /// case-insensitive exact matches, then shorter descriptions (more
2767 /// specific), then alphabetically.
2768 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2769 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2770 let snapshots: Vec<Snapshot> = self
2771 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2772 .await?;
2773 if let Some(name) = name {
2774 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2775 s.description()
2776 }))
2777 } else {
2778 Ok(snapshots)
2779 }
2780 }
2781
2782 /// Get the snapshot with the specified id.
2783 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
2784 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2785 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2786 self.rpc("snapshots.get".to_owned(), Some(params)).await
2787 }
2788
2789 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2790 ///
2791 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2792 /// pairs) that serve two primary purposes:
2793 ///
2794 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2795 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2796 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2797 /// auto-depth processing.
2798 ///
2799 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2800 /// migration between EdgeFirst Studio instances using the create →
2801 /// download → upload → restore workflow.
2802 ///
2803 /// Large files are automatically chunked into 100MB parts and uploaded
2804 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2805 /// is streamed without loading into memory, maintaining constant memory
2806 /// usage.
2807 ///
2808 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2809 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2810 /// better for large files to avoid timeout issues. Higher values (16-32)
2811 /// are better for many small files.
2812 ///
2813 /// # Arguments
2814 ///
2815 /// * `path` - Local file path to MCAP file or directory containing
2816 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2817 /// * `progress` - Optional channel to receive upload progress updates
2818 ///
2819 /// # Progress
2820 ///
2821 /// Reports progress with `status: None` as file data is uploaded. Progress
2822 /// unit is bytes uploaded. For single files, total is the file size. For
2823 /// directories, total is the combined size of all files.
2824 ///
2825 /// # Returns
2826 ///
2827 /// Returns a `Snapshot` object with ID, description, status, path, and
2828 /// creation timestamp on success.
2829 ///
2830 /// # Errors
2831 ///
2832 /// Returns an error if:
2833 /// * Path doesn't exist or contains invalid UTF-8
2834 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2835 /// * Upload fails or network error occurs
2836 /// * Server rejects the snapshot
2837 ///
2838 /// # Example
2839 ///
2840 /// ```no_run
2841 /// # use edgefirst_client::{Client, Progress};
2842 /// # use tokio::sync::mpsc;
2843 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2844 /// let client = Client::new()?.with_token_path(None)?;
2845 ///
2846 /// // Upload MCAP file with progress tracking
2847 /// let (tx, mut rx) = mpsc::channel(1);
2848 /// tokio::spawn(async move {
2849 /// while let Some(Progress {
2850 /// current,
2851 /// total,
2852 /// status,
2853 /// }) = rx.recv().await
2854 /// {
2855 /// println!(
2856 /// "{}: {}/{} bytes ({:.1}%)",
2857 /// status.as_deref().unwrap_or("Upload"),
2858 /// current,
2859 /// total,
2860 /// (current as f64 / total as f64) * 100.0
2861 /// );
2862 /// }
2863 /// });
2864 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2865 /// println!("Created snapshot: {:?}", snapshot.id());
2866 ///
2867 /// // Upload dataset directory (no progress)
2868 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2869 /// # Ok(())
2870 /// # }
2871 /// ```
2872 ///
2873 /// # See Also
2874 ///
2875 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2876 /// dataset
2877 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2878 /// data
2879 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2880 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2881 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
2882 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
2883 pub async fn create_snapshot(
2884 &self,
2885 path: &str,
2886 progress: Option<Sender<Progress>>,
2887 ) -> Result<Snapshot, Error> {
2888 let path = Path::new(path);
2889
2890 if path.is_dir() {
2891 let path_str = path.to_str().ok_or_else(|| {
2892 Error::IoError(std::io::Error::new(
2893 std::io::ErrorKind::InvalidInput,
2894 "Path contains invalid UTF-8",
2895 ))
2896 })?;
2897 return self.create_snapshot_folder(path_str, progress).await;
2898 }
2899
2900 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2901 Error::IoError(std::io::Error::new(
2902 std::io::ErrorKind::InvalidInput,
2903 "Invalid filename",
2904 ))
2905 })?;
2906 let total = path.metadata()?.len() as usize;
2907 let current = Arc::new(AtomicUsize::new(0));
2908
2909 if let Some(progress) = &progress {
2910 let _ = progress
2911 .send(Progress {
2912 current: 0,
2913 total,
2914 status: None,
2915 })
2916 .await;
2917 }
2918
2919 let params = SnapshotCreateMultipartParams {
2920 snapshot_name: name.to_owned(),
2921 keys: vec![name.to_owned()],
2922 file_sizes: vec![total],
2923 snapshot_type: None,
2924 };
2925 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2926 .rpc(
2927 "snapshots.create_upload_url_multipart".to_owned(),
2928 Some(params),
2929 )
2930 .await?;
2931
2932 let snapshot_id = match multipart.get("snapshot_id") {
2933 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2934 _ => return Err(Error::InvalidResponse),
2935 };
2936
2937 let snapshot = self.snapshot(snapshot_id).await?;
2938 let part_prefix = snapshot
2939 .path()
2940 .split("::/")
2941 .last()
2942 .ok_or(Error::InvalidResponse)?
2943 .to_owned();
2944 let part_key = format!("{}/{}", part_prefix, name);
2945 let mut part = match multipart.get(&part_key) {
2946 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2947 _ => return Err(Error::InvalidResponse),
2948 }
2949 .clone();
2950 part.key = Some(part_key);
2951
2952 let params = upload_multipart(
2953 self.upload_http.clone(),
2954 part.clone(),
2955 path.to_path_buf(),
2956 total,
2957 current,
2958 progress.clone(),
2959 )
2960 .await?;
2961
2962 let complete: String = self
2963 .rpc(
2964 "snapshots.complete_multipart_upload".to_owned(),
2965 Some(params),
2966 )
2967 .await?;
2968 debug!("Snapshot Multipart Complete: {:?}", complete);
2969
2970 let params: SnapshotStatusParams = SnapshotStatusParams {
2971 snapshot_id,
2972 status: "available".to_owned(),
2973 };
2974 let _: SnapshotStatusResult = self
2975 .rpc("snapshots.update".to_owned(), Some(params))
2976 .await?;
2977
2978 if let Some(progress) = progress {
2979 drop(progress);
2980 }
2981
2982 self.snapshot(snapshot_id).await
2983 }
2984
2985 async fn create_snapshot_folder(
2986 &self,
2987 path: &str,
2988 progress: Option<Sender<Progress>>,
2989 ) -> Result<Snapshot, Error> {
2990 let path = Path::new(path);
2991 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2992 Error::IoError(std::io::Error::new(
2993 std::io::ErrorKind::InvalidInput,
2994 "Invalid directory name",
2995 ))
2996 })?;
2997
2998 let files = WalkDir::new(path)
2999 .into_iter()
3000 .filter_map(|entry| entry.ok())
3001 .filter(|entry| entry.file_type().is_file())
3002 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3003 .collect::<Vec<_>>();
3004
3005 let total: usize = files
3006 .iter()
3007 .filter_map(|file| path.join(file).metadata().ok())
3008 .map(|metadata| metadata.len() as usize)
3009 .sum();
3010 let current = Arc::new(AtomicUsize::new(0));
3011
3012 if let Some(progress) = &progress {
3013 let _ = progress
3014 .send(Progress {
3015 current: 0,
3016 total,
3017 status: None,
3018 })
3019 .await;
3020 }
3021
3022 let keys = files
3023 .iter()
3024 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3025 .collect::<Vec<_>>();
3026 let file_sizes = files
3027 .iter()
3028 .filter_map(|key| path.join(key).metadata().ok())
3029 .map(|metadata| metadata.len() as usize)
3030 .collect::<Vec<_>>();
3031
3032 let params = SnapshotCreateMultipartParams {
3033 snapshot_name: name.to_owned(),
3034 keys,
3035 file_sizes,
3036 snapshot_type: None,
3037 };
3038
3039 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3040 .rpc(
3041 "snapshots.create_upload_url_multipart".to_owned(),
3042 Some(params),
3043 )
3044 .await?;
3045
3046 let snapshot_id = match multipart.get("snapshot_id") {
3047 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3048 _ => return Err(Error::InvalidResponse),
3049 };
3050
3051 let snapshot = self.snapshot(snapshot_id).await?;
3052 let part_prefix = snapshot
3053 .path()
3054 .split("::/")
3055 .last()
3056 .ok_or(Error::InvalidResponse)?
3057 .to_owned();
3058
3059 for file in files {
3060 let file_str = file.to_str().ok_or_else(|| {
3061 Error::IoError(std::io::Error::new(
3062 std::io::ErrorKind::InvalidInput,
3063 "File path contains invalid UTF-8",
3064 ))
3065 })?;
3066 let part_key = format!("{}/{}", part_prefix, file_str);
3067 let mut part = match multipart.get(&part_key) {
3068 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3069 _ => return Err(Error::InvalidResponse),
3070 }
3071 .clone();
3072 part.key = Some(part_key);
3073
3074 let params = upload_multipart(
3075 self.upload_http.clone(),
3076 part.clone(),
3077 path.join(file),
3078 total,
3079 current.clone(),
3080 progress.clone(),
3081 )
3082 .await?;
3083
3084 let complete: String = self
3085 .rpc(
3086 "snapshots.complete_multipart_upload".to_owned(),
3087 Some(params),
3088 )
3089 .await?;
3090 debug!("Snapshot Part Complete: {:?}", complete);
3091 }
3092
3093 let params = SnapshotStatusParams {
3094 snapshot_id,
3095 status: "available".to_owned(),
3096 };
3097 let _: SnapshotStatusResult = self
3098 .rpc("snapshots.update".to_owned(), Some(params))
3099 .await?;
3100
3101 if let Some(progress) = progress {
3102 drop(progress);
3103 }
3104
3105 self.snapshot(snapshot_id).await
3106 }
3107
3108 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3109 ///
3110 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3111 /// This format is the native EdgeFirst Dataset Format used for efficient
3112 /// dataset storage and transfer.
3113 ///
3114 /// # Arguments
3115 ///
3116 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3117 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3118 /// * `description` - Optional description for the snapshot
3119 /// * `progress` - Optional progress channel for upload tracking
3120 ///
3121 /// # File Requirements
3122 ///
3123 /// - Arrow file must have `.arrow` extension
3124 /// - ZIP file must have `.zip` extension
3125 /// - Both files must exist and be readable
3126 ///
3127 /// # Example
3128 ///
3129 /// ```no_run
3130 /// # use edgefirst_client::Client;
3131 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3132 /// let client = Client::new()?.with_token_path(None)?;
3133 ///
3134 /// let snapshot = client
3135 /// .create_snapshot_edgefirst_format(
3136 /// "dataset.arrow",
3137 /// "dataset.zip",
3138 /// Some("My Dataset Snapshot"),
3139 /// None,
3140 /// )
3141 /// .await?;
3142 /// println!("Created snapshot: {}", snapshot.id());
3143 /// # Ok(())
3144 /// # }
3145 /// ```
3146 ///
3147 /// # See Also
3148 ///
3149 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3150 /// folder
3151 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3152 /// dataset
3153 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3154 pub async fn create_snapshot_edgefirst_format(
3155 &self,
3156 arrow_path: &str,
3157 zip_path: &str,
3158 description: Option<&str>,
3159 progress: Option<Sender<Progress>>,
3160 ) -> Result<Snapshot, Error> {
3161 let arrow_path = Path::new(arrow_path);
3162 let zip_path = Path::new(zip_path);
3163
3164 // Validate files exist
3165 if !arrow_path.exists() {
3166 return Err(Error::IoError(std::io::Error::new(
3167 std::io::ErrorKind::NotFound,
3168 format!("Arrow file not found: {}", arrow_path.display()),
3169 )));
3170 }
3171 if !zip_path.exists() {
3172 return Err(Error::IoError(std::io::Error::new(
3173 std::io::ErrorKind::NotFound,
3174 format!("ZIP file not found: {}", zip_path.display()),
3175 )));
3176 }
3177
3178 // Get file names
3179 let arrow_name = arrow_path
3180 .file_name()
3181 .and_then(|n| n.to_str())
3182 .ok_or_else(|| {
3183 Error::IoError(std::io::Error::new(
3184 std::io::ErrorKind::InvalidInput,
3185 "Invalid Arrow filename",
3186 ))
3187 })?;
3188 let zip_name = zip_path
3189 .file_name()
3190 .and_then(|n| n.to_str())
3191 .ok_or_else(|| {
3192 Error::IoError(std::io::Error::new(
3193 std::io::ErrorKind::InvalidInput,
3194 "Invalid ZIP filename",
3195 ))
3196 })?;
3197
3198 // Generate snapshot name from arrow file (without extension)
3199 let snapshot_name = description
3200 .map(|s| s.to_string())
3201 .or_else(|| {
3202 arrow_path
3203 .file_stem()
3204 .and_then(|s| s.to_str())
3205 .map(|s| s.to_string())
3206 })
3207 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3208
3209 // Calculate file sizes
3210 let arrow_size = arrow_path.metadata()?.len() as usize;
3211 let zip_size = zip_path.metadata()?.len() as usize;
3212 let total = arrow_size + zip_size;
3213 let current = Arc::new(AtomicUsize::new(0));
3214
3215 if let Some(progress) = &progress {
3216 let _ = progress
3217 .send(Progress {
3218 current: 0,
3219 total,
3220 status: None,
3221 })
3222 .await;
3223 }
3224
3225 // Create multipart upload request with "ziparrow" type
3226 let params = SnapshotCreateMultipartParams {
3227 snapshot_name,
3228 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3229 file_sizes: vec![arrow_size, zip_size],
3230 snapshot_type: Some("ziparrow".to_string()),
3231 };
3232
3233 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3234 .rpc(
3235 "snapshots.create_upload_url_multipart".to_owned(),
3236 Some(params),
3237 )
3238 .await?;
3239
3240 let snapshot_id = match multipart.get("snapshot_id") {
3241 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3242 _ => return Err(Error::InvalidResponse),
3243 };
3244
3245 let snapshot = self.snapshot(snapshot_id).await?;
3246 let part_prefix = snapshot
3247 .path()
3248 .split("::/")
3249 .last()
3250 .ok_or(Error::InvalidResponse)?
3251 .to_owned();
3252
3253 // Upload Arrow file
3254 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3255 let mut arrow_part = match multipart.get(&arrow_key) {
3256 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3257 _ => return Err(Error::InvalidResponse),
3258 };
3259 arrow_part.key = Some(arrow_key);
3260
3261 let params = upload_multipart(
3262 self.upload_http.clone(),
3263 arrow_part,
3264 arrow_path.to_path_buf(),
3265 total,
3266 current.clone(),
3267 progress.clone(),
3268 )
3269 .await?;
3270
3271 let _: String = self
3272 .rpc(
3273 "snapshots.complete_multipart_upload".to_owned(),
3274 Some(params),
3275 )
3276 .await?;
3277 debug!("Arrow file upload complete");
3278
3279 // Upload ZIP file
3280 let zip_key = format!("{}/{}", part_prefix, zip_name);
3281 let mut zip_part = match multipart.get(&zip_key) {
3282 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3283 _ => return Err(Error::InvalidResponse),
3284 };
3285 zip_part.key = Some(zip_key);
3286
3287 let params = upload_multipart(
3288 self.upload_http.clone(),
3289 zip_part,
3290 zip_path.to_path_buf(),
3291 total,
3292 current.clone(),
3293 progress.clone(),
3294 )
3295 .await?;
3296
3297 let _: String = self
3298 .rpc(
3299 "snapshots.complete_multipart_upload".to_owned(),
3300 Some(params),
3301 )
3302 .await?;
3303 debug!("ZIP file upload complete");
3304
3305 // Mark snapshot as available
3306 let params = SnapshotStatusParams {
3307 snapshot_id,
3308 status: "available".to_owned(),
3309 };
3310 let _: SnapshotStatusResult = self
3311 .rpc("snapshots.update".to_owned(), Some(params))
3312 .await?;
3313
3314 if let Some(progress) = progress {
3315 drop(progress);
3316 }
3317
3318 self.snapshot(snapshot_id).await
3319 }
3320
3321 /// Delete a snapshot from EdgeFirst Studio.
3322 ///
3323 /// Permanently removes a snapshot and its associated data. This operation
3324 /// cannot be undone.
3325 ///
3326 /// # Arguments
3327 ///
3328 /// * `snapshot_id` - The snapshot ID to delete
3329 ///
3330 /// # Errors
3331 ///
3332 /// Returns an error if:
3333 /// * Snapshot doesn't exist
3334 /// * User lacks permission to delete the snapshot
3335 /// * Server error occurs
3336 ///
3337 /// # Example
3338 ///
3339 /// ```no_run
3340 /// # use edgefirst_client::{Client, SnapshotID};
3341 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3342 /// let client = Client::new()?.with_token_path(None)?;
3343 /// let snapshot_id = SnapshotID::from(123);
3344 /// client.delete_snapshot(snapshot_id).await?;
3345 /// # Ok(())
3346 /// # }
3347 /// ```
3348 ///
3349 /// # See Also
3350 ///
3351 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3352 /// * [`snapshots`](Self::snapshots) - List all snapshots
3353 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3354 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3355 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3356 let _: serde_json::Value = self
3357 .rpc("snapshots.delete".to_owned(), Some(params))
3358 .await?;
3359 Ok(())
3360 }
3361
3362 /// Create a snapshot from an existing dataset on the server.
3363 ///
3364 /// Triggers server-side snapshot generation which exports the dataset's
3365 /// images and annotations into a downloadable EdgeFirst Dataset Format
3366 /// snapshot.
3367 ///
3368 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3369 /// while restore creates a dataset from a snapshot, this method creates a
3370 /// snapshot from a dataset.
3371 ///
3372 /// # Arguments
3373 ///
3374 /// * `dataset_id` - The dataset ID to create snapshot from
3375 /// * `description` - Description for the created snapshot
3376 ///
3377 /// # Returns
3378 ///
3379 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3380 /// for monitoring progress.
3381 ///
3382 /// # Errors
3383 ///
3384 /// Returns an error if:
3385 /// * Dataset doesn't exist
3386 /// * User lacks permission to access the dataset
3387 /// * Server rejects the request
3388 ///
3389 /// # Example
3390 ///
3391 /// ```no_run
3392 /// # use edgefirst_client::{Client, DatasetID};
3393 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3394 /// let client = Client::new()?.with_token_path(None)?;
3395 /// let dataset_id = DatasetID::from(123);
3396 ///
3397 /// // Create snapshot from dataset (all annotation sets)
3398 /// let result = client
3399 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3400 /// .await?;
3401 /// println!("Created snapshot: {:?}", result.id);
3402 ///
3403 /// // Monitor progress via task ID
3404 /// if let Some(task_id) = result.task_id {
3405 /// println!("Task: {}", task_id);
3406 /// }
3407 /// # Ok(())
3408 /// # }
3409 /// ```
3410 ///
3411 /// # See Also
3412 ///
3413 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3414 /// snapshot
3415 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3416 /// dataset
3417 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3418 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3419 pub async fn create_snapshot_from_dataset(
3420 &self,
3421 dataset_id: DatasetID,
3422 description: &str,
3423 annotation_set_id: Option<AnnotationSetID>,
3424 ) -> Result<SnapshotFromDatasetResult, Error> {
3425 // Resolve annotation_set_id: use provided value or fetch default
3426 let annotation_set_id = match annotation_set_id {
3427 Some(id) => id,
3428 None => {
3429 // Fetch annotation sets and find default ("annotations") or use first
3430 let sets = self.annotation_sets(dataset_id).await?;
3431 if sets.is_empty() {
3432 return Err(Error::InvalidParameters(
3433 "No annotation sets available for dataset".to_owned(),
3434 ));
3435 }
3436 // Look for "annotations" set (default), otherwise use first
3437 sets.iter()
3438 .find(|s| s.name() == "annotations")
3439 .unwrap_or(&sets[0])
3440 .id()
3441 }
3442 };
3443 let params = SnapshotCreateFromDataset {
3444 description: description.to_owned(),
3445 dataset_id,
3446 annotation_set_id,
3447 };
3448 self.rpc("snapshots.create".to_owned(), Some(params)).await
3449 }
3450
3451 /// Download a snapshot from EdgeFirst Studio to local storage.
3452 ///
3453 /// Downloads all files in a snapshot (single MCAP file or directory of
3454 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3455 /// downloaded concurrently with progress tracking.
3456 ///
3457 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3458 /// downloads (default: half of CPU cores, min 2, max 8).
3459 ///
3460 /// # Arguments
3461 ///
3462 /// * `snapshot_id` - The snapshot ID to download
3463 /// * `output` - Local directory path to save downloaded files
3464 /// * `progress` - Optional channel to receive download progress updates
3465 ///
3466 /// # Progress
3467 ///
3468 /// Reports progress with `status: None` as file data is received. Progress
3469 /// unit is bytes downloaded across all files combined. The total
3470 /// accumulates as file sizes become known (from HTTP Content-Length
3471 /// headers), so both `current` and `total` may increase during
3472 /// download.
3473 ///
3474 /// # Errors
3475 ///
3476 /// Returns an error if:
3477 /// * Snapshot doesn't exist
3478 /// * Output directory cannot be created
3479 /// * Download fails or network error occurs
3480 ///
3481 /// # Example
3482 ///
3483 /// ```no_run
3484 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3485 /// # use tokio::sync::mpsc;
3486 /// # use std::path::PathBuf;
3487 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3488 /// let client = Client::new()?.with_token_path(None)?;
3489 /// let snapshot_id = SnapshotID::from(123);
3490 ///
3491 /// // Download with progress tracking
3492 /// let (tx, mut rx) = mpsc::channel(1);
3493 /// tokio::spawn(async move {
3494 /// while let Some(Progress {
3495 /// current,
3496 /// total,
3497 /// status,
3498 /// }) = rx.recv().await
3499 /// {
3500 /// println!(
3501 /// "{}: {}/{} bytes",
3502 /// status.as_deref().unwrap_or("Download"),
3503 /// current,
3504 /// total
3505 /// );
3506 /// }
3507 /// });
3508 /// client
3509 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
3510 /// .await?;
3511 /// # Ok(())
3512 /// # }
3513 /// ```
3514 ///
3515 /// # See Also
3516 ///
3517 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3518 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3519 /// dataset
3520 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3521 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
3522 pub async fn download_snapshot(
3523 &self,
3524 snapshot_id: SnapshotID,
3525 output: PathBuf,
3526 progress: Option<Sender<Progress>>,
3527 ) -> Result<(), Error> {
3528 fs::create_dir_all(&output).await?;
3529
3530 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3531 let items: HashMap<String, String> = self
3532 .rpc("snapshots.create_download_url".to_owned(), Some(params))
3533 .await?;
3534
3535 let total = Arc::new(AtomicUsize::new(0));
3536 let current = Arc::new(AtomicUsize::new(0));
3537 let sem = Arc::new(Semaphore::new(max_tasks()));
3538
3539 let tasks = items
3540 .iter()
3541 .map(|(key, url)| {
3542 let http = self.http.clone();
3543 let key = key.clone();
3544 let url = url.clone();
3545 let output = output.clone();
3546 let progress = progress.clone();
3547 let current = current.clone();
3548 let total = total.clone();
3549 let sem = sem.clone();
3550
3551 tokio::spawn(async move {
3552 let _permit = sem.acquire().await.map_err(|_| {
3553 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3554 })?;
3555 let res = http.get(url).send().await?;
3556 let content_length = res.content_length().unwrap_or(0) as usize;
3557
3558 if let Some(progress) = &progress {
3559 let total = total.fetch_add(content_length, Ordering::SeqCst);
3560 let _ = progress
3561 .send(Progress {
3562 current: current.load(Ordering::SeqCst),
3563 total: total + content_length,
3564 status: None,
3565 })
3566 .await;
3567 }
3568
3569 let mut file = File::create(output.join(key)).await?;
3570 let mut stream = res.bytes_stream();
3571
3572 while let Some(chunk) = stream.next().await {
3573 let chunk = chunk?;
3574 file.write_all(&chunk).await?;
3575 let len = chunk.len();
3576
3577 if let Some(progress) = &progress {
3578 let total = total.load(Ordering::SeqCst);
3579 let current = current.fetch_add(len, Ordering::SeqCst);
3580
3581 let _ = progress
3582 .send(Progress {
3583 current: current + len,
3584 total,
3585 status: None,
3586 })
3587 .await;
3588 }
3589 }
3590
3591 Ok::<(), Error>(())
3592 })
3593 })
3594 .collect::<Vec<_>>();
3595
3596 join_all(tasks)
3597 .await
3598 .into_iter()
3599 .collect::<Result<Vec<_>, _>>()?
3600 .into_iter()
3601 .collect::<Result<Vec<_>, _>>()?;
3602
3603 Ok(())
3604 }
3605
3606 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
3607 ///
3608 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
3609 /// the specified project. For MCAP files, supports:
3610 ///
3611 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
3612 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
3613 /// present)
3614 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
3615 /// * **Topic filtering**: Select specific MCAP topics to restore
3616 ///
3617 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
3618 /// dataset structure.
3619 ///
3620 /// # Arguments
3621 ///
3622 /// * `project_id` - Target project ID
3623 /// * `snapshot_id` - Snapshot ID to restore
3624 /// * `topics` - MCAP topics to include (empty = all topics)
3625 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
3626 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
3627 /// * `dataset_name` - Optional custom dataset name
3628 /// * `dataset_description` - Optional dataset description
3629 ///
3630 /// # Returns
3631 ///
3632 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
3633 ///
3634 /// # Errors
3635 ///
3636 /// Returns an error if:
3637 /// * Snapshot or project doesn't exist
3638 /// * Snapshot format is invalid
3639 /// * Server rejects restoration parameters
3640 ///
3641 /// # Example
3642 ///
3643 /// ```no_run
3644 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
3645 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3646 /// let client = Client::new()?.with_token_path(None)?;
3647 /// let project_id = ProjectID::from(1);
3648 /// let snapshot_id = SnapshotID::from(123);
3649 ///
3650 /// // Restore MCAP with AGTG for "person" and "car" detection
3651 /// let result = client
3652 /// .restore_snapshot(
3653 /// project_id,
3654 /// snapshot_id,
3655 /// &[], // All topics
3656 /// &["person".to_string(), "car".to_string()], // AGTG labels
3657 /// true, // Auto-depth
3658 /// Some("Highway Dataset"),
3659 /// Some("Collected on I-95"),
3660 /// )
3661 /// .await?;
3662 /// println!("Restored to dataset: {:?}", result.dataset_id);
3663 /// # Ok(())
3664 /// # }
3665 /// ```
3666 ///
3667 /// # See Also
3668 ///
3669 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3670 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3671 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3672 #[allow(clippy::too_many_arguments)]
3673 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3674 pub async fn restore_snapshot(
3675 &self,
3676 project_id: ProjectID,
3677 snapshot_id: SnapshotID,
3678 topics: &[String],
3679 autolabel: &[String],
3680 autodepth: bool,
3681 dataset_name: Option<&str>,
3682 dataset_description: Option<&str>,
3683 ) -> Result<SnapshotRestoreResult, Error> {
3684 let params = SnapshotRestore {
3685 project_id,
3686 snapshot_id,
3687 fps: 1,
3688 autodepth,
3689 agtg_pipeline: !autolabel.is_empty(),
3690 autolabel: autolabel.to_vec(),
3691 topics: topics.to_vec(),
3692 dataset_name: dataset_name.map(|s| s.to_owned()),
3693 dataset_description: dataset_description.map(|s| s.to_owned()),
3694 };
3695 self.rpc("snapshots.restore".to_owned(), Some(params)).await
3696 }
3697
3698 /// Returns a list of experiments available to the user. The experiments
3699 /// are returned as a vector of Experiment objects. If name is provided
3700 /// then only experiments containing this string are returned.
3701 ///
3702 /// Results are sorted by match quality: exact matches first, then
3703 /// case-insensitive exact matches, then shorter names (more specific),
3704 /// then alphabetically.
3705 ///
3706 /// Experiments provide a method of organizing training and validation
3707 /// sessions together and are akin to an Experiment in MLFlow terminology.
3708 /// Each experiment can have multiple trainer sessions associated with it,
3709 /// these would be akin to runs in MLFlow terminology.
3710 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3711 pub async fn experiments(
3712 &self,
3713 project_id: ProjectID,
3714 name: Option<&str>,
3715 ) -> Result<Vec<Experiment>, Error> {
3716 let params = HashMap::from([("project_id", project_id)]);
3717 let experiments: Vec<Experiment> =
3718 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
3719 if let Some(name) = name {
3720 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
3721 } else {
3722 Ok(experiments)
3723 }
3724 }
3725
3726 /// Return the experiment with the specified experiment ID. If the
3727 /// experiment does not exist, an error is returned.
3728 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3729 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
3730 let params = HashMap::from([("trainer_id", experiment_id)]);
3731 self.rpc("trainer.get".to_owned(), Some(params)).await
3732 }
3733
3734 /// Returns a list of trainer sessions available to the user. The trainer
3735 /// sessions are returned as a vector of TrainingSession objects. If name
3736 /// is provided then only trainer sessions containing this string are
3737 /// returned.
3738 ///
3739 /// Results are sorted by match quality: exact matches first, then
3740 /// case-insensitive exact matches, then shorter names (more specific),
3741 /// then alphabetically.
3742 ///
3743 /// Trainer sessions are akin to runs in MLFlow terminology. These
3744 /// represent an actual training session which will produce metrics and
3745 /// model artifacts.
3746 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3747 pub async fn training_sessions(
3748 &self,
3749 experiment_id: ExperimentID,
3750 name: Option<&str>,
3751 ) -> Result<Vec<TrainingSession>, Error> {
3752 let params = HashMap::from([("trainer_id", experiment_id)]);
3753 let sessions: Vec<TrainingSession> = self
3754 .rpc("trainer.session.list".to_owned(), Some(params))
3755 .await?;
3756 if let Some(name) = name {
3757 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
3758 } else {
3759 Ok(sessions)
3760 }
3761 }
3762
3763 /// Return the trainer session with the specified trainer session ID. If
3764 /// the trainer session does not exist, an error is returned.
3765 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3766 pub async fn training_session(
3767 &self,
3768 session_id: TrainingSessionID,
3769 ) -> Result<TrainingSession, Error> {
3770 let params = HashMap::from([("trainer_session_id", session_id)]);
3771 self.rpc("trainer.session.get".to_owned(), Some(params))
3772 .await
3773 }
3774
3775 /// List validation sessions for the given project.
3776 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3777 pub async fn validation_sessions(
3778 &self,
3779 project_id: ProjectID,
3780 ) -> Result<Vec<ValidationSession>, Error> {
3781 let params = HashMap::from([("project_id", project_id)]);
3782 self.rpc("validate.session.list".to_owned(), Some(params))
3783 .await
3784 }
3785
3786 /// Retrieve a specific validation session.
3787 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3788 pub async fn validation_session(
3789 &self,
3790 session_id: ValidationSessionID,
3791 ) -> Result<ValidationSession, Error> {
3792 let params = HashMap::from([("validate_session_id", session_id)]);
3793 self.rpc("validate.session.get".to_owned(), Some(params))
3794 .await
3795 }
3796
3797 /// List the artifacts for the specified trainer session. The artifacts
3798 /// are returned as a vector of strings.
3799 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3800 pub async fn artifacts(
3801 &self,
3802 training_session_id: TrainingSessionID,
3803 ) -> Result<Vec<Artifact>, Error> {
3804 let params = HashMap::from([("training_session_id", training_session_id)]);
3805 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
3806 .await
3807 }
3808
3809 /// Download the model artifact for the specified trainer session to the
3810 /// specified file path, if path is not provided it will be downloaded to
3811 /// the current directory with the same filename.
3812 ///
3813 /// # Progress
3814 ///
3815 /// Reports progress with `status: None` as file data is received. Progress
3816 /// unit is bytes downloaded. Total is determined from the HTTP
3817 /// Content-Length header (may be 0 if server doesn't provide it).
3818 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
3819 pub async fn download_artifact(
3820 &self,
3821 training_session_id: TrainingSessionID,
3822 modelname: &str,
3823 filename: Option<PathBuf>,
3824 progress: Option<Sender<Progress>>,
3825 ) -> Result<(), Error> {
3826 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
3827 let resp = self
3828 .http
3829 .get(format!(
3830 "{}/download_model?training_session_id={}&file={}",
3831 self.url,
3832 training_session_id.value(),
3833 modelname
3834 ))
3835 .header("Authorization", format!("Bearer {}", self.token().await))
3836 .send()
3837 .await?;
3838 if !resp.status().is_success() {
3839 let err = resp.error_for_status_ref().unwrap_err();
3840 return Err(Error::HttpError(err));
3841 }
3842
3843 if let Some(parent) = filename.parent() {
3844 fs::create_dir_all(parent).await?;
3845 }
3846
3847 if let Some(progress) = progress {
3848 let total = resp.content_length().unwrap_or(0) as usize;
3849 let _ = progress
3850 .send(Progress {
3851 current: 0,
3852 total,
3853 status: None,
3854 })
3855 .await;
3856
3857 let mut file = File::create(filename).await?;
3858 let mut current = 0;
3859 let mut stream = resp.bytes_stream();
3860
3861 while let Some(item) = stream.next().await {
3862 let chunk = item?;
3863 file.write_all(&chunk).await?;
3864 current += chunk.len();
3865 let _ = progress
3866 .send(Progress {
3867 current,
3868 total,
3869 status: None,
3870 })
3871 .await;
3872 }
3873 } else {
3874 let body = resp.bytes().await?;
3875 fs::write(filename, body).await?;
3876 }
3877
3878 Ok(())
3879 }
3880
3881 /// Download the model checkpoint associated with the specified trainer
3882 /// session to the specified file path, if path is not provided it will be
3883 /// downloaded to the current directory with the same filename.
3884 ///
3885 /// There is no API for listing checkpoints it is expected that trainers are
3886 /// aware of possible checkpoints and their names within the checkpoint
3887 /// folder on the server.
3888 ///
3889 /// # Progress
3890 ///
3891 /// Reports progress with `status: None` as file data is received. Progress
3892 /// unit is bytes downloaded. Total is determined from the HTTP
3893 /// Content-Length header (may be 0 if server doesn't provide it).
3894 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
3895 pub async fn download_checkpoint(
3896 &self,
3897 training_session_id: TrainingSessionID,
3898 checkpoint: &str,
3899 filename: Option<PathBuf>,
3900 progress: Option<Sender<Progress>>,
3901 ) -> Result<(), Error> {
3902 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
3903 let resp = self
3904 .http
3905 .get(format!(
3906 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
3907 self.url,
3908 training_session_id.value(),
3909 checkpoint
3910 ))
3911 .header("Authorization", format!("Bearer {}", self.token().await))
3912 .send()
3913 .await?;
3914 if !resp.status().is_success() {
3915 let err = resp.error_for_status_ref().unwrap_err();
3916 return Err(Error::HttpError(err));
3917 }
3918
3919 if let Some(parent) = filename.parent() {
3920 fs::create_dir_all(parent).await?;
3921 }
3922
3923 if let Some(progress) = progress {
3924 let total = resp.content_length().unwrap_or(0) as usize;
3925 let _ = progress
3926 .send(Progress {
3927 current: 0,
3928 total,
3929 status: None,
3930 })
3931 .await;
3932
3933 let mut file = File::create(filename).await?;
3934 let mut current = 0;
3935 let mut stream = resp.bytes_stream();
3936
3937 while let Some(item) = stream.next().await {
3938 let chunk = item?;
3939 file.write_all(&chunk).await?;
3940 current += chunk.len();
3941 let _ = progress
3942 .send(Progress {
3943 current,
3944 total,
3945 status: None,
3946 })
3947 .await;
3948 }
3949 } else {
3950 let body = resp.bytes().await?;
3951 fs::write(filename, body).await?;
3952 }
3953
3954 Ok(())
3955 }
3956
3957 /// Return a list of tasks for the current user.
3958 ///
3959 /// # Arguments
3960 ///
3961 /// * `name` - Optional filter for task name (client-side substring match)
3962 /// * `workflow` - Optional filter for workflow/task type. If provided,
3963 /// filters server-side by exact match. Valid values include: "trainer",
3964 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
3965 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
3966 /// "convertor", "twostage"
3967 /// * `status` - Optional filter for task status (e.g., "running",
3968 /// "complete", "error")
3969 /// * `manager` - Optional filter for task manager type (e.g., "aws",
3970 /// "user", "kubernetes")
3971 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3972 pub async fn tasks(
3973 &self,
3974 name: Option<&str>,
3975 workflow: Option<&str>,
3976 status: Option<&str>,
3977 manager: Option<&str>,
3978 ) -> Result<Vec<Task>, Error> {
3979 let mut params = TasksListParams {
3980 continue_token: None,
3981 types: workflow.map(|w| vec![w.to_owned()]),
3982 status: status.map(|s| vec![s.to_owned()]),
3983 manager: manager.map(|m| vec![m.to_owned()]),
3984 };
3985 let mut tasks = Vec::new();
3986
3987 loop {
3988 let result = self
3989 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
3990 .await?;
3991 tasks.extend(result.tasks);
3992
3993 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
3994 params.continue_token = None;
3995 } else {
3996 params.continue_token = result.continue_token;
3997 }
3998
3999 if params.continue_token.is_none() {
4000 break;
4001 }
4002 }
4003
4004 if let Some(name) = name {
4005 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4006 }
4007
4008 Ok(tasks)
4009 }
4010
4011 /// Retrieve the task information and status.
4012 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4013 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4014 self.rpc(
4015 "task.get".to_owned(),
4016 Some(HashMap::from([("id", task_id)])),
4017 )
4018 .await
4019 }
4020
4021 /// Updates the tasks status.
4022 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4023 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4024 let status = TaskStatus {
4025 task_id,
4026 status: status.to_owned(),
4027 };
4028 self.rpc("docker.update.status".to_owned(), Some(status))
4029 .await
4030 }
4031
4032 /// Defines the stages for the task. The stages are defined as a mapping
4033 /// from stage names to their descriptions. Once stages are defined their
4034 /// status can be updated using the update_stage method.
4035 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4036 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4037 let stages: Vec<HashMap<String, String>> = stages
4038 .iter()
4039 .map(|(key, value)| {
4040 let mut stage_map = HashMap::new();
4041 stage_map.insert(key.to_string(), value.to_string());
4042 stage_map
4043 })
4044 .collect();
4045 let params = TaskStages { task_id, stages };
4046 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4047 Ok(())
4048 }
4049
4050 /// Updates the progress of the task for the provided stage and status
4051 /// information.
4052 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4053 pub async fn update_stage(
4054 &self,
4055 task_id: TaskID,
4056 stage: &str,
4057 status: &str,
4058 message: &str,
4059 percentage: u8,
4060 ) -> Result<(), Error> {
4061 let stage = Stage::new(
4062 Some(task_id),
4063 stage.to_owned(),
4064 Some(status.to_owned()),
4065 Some(message.to_owned()),
4066 percentage,
4067 );
4068 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4069 Ok(())
4070 }
4071
4072 /// Raw fetch from the Studio server is used for downloading files.
4073 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4074 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4075 let req = self
4076 .http
4077 .get(format!("{}/{}", self.url, query))
4078 .header("User-Agent", "EdgeFirst Client")
4079 .header("Authorization", format!("Bearer {}", self.token().await));
4080 let resp = req.send().await?;
4081
4082 if resp.status().is_success() {
4083 let body = resp.bytes().await?;
4084
4085 if log_enabled!(Level::Trace) {
4086 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4087 }
4088
4089 Ok(body.to_vec())
4090 } else {
4091 let err = resp.error_for_status_ref().unwrap_err();
4092 Err(Error::HttpError(err))
4093 }
4094 }
4095
4096 /// Sends a multipart post request to the server. This is used by the
4097 /// upload and download APIs which do not use JSON-RPC but instead transfer
4098 /// files using multipart/form-data.
4099 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4100 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
4101 let req = self
4102 .http
4103 .post(format!("{}/api?method={}", self.url, method))
4104 .header("Accept", "application/json")
4105 .header("User-Agent", "EdgeFirst Client")
4106 .header("Authorization", format!("Bearer {}", self.token().await))
4107 .multipart(form);
4108 let resp = req.send().await?;
4109
4110 if resp.status().is_success() {
4111 let body = resp.bytes().await?;
4112
4113 if log_enabled!(Level::Trace) {
4114 trace!(
4115 "POST Multipart Response: {}",
4116 String::from_utf8_lossy(&body)
4117 );
4118 }
4119
4120 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
4121 Ok(response) => response,
4122 Err(err) => {
4123 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4124 return Err(err.into());
4125 }
4126 };
4127
4128 if let Some(error) = response.error {
4129 Err(Error::RpcError(error.code, error.message))
4130 } else if let Some(result) = response.result {
4131 Ok(result)
4132 } else {
4133 Err(Error::InvalidResponse)
4134 }
4135 } else {
4136 let err = resp.error_for_status_ref().unwrap_err();
4137 Err(Error::HttpError(err))
4138 }
4139 }
4140
4141 /// Send a JSON-RPC request to the server. The method is the name of the
4142 /// method to call on the server. The params are the parameters to pass to
4143 /// the method. The method and params are serialized into a JSON-RPC
4144 /// request and sent to the server. The response is deserialized into
4145 /// the specified type and returned to the caller.
4146 ///
4147 /// NOTE: This API would generally not be called directly and instead users
4148 /// should use the higher-level methods provided by the client.
4149 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4150 pub async fn rpc<Params, RpcResult>(
4151 &self,
4152 method: String,
4153 params: Option<Params>,
4154 ) -> Result<RpcResult, Error>
4155 where
4156 Params: Serialize,
4157 RpcResult: DeserializeOwned,
4158 {
4159 let auth_expires = self.token_expiration().await?;
4160 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4161 self.renew_token().await?;
4162 }
4163
4164 self.rpc_without_auth(method, params).await
4165 }
4166
4167 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
4168 async fn rpc_without_auth<Params, RpcResult>(
4169 &self,
4170 method: String,
4171 params: Option<Params>,
4172 ) -> Result<RpcResult, Error>
4173 where
4174 Params: Serialize,
4175 RpcResult: DeserializeOwned,
4176 {
4177 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4178 .ok()
4179 .and_then(|s| s.parse().ok())
4180 .unwrap_or(5usize);
4181
4182 let url = format!("{}/api", self.url);
4183
4184 // Serialize request body once before retry loop to avoid Clone bound on Params
4185 let request = RpcRequest {
4186 method: method.clone(),
4187 params,
4188 ..Default::default()
4189 };
4190
4191 // Log request for debugging (log crate) and profiling (tracing crate)
4192 let request_json = if method == "auth.login" {
4193 // Redact auth.login params (contains password)
4194 serde_json::json!({
4195 "jsonrpc": "2.0",
4196 "method": &method,
4197 "params": "[REDACTED - contains credentials]",
4198 "id": request.id
4199 })
4200 .to_string()
4201 } else {
4202 serde_json::to_string(&request)?
4203 };
4204
4205 if log_enabled!(Level::Trace) {
4206 trace!("RPC Request: {}", request_json);
4207 }
4208
4209 // Record request on current span for Perfetto when profiling is enabled
4210 #[cfg(feature = "profiling")]
4211 tracing::Span::current().record("request", &request_json);
4212
4213 let request_body = serde_json::to_vec(&request)?;
4214 let mut last_error: Option<Error> = None;
4215
4216 for attempt in 0..=max_retries {
4217 if attempt > 0 {
4218 // Exponential backoff with jitter: base delay * 2^attempt, capped at 30s
4219 // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
4220 // while ensuring we never retry faster than the base delay
4221 let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
4222 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
4223 let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
4224 let delay = Duration::from_millis(delay_ms);
4225 warn!(
4226 "Retry {}/{} for RPC '{}' after {:?}",
4227 attempt, max_retries, method, delay
4228 );
4229 tokio::time::sleep(delay).await;
4230 }
4231
4232 let result = self
4233 .http
4234 .post(&url)
4235 .header("Accept", "application/json")
4236 .header("Content-Type", "application/json")
4237 .header("User-Agent", "EdgeFirst Client")
4238 .header("Authorization", format!("Bearer {}", self.token().await))
4239 .body(request_body.clone())
4240 .send()
4241 .await;
4242
4243 match result {
4244 Ok(res) => {
4245 let status = res.status();
4246 let status_code = status.as_u16();
4247
4248 // Check for retryable HTTP status codes before processing response
4249 if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
4250 && attempt < max_retries
4251 {
4252 warn!(
4253 "RPC '{}' failed with HTTP {} (retrying)",
4254 method, status_code
4255 );
4256 last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
4257 continue;
4258 }
4259
4260 // Process the response
4261 match self.process_rpc_response(res).await {
4262 Ok(result) => {
4263 if attempt > 0 {
4264 debug!("RPC '{}' succeeded on retry {}", method, attempt);
4265 }
4266 return Ok(result);
4267 }
4268 Err(e) => {
4269 // Don't retry client errors (4xx except 408, 429)
4270 if attempt > 0 {
4271 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
4272 }
4273 return Err(e);
4274 }
4275 }
4276 }
4277 Err(e) => {
4278 // Transport error (timeout, connection failure, etc.)
4279 let is_timeout = e.is_timeout();
4280 let is_connect = e.is_connect();
4281
4282 if (is_timeout || is_connect) && attempt < max_retries {
4283 warn!(
4284 "RPC '{}' transport error (retrying): {}",
4285 method,
4286 if is_timeout {
4287 "timeout"
4288 } else {
4289 "connection failed"
4290 }
4291 );
4292 last_error = Some(Error::HttpError(e));
4293 continue;
4294 }
4295
4296 if attempt > 0 {
4297 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
4298 }
4299 return Err(Error::HttpError(e));
4300 }
4301 }
4302 }
4303
4304 // Should not reach here
4305 Err(last_error.unwrap_or_else(|| {
4306 Error::InvalidParameters(format!(
4307 "RPC '{}' failed after {} retries",
4308 method, max_retries
4309 ))
4310 }))
4311 }
4312
4313 async fn process_rpc_response<RpcResult>(
4314 &self,
4315 res: reqwest::Response,
4316 ) -> Result<RpcResult, Error>
4317 where
4318 RpcResult: DeserializeOwned,
4319 {
4320 let body = res.bytes().await?;
4321 let response_str = String::from_utf8_lossy(&body);
4322
4323 if log_enabled!(Level::Trace) {
4324 trace!("RPC Response: {}", response_str);
4325 }
4326
4327 // Record response on current span for Perfetto when profiling is enabled
4328 // Truncate large responses to avoid bloating trace files
4329 #[cfg(feature = "profiling")]
4330 {
4331 const MAX_RESPONSE_LEN: usize = 4096;
4332 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
4333 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
4334 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
4335 format!(
4336 "{}...[truncated {} bytes]",
4337 &response_str[..safe_end],
4338 response_str.len() - safe_end
4339 )
4340 } else {
4341 response_str.to_string()
4342 };
4343 tracing::Span::current().record("response", &truncated);
4344 }
4345
4346 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
4347 Ok(response) => response,
4348 Err(err) => {
4349 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4350 return Err(err.into());
4351 }
4352 };
4353
4354 // FIXME: Studio Server always returns 999 as the id.
4355 // if request.id.to_string() != response.id {
4356 // return Err(Error::InvalidRpcId(response.id));
4357 // }
4358
4359 if let Some(error) = response.error {
4360 Err(Error::RpcError(error.code, error.message))
4361 } else if let Some(result) = response.result {
4362 Ok(result)
4363 } else {
4364 Err(Error::InvalidResponse)
4365 }
4366 }
4367}
4368
4369/// Process items in parallel with semaphore concurrency control and progress
4370/// tracking.
4371///
4372/// This helper eliminates boilerplate for parallel item processing with:
4373/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
4374/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
4375/// - Atomic progress counter with automatic item-level updates
4376/// - Progress updates sent after each item completes (not byte-level streaming)
4377/// - Proper error propagation from spawned tasks
4378///
4379/// Note: This is optimized for discrete items with post-completion progress
4380/// updates. For byte-level streaming progress or custom retry logic, use
4381/// specialized implementations.
4382///
4383/// # Arguments
4384///
4385/// * `items` - Collection of items to process in parallel
4386/// * `progress` - Optional progress channel for tracking completion
4387/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
4388/// * `work_fn` - Async function to execute for each item
4389///
4390/// # Examples
4391///
4392/// ```rust,ignore
4393/// // Use default concurrency
4394/// parallel_foreach_items(samples, progress, None, |sample| async move {
4395/// sample.download(&client, file_type).await?;
4396/// Ok(())
4397/// }).await?;
4398/// ```
4399async fn parallel_foreach_items<T, F, Fut>(
4400 items: Vec<T>,
4401 progress: Option<Sender<Progress>>,
4402 concurrency: Option<usize>,
4403 work_fn: F,
4404) -> Result<(), Error>
4405where
4406 T: Send + 'static,
4407 F: Fn(T) -> Fut + Send + Sync + 'static,
4408 Fut: Future<Output = Result<(), Error>> + Send + 'static,
4409{
4410 let total = items.len();
4411 let current = Arc::new(AtomicUsize::new(0));
4412 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
4413 let work_fn = Arc::new(work_fn);
4414
4415 let tasks = items
4416 .into_iter()
4417 .map(|item| {
4418 let sem = sem.clone();
4419 let current = current.clone();
4420 let progress = progress.clone();
4421 let work_fn = work_fn.clone();
4422
4423 tokio::spawn(async move {
4424 let _permit = sem.acquire().await.map_err(|_| {
4425 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4426 })?;
4427
4428 // Execute the actual work
4429 work_fn(item).await?;
4430
4431 // Update progress
4432 if let Some(progress) = &progress {
4433 let current = current.fetch_add(1, Ordering::SeqCst);
4434 let _ = progress
4435 .send(Progress {
4436 current: current + 1,
4437 total,
4438 status: None,
4439 })
4440 .await;
4441 }
4442
4443 Ok::<(), Error>(())
4444 })
4445 })
4446 .collect::<Vec<_>>();
4447
4448 join_all(tasks)
4449 .await
4450 .into_iter()
4451 .collect::<Result<Vec<_>, _>>()?
4452 .into_iter()
4453 .collect::<Result<Vec<_>, _>>()?;
4454
4455 if let Some(progress) = progress {
4456 drop(progress);
4457 }
4458
4459 Ok(())
4460}
4461
4462/// Upload a file to S3 using multipart upload with presigned URLs.
4463///
4464/// Splits a file into chunks (100MB each) and uploads them in parallel using
4465/// S3 multipart upload protocol. Returns completion parameters with ETags for
4466/// finalizing the upload.
4467///
4468/// This function handles:
4469/// - Splitting files into parts based on PART_SIZE (100MB)
4470/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
4471/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
4472/// - Retry logic (handled by reqwest client)
4473/// - Progress tracking across all parts
4474///
4475/// # Arguments
4476///
4477/// * `http` - HTTP client for making requests
4478/// * `part` - Snapshot part info with presigned URLs for each chunk
4479/// * `path` - Local file path to upload
4480/// * `total` - Total bytes across all files for progress calculation
4481/// * `current` - Atomic counter tracking bytes uploaded across all operations
4482/// * `progress` - Optional channel for sending progress updates
4483///
4484/// # Returns
4485///
4486/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
4487async fn upload_multipart(
4488 http: reqwest::Client,
4489 part: SnapshotPart,
4490 path: PathBuf,
4491 total: usize,
4492 confirmed_bytes: Arc<AtomicUsize>,
4493 progress: Option<Sender<Progress>>,
4494) -> Result<SnapshotCompleteMultipartParams, Error> {
4495 let filesize = path.metadata()?.len() as usize;
4496 let n_parts = filesize.div_ceil(PART_SIZE);
4497 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
4498
4499 let key = part.key.ok_or(Error::InvalidResponse)?;
4500 let upload_id = part.upload_id;
4501
4502 let urls = part.urls.clone();
4503
4504 // Pre-allocate ETag slots for all parts
4505 let etags = Arc::new(tokio::sync::Mutex::new(vec![
4506 EtagPart {
4507 etag: "".to_owned(),
4508 part_number: 0,
4509 };
4510 n_parts
4511 ]));
4512
4513 // Per-part byte counters for streaming progress (reset on retry)
4514 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
4515 (0..n_parts)
4516 .map(|_| AtomicUsize::new(0))
4517 .collect::<Vec<_>>(),
4518 );
4519
4520 // Upload all parts in parallel with concurrency limiting
4521 let tasks = (0..n_parts)
4522 .map(|part_idx| {
4523 let http = http.clone();
4524 let url = urls[part_idx].clone();
4525 let etags = etags.clone();
4526 let path = path.to_owned();
4527 let sem = sem.clone();
4528 let progress = progress.clone();
4529 let confirmed_bytes = confirmed_bytes.clone();
4530 let part_bytes = part_bytes.clone();
4531
4532 // Calculate this part's size
4533 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
4534 filesize % PART_SIZE
4535 } else {
4536 PART_SIZE
4537 };
4538
4539 tokio::spawn(async move {
4540 // Acquire semaphore permit to limit concurrent uploads
4541 let _permit = sem.acquire().await.map_err(|_| {
4542 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4543 })?;
4544
4545 // Upload part with streaming progress and retry logic
4546 let etag = upload_part_with_progress(
4547 http,
4548 url,
4549 path,
4550 part_idx,
4551 n_parts,
4552 part_size,
4553 total,
4554 confirmed_bytes.clone(),
4555 part_bytes.clone(),
4556 progress.clone(),
4557 )
4558 .await?;
4559
4560 // Store ETag for this part (needed to complete multipart upload)
4561 let mut etags_guard = etags.lock().await;
4562 etags_guard[part_idx] = EtagPart {
4563 etag,
4564 part_number: part_idx + 1,
4565 };
4566
4567 // Part completed successfully - add to confirmed bytes
4568 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
4569 // Reset part counter since it's now confirmed
4570 part_bytes[part_idx].store(0, Ordering::SeqCst);
4571
4572 // Send final progress update for this part
4573 if let Some(progress) = &progress {
4574 let current = confirmed_bytes.load(Ordering::SeqCst)
4575 + part_bytes
4576 .iter()
4577 .map(|p| p.load(Ordering::SeqCst))
4578 .sum::<usize>();
4579 let _ = progress
4580 .send(Progress {
4581 current,
4582 total,
4583 status: None,
4584 })
4585 .await;
4586 }
4587
4588 Ok::<(), Error>(())
4589 })
4590 })
4591 .collect::<Vec<_>>();
4592
4593 // Wait for all parts to complete (double collect to handle both JoinError and
4594 // inner Error)
4595 join_all(tasks)
4596 .await
4597 .into_iter()
4598 .collect::<Result<Vec<_>, _>>()?
4599 .into_iter()
4600 .collect::<Result<Vec<_>, _>>()?;
4601
4602 Ok(SnapshotCompleteMultipartParams {
4603 key,
4604 upload_id,
4605 etag_list: etags.lock().await.clone(),
4606 })
4607}
4608
4609/// Upload a single part with streaming progress tracking and retry logic.
4610///
4611/// Progress is reported continuously as bytes are sent. On retry, the part's
4612/// progress counter is reset to avoid over-reporting.
4613#[allow(clippy::too_many_arguments)]
4614async fn upload_part_with_progress(
4615 http: reqwest::Client,
4616 url: String,
4617 path: PathBuf,
4618 part_idx: usize,
4619 n_parts: usize,
4620 part_size: usize,
4621 total: usize,
4622 confirmed_bytes: Arc<AtomicUsize>,
4623 part_bytes: Arc<Vec<AtomicUsize>>,
4624 progress: Option<Sender<Progress>>,
4625) -> Result<String, Error> {
4626 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4627 .ok()
4628 .and_then(|s| s.parse().ok())
4629 .unwrap_or(3usize);
4630
4631 let mut last_error: Option<Error> = None;
4632
4633 for attempt in 0..=max_retries {
4634 if attempt > 0 {
4635 // Reset this part's progress counter before retry
4636 part_bytes[part_idx].store(0, Ordering::SeqCst);
4637
4638 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4639 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4640 warn!(
4641 "Retry {}/{} for part {} after {:?}",
4642 attempt, max_retries, part_idx, delay
4643 );
4644 tokio::time::sleep(delay).await;
4645 }
4646
4647 match upload_part_streaming(
4648 http.clone(),
4649 url.clone(),
4650 path.clone(),
4651 part_idx,
4652 n_parts,
4653 part_size,
4654 total,
4655 confirmed_bytes.clone(),
4656 part_bytes.clone(),
4657 progress.clone(),
4658 )
4659 .await
4660 {
4661 Ok(etag) => return Ok(etag),
4662 Err(e) => {
4663 // Check if error is retryable
4664 let is_retryable = matches!(
4665 &e,
4666 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
4667 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
4668 );
4669
4670 if is_retryable && attempt < max_retries {
4671 last_error = Some(e);
4672 continue;
4673 }
4674
4675 return Err(e);
4676 }
4677 }
4678 }
4679
4680 Err(last_error
4681 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
4682}
4683
4684/// Perform the actual upload with streaming progress.
4685#[allow(clippy::too_many_arguments)]
4686async fn upload_part_streaming(
4687 http: reqwest::Client,
4688 url: String,
4689 path: PathBuf,
4690 part_idx: usize,
4691 n_parts: usize,
4692 _part_size: usize,
4693 total: usize,
4694 confirmed_bytes: Arc<AtomicUsize>,
4695 part_bytes: Arc<Vec<AtomicUsize>>,
4696 progress: Option<Sender<Progress>>,
4697) -> Result<String, Error> {
4698 let filesize = path.metadata()?.len() as usize;
4699 let mut file = File::open(&path).await?;
4700 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
4701 .await?;
4702 let file = file.take(PART_SIZE as u64);
4703
4704 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
4705 filesize % PART_SIZE
4706 } else {
4707 PART_SIZE
4708 };
4709
4710 // Create stream with progress tracking
4711 let stream = FramedRead::new(file, BytesCodec::new());
4712
4713 // Wrap stream to track bytes sent and report progress
4714 let progress_stream = stream.map(move |result| {
4715 if let Ok(ref bytes) = result {
4716 let bytes_len = bytes.len();
4717 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
4718
4719 // Send progress update (fire-and-forget via try_send to avoid blocking)
4720 if let Some(ref progress) = progress {
4721 let current = confirmed_bytes.load(Ordering::SeqCst)
4722 + part_bytes
4723 .iter()
4724 .map(|p| p.load(Ordering::SeqCst))
4725 .sum::<usize>();
4726 // Best-effort progress reporting: use try_send to avoid blocking.
4727 // If the channel is full or closed, we intentionally skip this update
4728 // to avoid stalling the upload; subsequent updates will still be delivered.
4729 let _ = progress.try_send(Progress {
4730 current,
4731 total,
4732 status: None,
4733 });
4734 }
4735 }
4736 result.map(|b| b.freeze())
4737 });
4738
4739 let body = Body::wrap_stream(progress_stream);
4740
4741 let resp = http
4742 .put(url)
4743 .header(CONTENT_LENGTH, body_length)
4744 .body(body)
4745 .send()
4746 .await?
4747 .error_for_status()?;
4748
4749 let etag = resp
4750 .headers()
4751 .get("etag")
4752 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
4753 .to_str()
4754 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
4755 .to_owned();
4756
4757 // Studio Server requires etag without the quotes.
4758 let etag = etag
4759 .strip_prefix("\"")
4760 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
4761 let etag = etag
4762 .strip_suffix("\"")
4763 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
4764
4765 Ok(etag.to_owned())
4766}
4767
4768/// Upload a complete file to a presigned S3 URL using HTTP PUT.
4769///
4770/// This is used for populate_samples to upload files to S3 after
4771/// receiving presigned URLs from the server.
4772///
4773/// Includes explicit retry logic with exponential backoff for transient
4774/// failures.
4775async fn upload_file_to_presigned_url(
4776 http: reqwest::Client,
4777 url: &str,
4778 path: PathBuf,
4779) -> Result<(), Error> {
4780 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4781 .ok()
4782 .and_then(|s| s.parse().ok())
4783 .unwrap_or(3usize);
4784
4785 // Read the entire file into memory once
4786 let file_data = fs::read(&path).await?;
4787 let file_size = file_data.len();
4788 let filename = path.file_name().unwrap_or_default().to_string_lossy();
4789
4790 let mut last_error: Option<Error> = None;
4791
4792 for attempt in 0..=max_retries {
4793 if attempt > 0 {
4794 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4795 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4796 warn!(
4797 "Retry {}/{} for upload '{}' after {:?}",
4798 attempt, max_retries, filename, delay
4799 );
4800 tokio::time::sleep(delay).await;
4801 }
4802
4803 // Attempt upload
4804 let result = http
4805 .put(url)
4806 .header(CONTENT_LENGTH, file_size)
4807 .body(file_data.clone())
4808 .send()
4809 .await;
4810
4811 match result {
4812 Ok(resp) => {
4813 if resp.status().is_success() {
4814 if attempt > 0 {
4815 debug!(
4816 "Upload '{}' succeeded on retry {} ({} bytes)",
4817 filename, attempt, file_size
4818 );
4819 } else {
4820 debug!(
4821 "Successfully uploaded file: {} ({} bytes)",
4822 filename, file_size
4823 );
4824 }
4825 return Ok(());
4826 }
4827
4828 let status = resp.status();
4829 let status_code = status.as_u16();
4830
4831 // Check if error is retryable
4832 let is_retryable =
4833 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
4834
4835 if is_retryable && attempt < max_retries {
4836 let error_text = resp.text().await.unwrap_or_default();
4837 warn!(
4838 "Upload '{}' failed with HTTP {} (retryable): {}",
4839 filename, status_code, error_text
4840 );
4841 last_error = Some(Error::InvalidParameters(format!(
4842 "Upload failed: HTTP {} - {}",
4843 status, error_text
4844 )));
4845 continue;
4846 }
4847
4848 // Non-retryable error or max retries exceeded
4849 let error_text = resp.text().await.unwrap_or_default();
4850 if attempt > 0 {
4851 error!(
4852 "Upload '{}' failed after {} retries: HTTP {} - {}",
4853 filename, attempt, status, error_text
4854 );
4855 }
4856 return Err(Error::InvalidParameters(format!(
4857 "Upload failed: HTTP {} - {}",
4858 status, error_text
4859 )));
4860 }
4861 Err(e) => {
4862 // Transport error (timeout, connection failure, etc.)
4863 let is_timeout = e.is_timeout();
4864 let is_connect = e.is_connect();
4865
4866 if (is_timeout || is_connect) && attempt < max_retries {
4867 warn!(
4868 "Upload '{}' transport error (retrying): {}",
4869 filename,
4870 if is_timeout {
4871 "timeout"
4872 } else {
4873 "connection failed"
4874 }
4875 );
4876 last_error = Some(Error::HttpError(e));
4877 continue;
4878 }
4879
4880 // Non-retryable or max retries exceeded
4881 if attempt > 0 {
4882 error!(
4883 "Upload '{}' failed after {} retries: {}",
4884 filename, attempt, e
4885 );
4886 }
4887 return Err(Error::HttpError(e));
4888 }
4889 }
4890 }
4891
4892 // Should not reach here, but return last error if we do
4893 Err(last_error.unwrap_or_else(|| {
4894 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
4895 }))
4896}
4897
4898/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
4899///
4900/// This is used for populate_samples to upload file content from memory
4901/// (e.g., from ZIP archives) without writing to disk first.
4902///
4903/// Includes explicit retry logic with exponential backoff for transient
4904/// failures.
4905async fn upload_bytes_to_presigned_url(
4906 http: reqwest::Client,
4907 url: &str,
4908 file_data: Vec<u8>,
4909 filename: &str,
4910) -> Result<(), Error> {
4911 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4912 .ok()
4913 .and_then(|s| s.parse().ok())
4914 .unwrap_or(3usize);
4915
4916 let file_size = file_data.len();
4917 let mut last_error: Option<Error> = None;
4918
4919 for attempt in 0..=max_retries {
4920 if attempt > 0 {
4921 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4922 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4923 warn!(
4924 "Retry {}/{} for upload '{}' after {:?}",
4925 attempt, max_retries, filename, delay
4926 );
4927 tokio::time::sleep(delay).await;
4928 }
4929
4930 // Attempt upload
4931 let result = http
4932 .put(url)
4933 .header(CONTENT_LENGTH, file_size)
4934 .body(file_data.clone())
4935 .send()
4936 .await;
4937
4938 match result {
4939 Ok(resp) => {
4940 if resp.status().is_success() {
4941 if attempt > 0 {
4942 debug!(
4943 "Upload '{}' succeeded on retry {} ({} bytes)",
4944 filename, attempt, file_size
4945 );
4946 } else {
4947 debug!(
4948 "Successfully uploaded file: {} ({} bytes)",
4949 filename, file_size
4950 );
4951 }
4952 return Ok(());
4953 }
4954
4955 let status = resp.status();
4956 let status_code = status.as_u16();
4957
4958 // Check if error is retryable
4959 let is_retryable =
4960 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
4961
4962 if is_retryable && attempt < max_retries {
4963 let error_text = resp.text().await.unwrap_or_default();
4964 warn!(
4965 "Upload '{}' failed with HTTP {} (retryable): {}",
4966 filename, status_code, error_text
4967 );
4968 last_error = Some(Error::InvalidParameters(format!(
4969 "Upload failed: HTTP {} - {}",
4970 status, error_text
4971 )));
4972 continue;
4973 }
4974
4975 // Non-retryable error or max retries exceeded
4976 let error_text = resp.text().await.unwrap_or_default();
4977 if attempt > 0 {
4978 error!(
4979 "Upload '{}' failed after {} retries: HTTP {} - {}",
4980 filename, attempt, status, error_text
4981 );
4982 }
4983 return Err(Error::InvalidParameters(format!(
4984 "Upload failed: HTTP {} - {}",
4985 status, error_text
4986 )));
4987 }
4988 Err(e) => {
4989 // Transport error (timeout, connection failure, etc.)
4990 let is_timeout = e.is_timeout();
4991 let is_connect = e.is_connect();
4992
4993 if (is_timeout || is_connect) && attempt < max_retries {
4994 warn!(
4995 "Upload '{}' transport error (retrying): {}",
4996 filename,
4997 if is_timeout {
4998 "timeout"
4999 } else {
5000 "connection failed"
5001 }
5002 );
5003 last_error = Some(Error::HttpError(e));
5004 continue;
5005 }
5006
5007 // Non-retryable or max retries exceeded
5008 if attempt > 0 {
5009 error!(
5010 "Upload '{}' failed after {} retries: {}",
5011 filename, attempt, e
5012 );
5013 }
5014 return Err(Error::HttpError(e));
5015 }
5016 }
5017 }
5018
5019 // Should not reach here, but return last error if we do
5020 Err(last_error.unwrap_or_else(|| {
5021 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5022 }))
5023}
5024
#[cfg(test)]
mod tests {
    // Unit tests for the pure name-matching and filename helpers, plus the
    // storage-reset behaviour of `Client::with_server`.
    use super::*;

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = [
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // All four names contain "deer" case-insensitively.
        assert_eq!(result.len(), 4);
        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: both case-insensitive exact matches next, in either
        // order. Checking the two slots as a set (rather than each slot
        // independently) rejects a result that repeats one variant and
        // drops the other.
        let mut case_insensitive = [result[1].as_str(), result[2].as_str()];
        case_insensitive.sort_unstable();
        assert_eq!(case_insensitive, ["DEER", "deer"]);
        // Substring-only match comes last
        assert_eq!(result[3], "Deer Test");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios: [(&[&str], &str, &str); 3] = [
            // (items, filter, expected_first)
            (&["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (&["test", "TEST", "Test Data"], "test", "test"),
            (&["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Precondition: the shared storage still holds the token before the
        // server is changed.
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }
}