edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SampleID, SamplesCountResult, SamplesListParams, SamplesListResult,
9 Snapshot, SnapshotCreateFromDataset, SnapshotFromDatasetResult, SnapshotID,
10 SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages, TaskStatus,
11 TasksListParams, TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
12 ValidationSessionID,
13 },
14 dataset::{
15 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
16 },
17 retry::{create_retry_policy, log_retry_configuration},
18 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
19};
20use base64::Engine as _;
21use chrono::{DateTime, Utc};
22use directories::ProjectDirs;
23use futures::{StreamExt as _, future::join_all};
24use log::{Level, debug, error, log_enabled, trace, warn};
25use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use std::{
28 collections::HashMap,
29 ffi::OsStr,
30 fs::create_dir_all,
31 io::{SeekFrom, Write as _},
32 path::{Path, PathBuf},
33 sync::{
34 Arc,
35 atomic::{AtomicUsize, Ordering},
36 },
37 time::Duration,
38 vec,
39};
40use tokio::{
41 fs::{self, File},
42 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
43 sync::{RwLock, Semaphore, mpsc::Sender},
44};
45use tokio_util::codec::{BytesCodec, FramedRead};
46use walkdir::WalkDir;
47
48#[cfg(feature = "polars")]
49use polars::prelude::*;
50
/// Size in bytes of each part of a multipart upload: 100 MiB.
static PART_SIZE: usize = 100 << 20;
52
/// Where the content of a file being uploaded comes from.
///
/// Content is either read from disk or taken from a buffer already held in
/// memory (for example an entry extracted from a ZIP archive).
#[derive(Clone)]
enum FileSource {
    /// Read the content from this path on the local filesystem.
    Path(PathBuf),
    /// Use these bytes directly as the content.
    Bytes(Vec<u8>),
}
61
/// Maximum number of concurrent tasks for general parallel operations.
///
/// Reads the `MAX_TASKS` environment variable (surrounding whitespace is
/// ignored). If it is unset, not a valid unsigned integer, or `0`, the
/// default is used. A limit of zero would let no work run at all.
///
/// The default is half the available CPUs, clamped to `2..=8`. If the CPU
/// count cannot be queried, 4 CPUs are assumed.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
74
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating the available
/// bandwidth. The default is 8 concurrent part uploads. It can be overridden
/// with the `MAX_UPLOAD_TASKS` environment variable (surrounding whitespace is
/// ignored).
///
/// Values that are not a valid unsigned integer, or are `0`, fall back to the
/// default. A limit of zero would stall every upload.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8)
}
85
/// Keeps the items whose name contains `filter` (ignoring case) and orders
/// them from best to worst match.
///
/// Items are ranked by these rules, applied in order until two items differ:
/// 1. A name equal to `filter` exactly.
/// 2. A name equal to `filter` ignoring case.
/// 3. A shorter name (by byte length), as a more specific match.
/// 4. The lexicographic order of the name, so the result is deterministic.
///
/// For example, searching for "Deer" ranks "Deer" ahead of
/// "Deer Roundtrip 20251129" and "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let needle = filter.to_lowercase();

    let mut matches = Vec::new();
    for item in items {
        if get_name(&item).to_lowercase().contains(needle.as_str()) {
            matches.push(item);
        }
    }

    // Each item's key is built once. `false` sorts before `true`, so the
    // negated equality flags move exact matches to the front.
    matches.sort_by_cached_key(|item| {
        let name = get_name(item);
        (
            name != filter,
            name.to_lowercase() != needle,
            name.len(),
            name.to_owned(),
        )
    });

    matches
}
135
/// Turns an arbitrary name into a single, safe filesystem path component.
///
/// - Surrounding whitespace is trimmed. An empty result becomes `"unnamed"`.
/// - If the name looks like a path, only its final component is kept
///   (`"a/b"` becomes `"b"`).
/// - Characters reserved on common filesystems (`/ \ : * ? " < > |`) and
///   control characters (including NUL) are replaced with `_`.
/// - The special components `.` and `..` become `"unnamed"`. This ensures the
///   result can never refer to the current or parent directory when it is
///   joined onto a base path.
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // `file_name` is `None` for inputs such as "." or "..". In that case the
    // whole trimmed string is sanitized instead.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            c if c.is_control() => '_',
            c => c,
        })
        .collect();

    // "." and ".." would resolve to the current or parent directory when
    // joined onto a path, allowing writes outside the intended folder.
    if matches!(sanitized.as_str(), "" | "." | "..") {
        FALLBACK.to_string()
    } else {
        sanitized
    }
}
161
/// A snapshot of how far a long-running operation has progressed.
///
/// Operations such as uploads, downloads and dataset processing emit
/// `Progress` values. Each value reports how much work is done (`current`)
/// out of the expected amount (`total`). It also carries an optional label
/// (`status`) naming the phase that is currently running, so applications can
/// render progress feedback.
///
/// # Phases
///
/// Some operations run in several phases, and each new phase is signalled by
/// a change in `status`. Consumers should compare `status` with the previous
/// value and restart their progress display whenever it changes.
///
/// # Progress Reported by Each Operation
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Printing a single update:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let update = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// let label = update.status.as_deref().unwrap_or("Progress");
/// let percent = update.current as f64 * 100.0 / update.total as f64;
/// println!("{}: {:.1}% ({}/{})", label, percent, update.current, update.total);
/// ```
///
/// Following the phases of a multi-phase operation such as
/// `download_dataset`:
///
/// ```rust,ignore
/// let mut phase: Option<String> = None;
///
/// while let Some(update) = rx.recv().await {
///     // A different status means a new phase: announce it and reset.
///     if update.status != phase {
///         if let Some(name) = &update.status {
///             println!("\n{}", name);
///         }
///         phase = update.status.clone();
///     }
///
///     let percent = update.current as f64 * 100.0 / update.total as f64;
///     print!("\r{:.1}% ({}/{})", percent, update.current, update.total);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// How many items (or bytes) have been completed so far.
    pub current: usize,
    /// How many items (or bytes) are expected in total.
    pub total: usize,
    /// Optional label for the phase that is currently running.
    ///
    /// A change of this value, whether from `None` to `Some(...)` or between
    /// two different labels, marks the start of a new phase. Progress
    /// displays should be reset when that happens.
    ///
    /// [`Client::download_dataset`] is currently the only operation that
    /// changes it:
    /// - Phase 1: `None` while sample metadata is fetched.
    /// - Phase 2: `"Downloading"` while the files are downloaded.
    ///
    /// Every other operation leaves it as `None` throughout.
    pub status: Option<String>,
}
256
257#[derive(Serialize)]
258struct RpcRequest<Params> {
259 id: u64,
260 jsonrpc: String,
261 method: String,
262 params: Option<Params>,
263}
264
265impl<T> Default for RpcRequest<T> {
266 fn default() -> Self {
267 RpcRequest {
268 id: 0,
269 jsonrpc: "2.0".to_string(),
270 method: "".to_string(),
271 params: None,
272 }
273 }
274}
275
276#[derive(Deserialize)]
277struct RpcError {
278 code: i32,
279 message: String,
280}
281
282#[derive(Deserialize)]
283struct RpcResponse<RpcResult> {
284 #[allow(dead_code)]
285 id: String,
286 #[allow(dead_code)]
287 jsonrpc: String,
288 error: Option<RpcError>,
289 result: Option<RpcResult>,
290}
291
292#[derive(Deserialize)]
293#[allow(dead_code)]
294struct EmptyResult {}
295
296#[derive(Debug, Serialize)]
297#[allow(dead_code)]
298struct SnapshotCreateParams {
299 snapshot_name: String,
300 keys: Vec<String>,
301}
302
303#[derive(Debug, Deserialize)]
304#[allow(dead_code)]
305struct SnapshotCreateResult {
306 snapshot_id: SnapshotID,
307 urls: Vec<String>,
308}
309
310#[derive(Debug, Serialize)]
311struct SnapshotCreateMultipartParams {
312 snapshot_name: String,
313 keys: Vec<String>,
314 file_sizes: Vec<usize>,
315 /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
316 #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
317 snapshot_type: Option<String>,
318}
319
320#[derive(Debug, Deserialize)]
321#[serde(untagged)]
322enum SnapshotCreateMultipartResultField {
323 Id(u64),
324 Part(SnapshotPart),
325}
326
327#[derive(Debug, Serialize)]
328struct SnapshotCompleteMultipartParams {
329 key: String,
330 upload_id: String,
331 etag_list: Vec<EtagPart>,
332}
333
334#[derive(Debug, Clone, Serialize)]
335struct EtagPart {
336 #[serde(rename = "ETag")]
337 etag: String,
338 #[serde(rename = "PartNumber")]
339 part_number: usize,
340}
341
342#[derive(Debug, Clone, Deserialize)]
343struct SnapshotPart {
344 key: Option<String>,
345 upload_id: String,
346 urls: Vec<String>,
347}
348
349#[derive(Debug, Serialize)]
350struct SnapshotStatusParams {
351 snapshot_id: SnapshotID,
352 status: String,
353}
354
355#[derive(Deserialize, Debug)]
356struct SnapshotStatusResult {
357 #[allow(dead_code)]
358 pub id: SnapshotID,
359 #[allow(dead_code)]
360 pub uid: String,
361 #[allow(dead_code)]
362 pub description: String,
363 #[allow(dead_code)]
364 pub date: String,
365 #[allow(dead_code)]
366 pub status: String,
367}
368
369#[derive(Serialize)]
370#[allow(dead_code)]
371struct ImageListParams {
372 images_filter: ImagesFilter,
373 image_files_filter: HashMap<String, String>,
374 only_ids: bool,
375}
376
377#[derive(Serialize)]
378#[allow(dead_code)]
379struct ImagesFilter {
380 dataset_id: DatasetID,
381}
382
383/// Main client for interacting with EdgeFirst Studio Server.
384///
385/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
386/// and manages authentication, RPC calls, and data operations. It provides
387/// methods for managing projects, datasets, experiments, training sessions,
388/// and various utility functions for data processing.
389///
390/// The client supports multiple authentication methods and can work with both
391/// SaaS and self-hosted EdgeFirst Studio instances.
392///
393/// # Features
394///
395/// - **Authentication**: Token-based authentication with automatic persistence
396/// - **Dataset Management**: Upload, download, and manipulate datasets
397/// - **Project Operations**: Create and manage projects and experiments
398/// - **Training & Validation**: Submit and monitor ML training jobs
399/// - **Data Integration**: Convert between EdgeFirst datasets and popular
400/// formats
401/// - **Progress Tracking**: Real-time progress updates for long-running
402/// operations
403///
404/// # Examples
405///
406/// ```no_run
407/// use edgefirst_client::{Client, DatasetID};
408/// use std::str::FromStr;
409///
410/// # async fn example() -> Result<(), edgefirst_client::Error> {
411/// // Create a new client and authenticate
412/// let mut client = Client::new()?;
413/// let client = client
414/// .with_login("your-email@example.com", "password")
415/// .await?;
416///
417/// // Or use an existing token
418/// let base_client = Client::new()?;
419/// let client = base_client.with_token("your-token-here")?;
420///
421/// // Get organization and projects
422/// let org = client.organization().await?;
423/// let projects = client.projects(None).await?;
424///
425/// // Work with datasets
426/// let dataset_id = DatasetID::from_str("ds-abc123")?;
427/// let dataset = client.dataset(dataset_id).await?;
428/// # Ok(())
429/// # }
430/// ```
431/// Client is Clone but cannot derive Debug due to dyn TokenStorage
432#[derive(Clone)]
433pub struct Client {
434 http: reqwest::Client,
435 /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
436 /// timeout). An idle read timeout is still configured on the underlying client, and
437 /// some operations (such as uploads) may apply additional per-request timeouts.
438 bulk_http: reqwest::Client,
439 url: String,
440 token: Arc<RwLock<String>>,
441 /// Token storage backend. When set, tokens are automatically persisted.
442 storage: Option<Arc<dyn TokenStorage>>,
443 /// Legacy token path field for backwards compatibility with
444 /// with_token_path(). Deprecated: Use with_storage() instead.
445 token_path: Option<PathBuf>,
446}
447
448impl std::fmt::Debug for Client {
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 f.debug_struct("Client")
451 .field("url", &self.url)
452 .field("has_storage", &self.storage.is_some())
453 .field("token_path", &self.token_path)
454 .finish()
455 }
456}
457
458/// Private context struct for pagination operations
459struct FetchContext<'a> {
460 dataset_id: DatasetID,
461 annotation_set_id: Option<AnnotationSetID>,
462 groups: &'a [String],
463 types: Vec<String>,
464 labels: &'a HashMap<String, u64>,
465}
466
467impl Client {
468 /// Create a new unauthenticated client with the default saas server.
469 ///
470 /// By default, the client uses [`FileTokenStorage`] for token persistence.
471 /// Use [`with_storage`][Self::with_storage],
472 /// [`with_memory_storage`][Self::with_memory_storage],
473 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
474 /// behavior.
475 ///
476 /// To connect to a different server, use [`with_server`][Self::with_server]
477 /// or [`with_token`][Self::with_token] (tokens include the server
478 /// instance).
479 ///
480 /// This client is created without a token and will need to authenticate
481 /// before using methods that require authentication.
482 ///
483 /// # Examples
484 ///
485 /// ```rust,no_run
486 /// use edgefirst_client::Client;
487 ///
488 /// # fn main() -> Result<(), edgefirst_client::Error> {
489 /// // Create client with default file storage
490 /// let client = Client::new()?;
491 ///
492 /// // Create client without token persistence
493 /// let client = Client::new()?.with_memory_storage();
494 /// # Ok(())
495 /// # }
496 /// ```
497 pub fn new() -> Result<Self, Error> {
498 log_retry_configuration();
499
500 // Get timeout from environment or use default
501 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
502 .ok()
503 .and_then(|s| s.parse().ok())
504 .unwrap_or(30); // Default 30s total deadline for API calls
505
506 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
507 // arrive for this duration. Resets after every received chunk, so a
508 // healthy multi-GB transfer will never be interrupted.
509 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
510 .ok()
511 .and_then(|s| s.parse().ok())
512 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
513
514 // Create single HTTP client with URL-based retry policy
515 //
516 // The retry policy classifies requests into two categories:
517 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
518 // errors
519 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
520 //
521 // This allows the same client to handle both API calls and file operations
522 // with appropriate retry behavior for each. See retry.rs for details.
523 let http = reqwest::Client::builder()
524 .connect_timeout(Duration::from_secs(10))
525 .timeout(Duration::from_secs(timeout_secs))
526 .pool_idle_timeout(Duration::from_secs(90))
527 .pool_max_idle_per_host(10)
528 .retry(create_retry_policy())
529 .build()?;
530
531 // Separate HTTP client for bulk transfers (uploads and downloads).
532 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
533 // Uses read_timeout instead: resets after every received chunk, so a
534 // healthy large transfer is never interrupted, but a truly stalled
535 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
536 let bulk_http = reqwest::Client::builder()
537 .connect_timeout(Duration::from_secs(30))
538 .read_timeout(Duration::from_secs(read_timeout_secs))
539 .pool_idle_timeout(Duration::from_secs(90))
540 .pool_max_idle_per_host(10)
541 .retry(create_retry_policy())
542 .build()?;
543
544 // Default to file storage, loading any existing token
545 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
546 Ok(file_storage) => Arc::new(file_storage),
547 Err(e) => {
548 warn!(
549 "Could not initialize file token storage: {}. Using memory storage.",
550 e
551 );
552 Arc::new(MemoryTokenStorage::new())
553 }
554 };
555
556 // Try to load existing token from storage
557 let token = match storage.load() {
558 Ok(Some(t)) => t,
559 Ok(None) => String::new(),
560 Err(e) => {
561 warn!(
562 "Failed to load token from storage: {}. Starting with empty token.",
563 e
564 );
565 String::new()
566 }
567 };
568
569 // Extract server from token if available
570 let url = if !token.is_empty() {
571 match Self::extract_server_from_token(&token) {
572 Ok(server) => format!("https://{}.edgefirst.studio", server),
573 Err(e) => {
574 warn!(
575 "Failed to extract server from token: {}. Using default server.",
576 e
577 );
578 "https://edgefirst.studio".to_string()
579 }
580 }
581 } else {
582 "https://edgefirst.studio".to_string()
583 };
584
585 Ok(Client {
586 http,
587 bulk_http,
588 url,
589 token: Arc::new(tokio::sync::RwLock::new(token)),
590 storage: Some(storage),
591 token_path: None,
592 })
593 }
594
595 /// Returns a new client connected to the specified server instance.
596 ///
597 /// The server parameter is an instance name that maps to a URL:
598 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
599 /// server)
600 /// - `"test"` → `https://test.edgefirst.studio`
601 /// - `"stage"` → `https://stage.edgefirst.studio`
602 /// - `"dev"` → `https://dev.edgefirst.studio`
603 /// - `"{name}"` → `https://{name}.edgefirst.studio`
604 ///
605 /// # Server Selection Priority
606 ///
607 /// When using the CLI or Python API, server selection follows this
608 /// priority:
609 ///
610 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
611 /// they were issued for. If you have a valid token, its server is used.
612 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
613 /// token is available. If a token exists with a different server, a
614 /// warning is emitted and the token's server takes priority.
615 /// 3. **Default `"saas"`** - If no token and no server specified, the
616 /// production server (`https://edgefirst.studio`) is used.
617 ///
618 /// # Important Notes
619 ///
620 /// - If a token is already set in the client, calling this method will
621 /// **drop the token** as tokens are specific to the server instance.
622 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
623 /// token's server before calling this method.
624 /// - For login operations, call `with_server()` first, then authenticate.
625 ///
626 /// # Examples
627 ///
628 /// ```rust,no_run
629 /// use edgefirst_client::Client;
630 ///
631 /// # fn main() -> Result<(), edgefirst_client::Error> {
632 /// let client = Client::new()?.with_server("test")?;
633 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
634 /// # Ok(())
635 /// # }
636 /// ```
637 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
638 let url = match server {
639 "" | "saas" => "https://edgefirst.studio".to_string(),
640 name => format!("https://{}.edgefirst.studio", name),
641 };
642
643 // Clear token from storage when changing servers to prevent
644 // authentication issues with stale tokens from different instances
645 if let Some(ref storage) = self.storage
646 && let Err(e) = storage.clear()
647 {
648 warn!(
649 "Failed to clear token from storage when changing servers: {}",
650 e
651 );
652 }
653
654 Ok(Client {
655 url,
656 token: Arc::new(tokio::sync::RwLock::new(String::new())),
657 ..self.clone()
658 })
659 }
660
661 /// Returns a new client with the specified token storage backend.
662 ///
663 /// Use this to configure custom token storage, such as platform-specific
664 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
665 ///
666 /// # Examples
667 ///
668 /// ```rust,no_run
669 /// use edgefirst_client::{Client, FileTokenStorage};
670 /// use std::{path::PathBuf, sync::Arc};
671 ///
672 /// # fn main() -> Result<(), edgefirst_client::Error> {
673 /// // Use a custom file path for token storage
674 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
675 /// let client = Client::new()?.with_storage(Arc::new(storage));
676 /// # Ok(())
677 /// # }
678 /// ```
679 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
680 // Try to load existing token from the new storage
681 let token = match storage.load() {
682 Ok(Some(t)) => t,
683 Ok(None) => String::new(),
684 Err(e) => {
685 warn!(
686 "Failed to load token from storage: {}. Starting with empty token.",
687 e
688 );
689 String::new()
690 }
691 };
692
693 Client {
694 token: Arc::new(tokio::sync::RwLock::new(token)),
695 storage: Some(storage),
696 token_path: None,
697 ..self
698 }
699 }
700
701 /// Returns a new client with in-memory token storage (no persistence).
702 ///
703 /// Tokens are stored in memory only and lost when the application exits.
704 /// This is useful for testing or when you want to manage token persistence
705 /// externally.
706 ///
707 /// # Examples
708 ///
709 /// ```rust,no_run
710 /// use edgefirst_client::Client;
711 ///
712 /// # fn main() -> Result<(), edgefirst_client::Error> {
713 /// let client = Client::new()?.with_memory_storage();
714 /// # Ok(())
715 /// # }
716 /// ```
717 pub fn with_memory_storage(self) -> Self {
718 Client {
719 token: Arc::new(tokio::sync::RwLock::new(String::new())),
720 storage: Some(Arc::new(MemoryTokenStorage::new())),
721 token_path: None,
722 ..self
723 }
724 }
725
726 /// Returns a new client with no token storage.
727 ///
728 /// Tokens are not persisted. Use this when you want to manage tokens
729 /// entirely manually.
730 ///
731 /// # Examples
732 ///
733 /// ```rust,no_run
734 /// use edgefirst_client::Client;
735 ///
736 /// # fn main() -> Result<(), edgefirst_client::Error> {
737 /// let client = Client::new()?.with_no_storage();
738 /// # Ok(())
739 /// # }
740 /// ```
741 pub fn with_no_storage(self) -> Self {
742 Client {
743 storage: None,
744 token_path: None,
745 ..self
746 }
747 }
748
749 /// Returns a new client authenticated with the provided username and
750 /// password.
751 ///
752 /// The token is automatically persisted to storage (if configured).
753 ///
754 /// # Examples
755 ///
756 /// ```rust,no_run
757 /// use edgefirst_client::Client;
758 ///
759 /// # async fn example() -> Result<(), edgefirst_client::Error> {
760 /// let client = Client::new()?
761 /// .with_server("test")?
762 /// .with_login("user@example.com", "password")
763 /// .await?;
764 /// # Ok(())
765 /// # }
766 /// ```
767 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
768 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
769 let params = HashMap::from([("username", username), ("password", password)]);
770 let login: LoginResult = self
771 .rpc_without_auth("auth.login".to_owned(), Some(params))
772 .await?;
773
774 // Validate that the server returned a non-empty token
775 if login.token.is_empty() {
776 return Err(Error::EmptyToken);
777 }
778
779 // Persist token to storage if configured
780 if let Some(ref storage) = self.storage
781 && let Err(e) = storage.store(&login.token)
782 {
783 warn!("Failed to persist token to storage: {}", e);
784 }
785
786 Ok(Client {
787 token: Arc::new(tokio::sync::RwLock::new(login.token)),
788 ..self.clone()
789 })
790 }
791
792 /// Returns a new client which will load and save the token to the specified
793 /// path.
794 ///
795 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
796 /// [`FileTokenStorage`] instead for more flexible token management.
797 ///
798 /// This method is maintained for backwards compatibility with existing
799 /// code. It disables the default storage and uses file-based storage at
800 /// the specified path.
801 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
802 let token_path = match token_path {
803 Some(path) => path.to_path_buf(),
804 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
805 .ok_or_else(|| {
806 Error::IoError(std::io::Error::new(
807 std::io::ErrorKind::NotFound,
808 "Could not determine user config directory",
809 ))
810 })?
811 .config_dir()
812 .join("token"),
813 };
814
815 debug!("Using token path (legacy): {:?}", token_path);
816
817 let token = match token_path.exists() {
818 true => std::fs::read_to_string(&token_path)?,
819 false => "".to_string(),
820 };
821
822 if !token.is_empty() {
823 match self.with_token(&token) {
824 Ok(client) => Ok(Client {
825 token_path: Some(token_path),
826 storage: None, // Disable new storage when using legacy token_path
827 ..client
828 }),
829 Err(e) => {
830 // Token is corrupted or invalid - remove it and continue with no token
831 warn!(
832 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
833 token_path, e
834 );
835 if let Err(remove_err) = std::fs::remove_file(&token_path) {
836 warn!("Failed to remove corrupted token file: {:?}", remove_err);
837 }
838 // Clear any token from default storage to ensure we don't use it
839 Ok(Client {
840 token_path: Some(token_path),
841 storage: None,
842 token: Arc::new(RwLock::new("".to_string())),
843 ..self.clone()
844 })
845 }
846 }
847 } else {
848 // No token in the legacy file - clear any token from default storage
849 Ok(Client {
850 token_path: Some(token_path),
851 storage: None,
852 token: Arc::new(RwLock::new("".to_string())),
853 ..self.clone()
854 })
855 }
856 }
857
858 /// Returns a new client authenticated with the provided token.
859 ///
860 /// The token is automatically persisted to storage (if configured).
861 /// The server URL is extracted from the token payload.
862 ///
863 /// # Examples
864 ///
865 /// ```rust,no_run
866 /// use edgefirst_client::Client;
867 ///
868 /// # fn main() -> Result<(), edgefirst_client::Error> {
869 /// let client = Client::new()?.with_token("your-jwt-token")?;
870 /// # Ok(())
871 /// # }
872 /// ```
873 /// Extract server name from JWT token payload.
874 ///
875 /// Helper method to parse the JWT token and extract the "server" field
876 /// from the payload. Returns the server name (e.g., "test", "stage", "")
877 /// or an error if the token is invalid.
878 fn extract_server_from_token(token: &str) -> Result<String, Error> {
879 let token_parts: Vec<&str> = token.split('.').collect();
880 if token_parts.len() != 3 {
881 return Err(Error::InvalidToken);
882 }
883
884 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
885 .decode(token_parts[1])
886 .map_err(|_| Error::InvalidToken)?;
887 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
888 let server = match payload.get("server") {
889 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
890 None => return Err(Error::InvalidToken),
891 };
892
893 Ok(server)
894 }
895
896 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
897 if token.is_empty() {
898 return Ok(self.clone());
899 }
900
901 let server = Self::extract_server_from_token(token)?;
902
903 // Persist token to storage if configured
904 if let Some(ref storage) = self.storage
905 && let Err(e) = storage.store(token)
906 {
907 warn!("Failed to persist token to storage: {}", e);
908 }
909
910 Ok(Client {
911 url: format!("https://{}.edgefirst.studio", server),
912 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
913 ..self.clone()
914 })
915 }
916
917 /// Persist the current token to storage.
918 ///
919 /// This is automatically called when using [`with_login`][Self::with_login]
920 /// or [`with_token`][Self::with_token], so you typically don't need to call
921 /// this directly.
922 ///
923 /// If using the legacy `token_path` configuration, saves to the file path.
924 /// If using the new storage abstraction, saves to the configured storage.
925 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
926 pub async fn save_token(&self) -> Result<(), Error> {
927 let token = self.token.read().await;
928
929 // Try new storage first
930 if let Some(ref storage) = self.storage {
931 storage.store(&token)?;
932 debug!("Token saved to storage");
933 return Ok(());
934 }
935
936 // Fall back to legacy token_path behavior
937 let path = self.token_path.clone().unwrap_or_else(|| {
938 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
939 .map(|dirs| dirs.config_dir().join("token"))
940 .unwrap_or_else(|| PathBuf::from(".token"))
941 });
942
943 create_dir_all(path.parent().ok_or_else(|| {
944 Error::IoError(std::io::Error::new(
945 std::io::ErrorKind::InvalidInput,
946 "Token path has no parent directory",
947 ))
948 })?)?;
949 let mut file = std::fs::File::create(&path)?;
950 file.write_all(token.as_bytes())?;
951
952 debug!("Saved token to {:?}", path);
953
954 Ok(())
955 }
956
957 /// Return the version of the EdgeFirst Studio server for the current
958 /// client connection.
959 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
960 pub async fn version(&self) -> Result<String, Error> {
961 let version: HashMap<String, String> = self
962 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
963 .await?;
964 let version = version.get("version").ok_or(Error::InvalidResponse)?;
965 Ok(version.to_owned())
966 }
967
968 /// Clear the token used to authenticate the client with the server.
969 ///
970 /// Clears the token from memory and from storage (if configured).
971 /// If using the legacy `token_path` configuration, removes the token file.
972 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
973 pub async fn logout(&self) -> Result<(), Error> {
974 {
975 let mut token = self.token.write().await;
976 *token = "".to_string();
977 }
978
979 // Clear from new storage if configured
980 if let Some(ref storage) = self.storage
981 && let Err(e) = storage.clear()
982 {
983 warn!("Failed to clear token from storage: {}", e);
984 }
985
986 // Also clear legacy token_path if configured
987 if let Some(path) = &self.token_path
988 && path.exists()
989 {
990 fs::remove_file(path).await?;
991 }
992
993 Ok(())
994 }
995
996 /// Return the token used to authenticate the client with the server. When
997 /// logging into the server using a username and password, the token is
998 /// returned by the server and stored in the client for future interactions.
999 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1000 pub async fn token(&self) -> String {
1001 self.token.read().await.clone()
1002 }
1003
1004 /// Verify the token used to authenticate the client with the server. This
1005 /// method is used to ensure that the token is still valid and has not
1006 /// expired. If the token is invalid, the server will return an error and
1007 /// the client will need to login again.
1008 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1009 pub async fn verify_token(&self) -> Result<(), Error> {
1010 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1011 .await?;
1012 Ok::<(), Error>(())
1013 }
1014
1015 /// Renew the token used to authenticate the client with the server.
1016 ///
1017 /// Refreshes the token before it expires. If the token has already expired,
1018 /// the server will return an error and you will need to login again.
1019 ///
1020 /// The new token is automatically persisted to storage (if configured).
1021 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1022 pub async fn renew_token(&self) -> Result<(), Error> {
1023 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1024 let result: LoginResult = self
1025 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1026 .await?;
1027
1028 {
1029 let mut token = self.token.write().await;
1030 *token = result.token.clone();
1031 }
1032
1033 // Persist to new storage if configured
1034 if let Some(ref storage) = self.storage
1035 && let Err(e) = storage.store(&result.token)
1036 {
1037 warn!("Failed to persist renewed token to storage: {}", e);
1038 }
1039
1040 // Also persist to legacy token_path if configured
1041 if self.token_path.is_some() {
1042 self.save_token().await?;
1043 }
1044
1045 Ok(())
1046 }
1047
1048 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1049 let token = self.token.read().await;
1050 if token.is_empty() {
1051 return Err(Error::EmptyToken);
1052 }
1053
1054 let token_parts: Vec<&str> = token.split('.').collect();
1055 if token_parts.len() != 3 {
1056 return Err(Error::InvalidToken);
1057 }
1058
1059 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1060 .decode(token_parts[1])
1061 .map_err(|_| Error::InvalidToken)?;
1062 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1063 match payload.get(field) {
1064 Some(value) => Ok(value.to_owned()),
1065 None => Err(Error::InvalidToken),
1066 }
1067 }
1068
1069 /// Returns the URL of the EdgeFirst Studio server for the current client.
1070 pub fn url(&self) -> &str {
1071 &self.url
1072 }
1073
1074 /// Returns the server name for the current client.
1075 ///
1076 /// This extracts the server name from the client's URL:
1077 /// - `https://edgefirst.studio` → `"saas"`
1078 /// - `https://test.edgefirst.studio` → `"test"`
1079 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1080 ///
1081 /// # Examples
1082 ///
1083 /// ```rust,no_run
1084 /// use edgefirst_client::Client;
1085 ///
1086 /// # fn main() -> Result<(), edgefirst_client::Error> {
1087 /// let client = Client::new()?.with_server("test")?;
1088 /// assert_eq!(client.server(), "test");
1089 ///
1090 /// let client = Client::new()?; // default
1091 /// assert_eq!(client.server(), "saas");
1092 /// # Ok(())
1093 /// # }
1094 /// ```
1095 pub fn server(&self) -> &str {
1096 if self.url == "https://edgefirst.studio" {
1097 "saas"
1098 } else if let Some(name) = self.url.strip_prefix("https://") {
1099 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1100 } else {
1101 "saas"
1102 }
1103 }
1104
1105 /// Returns the username associated with the current token.
1106 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1107 pub async fn username(&self) -> Result<String, Error> {
1108 match self.token_field("username").await? {
1109 serde_json::Value::String(username) => Ok(username),
1110 _ => Err(Error::InvalidToken),
1111 }
1112 }
1113
1114 /// Returns the expiration time for the current token.
1115 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1116 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1117 let ts = match self.token_field("exp").await? {
1118 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1119 _ => return Err(Error::InvalidToken),
1120 };
1121
1122 match DateTime::<Utc>::from_timestamp(ts, 0) {
1123 Some(dt) => Ok(dt),
1124 None => Err(Error::InvalidToken),
1125 }
1126 }
1127
1128 /// Returns the organization information for the current user.
1129 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1130 pub async fn organization(&self) -> Result<Organization, Error> {
1131 self.rpc::<(), Organization>("org.get".to_owned(), None)
1132 .await
1133 }
1134
1135 /// Returns a list of projects available to the user. The projects are
1136 /// returned as a vector of Project objects. If a name filter is
1137 /// provided, only projects matching the filter are returned.
1138 ///
1139 /// Results are sorted by match quality: exact matches first, then
1140 /// case-insensitive exact matches, then shorter names (more specific),
1141 /// then alphabetically.
1142 ///
1143 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1144 /// Projects contain datasets, trainers, and trainer sessions. Projects
1145 /// are used to group related datasets and trainers together.
1146 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1147 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1148 let projects = self
1149 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1150 .await?;
1151 if let Some(name) = name {
1152 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1153 } else {
1154 Ok(projects)
1155 }
1156 }
1157
1158 /// Return the project with the specified project ID. If the project does
1159 /// not exist, an error is returned.
1160 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1161 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1162 let params = HashMap::from([("project_id", project_id)]);
1163 self.rpc("project.get".to_owned(), Some(params)).await
1164 }
1165
1166 /// Returns a list of datasets available to the user. The datasets are
1167 /// returned as a vector of Dataset objects. If a name filter is
1168 /// provided, only datasets matching the filter are returned.
1169 ///
1170 /// Results are sorted by match quality: exact matches first, then
1171 /// case-insensitive exact matches, then shorter names (more specific),
1172 /// then alphabetically. This ensures "Deer" returns before "Deer
1173 /// Roundtrip".
1174 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1175 pub async fn datasets(
1176 &self,
1177 project_id: ProjectID,
1178 name: Option<&str>,
1179 ) -> Result<Vec<Dataset>, Error> {
1180 let params = HashMap::from([("project_id", project_id)]);
1181 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1182 if let Some(name) = name {
1183 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1184 } else {
1185 Ok(datasets)
1186 }
1187 }
1188
1189 /// Return the dataset with the specified dataset ID. If the dataset does
1190 /// not exist, an error is returned.
1191 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1192 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1193 let params = HashMap::from([("dataset_id", dataset_id)]);
1194 self.rpc("dataset.get".to_owned(), Some(params)).await
1195 }
1196
1197 /// Lists the labels for the specified dataset.
1198 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1199 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1200 let params = HashMap::from([("dataset_id", dataset_id)]);
1201 self.rpc("label.list".to_owned(), Some(params)).await
1202 }
1203
1204 /// Add a new label to the dataset with the specified name.
1205 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1206 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1207 let new_label = NewLabel {
1208 dataset_id,
1209 labels: vec![NewLabelObject {
1210 name: name.to_owned(),
1211 }],
1212 };
1213 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1214 Ok(())
1215 }
1216
1217 /// Removes the label with the specified ID from the dataset. Label IDs are
1218 /// globally unique so the dataset_id is not required.
1219 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1220 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1221 let params = HashMap::from([("label_id", label_id)]);
1222 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1223 Ok(())
1224 }
1225
1226 /// Creates a new dataset in the specified project.
1227 ///
1228 /// # Arguments
1229 ///
1230 /// * `project_id` - The ID of the project to create the dataset in
1231 /// * `name` - The name of the new dataset
1232 /// * `description` - Optional description for the dataset
1233 ///
1234 /// # Returns
1235 ///
1236 /// Returns the dataset ID of the newly created dataset.
1237 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1238 pub async fn create_dataset(
1239 &self,
1240 project_id: &str,
1241 name: &str,
1242 description: Option<&str>,
1243 ) -> Result<DatasetID, Error> {
1244 let mut params = HashMap::new();
1245 params.insert("project_id", project_id);
1246 params.insert("name", name);
1247 if let Some(desc) = description {
1248 params.insert("description", desc);
1249 }
1250
1251 #[derive(Deserialize)]
1252 struct CreateDatasetResult {
1253 id: DatasetID,
1254 }
1255
1256 let result: CreateDatasetResult =
1257 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1258 Ok(result.id)
1259 }
1260
1261 /// Deletes a dataset by marking it as deleted.
1262 ///
1263 /// # Arguments
1264 ///
1265 /// * `dataset_id` - The ID of the dataset to delete
1266 ///
1267 /// # Returns
1268 ///
1269 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1270 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1271 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1272 let params = HashMap::from([("id", dataset_id)]);
1273 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1274 Ok(())
1275 }
1276
1277 /// Updates the label with the specified ID to have the new name or index.
1278 /// Label IDs cannot be changed. Label IDs are globally unique so the
1279 /// dataset_id is not required.
1280 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1281 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1282 #[derive(Serialize)]
1283 struct Params {
1284 dataset_id: DatasetID,
1285 label_id: u64,
1286 label_name: String,
1287 label_index: u64,
1288 }
1289
1290 let _: String = self
1291 .rpc(
1292 "label.update".to_owned(),
1293 Some(Params {
1294 dataset_id: label.dataset_id(),
1295 label_id: label.id(),
1296 label_name: label.name().to_owned(),
1297 label_index: label.index(),
1298 }),
1299 )
1300 .await?;
1301 Ok(())
1302 }
1303
1304 /// Lists the groups for the specified dataset.
1305 ///
1306 /// Groups are used to organize samples into logical subsets such as
1307 /// "train", "val", "test", etc. Each sample can belong to at most one
1308 /// group at a time.
1309 ///
1310 /// # Arguments
1311 ///
1312 /// * `dataset_id` - The ID of the dataset to list groups for
1313 ///
1314 /// # Returns
1315 ///
1316 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1317 /// empty vector if no groups have been created yet.
1318 ///
1319 /// # Errors
1320 ///
1321 /// Returns an error if the dataset does not exist or cannot be accessed.
1322 ///
1323 /// # Example
1324 ///
1325 /// ```rust,no_run
1326 /// # use edgefirst_client::{Client, DatasetID};
1327 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1328 /// let client = Client::new()?.with_token_path(None)?;
1329 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1330 ///
1331 /// let groups = client.groups(dataset_id).await?;
1332 /// for group in groups {
1333 /// println!("{}: {}", group.id, group.name);
1334 /// }
1335 /// # Ok(())
1336 /// # }
1337 /// ```
1338 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1339 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1340 let params = HashMap::from([("dataset_id", dataset_id)]);
1341 self.rpc("groups.list".to_owned(), Some(params)).await
1342 }
1343
1344 /// Gets an existing group by name or creates a new one.
1345 ///
1346 /// This is a convenience method that first checks if a group with the
1347 /// specified name exists, and creates it if not. This is useful when
1348 /// you need to ensure a group exists before assigning samples to it.
1349 ///
1350 /// # Arguments
1351 ///
1352 /// * `dataset_id` - The ID of the dataset
1353 /// * `name` - The name of the group (e.g., "train", "val", "test")
1354 ///
1355 /// # Returns
1356 ///
1357 /// Returns the group ID (either existing or newly created).
1358 ///
1359 /// # Errors
1360 ///
1361 /// Returns an error if:
1362 /// - The dataset does not exist or cannot be accessed
1363 /// - The group creation fails
1364 ///
1365 /// # Concurrency
1366 ///
1367 /// This method handles concurrent creation attempts gracefully. If another
1368 /// process creates the group between the existence check and creation,
1369 /// this method will return the existing group's ID.
1370 ///
1371 /// # Example
1372 ///
1373 /// ```rust,no_run
1374 /// # use edgefirst_client::{Client, DatasetID};
1375 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1376 /// let client = Client::new()?.with_token_path(None)?;
1377 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1378 ///
1379 /// // Get or create a "train" group
1380 /// let train_group_id = client
1381 /// .get_or_create_group(dataset_id.clone(), "train")
1382 /// .await?;
1383 /// println!("Train group ID: {}", train_group_id);
1384 ///
1385 /// // Calling again returns the same ID
1386 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1387 /// assert_eq!(train_group_id, same_id);
1388 /// # Ok(())
1389 /// # }
1390 /// ```
1391 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1392 pub async fn get_or_create_group(
1393 &self,
1394 dataset_id: DatasetID,
1395 name: &str,
1396 ) -> Result<u64, Error> {
1397 // First check if the group already exists
1398 let groups = self.groups(dataset_id).await?;
1399 if let Some(group) = groups.iter().find(|g| g.name == name) {
1400 return Ok(group.id);
1401 }
1402
1403 // Create the group
1404 #[derive(Serialize)]
1405 struct CreateGroupParams {
1406 dataset_id: DatasetID,
1407 group_names: Vec<String>,
1408 group_splits: Vec<i64>,
1409 }
1410
1411 let params = CreateGroupParams {
1412 dataset_id,
1413 group_names: vec![name.to_string()],
1414 group_splits: vec![0], // No automatic splitting
1415 };
1416
1417 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1418 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1419 Ok(group.id)
1420 } else {
1421 // Group might have been created by concurrent call, try fetching again
1422 let groups = self.groups(dataset_id).await?;
1423 groups
1424 .iter()
1425 .find(|g| g.name == name)
1426 .map(|g| g.id)
1427 .ok_or_else(|| {
1428 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1429 })
1430 }
1431 }
1432
1433 /// Sets the group for a sample.
1434 ///
1435 /// Assigns a sample to a specific group. Each sample can belong to at most
1436 /// one group at a time. Setting a new group replaces any existing group
1437 /// assignment.
1438 ///
1439 /// # Arguments
1440 ///
1441 /// * `sample_id` - The ID of the sample (image) to update
1442 /// * `group_id` - The ID of the group to assign. Use
1443 /// [`get_or_create_group`] to obtain a group ID from a name.
1444 ///
1445 /// # Returns
1446 ///
1447 /// Returns `Ok(())` on success.
1448 ///
1449 /// # Errors
1450 ///
1451 /// Returns an error if:
1452 /// - The sample does not exist
1453 /// - The group does not exist
1454 /// - Insufficient permissions to modify the sample
1455 ///
1456 /// # Example
1457 ///
1458 /// ```rust,no_run
1459 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1460 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1461 /// let client = Client::new()?.with_token_path(None)?;
1462 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1463 /// let sample_id: SampleID = 12345.into();
1464 ///
1465 /// // Get or create the "val" group
1466 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1467 ///
1468 /// // Assign the sample to the "val" group
1469 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1470 /// # Ok(())
1471 /// # }
1472 /// ```
1473 ///
1474 /// [`get_or_create_group`]: Self::get_or_create_group
1475 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1476 pub async fn set_sample_group_id(
1477 &self,
1478 sample_id: SampleID,
1479 group_id: u64,
1480 ) -> Result<(), Error> {
1481 #[derive(Serialize)]
1482 struct SetGroupParams {
1483 image_id: SampleID,
1484 group_id: u64,
1485 }
1486
1487 let params = SetGroupParams {
1488 image_id: sample_id,
1489 group_id,
1490 };
1491 let _: String = self
1492 .rpc("image.set_group_id".to_owned(), Some(params))
1493 .await?;
1494 Ok(())
1495 }
1496
1497 /// Downloads dataset samples to the local filesystem.
1498 ///
1499 /// # Arguments
1500 ///
1501 /// * `dataset_id` - The unique identifier of the dataset
1502 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1503 /// * `file_types` - File types to download. Supported types:
1504 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1505 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1506 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1507 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1508 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1509 /// - `FileType::RadarCube` - Radar cube data (.png format)
1510 /// - `FileType::All` - All sensor types (expands to all of the above)
1511 /// * `output` - Local directory to save downloaded files
1512 /// * `flatten` - If true, download all files to output root without
1513 /// sequence subdirectories. When flattening, filenames are prefixed with
1514 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1515 /// unavailable) unless the filename already starts with
1516 /// `{sequence_name}_`, to avoid conflicts between sequences.
1517 /// * `progress` - Optional channel for progress updates
1518 ///
1519 /// # Progress
1520 ///
1521 /// This operation has two phases with distinct progress reporting:
1522 ///
1523 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1524 /// from the server. Progress counts samples fetched.
1525 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1526 /// files to disk. Progress counts samples completed (each sample may
1527 /// have multiple files for different sensor types).
1528 ///
1529 /// Applications should detect the status change from `None` to
1530 /// `"Downloading"` to reset their progress bar for the second phase.
1531 ///
1532 /// # Returns
1533 ///
1534 /// Returns `Ok(())` on success or an error if download fails.
1535 ///
1536 /// # Example
1537 ///
1538 /// ```rust,no_run
1539 /// # use edgefirst_client::{Client, DatasetID, FileType};
1540 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1541 /// let client = Client::new()?.with_token_path(None)?;
1542 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1543 ///
1544 /// // Download with sequence subdirectories (default)
1545 /// client
1546 /// .download_dataset(
1547 /// dataset_id,
1548 /// &[],
1549 /// &[FileType::Image],
1550 /// "./data".into(),
1551 /// false,
1552 /// None,
1553 /// )
1554 /// .await?;
1555 ///
1556 /// // Download flattened (all files in one directory)
1557 /// client
1558 /// .download_dataset(
1559 /// dataset_id,
1560 /// &[],
1561 /// &[FileType::Image],
1562 /// "./data".into(),
1563 /// true,
1564 /// None,
1565 /// )
1566 /// .await?;
1567 ///
1568 /// // Download all sensor types
1569 /// client
1570 /// .download_dataset(
1571 /// dataset_id,
1572 /// &[],
1573 /// &FileType::expand_types(&[FileType::All]),
1574 /// "./data".into(),
1575 /// false,
1576 /// None,
1577 /// )
1578 /// .await?;
1579 /// # Ok(())
1580 /// # }
1581 /// ```
1582 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1583 pub async fn download_dataset(
1584 &self,
1585 dataset_id: DatasetID,
1586 groups: &[String],
1587 file_types: &[FileType],
1588 output: PathBuf,
1589 flatten: bool,
1590 progress: Option<Sender<Progress>>,
1591 ) -> Result<(), Error> {
1592 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1593 let samples = self
1594 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1595 .await?;
1596 fs::create_dir_all(&output).await?;
1597
1598 // Phase 2: Download actual files using direct semaphore pattern
1599 let total = samples.len();
1600 let current = Arc::new(AtomicUsize::new(0));
1601 let sem = Arc::new(Semaphore::new(max_tasks()));
1602
1603 // Send initial progress for download phase
1604 if let Some(ref progress) = progress {
1605 let _ = progress
1606 .send(Progress {
1607 current: 0,
1608 total,
1609 status: Some("Downloading".to_string()),
1610 })
1611 .await;
1612 }
1613
1614 let tasks = samples
1615 .into_iter()
1616 .map(|sample| {
1617 let client = self.clone();
1618 let file_types = file_types.to_vec();
1619 let output = output.clone();
1620 let progress = progress.clone();
1621 let current = current.clone();
1622 let sem = sem.clone();
1623
1624 tokio::spawn(async move {
1625 let _permit = sem.acquire().await.map_err(|_| {
1626 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1627 })?;
1628
1629 for file_type in &file_types {
1630 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1631 let (file_ext, is_image) = match file_type {
1632 FileType::Image => (
1633 infer::get(&data)
1634 .expect("Failed to identify image file format for sample")
1635 .extension()
1636 .to_string(),
1637 true,
1638 ),
1639 other => (other.file_extension().to_string(), false),
1640 };
1641
1642 // Determine target directory based on sequence membership and
1643 // flatten option
1644 // - flatten=false + sequence_name: dataset/sequence_name/
1645 // - flatten=false + no sequence: dataset/ (root level)
1646 // - flatten=true: dataset/ (all files in output root)
1647 // NOTE: group (train/val/test) is NOT used for directory structure
1648 let sequence_dir = sample
1649 .sequence_name()
1650 .map(|name| sanitize_path_component(name));
1651
1652 let target_dir = if flatten {
1653 output.clone()
1654 } else {
1655 sequence_dir
1656 .as_ref()
1657 .map(|seq| output.join(seq))
1658 .unwrap_or_else(|| output.clone())
1659 };
1660 fs::create_dir_all(&target_dir).await?;
1661
1662 let sanitized_sample_name = sample
1663 .name()
1664 .map(|name| sanitize_path_component(&name))
1665 .unwrap_or_else(|| "unknown".to_string());
1666
1667 let image_name = sample.image_name().map(sanitize_path_component);
1668
1669 // Construct filename with smart prefixing for flatten mode
1670 // When flatten=true and sample belongs to a sequence:
1671 // - Check if filename already starts with "{sequence_name}_"
1672 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1673 // - If yes, use filename as-is (already uniquely named)
1674 let file_name = if is_image {
1675 if let Some(img_name) = image_name {
1676 Client::build_filename(
1677 &img_name,
1678 flatten,
1679 sequence_dir.as_ref(),
1680 sample.frame_number(),
1681 )
1682 } else {
1683 format!("{}.{}", sanitized_sample_name, file_ext)
1684 }
1685 } else {
1686 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1687 Client::build_filename(
1688 &base_name,
1689 flatten,
1690 sequence_dir.as_ref(),
1691 sample.frame_number(),
1692 )
1693 };
1694
1695 let file_path = target_dir.join(&file_name);
1696
1697 let mut file = File::create(&file_path).await?;
1698 file.write_all(&data).await?;
1699 }
1700 }
1701
1702 // Update progress after sample completes
1703 if let Some(progress) = &progress {
1704 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
1705 let _ = progress
1706 .send(Progress {
1707 current: completed,
1708 total,
1709 status: Some("Downloading".to_string()),
1710 })
1711 .await;
1712 }
1713
1714 Ok::<(), Error>(())
1715 })
1716 })
1717 .collect::<Vec<_>>();
1718
1719 join_all(tasks)
1720 .await
1721 .into_iter()
1722 .collect::<Result<Vec<_>, _>>()?
1723 .into_iter()
1724 .collect::<Result<Vec<_>, _>>()?;
1725
1726 Ok(())
1727 }
1728
1729 /// Builds a filename with smart prefixing for flatten mode.
1730 ///
1731 /// When flattening sequences into a single directory, this function ensures
1732 /// unique filenames by checking if the sequence prefix already exists and
1733 /// adding it if necessary.
1734 ///
1735 /// # Logic
1736 ///
1737 /// - If `flatten=false`: returns `base_name` unchanged
1738 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
1739 /// - If `flatten=true` and in sequence:
1740 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
1741 /// unchanged
1742 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
1743 /// `{sequence_name}_{base_name}`
1744 fn build_filename(
1745 base_name: &str,
1746 flatten: bool,
1747 sequence_name: Option<&String>,
1748 frame_number: Option<u32>,
1749 ) -> String {
1750 if !flatten || sequence_name.is_none() {
1751 return base_name.to_string();
1752 }
1753
1754 let seq_name = sequence_name.unwrap();
1755 let prefix = format!("{}_", seq_name);
1756
1757 // Check if already prefixed with sequence name
1758 if base_name.starts_with(&prefix) {
1759 base_name.to_string()
1760 } else {
1761 // Add sequence (and optionally frame) prefix
1762 match frame_number {
1763 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
1764 None => format!("{}{}", prefix, base_name),
1765 }
1766 }
1767 }
1768
1769 /// List available annotation sets for the specified dataset.
1770 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1771 pub async fn annotation_sets(
1772 &self,
1773 dataset_id: DatasetID,
1774 ) -> Result<Vec<AnnotationSet>, Error> {
1775 let params = HashMap::from([("dataset_id", dataset_id)]);
1776 self.rpc("annset.list".to_owned(), Some(params)).await
1777 }
1778
1779 /// Create a new annotation set for the specified dataset.
1780 ///
1781 /// # Arguments
1782 ///
1783 /// * `dataset_id` - The ID of the dataset to create the annotation set in
1784 /// * `name` - The name of the new annotation set
1785 /// * `description` - Optional description for the annotation set
1786 ///
1787 /// # Returns
1788 ///
1789 /// Returns the annotation set ID of the newly created annotation set.
1790 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1791 pub async fn create_annotation_set(
1792 &self,
1793 dataset_id: DatasetID,
1794 name: &str,
1795 description: Option<&str>,
1796 ) -> Result<AnnotationSetID, Error> {
1797 #[derive(Serialize)]
1798 struct Params<'a> {
1799 dataset_id: DatasetID,
1800 name: &'a str,
1801 operator: &'a str,
1802 #[serde(skip_serializing_if = "Option::is_none")]
1803 description: Option<&'a str>,
1804 }
1805
1806 #[derive(Deserialize)]
1807 struct CreateAnnotationSetResult {
1808 id: AnnotationSetID,
1809 }
1810
1811 let username = self.username().await?;
1812 let result: CreateAnnotationSetResult = self
1813 .rpc(
1814 "annset.add".to_owned(),
1815 Some(Params {
1816 dataset_id,
1817 name,
1818 operator: &username,
1819 description,
1820 }),
1821 )
1822 .await?;
1823 Ok(result.id)
1824 }
1825
1826 /// Deletes an annotation set by marking it as deleted.
1827 ///
1828 /// # Arguments
1829 ///
1830 /// * `annotation_set_id` - The ID of the annotation set to delete
1831 ///
1832 /// # Returns
1833 ///
1834 /// Returns `Ok(())` if the annotation set was successfully marked as
1835 /// deleted.
1836 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1837 pub async fn delete_annotation_set(
1838 &self,
1839 annotation_set_id: AnnotationSetID,
1840 ) -> Result<(), Error> {
1841 let params = HashMap::from([("id", annotation_set_id)]);
1842 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
1843 Ok(())
1844 }
1845
1846 /// Retrieve the annotation set with the specified ID.
1847 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1848 pub async fn annotation_set(
1849 &self,
1850 annotation_set_id: AnnotationSetID,
1851 ) -> Result<AnnotationSet, Error> {
1852 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
1853 self.rpc("annset.get".to_owned(), Some(params)).await
1854 }
1855
1856 /// Get the annotations for the specified annotation set with the
1857 /// requested annotation types. The annotation types are used to filter
1858 /// the annotations returned. The groups parameter is used to filter for
1859 /// dataset groups (train, val, test). Images which do not have any
1860 /// annotations are also included in the result as long as they are in the
1861 /// requested groups (when specified).
1862 ///
1863 /// The result is a vector of Annotations objects which contain the
1864 /// full dataset along with the annotations for the specified types.
1865 ///
1866 /// # Progress
1867 ///
1868 /// Reports progress with `status: None` as samples are fetched and
1869 /// processed for their annotations. Progress unit is samples processed
1870 /// (not individual annotations).
1871 ///
1872 /// To get the annotations as a DataFrame, use the `samples_dataframe`
1873 /// method instead.
1874 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
1875 pub async fn annotations(
1876 &self,
1877 annotation_set_id: AnnotationSetID,
1878 groups: &[String],
1879 annotation_types: &[AnnotationType],
1880 progress: Option<Sender<Progress>>,
1881 ) -> Result<Vec<Annotation>, Error> {
1882 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
1883 let labels = self
1884 .labels(dataset_id)
1885 .await?
1886 .into_iter()
1887 .map(|label| (label.name().to_string(), label.index()))
1888 .collect::<HashMap<_, _>>();
1889 let total = self
1890 .samples_count(
1891 dataset_id,
1892 Some(annotation_set_id),
1893 annotation_types,
1894 groups,
1895 &[],
1896 )
1897 .await?
1898 .total as usize;
1899
1900 if total == 0 {
1901 return Ok(vec![]);
1902 }
1903
1904 let context = FetchContext {
1905 dataset_id,
1906 annotation_set_id: Some(annotation_set_id),
1907 groups,
1908 types: annotation_types.iter().map(|t| t.to_string()).collect(),
1909 labels: &labels,
1910 };
1911
1912 self.fetch_annotations_paginated(context, total, progress)
1913 .await
1914 }
1915
1916 async fn fetch_annotations_paginated(
1917 &self,
1918 context: FetchContext<'_>,
1919 total: usize,
1920 progress: Option<Sender<Progress>>,
1921 ) -> Result<Vec<Annotation>, Error> {
1922 let mut annotations = vec![];
1923 let mut continue_token: Option<String> = None;
1924 let mut current = 0;
1925
1926 loop {
1927 let params = SamplesListParams {
1928 dataset_id: context.dataset_id,
1929 annotation_set_id: context.annotation_set_id,
1930 types: context.types.clone(),
1931 group_names: context.groups.to_vec(),
1932 continue_token,
1933 };
1934
1935 let result: SamplesListResult =
1936 self.rpc("samples.list".to_owned(), Some(params)).await?;
1937 current += result.samples.len();
1938 continue_token = result.continue_token;
1939
1940 if result.samples.is_empty() {
1941 break;
1942 }
1943
1944 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
1945
1946 if let Some(progress) = &progress {
1947 let _ = progress
1948 .send(Progress {
1949 current,
1950 total,
1951 status: None,
1952 })
1953 .await;
1954 }
1955
1956 match &continue_token {
1957 Some(token) if !token.is_empty() => continue,
1958 _ => break,
1959 }
1960 }
1961
1962 drop(progress);
1963 Ok(annotations)
1964 }
1965
1966 fn process_sample_annotations(
1967 &self,
1968 samples: &[Sample],
1969 labels: &HashMap<String, u64>,
1970 annotations: &mut Vec<Annotation>,
1971 ) {
1972 for sample in samples {
1973 if sample.annotations().is_empty() {
1974 let mut annotation = Annotation::new();
1975 annotation.set_sample_id(sample.id());
1976 annotation.set_name(sample.name());
1977 annotation.set_sequence_name(sample.sequence_name().cloned());
1978 annotation.set_frame_number(sample.frame_number());
1979 annotation.set_group(sample.group().cloned());
1980 annotations.push(annotation);
1981 continue;
1982 }
1983
1984 for annotation in sample.annotations() {
1985 let mut annotation = annotation.clone();
1986 annotation.set_sample_id(sample.id());
1987 annotation.set_name(sample.name());
1988 annotation.set_sequence_name(sample.sequence_name().cloned());
1989 annotation.set_frame_number(sample.frame_number());
1990 annotation.set_group(sample.group().cloned());
1991 Self::set_label_index_from_map(&mut annotation, labels);
1992 annotations.push(annotation);
1993 }
1994 }
1995 }
1996
1997 /// Delete annotations in bulk from specified samples.
1998 ///
1999 /// This method calls the `annotation.bulk.del` API to efficiently remove
2000 /// annotations from multiple samples at once. Useful for clearing
2001 /// annotations before re-importing updated data.
2002 ///
2003 /// # Arguments
2004 /// * `annotation_set_id` - The annotation set containing the annotations
2005 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2006 /// for masks
2007 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2008 ///
2009 /// # Example
2010 /// ```no_run
2011 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2012 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2013 /// # let client = Client::new()?.with_login("user", "pass").await?;
2014 /// let annotation_set_id = AnnotationSetID::from(123);
2015 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2016 ///
2017 /// client
2018 /// .delete_annotations_bulk(
2019 /// annotation_set_id,
2020 /// &["box".to_string(), "seg".to_string()],
2021 /// &sample_ids,
2022 /// )
2023 /// .await?;
2024 /// # Ok(())
2025 /// # }
2026 /// ```
2027 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2028 pub async fn delete_annotations_bulk(
2029 &self,
2030 annotation_set_id: AnnotationSetID,
2031 annotation_types: &[String],
2032 sample_ids: &[SampleID],
2033 ) -> Result<(), Error> {
2034 use crate::api::AnnotationBulkDeleteParams;
2035
2036 let params = AnnotationBulkDeleteParams {
2037 annotation_set_id: annotation_set_id.into(),
2038 annotation_types: annotation_types.to_vec(),
2039 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2040 delete_all: None,
2041 };
2042
2043 let _: String = self
2044 .rpc("annotation.bulk.del".to_owned(), Some(params))
2045 .await?;
2046 Ok(())
2047 }
2048
2049 /// Add annotations in bulk.
2050 ///
2051 /// This method calls the `annotation.add_bulk` API to efficiently add
2052 /// multiple annotations at once. The annotations must be in server format
2053 /// with image_id references.
2054 ///
2055 /// # Arguments
2056 /// * `annotation_set_id` - The annotation set to add annotations to
2057 /// * `annotations` - Vector of server-format annotations to add
2058 ///
2059 /// # Returns
2060 /// Vector of created annotation records from the server.
2061 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2062 pub async fn add_annotations_bulk(
2063 &self,
2064 annotation_set_id: AnnotationSetID,
2065 annotations: Vec<crate::api::ServerAnnotation>,
2066 ) -> Result<Vec<serde_json::Value>, Error> {
2067 use crate::api::AnnotationAddBulkParams;
2068
2069 let params = AnnotationAddBulkParams {
2070 annotation_set_id: annotation_set_id.into(),
2071 annotations,
2072 };
2073
2074 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2075 .await
2076 }
2077
2078 /// Helper to parse frame number from image_name when sequence_name is
2079 /// present. This ensures frame_number is always derived from the image
2080 /// filename, not from the server's frame_number field (which may be
2081 /// inconsistent).
2082 ///
2083 /// Returns Some(frame_number) if sequence_name is present and frame can be
2084 /// parsed, otherwise None.
2085 fn parse_frame_from_image_name(
2086 image_name: Option<&String>,
2087 sequence_name: Option<&String>,
2088 ) -> Option<u32> {
2089 use std::path::Path;
2090
2091 let sequence = sequence_name?;
2092 let name = image_name?;
2093
2094 // Extract stem (remove extension)
2095 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2096
2097 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2098 stem.strip_prefix(sequence)
2099 .and_then(|suffix| suffix.strip_prefix('_'))
2100 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2101 }
2102
2103 /// Helper to set label index from a label map
2104 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2105 if let Some(label) = annotation.label() {
2106 annotation.set_label_index(Some(labels[label.as_str()]));
2107 }
2108 }
2109
2110 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2111 pub async fn samples_count(
2112 &self,
2113 dataset_id: DatasetID,
2114 annotation_set_id: Option<AnnotationSetID>,
2115 annotation_types: &[AnnotationType],
2116 groups: &[String],
2117 types: &[FileType],
2118 ) -> Result<SamplesCountResult, Error> {
2119 // Use server type names for API calls (e.g., "box" instead of "box2d")
2120 let types = annotation_types
2121 .iter()
2122 .map(|t| t.as_server_type().to_string())
2123 .chain(types.iter().map(|t| t.to_string()))
2124 .collect::<Vec<_>>();
2125
2126 let params = SamplesListParams {
2127 dataset_id,
2128 annotation_set_id,
2129 group_names: groups.to_vec(),
2130 types,
2131 continue_token: None,
2132 };
2133
2134 self.rpc("samples.count".to_owned(), Some(params)).await
2135 }
2136
2137 /// Fetches samples from a dataset with optional annotation and file type
2138 /// filters.
2139 ///
2140 /// # Arguments
2141 ///
2142 /// * `dataset_id` - The dataset to fetch samples from
2143 /// * `annotation_set_id` - Optional annotation set to include annotations
2144 /// from
2145 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2146 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2147 /// * `types` - File types to include metadata for
2148 /// * `progress` - Optional channel for progress updates
2149 ///
2150 /// # Progress
2151 ///
2152 /// Reports progress with `status: None` as samples are fetched from the
2153 /// server in paginated batches. Progress unit is samples fetched.
2154 ///
2155 /// # Returns
2156 ///
2157 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2158 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2159 pub async fn samples(
2160 &self,
2161 dataset_id: DatasetID,
2162 annotation_set_id: Option<AnnotationSetID>,
2163 annotation_types: &[AnnotationType],
2164 groups: &[String],
2165 types: &[FileType],
2166 progress: Option<Sender<Progress>>,
2167 ) -> Result<Vec<Sample>, Error> {
2168 // Use server type names for API calls (e.g., "box" instead of "box2d")
2169 let types_vec = annotation_types
2170 .iter()
2171 .map(|t| t.as_server_type().to_string())
2172 .chain(types.iter().map(|t| t.to_string()))
2173 .collect::<Vec<_>>();
2174 let labels = self
2175 .labels(dataset_id)
2176 .await?
2177 .into_iter()
2178 .map(|label| (label.name().to_string(), label.index()))
2179 .collect::<HashMap<_, _>>();
2180 let total = self
2181 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2182 .await?
2183 .total as usize;
2184
2185 if total == 0 {
2186 return Ok(vec![]);
2187 }
2188
2189 let context = FetchContext {
2190 dataset_id,
2191 annotation_set_id,
2192 groups,
2193 types: types_vec,
2194 labels: &labels,
2195 };
2196
2197 self.fetch_samples_paginated(context, total, progress).await
2198 }
2199
2200 /// Get all sample names in a dataset.
2201 ///
2202 /// This is an efficient method for checking which samples already exist,
2203 /// useful for resuming interrupted imports. It only retrieves sample names
2204 /// without loading full annotation data.
2205 ///
2206 /// # Arguments
2207 ///
2208 /// * `dataset_id` - The dataset to query
2209 /// * `groups` - Optional group filter (empty = all groups)
2210 /// * `progress` - Optional progress channel
2211 ///
2212 /// # Progress
2213 ///
2214 /// Reports progress with `status: None` as sample names are fetched from
2215 /// the server in paginated batches. Progress unit is samples fetched.
2216 ///
2217 /// # Returns
2218 ///
2219 /// A HashSet of sample names (image_name field) that exist in the dataset.
2220 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2221 pub async fn sample_names(
2222 &self,
2223 dataset_id: DatasetID,
2224 groups: &[String],
2225 progress: Option<Sender<Progress>>,
2226 ) -> Result<std::collections::HashSet<String>, Error> {
2227 use std::collections::HashSet;
2228
2229 let total = self
2230 .samples_count(dataset_id, None, &[], groups, &[])
2231 .await?
2232 .total as usize;
2233
2234 if total == 0 {
2235 return Ok(HashSet::new());
2236 }
2237
2238 let mut names = HashSet::with_capacity(total);
2239 let mut continue_token: Option<String> = None;
2240 let mut current = 0;
2241
2242 loop {
2243 let params = SamplesListParams {
2244 dataset_id,
2245 annotation_set_id: None,
2246 types: vec![], // No type filter - we just want names
2247 group_names: groups.to_vec(),
2248 continue_token: continue_token.clone(),
2249 };
2250
2251 let result: SamplesListResult =
2252 self.rpc("samples.list".to_owned(), Some(params)).await?;
2253 current += result.samples.len();
2254 continue_token = result.continue_token;
2255
2256 if result.samples.is_empty() {
2257 break;
2258 }
2259
2260 // Extract sample names (normalized without extension)
2261 for sample in result.samples {
2262 if let Some(name) = sample.name() {
2263 names.insert(name);
2264 }
2265 }
2266
2267 if let Some(ref p) = progress {
2268 let _ = p
2269 .send(Progress {
2270 current,
2271 total,
2272 status: None,
2273 })
2274 .await;
2275 }
2276
2277 match &continue_token {
2278 Some(token) if !token.is_empty() => continue,
2279 _ => break,
2280 }
2281 }
2282
2283 Ok(names)
2284 }
2285
2286 async fn fetch_samples_paginated(
2287 &self,
2288 context: FetchContext<'_>,
2289 total: usize,
2290 progress: Option<Sender<Progress>>,
2291 ) -> Result<Vec<Sample>, Error> {
2292 let mut samples = vec![];
2293 let mut continue_token: Option<String> = None;
2294 let mut current = 0;
2295
2296 loop {
2297 let params = SamplesListParams {
2298 dataset_id: context.dataset_id,
2299 annotation_set_id: context.annotation_set_id,
2300 types: context.types.clone(),
2301 group_names: context.groups.to_vec(),
2302 continue_token: continue_token.clone(),
2303 };
2304
2305 let result: SamplesListResult =
2306 self.rpc("samples.list".to_owned(), Some(params)).await?;
2307 current += result.samples.len();
2308 continue_token = result.continue_token;
2309
2310 if result.samples.is_empty() {
2311 break;
2312 }
2313
2314 samples.append(
2315 &mut result
2316 .samples
2317 .into_iter()
2318 .map(|s| {
2319 // Use server's frame_number if valid (>= 0 after deserialization)
2320 // Otherwise parse from image_name as fallback
2321 // This ensures we respect explicit frame_number from uploads
2322 // while still handling legacy data that only has filename encoding
2323 let frame_number = s.frame_number.or_else(|| {
2324 Self::parse_frame_from_image_name(
2325 s.image_name.as_ref(),
2326 s.sequence_name.as_ref(),
2327 )
2328 });
2329
2330 let mut anns = s.annotations().to_vec();
2331 for ann in &mut anns {
2332 // Set annotation fields from parent sample
2333 ann.set_name(s.name());
2334 ann.set_group(s.group().cloned());
2335 ann.set_sequence_name(s.sequence_name().cloned());
2336 ann.set_frame_number(frame_number);
2337 Self::set_label_index_from_map(ann, context.labels);
2338 }
2339 s.with_annotations(anns).with_frame_number(frame_number)
2340 })
2341 .collect::<Vec<_>>(),
2342 );
2343
2344 if let Some(progress) = &progress {
2345 let _ = progress
2346 .send(Progress {
2347 current,
2348 total,
2349 status: None,
2350 })
2351 .await;
2352 }
2353
2354 match &continue_token {
2355 Some(token) if !token.is_empty() => continue,
2356 _ => break,
2357 }
2358 }
2359
2360 drop(progress);
2361 Ok(samples)
2362 }
2363
2364 /// Populates (imports) samples into a dataset using the `samples.populate2`
2365 /// API.
2366 ///
2367 /// This method creates new samples in the specified dataset, optionally
2368 /// with annotations and sensor data files. For each sample, the `files`
2369 /// field is checked for local file paths. If a filename is a valid path
2370 /// to an existing file, the file will be automatically uploaded to S3
2371 /// using presigned URLs returned by the server. The filename in the
2372 /// request is replaced with the basename (path removed) before sending
2373 /// to the server.
2374 ///
2375 /// # Important Notes
2376 ///
2377 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2378 /// annotations. Without it, the server will accept the request but will
2379 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2380 /// available annotation sets for a dataset, or create a new one via the
2381 /// Studio UI.
2382 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2383 /// boxes. Divide pixel coordinates by image width/height before creating
2384 /// [`Box2d`](crate::Box2d) annotations.
2385 /// - **Files are uploaded automatically** when the filename is a valid
2386 /// local path. The method will replace the full path with just the
2387 /// basename before sending to the server.
2388 /// - **Image dimensions are extracted automatically** for image files using
2389 /// the `imagesize` crate. The width/height are sent to the server, but
2390 /// note that the server currently doesn't return these fields when
2391 /// fetching samples back.
2392 /// - **UUIDs are generated automatically** if not provided. If you need
2393 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
2394 /// that the server doesn't currently return UUIDs in sample queries.
2395 ///
2396 /// # Arguments
2397 ///
2398 /// * `dataset_id` - The ID of the dataset to populate
2399 /// * `annotation_set_id` - **Required** if samples contain annotations,
2400 /// otherwise they will be ignored. Query with
2401 /// [`Client::annotation_sets`].
2402 /// * `samples` - Vector of samples to import with metadata and file
2403 /// references. For files, use the full local path - it will be uploaded
2404 /// automatically. UUIDs and image dimensions will be
2405 /// auto-generated/extracted if not provided.
2406 /// * `progress` - Optional channel for progress updates
2407 ///
2408 /// # Progress
2409 ///
2410 /// Reports progress with `status: None` as each sample's files are
2411 /// uploaded. Progress unit is samples (not individual files). Each
2412 /// sample may contain multiple files (image, lidar, radar, etc.) which
2413 /// are all uploaded before the sample is counted as complete.
2414 ///
2415 /// # Returns
2416 ///
2417 /// Returns the API result with sample UUIDs and upload status.
2418 ///
2419 /// # Example
2420 ///
2421 /// ```no_run
2422 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2423 ///
2424 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2425 /// # let client = Client::new()?.with_login("user", "pass").await?;
2426 /// # let dataset_id = DatasetID::from(1);
2427 /// // Query available annotation sets for the dataset
2428 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2429 /// let annotation_set_id = annotation_sets
2430 /// .first()
2431 /// .ok_or_else(|| {
2432 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2433 /// })?
2434 /// .id();
2435 ///
2436 /// // Create sample with annotation (UUID will be auto-generated)
2437 /// let mut sample = Sample::new();
2438 /// sample.width = Some(1920);
2439 /// sample.height = Some(1080);
2440 /// sample.group = Some("train".to_string());
2441 ///
2442 /// // Add file - use full path to local file, it will be uploaded automatically
2443 /// sample.files = vec![SampleFile::with_filename(
2444 /// "image".to_string(),
2445 /// "/path/to/image.jpg".to_string(),
2446 /// )];
2447 ///
2448 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2449 /// let mut annotation = Annotation::new();
2450 /// annotation.set_label(Some("person".to_string()));
2451 /// // Normalize pixel coordinates by dividing by image dimensions
2452 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2453 /// annotation.set_box2d(Some(bbox));
2454 /// sample.annotations = vec![annotation];
2455 ///
2456 /// // Populate with annotation_set_id (REQUIRED for annotations)
2457 /// let result = client
2458 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2459 /// .await?;
2460 /// # Ok(())
2461 /// # }
2462 /// ```
2463 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2464 pub async fn populate_samples(
2465 &self,
2466 dataset_id: DatasetID,
2467 annotation_set_id: Option<AnnotationSetID>,
2468 samples: Vec<Sample>,
2469 progress: Option<Sender<Progress>>,
2470 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2471 self.populate_samples_with_concurrency(
2472 dataset_id,
2473 annotation_set_id,
2474 samples,
2475 progress,
2476 None,
2477 )
2478 .await
2479 }
2480
2481 /// Populate samples with custom upload concurrency.
2482 ///
2483 /// Same as [`populate_samples`](Self::populate_samples) but allows
2484 /// specifying the maximum number of concurrent file uploads. Use this
2485 /// for bulk imports where higher concurrency can significantly reduce
2486 /// upload time.
2487 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2488 pub async fn populate_samples_with_concurrency(
2489 &self,
2490 dataset_id: DatasetID,
2491 annotation_set_id: Option<AnnotationSetID>,
2492 samples: Vec<Sample>,
2493 progress: Option<Sender<Progress>>,
2494 concurrency: Option<usize>,
2495 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2496 use crate::api::SamplesPopulateParams;
2497
2498 // Track which files need to be uploaded
2499 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2500
2501 // Process samples to detect local files and generate UUIDs
2502 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2503
2504 let has_files_to_upload = !files_to_upload.is_empty();
2505
2506 // Call populate API with presigned_urls=true if we have files to upload
2507 let params = SamplesPopulateParams {
2508 dataset_id,
2509 annotation_set_id,
2510 presigned_urls: Some(has_files_to_upload),
2511 samples,
2512 };
2513
2514 let results: Vec<crate::SamplesPopulateResult> = self
2515 .rpc("samples.populate2".to_owned(), Some(params))
2516 .await?;
2517
2518 // Upload files if we have any
2519 if has_files_to_upload {
2520 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2521 .await?;
2522 }
2523
2524 Ok(results)
2525 }
2526
2527 fn prepare_samples_for_upload(
2528 &self,
2529 samples: Vec<Sample>,
2530 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2531 ) -> Result<Vec<Sample>, Error> {
2532 Ok(samples
2533 .into_iter()
2534 .map(|mut sample| {
2535 // Generate UUID if not provided
2536 if sample.uuid.is_none() {
2537 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2538 }
2539
2540 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2541
2542 // Process files: detect local paths and queue for upload
2543 let files_copy = sample.files.clone();
2544 let updated_files: Vec<crate::SampleFile> = files_copy
2545 .iter()
2546 .map(|file| {
2547 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2548 })
2549 .collect();
2550
2551 sample.files = updated_files;
2552 sample
2553 })
2554 .collect())
2555 }
2556
2557 fn process_sample_file(
2558 &self,
2559 file: &crate::SampleFile,
2560 sample_uuid: &str,
2561 sample: &mut Sample,
2562 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2563 ) -> crate::SampleFile {
2564 use std::path::Path;
2565
2566 // Handle files with raw bytes (e.g., from ZIP archives)
2567 if let Some(bytes) = file.bytes()
2568 && let Some(filename) = file.filename()
2569 {
2570 // For image files with bytes, try to extract dimensions if not already set
2571 if file.file_type() == "image"
2572 && (sample.width.is_none() || sample.height.is_none())
2573 && let Ok(size) = imagesize::blob_size(bytes)
2574 {
2575 sample.width = Some(size.width as u32);
2576 sample.height = Some(size.height as u32);
2577 }
2578
2579 // Store the bytes for later upload
2580 files_to_upload.push((
2581 sample_uuid.to_string(),
2582 file.file_type().to_string(),
2583 FileSource::Bytes(bytes.to_vec()),
2584 filename.to_string(),
2585 ));
2586
2587 // Return SampleFile with just the filename
2588 return crate::SampleFile::with_filename(
2589 file.file_type().to_string(),
2590 filename.to_string(),
2591 );
2592 }
2593
2594 // Handle files with local paths
2595 if let Some(filename) = file.filename() {
2596 let path = Path::new(filename);
2597
2598 // Check if this is a valid local file path
2599 if path.exists()
2600 && path.is_file()
2601 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2602 {
2603 // For image files, try to extract dimensions if not already set
2604 if file.file_type() == "image"
2605 && (sample.width.is_none() || sample.height.is_none())
2606 && let Ok(size) = imagesize::size(path)
2607 {
2608 sample.width = Some(size.width as u32);
2609 sample.height = Some(size.height as u32);
2610 }
2611
2612 // Store the full path for later upload
2613 files_to_upload.push((
2614 sample_uuid.to_string(),
2615 file.file_type().to_string(),
2616 FileSource::Path(path.to_path_buf()),
2617 basename.to_string(),
2618 ));
2619
2620 // Return SampleFile with just the basename
2621 return crate::SampleFile::with_filename(
2622 file.file_type().to_string(),
2623 basename.to_string(),
2624 );
2625 }
2626 }
2627 // Return the file unchanged if not a local path
2628 file.clone()
2629 }
2630
2631 async fn upload_sample_files(
2632 &self,
2633 results: &[crate::SamplesPopulateResult],
2634 files_to_upload: Vec<(String, String, FileSource, String)>,
2635 progress: Option<Sender<Progress>>,
2636 concurrency: Option<usize>,
2637 ) -> Result<(), Error> {
2638 // Build a map from (sample_uuid, basename) -> file source
2639 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2640 for (uuid, _file_type, source, basename) in files_to_upload {
2641 upload_map.insert((uuid, basename), source);
2642 }
2643
2644 let http = self.bulk_http.clone();
2645
2646 // Extract the data we need for parallel upload
2647 let upload_tasks: Vec<_> = results
2648 .iter()
2649 .map(|result| (result.uuid.clone(), result.urls.clone()))
2650 .collect();
2651
2652 parallel_foreach_items(
2653 upload_tasks,
2654 progress.clone(),
2655 concurrency,
2656 move |(uuid, urls)| {
2657 let http = http.clone();
2658 let upload_map = upload_map.clone();
2659
2660 async move {
2661 // Upload all files for this sample
2662 for url_info in &urls {
2663 if let Some(source) =
2664 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2665 {
2666 match source {
2667 FileSource::Path(path) => {
2668 upload_file_to_presigned_url(
2669 http.clone(),
2670 &url_info.url,
2671 path.clone(),
2672 )
2673 .await?;
2674 }
2675 FileSource::Bytes(bytes) => {
2676 upload_bytes_to_presigned_url(
2677 http.clone(),
2678 &url_info.url,
2679 bytes.clone(),
2680 &url_info.filename,
2681 )
2682 .await?;
2683 }
2684 }
2685 }
2686 }
2687
2688 Ok(())
2689 }
2690 },
2691 )
2692 .await
2693 }
2694
2695 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2696 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2697 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
2698 if !url.starts_with("http://") && !url.starts_with("https://") {
2699 return Err(Error::InvalidParameters(format!(
2700 "Invalid URL (must be absolute): {}",
2701 url
2702 )));
2703 }
2704
2705 let resp = self.bulk_http.get(url).send().await?;
2706
2707 if !resp.status().is_success() {
2708 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2709 }
2710
2711 let bytes = resp.bytes().await?;
2712 Ok(bytes.to_vec())
2713 }
2714
2715 /// Get samples as a DataFrame with complete 2025.10 schema.
2716 ///
2717 /// This is the recommended method for obtaining dataset annotations in
2718 /// DataFrame format. It includes all sample metadata (size, location,
2719 /// pose, degradation) as optional columns.
2720 ///
2721 /// # Arguments
2722 ///
2723 /// * `dataset_id` - Dataset identifier
2724 /// * `annotation_set_id` - Optional annotation set filter
2725 /// * `groups` - Dataset groups to include (train, val, test)
2726 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2727 /// * `progress` - Optional progress callback
2728 ///
2729 /// # Progress
2730 ///
2731 /// Reports progress with `status: None` as samples are fetched from the
2732 /// server in paginated batches. Progress unit is samples fetched. This
2733 /// method delegates to [`samples()`](Self::samples) and shares its
2734 /// progress behavior.
2735 ///
2736 /// # Example
2737 ///
2738 /// ```rust,no_run
2739 /// use edgefirst_client::Client;
2740 ///
2741 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2742 /// # let client = Client::new()?;
2743 /// # let dataset_id = 1.into();
2744 /// # let annotation_set_id = 1.into();
2745 /// let df = client
2746 /// .samples_dataframe(
2747 /// dataset_id,
2748 /// Some(annotation_set_id),
2749 /// &["train".to_string()],
2750 /// &[],
2751 /// None,
2752 /// )
2753 /// .await?;
2754 /// println!("DataFrame shape: {:?}", df.shape());
2755 /// # Ok(())
2756 /// # }
2757 /// ```
2758 #[cfg(feature = "polars")]
2759 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2760 pub async fn samples_dataframe(
2761 &self,
2762 dataset_id: DatasetID,
2763 annotation_set_id: Option<AnnotationSetID>,
2764 groups: &[String],
2765 types: &[AnnotationType],
2766 progress: Option<Sender<Progress>>,
2767 ) -> Result<DataFrame, Error> {
2768 use crate::dataset::samples_dataframe;
2769
2770 let samples = self
2771 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
2772 .await?;
2773 samples_dataframe(&samples)
2774 }
2775
2776 /// List available snapshots. If a name is provided, only snapshots
2777 /// containing that name are returned.
2778 ///
2779 /// Results are sorted by match quality: exact matches first, then
2780 /// case-insensitive exact matches, then shorter descriptions (more
2781 /// specific), then alphabetically.
2782 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2783 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
2784 let snapshots: Vec<Snapshot> = self
2785 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
2786 .await?;
2787 if let Some(name) = name {
2788 Ok(filter_and_sort_by_name(snapshots, name, |s| {
2789 s.description()
2790 }))
2791 } else {
2792 Ok(snapshots)
2793 }
2794 }
2795
2796 /// Get the snapshot with the specified id.
2797 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
2798 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
2799 let params = HashMap::from([("snapshot_id", snapshot_id)]);
2800 self.rpc("snapshots.get".to_owned(), Some(params)).await
2801 }
2802
2803 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
2804 ///
2805 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
2806 /// pairs) that serve two primary purposes:
2807 ///
2808 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
2809 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
2810 /// restored with AGTG (Automatic Ground Truth Generation) and optional
2811 /// auto-depth processing.
2812 ///
2813 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
2814 /// migration between EdgeFirst Studio instances using the create →
2815 /// download → upload → restore workflow.
2816 ///
2817 /// Large files are automatically chunked into 100MB parts and uploaded
2818 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
2819 /// is streamed without loading into memory, maintaining constant memory
2820 /// usage.
2821 ///
2822 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
2823 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
2824 /// better for large files to avoid timeout issues. Higher values (16-32)
2825 /// are better for many small files.
2826 ///
2827 /// # Arguments
2828 ///
2829 /// * `path` - Local file path to MCAP file or directory containing
2830 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
2831 /// * `progress` - Optional channel to receive upload progress updates
2832 ///
2833 /// # Progress
2834 ///
2835 /// Reports progress with `status: None` as file data is uploaded. Progress
2836 /// unit is bytes uploaded. For single files, total is the file size. For
2837 /// directories, total is the combined size of all files.
2838 ///
2839 /// # Returns
2840 ///
2841 /// Returns a `Snapshot` object with ID, description, status, path, and
2842 /// creation timestamp on success.
2843 ///
2844 /// # Errors
2845 ///
2846 /// Returns an error if:
2847 /// * Path doesn't exist or contains invalid UTF-8
2848 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
2849 /// * Upload fails or network error occurs
2850 /// * Server rejects the snapshot
2851 ///
2852 /// # Example
2853 ///
2854 /// ```no_run
2855 /// # use edgefirst_client::{Client, Progress};
2856 /// # use tokio::sync::mpsc;
2857 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2858 /// let client = Client::new()?.with_token_path(None)?;
2859 ///
2860 /// // Upload MCAP file with progress tracking
2861 /// let (tx, mut rx) = mpsc::channel(1);
2862 /// tokio::spawn(async move {
2863 /// while let Some(Progress {
2864 /// current,
2865 /// total,
2866 /// status,
2867 /// }) = rx.recv().await
2868 /// {
2869 /// println!(
2870 /// "{}: {}/{} bytes ({:.1}%)",
2871 /// status.as_deref().unwrap_or("Upload"),
2872 /// current,
2873 /// total,
2874 /// (current as f64 / total as f64) * 100.0
2875 /// );
2876 /// }
2877 /// });
2878 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
2879 /// println!("Created snapshot: {:?}", snapshot.id());
2880 ///
2881 /// // Upload dataset directory (no progress)
2882 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
2883 /// # Ok(())
2884 /// # }
2885 /// ```
2886 ///
2887 /// # See Also
2888 ///
2889 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
2890 /// dataset
2891 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
2892 /// data
2893 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
2894 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
2895 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
2896 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
2897 pub async fn create_snapshot(
2898 &self,
2899 path: &str,
2900 progress: Option<Sender<Progress>>,
2901 ) -> Result<Snapshot, Error> {
2902 let path = Path::new(path);
2903
2904 if path.is_dir() {
2905 let path_str = path.to_str().ok_or_else(|| {
2906 Error::IoError(std::io::Error::new(
2907 std::io::ErrorKind::InvalidInput,
2908 "Path contains invalid UTF-8",
2909 ))
2910 })?;
2911 return self.create_snapshot_folder(path_str, progress).await;
2912 }
2913
2914 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
2915 Error::IoError(std::io::Error::new(
2916 std::io::ErrorKind::InvalidInput,
2917 "Invalid filename",
2918 ))
2919 })?;
2920 let total = path.metadata()?.len() as usize;
2921 let current = Arc::new(AtomicUsize::new(0));
2922
2923 if let Some(progress) = &progress {
2924 let _ = progress
2925 .send(Progress {
2926 current: 0,
2927 total,
2928 status: None,
2929 })
2930 .await;
2931 }
2932
2933 let params = SnapshotCreateMultipartParams {
2934 snapshot_name: name.to_owned(),
2935 keys: vec![name.to_owned()],
2936 file_sizes: vec![total],
2937 snapshot_type: None,
2938 };
2939 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
2940 .rpc(
2941 "snapshots.create_upload_url_multipart".to_owned(),
2942 Some(params),
2943 )
2944 .await?;
2945
2946 let snapshot_id = match multipart.get("snapshot_id") {
2947 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
2948 _ => return Err(Error::InvalidResponse),
2949 };
2950
2951 let snapshot = self.snapshot(snapshot_id).await?;
2952 let part_prefix = snapshot
2953 .path()
2954 .split("::/")
2955 .last()
2956 .ok_or(Error::InvalidResponse)?
2957 .to_owned();
2958 let part_key = format!("{}/{}", part_prefix, name);
2959 let mut part = match multipart.get(&part_key) {
2960 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
2961 _ => return Err(Error::InvalidResponse),
2962 }
2963 .clone();
2964 part.key = Some(part_key);
2965
2966 let params = upload_multipart(
2967 self.bulk_http.clone(),
2968 part.clone(),
2969 path.to_path_buf(),
2970 total,
2971 current,
2972 progress.clone(),
2973 )
2974 .await?;
2975
2976 let complete: String = self
2977 .rpc(
2978 "snapshots.complete_multipart_upload".to_owned(),
2979 Some(params),
2980 )
2981 .await?;
2982 debug!("Snapshot Multipart Complete: {:?}", complete);
2983
2984 let params: SnapshotStatusParams = SnapshotStatusParams {
2985 snapshot_id,
2986 status: "available".to_owned(),
2987 };
2988 let _: SnapshotStatusResult = self
2989 .rpc("snapshots.update".to_owned(), Some(params))
2990 .await?;
2991
2992 if let Some(progress) = progress {
2993 drop(progress);
2994 }
2995
2996 self.snapshot(snapshot_id).await
2997 }
2998
2999 async fn create_snapshot_folder(
3000 &self,
3001 path: &str,
3002 progress: Option<Sender<Progress>>,
3003 ) -> Result<Snapshot, Error> {
3004 let path = Path::new(path);
3005 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3006 Error::IoError(std::io::Error::new(
3007 std::io::ErrorKind::InvalidInput,
3008 "Invalid directory name",
3009 ))
3010 })?;
3011
3012 let files = WalkDir::new(path)
3013 .into_iter()
3014 .filter_map(|entry| entry.ok())
3015 .filter(|entry| entry.file_type().is_file())
3016 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3017 .collect::<Vec<_>>();
3018
3019 let total: usize = files
3020 .iter()
3021 .filter_map(|file| path.join(file).metadata().ok())
3022 .map(|metadata| metadata.len() as usize)
3023 .sum();
3024 let current = Arc::new(AtomicUsize::new(0));
3025
3026 if let Some(progress) = &progress {
3027 let _ = progress
3028 .send(Progress {
3029 current: 0,
3030 total,
3031 status: None,
3032 })
3033 .await;
3034 }
3035
3036 let keys = files
3037 .iter()
3038 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3039 .collect::<Vec<_>>();
3040 let file_sizes = files
3041 .iter()
3042 .filter_map(|key| path.join(key).metadata().ok())
3043 .map(|metadata| metadata.len() as usize)
3044 .collect::<Vec<_>>();
3045
3046 let params = SnapshotCreateMultipartParams {
3047 snapshot_name: name.to_owned(),
3048 keys,
3049 file_sizes,
3050 snapshot_type: None,
3051 };
3052
3053 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3054 .rpc(
3055 "snapshots.create_upload_url_multipart".to_owned(),
3056 Some(params),
3057 )
3058 .await?;
3059
3060 let snapshot_id = match multipart.get("snapshot_id") {
3061 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3062 _ => return Err(Error::InvalidResponse),
3063 };
3064
3065 let snapshot = self.snapshot(snapshot_id).await?;
3066 let part_prefix = snapshot
3067 .path()
3068 .split("::/")
3069 .last()
3070 .ok_or(Error::InvalidResponse)?
3071 .to_owned();
3072
3073 for file in files {
3074 let file_str = file.to_str().ok_or_else(|| {
3075 Error::IoError(std::io::Error::new(
3076 std::io::ErrorKind::InvalidInput,
3077 "File path contains invalid UTF-8",
3078 ))
3079 })?;
3080 let part_key = format!("{}/{}", part_prefix, file_str);
3081 let mut part = match multipart.get(&part_key) {
3082 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3083 _ => return Err(Error::InvalidResponse),
3084 }
3085 .clone();
3086 part.key = Some(part_key);
3087
3088 let params = upload_multipart(
3089 self.bulk_http.clone(),
3090 part.clone(),
3091 path.join(file),
3092 total,
3093 current.clone(),
3094 progress.clone(),
3095 )
3096 .await?;
3097
3098 let complete: String = self
3099 .rpc(
3100 "snapshots.complete_multipart_upload".to_owned(),
3101 Some(params),
3102 )
3103 .await?;
3104 debug!("Snapshot Part Complete: {:?}", complete);
3105 }
3106
3107 let params = SnapshotStatusParams {
3108 snapshot_id,
3109 status: "available".to_owned(),
3110 };
3111 let _: SnapshotStatusResult = self
3112 .rpc("snapshots.update".to_owned(), Some(params))
3113 .await?;
3114
3115 if let Some(progress) = progress {
3116 drop(progress);
3117 }
3118
3119 self.snapshot(snapshot_id).await
3120 }
3121
3122 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3123 ///
3124 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3125 /// This format is the native EdgeFirst Dataset Format used for efficient
3126 /// dataset storage and transfer.
3127 ///
3128 /// # Arguments
3129 ///
3130 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3131 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3132 /// * `description` - Optional description for the snapshot
3133 /// * `progress` - Optional progress channel for upload tracking
3134 ///
3135 /// # File Requirements
3136 ///
3137 /// - Arrow file must have `.arrow` extension
3138 /// - ZIP file must have `.zip` extension
3139 /// - Both files must exist and be readable
3140 ///
3141 /// # Example
3142 ///
3143 /// ```no_run
3144 /// # use edgefirst_client::Client;
3145 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3146 /// let client = Client::new()?.with_token_path(None)?;
3147 ///
3148 /// let snapshot = client
3149 /// .create_snapshot_edgefirst_format(
3150 /// "dataset.arrow",
3151 /// "dataset.zip",
3152 /// Some("My Dataset Snapshot"),
3153 /// None,
3154 /// )
3155 /// .await?;
3156 /// println!("Created snapshot: {}", snapshot.id());
3157 /// # Ok(())
3158 /// # }
3159 /// ```
3160 ///
3161 /// # See Also
3162 ///
3163 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3164 /// folder
3165 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3166 /// dataset
3167 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3168 pub async fn create_snapshot_edgefirst_format(
3169 &self,
3170 arrow_path: &str,
3171 zip_path: &str,
3172 description: Option<&str>,
3173 progress: Option<Sender<Progress>>,
3174 ) -> Result<Snapshot, Error> {
3175 let arrow_path = Path::new(arrow_path);
3176 let zip_path = Path::new(zip_path);
3177
3178 // Validate files exist
3179 if !arrow_path.exists() {
3180 return Err(Error::IoError(std::io::Error::new(
3181 std::io::ErrorKind::NotFound,
3182 format!("Arrow file not found: {}", arrow_path.display()),
3183 )));
3184 }
3185 if !zip_path.exists() {
3186 return Err(Error::IoError(std::io::Error::new(
3187 std::io::ErrorKind::NotFound,
3188 format!("ZIP file not found: {}", zip_path.display()),
3189 )));
3190 }
3191
3192 // Get file names
3193 let arrow_name = arrow_path
3194 .file_name()
3195 .and_then(|n| n.to_str())
3196 .ok_or_else(|| {
3197 Error::IoError(std::io::Error::new(
3198 std::io::ErrorKind::InvalidInput,
3199 "Invalid Arrow filename",
3200 ))
3201 })?;
3202 let zip_name = zip_path
3203 .file_name()
3204 .and_then(|n| n.to_str())
3205 .ok_or_else(|| {
3206 Error::IoError(std::io::Error::new(
3207 std::io::ErrorKind::InvalidInput,
3208 "Invalid ZIP filename",
3209 ))
3210 })?;
3211
3212 // Generate snapshot name from arrow file (without extension)
3213 let snapshot_name = description
3214 .map(|s| s.to_string())
3215 .or_else(|| {
3216 arrow_path
3217 .file_stem()
3218 .and_then(|s| s.to_str())
3219 .map(|s| s.to_string())
3220 })
3221 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3222
3223 // Calculate file sizes
3224 let arrow_size = arrow_path.metadata()?.len() as usize;
3225 let zip_size = zip_path.metadata()?.len() as usize;
3226 let total = arrow_size + zip_size;
3227 let current = Arc::new(AtomicUsize::new(0));
3228
3229 if let Some(progress) = &progress {
3230 let _ = progress
3231 .send(Progress {
3232 current: 0,
3233 total,
3234 status: None,
3235 })
3236 .await;
3237 }
3238
3239 // Create multipart upload request with "ziparrow" type
3240 let params = SnapshotCreateMultipartParams {
3241 snapshot_name,
3242 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3243 file_sizes: vec![arrow_size, zip_size],
3244 snapshot_type: Some("ziparrow".to_string()),
3245 };
3246
3247 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3248 .rpc(
3249 "snapshots.create_upload_url_multipart".to_owned(),
3250 Some(params),
3251 )
3252 .await?;
3253
3254 let snapshot_id = match multipart.get("snapshot_id") {
3255 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3256 _ => return Err(Error::InvalidResponse),
3257 };
3258
3259 let snapshot = self.snapshot(snapshot_id).await?;
3260 let part_prefix = snapshot
3261 .path()
3262 .split("::/")
3263 .last()
3264 .ok_or(Error::InvalidResponse)?
3265 .to_owned();
3266
3267 // Upload Arrow file
3268 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3269 let mut arrow_part = match multipart.get(&arrow_key) {
3270 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3271 _ => return Err(Error::InvalidResponse),
3272 };
3273 arrow_part.key = Some(arrow_key);
3274
3275 let params = upload_multipart(
3276 self.bulk_http.clone(),
3277 arrow_part,
3278 arrow_path.to_path_buf(),
3279 total,
3280 current.clone(),
3281 progress.clone(),
3282 )
3283 .await?;
3284
3285 let _: String = self
3286 .rpc(
3287 "snapshots.complete_multipart_upload".to_owned(),
3288 Some(params),
3289 )
3290 .await?;
3291 debug!("Arrow file upload complete");
3292
3293 // Upload ZIP file
3294 let zip_key = format!("{}/{}", part_prefix, zip_name);
3295 let mut zip_part = match multipart.get(&zip_key) {
3296 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3297 _ => return Err(Error::InvalidResponse),
3298 };
3299 zip_part.key = Some(zip_key);
3300
3301 let params = upload_multipart(
3302 self.bulk_http.clone(),
3303 zip_part,
3304 zip_path.to_path_buf(),
3305 total,
3306 current.clone(),
3307 progress.clone(),
3308 )
3309 .await?;
3310
3311 let _: String = self
3312 .rpc(
3313 "snapshots.complete_multipart_upload".to_owned(),
3314 Some(params),
3315 )
3316 .await?;
3317 debug!("ZIP file upload complete");
3318
3319 // Mark snapshot as available
3320 let params = SnapshotStatusParams {
3321 snapshot_id,
3322 status: "available".to_owned(),
3323 };
3324 let _: SnapshotStatusResult = self
3325 .rpc("snapshots.update".to_owned(), Some(params))
3326 .await?;
3327
3328 if let Some(progress) = progress {
3329 drop(progress);
3330 }
3331
3332 self.snapshot(snapshot_id).await
3333 }
3334
3335 /// Delete a snapshot from EdgeFirst Studio.
3336 ///
3337 /// Permanently removes a snapshot and its associated data. This operation
3338 /// cannot be undone.
3339 ///
3340 /// # Arguments
3341 ///
3342 /// * `snapshot_id` - The snapshot ID to delete
3343 ///
3344 /// # Errors
3345 ///
3346 /// Returns an error if:
3347 /// * Snapshot doesn't exist
3348 /// * User lacks permission to delete the snapshot
3349 /// * Server error occurs
3350 ///
3351 /// # Example
3352 ///
3353 /// ```no_run
3354 /// # use edgefirst_client::{Client, SnapshotID};
3355 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3356 /// let client = Client::new()?.with_token_path(None)?;
3357 /// let snapshot_id = SnapshotID::from(123);
3358 /// client.delete_snapshot(snapshot_id).await?;
3359 /// # Ok(())
3360 /// # }
3361 /// ```
3362 ///
3363 /// # See Also
3364 ///
3365 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3366 /// * [`snapshots`](Self::snapshots) - List all snapshots
3367 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3368 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3369 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3370 let _: serde_json::Value = self
3371 .rpc("snapshots.delete".to_owned(), Some(params))
3372 .await?;
3373 Ok(())
3374 }
3375
3376 /// Create a snapshot from an existing dataset on the server.
3377 ///
3378 /// Triggers server-side snapshot generation which exports the dataset's
3379 /// images and annotations into a downloadable EdgeFirst Dataset Format
3380 /// snapshot.
3381 ///
3382 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3383 /// while restore creates a dataset from a snapshot, this method creates a
3384 /// snapshot from a dataset.
3385 ///
3386 /// # Arguments
3387 ///
3388 /// * `dataset_id` - The dataset ID to create snapshot from
3389 /// * `description` - Description for the created snapshot
3390 ///
3391 /// # Returns
3392 ///
3393 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3394 /// for monitoring progress.
3395 ///
3396 /// # Errors
3397 ///
3398 /// Returns an error if:
3399 /// * Dataset doesn't exist
3400 /// * User lacks permission to access the dataset
3401 /// * Server rejects the request
3402 ///
3403 /// # Example
3404 ///
3405 /// ```no_run
3406 /// # use edgefirst_client::{Client, DatasetID};
3407 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3408 /// let client = Client::new()?.with_token_path(None)?;
3409 /// let dataset_id = DatasetID::from(123);
3410 ///
3411 /// // Create snapshot from dataset (all annotation sets)
3412 /// let result = client
3413 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3414 /// .await?;
3415 /// println!("Created snapshot: {:?}", result.id);
3416 ///
3417 /// // Monitor progress via task ID
3418 /// if let Some(task_id) = result.task_id {
3419 /// println!("Task: {}", task_id);
3420 /// }
3421 /// # Ok(())
3422 /// # }
3423 /// ```
3424 ///
3425 /// # See Also
3426 ///
3427 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3428 /// snapshot
3429 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3430 /// dataset
3431 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3432 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3433 pub async fn create_snapshot_from_dataset(
3434 &self,
3435 dataset_id: DatasetID,
3436 description: &str,
3437 annotation_set_id: Option<AnnotationSetID>,
3438 ) -> Result<SnapshotFromDatasetResult, Error> {
3439 // Resolve annotation_set_id: use provided value or fetch default
3440 let annotation_set_id = match annotation_set_id {
3441 Some(id) => id,
3442 None => {
3443 // Fetch annotation sets and find default ("annotations") or use first
3444 let sets = self.annotation_sets(dataset_id).await?;
3445 if sets.is_empty() {
3446 return Err(Error::InvalidParameters(
3447 "No annotation sets available for dataset".to_owned(),
3448 ));
3449 }
3450 // Look for "annotations" set (default), otherwise use first
3451 sets.iter()
3452 .find(|s| s.name() == "annotations")
3453 .unwrap_or(&sets[0])
3454 .id()
3455 }
3456 };
3457 let params = SnapshotCreateFromDataset {
3458 description: description.to_owned(),
3459 dataset_id,
3460 annotation_set_id,
3461 };
3462 self.rpc("snapshots.create".to_owned(), Some(params)).await
3463 }
3464
3465 /// Download a snapshot from EdgeFirst Studio to local storage.
3466 ///
3467 /// Downloads all files in a snapshot (single MCAP file or directory of
3468 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3469 /// downloaded concurrently with progress tracking.
3470 ///
3471 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3472 /// downloads (default: half of CPU cores, min 2, max 8).
3473 ///
3474 /// # Arguments
3475 ///
3476 /// * `snapshot_id` - The snapshot ID to download
3477 /// * `output` - Local directory path to save downloaded files
3478 /// * `progress` - Optional channel to receive download progress updates
3479 ///
3480 /// # Progress
3481 ///
3482 /// Reports progress with `status: None` as file data is received. Progress
3483 /// unit is bytes downloaded across all files combined. The total
3484 /// accumulates as file sizes become known (from HTTP Content-Length
3485 /// headers), so both `current` and `total` may increase during
3486 /// download.
3487 ///
3488 /// # Errors
3489 ///
3490 /// Returns an error if:
3491 /// * Snapshot doesn't exist
3492 /// * Output directory cannot be created
3493 /// * Download fails or network error occurs
3494 ///
3495 /// # Example
3496 ///
3497 /// ```no_run
3498 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3499 /// # use tokio::sync::mpsc;
3500 /// # use std::path::PathBuf;
3501 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3502 /// let client = Client::new()?.with_token_path(None)?;
3503 /// let snapshot_id = SnapshotID::from(123);
3504 ///
3505 /// // Download with progress tracking
3506 /// let (tx, mut rx) = mpsc::channel(1);
3507 /// tokio::spawn(async move {
3508 /// while let Some(Progress {
3509 /// current,
3510 /// total,
3511 /// status,
3512 /// }) = rx.recv().await
3513 /// {
3514 /// println!(
3515 /// "{}: {}/{} bytes",
3516 /// status.as_deref().unwrap_or("Download"),
3517 /// current,
3518 /// total
3519 /// );
3520 /// }
3521 /// });
3522 /// client
3523 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
3524 /// .await?;
3525 /// # Ok(())
3526 /// # }
3527 /// ```
3528 ///
3529 /// # See Also
3530 ///
3531 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3532 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3533 /// dataset
3534 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3535 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
3536 pub async fn download_snapshot(
3537 &self,
3538 snapshot_id: SnapshotID,
3539 output: PathBuf,
3540 progress: Option<Sender<Progress>>,
3541 ) -> Result<(), Error> {
3542 fs::create_dir_all(&output).await?;
3543
3544 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3545 let items: HashMap<String, String> = self
3546 .rpc("snapshots.create_download_url".to_owned(), Some(params))
3547 .await?;
3548
3549 // Single-phase: each task holds its semaphore permit for the full
3550 // lifetime of the request (GET → headers → stream → disk). This bounds
3551 // the number of simultaneously-open connections to max_tasks() and
3552 // avoids accumulating all responses in memory before streaming.
3553 //
3554 // total is updated atomically as each response's Content-Length header
3555 // arrives, so progress tracking is accurate without a separate phase.
3556 let http = self.bulk_http.clone();
3557 let current = Arc::new(AtomicUsize::new(0));
3558 let total = Arc::new(AtomicUsize::new(0));
3559 let sem = Arc::new(Semaphore::new(max_tasks()));
3560
3561 let tasks = items
3562 .into_iter()
3563 .map(|(key, url)| {
3564 let http = http.clone();
3565 let output = output.clone();
3566 let progress = progress.clone();
3567 let current = current.clone();
3568 let total = total.clone();
3569 let sem = sem.clone();
3570
3571 tokio::spawn(async move {
3572 let _permit = sem.acquire().await.map_err(|_| {
3573 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3574 })?;
3575
3576 let res = http.get(url).send().await?;
3577 let res = res.error_for_status()?;
3578
3579 // Contribute this file's size to the running total so the
3580 // caller's progress bar knows the overall scope.
3581 if let Some(len) = res.content_length() {
3582 total.fetch_add(len as usize, Ordering::SeqCst);
3583 }
3584
3585 let mut file = File::create(output.join(key)).await?;
3586 let mut stream = res.bytes_stream();
3587
3588 while let Some(chunk) = stream.next().await {
3589 let chunk = chunk?;
3590 file.write_all(&chunk).await?;
3591 let len = chunk.len();
3592
3593 if let Some(progress) = &progress {
3594 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
3595 let tot = total.load(Ordering::SeqCst);
3596 let _ = progress
3597 .send(Progress {
3598 current: cur,
3599 total: tot,
3600 status: None,
3601 })
3602 .await;
3603 }
3604 }
3605
3606 Ok::<(), Error>(())
3607 })
3608 })
3609 .collect::<Vec<_>>();
3610
3611 join_all(tasks)
3612 .await
3613 .into_iter()
3614 .collect::<Result<Vec<_>, _>>()?
3615 .into_iter()
3616 .collect::<Result<Vec<_>, _>>()?;
3617
3618 Ok(())
3619 }
3620
3621 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
3622 ///
3623 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
3624 /// the specified project. For MCAP files, supports:
3625 ///
3626 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
3627 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
3628 /// present)
3629 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
3630 /// * **Topic filtering**: Select specific MCAP topics to restore
3631 ///
3632 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
3633 /// dataset structure.
3634 ///
3635 /// # Arguments
3636 ///
3637 /// * `project_id` - Target project ID
3638 /// * `snapshot_id` - Snapshot ID to restore
3639 /// * `topics` - MCAP topics to include (empty = all topics)
3640 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
3641 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
3642 /// * `dataset_name` - Optional custom dataset name
3643 /// * `dataset_description` - Optional dataset description
3644 ///
3645 /// # Returns
3646 ///
3647 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
3648 ///
3649 /// # Errors
3650 ///
3651 /// Returns an error if:
3652 /// * Snapshot or project doesn't exist
3653 /// * Snapshot format is invalid
3654 /// * Server rejects restoration parameters
3655 ///
3656 /// # Example
3657 ///
3658 /// ```no_run
3659 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
3660 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3661 /// let client = Client::new()?.with_token_path(None)?;
3662 /// let project_id = ProjectID::from(1);
3663 /// let snapshot_id = SnapshotID::from(123);
3664 ///
3665 /// // Restore MCAP with AGTG for "person" and "car" detection
3666 /// let result = client
3667 /// .restore_snapshot(
3668 /// project_id,
3669 /// snapshot_id,
3670 /// &[], // All topics
3671 /// &["person".to_string(), "car".to_string()], // AGTG labels
3672 /// true, // Auto-depth
3673 /// Some("Highway Dataset"),
3674 /// Some("Collected on I-95"),
3675 /// )
3676 /// .await?;
3677 /// println!("Restored to dataset: {:?}", result.dataset_id);
3678 /// # Ok(())
3679 /// # }
3680 /// ```
3681 ///
3682 /// # See Also
3683 ///
3684 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3685 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3686 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3687 #[allow(clippy::too_many_arguments)]
3688 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3689 pub async fn restore_snapshot(
3690 &self,
3691 project_id: ProjectID,
3692 snapshot_id: SnapshotID,
3693 topics: &[String],
3694 autolabel: &[String],
3695 autodepth: bool,
3696 dataset_name: Option<&str>,
3697 dataset_description: Option<&str>,
3698 ) -> Result<SnapshotRestoreResult, Error> {
3699 let params = SnapshotRestore {
3700 project_id,
3701 snapshot_id,
3702 fps: 1,
3703 autodepth,
3704 agtg_pipeline: !autolabel.is_empty(),
3705 autolabel: autolabel.to_vec(),
3706 topics: topics.to_vec(),
3707 dataset_name: dataset_name.map(|s| s.to_owned()),
3708 dataset_description: dataset_description.map(|s| s.to_owned()),
3709 };
3710 self.rpc("snapshots.restore".to_owned(), Some(params)).await
3711 }
3712
3713 /// Returns a list of experiments available to the user. The experiments
3714 /// are returned as a vector of Experiment objects. If name is provided
3715 /// then only experiments containing this string are returned.
3716 ///
3717 /// Results are sorted by match quality: exact matches first, then
3718 /// case-insensitive exact matches, then shorter names (more specific),
3719 /// then alphabetically.
3720 ///
3721 /// Experiments provide a method of organizing training and validation
3722 /// sessions together and are akin to an Experiment in MLFlow terminology.
3723 /// Each experiment can have multiple trainer sessions associated with it,
3724 /// these would be akin to runs in MLFlow terminology.
3725 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3726 pub async fn experiments(
3727 &self,
3728 project_id: ProjectID,
3729 name: Option<&str>,
3730 ) -> Result<Vec<Experiment>, Error> {
3731 let params = HashMap::from([("project_id", project_id)]);
3732 let experiments: Vec<Experiment> =
3733 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
3734 if let Some(name) = name {
3735 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
3736 } else {
3737 Ok(experiments)
3738 }
3739 }
3740
3741 /// Return the experiment with the specified experiment ID. If the
3742 /// experiment does not exist, an error is returned.
3743 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3744 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
3745 let params = HashMap::from([("trainer_id", experiment_id)]);
3746 self.rpc("trainer.get".to_owned(), Some(params)).await
3747 }
3748
3749 /// Returns a list of trainer sessions available to the user. The trainer
3750 /// sessions are returned as a vector of TrainingSession objects. If name
3751 /// is provided then only trainer sessions containing this string are
3752 /// returned.
3753 ///
3754 /// Results are sorted by match quality: exact matches first, then
3755 /// case-insensitive exact matches, then shorter names (more specific),
3756 /// then alphabetically.
3757 ///
3758 /// Trainer sessions are akin to runs in MLFlow terminology. These
3759 /// represent an actual training session which will produce metrics and
3760 /// model artifacts.
3761 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3762 pub async fn training_sessions(
3763 &self,
3764 experiment_id: ExperimentID,
3765 name: Option<&str>,
3766 ) -> Result<Vec<TrainingSession>, Error> {
3767 let params = HashMap::from([("trainer_id", experiment_id)]);
3768 let sessions: Vec<TrainingSession> = self
3769 .rpc("trainer.session.list".to_owned(), Some(params))
3770 .await?;
3771 if let Some(name) = name {
3772 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
3773 } else {
3774 Ok(sessions)
3775 }
3776 }
3777
3778 /// Return the trainer session with the specified trainer session ID. If
3779 /// the trainer session does not exist, an error is returned.
3780 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3781 pub async fn training_session(
3782 &self,
3783 session_id: TrainingSessionID,
3784 ) -> Result<TrainingSession, Error> {
3785 let params = HashMap::from([("trainer_session_id", session_id)]);
3786 self.rpc("trainer.session.get".to_owned(), Some(params))
3787 .await
3788 }
3789
3790 /// List validation sessions for the given project.
3791 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3792 pub async fn validation_sessions(
3793 &self,
3794 project_id: ProjectID,
3795 ) -> Result<Vec<ValidationSession>, Error> {
3796 let params = HashMap::from([("project_id", project_id)]);
3797 self.rpc("validate.session.list".to_owned(), Some(params))
3798 .await
3799 }
3800
3801 /// Retrieve a specific validation session.
3802 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3803 pub async fn validation_session(
3804 &self,
3805 session_id: ValidationSessionID,
3806 ) -> Result<ValidationSession, Error> {
3807 let params = HashMap::from([("validate_session_id", session_id)]);
3808 self.rpc("validate.session.get".to_owned(), Some(params))
3809 .await
3810 }
3811
3812 /// List the artifacts for the specified trainer session. The artifacts
3813 /// are returned as a vector of strings.
3814 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3815 pub async fn artifacts(
3816 &self,
3817 training_session_id: TrainingSessionID,
3818 ) -> Result<Vec<Artifact>, Error> {
3819 let params = HashMap::from([("training_session_id", training_session_id)]);
3820 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
3821 .await
3822 }
3823
3824 /// Download the model artifact for the specified trainer session to the
3825 /// specified file path, if path is not provided it will be downloaded to
3826 /// the current directory with the same filename.
3827 ///
3828 /// # Progress
3829 ///
3830 /// Reports progress with `status: None` as file data is received. Progress
3831 /// unit is bytes downloaded. Total is determined from the HTTP
3832 /// Content-Length header (may be 0 if server doesn't provide it).
3833 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
3834 pub async fn download_artifact(
3835 &self,
3836 training_session_id: TrainingSessionID,
3837 modelname: &str,
3838 filename: Option<PathBuf>,
3839 progress: Option<Sender<Progress>>,
3840 ) -> Result<(), Error> {
3841 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
3842 let resp = self
3843 .bulk_http
3844 .get(format!(
3845 "{}/download_model?training_session_id={}&file={}",
3846 self.url,
3847 training_session_id.value(),
3848 modelname
3849 ))
3850 .header("Authorization", format!("Bearer {}", self.token().await))
3851 .send()
3852 .await?;
3853 if !resp.status().is_success() {
3854 let err = resp.error_for_status_ref().unwrap_err();
3855 return Err(Error::HttpError(err));
3856 }
3857
3858 if let Some(parent) = filename.parent() {
3859 fs::create_dir_all(parent).await?;
3860 }
3861
3862 let total = resp.content_length().unwrap_or(0) as usize;
3863
3864 if let Some(ref progress) = progress {
3865 let _ = progress
3866 .send(Progress {
3867 current: 0,
3868 total,
3869 status: None,
3870 })
3871 .await;
3872 }
3873
3874 let mut file = File::create(filename).await?;
3875 let mut current = 0;
3876 let mut stream = resp.bytes_stream();
3877
3878 while let Some(item) = stream.next().await {
3879 let chunk = item?;
3880 file.write_all(&chunk).await?;
3881 current += chunk.len();
3882 if let Some(ref progress) = progress {
3883 let _ = progress
3884 .send(Progress {
3885 current,
3886 total,
3887 status: None,
3888 })
3889 .await;
3890 }
3891 }
3892
3893 // Flush tokio's internal write buffer to the OS before returning.
3894 // tokio::fs::File buffers writes internally; without this, the buffer
3895 // may not reach the filesystem before the caller reads the file.
3896 file.flush().await?;
3897
3898 Ok(())
3899 }
3900
3901 /// Download the model checkpoint associated with the specified trainer
3902 /// session to the specified file path, if path is not provided it will be
3903 /// downloaded to the current directory with the same filename.
3904 ///
3905 /// There is no API for listing checkpoints it is expected that trainers are
3906 /// aware of possible checkpoints and their names within the checkpoint
3907 /// folder on the server.
3908 ///
3909 /// # Progress
3910 ///
3911 /// Reports progress with `status: None` as file data is received. Progress
3912 /// unit is bytes downloaded. Total is determined from the HTTP
3913 /// Content-Length header (may be 0 if server doesn't provide it).
3914 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
3915 pub async fn download_checkpoint(
3916 &self,
3917 training_session_id: TrainingSessionID,
3918 checkpoint: &str,
3919 filename: Option<PathBuf>,
3920 progress: Option<Sender<Progress>>,
3921 ) -> Result<(), Error> {
3922 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
3923 let resp = self
3924 .bulk_http
3925 .get(format!(
3926 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
3927 self.url,
3928 training_session_id.value(),
3929 checkpoint
3930 ))
3931 .header("Authorization", format!("Bearer {}", self.token().await))
3932 .send()
3933 .await?;
3934 if !resp.status().is_success() {
3935 let err = resp.error_for_status_ref().unwrap_err();
3936 return Err(Error::HttpError(err));
3937 }
3938
3939 if let Some(parent) = filename.parent() {
3940 fs::create_dir_all(parent).await?;
3941 }
3942
3943 let total = resp.content_length().unwrap_or(0) as usize;
3944
3945 if let Some(ref progress) = progress {
3946 let _ = progress
3947 .send(Progress {
3948 current: 0,
3949 total,
3950 status: None,
3951 })
3952 .await;
3953 }
3954
3955 let mut file = File::create(filename).await?;
3956 let mut current = 0;
3957 let mut stream = resp.bytes_stream();
3958
3959 while let Some(item) = stream.next().await {
3960 let chunk = item?;
3961 file.write_all(&chunk).await?;
3962 current += chunk.len();
3963 if let Some(ref progress) = progress {
3964 let _ = progress
3965 .send(Progress {
3966 current,
3967 total,
3968 status: None,
3969 })
3970 .await;
3971 }
3972 }
3973
3974 // Flush tokio's internal write buffer to the OS before returning.
3975 // tokio::fs::File buffers writes internally; without this, the buffer
3976 // may not reach the filesystem before the caller reads the file.
3977 file.flush().await?;
3978
3979 Ok(())
3980 }
3981
3982 /// Return a list of tasks for the current user.
3983 ///
3984 /// # Arguments
3985 ///
3986 /// * `name` - Optional filter for task name (client-side substring match)
3987 /// * `workflow` - Optional filter for workflow/task type. If provided,
3988 /// filters server-side by exact match. Valid values include: "trainer",
3989 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
3990 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
3991 /// "convertor", "twostage"
3992 /// * `status` - Optional filter for task status (e.g., "running",
3993 /// "complete", "error")
3994 /// * `manager` - Optional filter for task manager type (e.g., "aws",
3995 /// "user", "kubernetes")
3996 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3997 pub async fn tasks(
3998 &self,
3999 name: Option<&str>,
4000 workflow: Option<&str>,
4001 status: Option<&str>,
4002 manager: Option<&str>,
4003 ) -> Result<Vec<Task>, Error> {
4004 let mut params = TasksListParams {
4005 continue_token: None,
4006 types: workflow.map(|w| vec![w.to_owned()]),
4007 status: status.map(|s| vec![s.to_owned()]),
4008 manager: manager.map(|m| vec![m.to_owned()]),
4009 };
4010 let mut tasks = Vec::new();
4011
4012 loop {
4013 let result = self
4014 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4015 .await?;
4016 tasks.extend(result.tasks);
4017
4018 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4019 params.continue_token = None;
4020 } else {
4021 params.continue_token = result.continue_token;
4022 }
4023
4024 if params.continue_token.is_none() {
4025 break;
4026 }
4027 }
4028
4029 if let Some(name) = name {
4030 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4031 }
4032
4033 Ok(tasks)
4034 }
4035
4036 /// Retrieve the task information and status.
4037 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4038 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4039 self.rpc(
4040 "task.get".to_owned(),
4041 Some(HashMap::from([("id", task_id)])),
4042 )
4043 .await
4044 }
4045
4046 /// Updates the tasks status.
4047 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4048 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4049 let status = TaskStatus {
4050 task_id,
4051 status: status.to_owned(),
4052 };
4053 self.rpc("docker.update.status".to_owned(), Some(status))
4054 .await
4055 }
4056
4057 /// Defines the stages for the task. The stages are defined as a mapping
4058 /// from stage names to their descriptions. Once stages are defined their
4059 /// status can be updated using the update_stage method.
4060 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4061 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4062 let stages: Vec<HashMap<String, String>> = stages
4063 .iter()
4064 .map(|(key, value)| {
4065 let mut stage_map = HashMap::new();
4066 stage_map.insert(key.to_string(), value.to_string());
4067 stage_map
4068 })
4069 .collect();
4070 let params = TaskStages { task_id, stages };
4071 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4072 Ok(())
4073 }
4074
4075 /// Updates the progress of the task for the provided stage and status
4076 /// information.
4077 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4078 pub async fn update_stage(
4079 &self,
4080 task_id: TaskID,
4081 stage: &str,
4082 status: &str,
4083 message: &str,
4084 percentage: u8,
4085 ) -> Result<(), Error> {
4086 let stage = Stage::new(
4087 Some(task_id),
4088 stage.to_owned(),
4089 Some(status.to_owned()),
4090 Some(message.to_owned()),
4091 percentage,
4092 );
4093 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4094 Ok(())
4095 }
4096
4097 /// Authenticated fetch from the Studio server using the bulk HTTP client
4098 /// (no total-request timeout; idle read timeout per chunk).
4099 ///
4100 /// **Buffers the entire response body into memory.** Suitable for small to
4101 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4102 /// checkpoints), prefer a streaming approach that writes directly to disk.
4103 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4104 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4105 let req = self
4106 .bulk_http
4107 .get(format!("{}/{}", self.url, query))
4108 .header("User-Agent", "EdgeFirst Client")
4109 .header("Authorization", format!("Bearer {}", self.token().await));
4110 let resp = req.send().await?;
4111
4112 if resp.status().is_success() {
4113 let body = resp.bytes().await?;
4114
4115 if log_enabled!(Level::Trace) {
4116 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4117 }
4118
4119 Ok(body.to_vec())
4120 } else {
4121 let err = resp.error_for_status_ref().unwrap_err();
4122 Err(Error::HttpError(err))
4123 }
4124 }
4125
4126 /// Sends a multipart post request to the server. This is used by the
4127 /// upload and download APIs which do not use JSON-RPC but instead transfer
4128 /// files using multipart/form-data.
4129 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4130 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
4131 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4132 .ok()
4133 .and_then(|s| s.parse().ok())
4134 .unwrap_or(600u64);
4135
4136 let req = self
4137 .http
4138 .post(format!("{}/api?method={}", self.url, method))
4139 .header("Accept", "application/json")
4140 .header("User-Agent", "EdgeFirst Client")
4141 .header("Authorization", format!("Bearer {}", self.token().await))
4142 .timeout(Duration::from_secs(upload_timeout_secs))
4143 .multipart(form);
4144 let resp = req.send().await?;
4145
4146 if resp.status().is_success() {
4147 let body = resp.bytes().await?;
4148
4149 if log_enabled!(Level::Trace) {
4150 trace!(
4151 "POST Multipart Response: {}",
4152 String::from_utf8_lossy(&body)
4153 );
4154 }
4155
4156 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
4157 Ok(response) => response,
4158 Err(err) => {
4159 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4160 return Err(err.into());
4161 }
4162 };
4163
4164 if let Some(error) = response.error {
4165 Err(Error::RpcError(error.code, error.message))
4166 } else if let Some(result) = response.result {
4167 Ok(result)
4168 } else {
4169 Err(Error::InvalidResponse)
4170 }
4171 } else {
4172 let err = resp.error_for_status_ref().unwrap_err();
4173 Err(Error::HttpError(err))
4174 }
4175 }
4176
4177 /// Send a JSON-RPC request to the server. The method is the name of the
4178 /// method to call on the server. The params are the parameters to pass to
4179 /// the method. The method and params are serialized into a JSON-RPC
4180 /// request and sent to the server. The response is deserialized into
4181 /// the specified type and returned to the caller.
4182 ///
4183 /// NOTE: This API would generally not be called directly and instead users
4184 /// should use the higher-level methods provided by the client.
4185 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4186 pub async fn rpc<Params, RpcResult>(
4187 &self,
4188 method: String,
4189 params: Option<Params>,
4190 ) -> Result<RpcResult, Error>
4191 where
4192 Params: Serialize,
4193 RpcResult: DeserializeOwned,
4194 {
4195 let auth_expires = self.token_expiration().await?;
4196 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4197 self.renew_token().await?;
4198 }
4199
4200 self.rpc_without_auth(method, params).await
4201 }
4202
    /// Send a JSON-RPC request to `{url}/api` with the current bearer token,
    /// retrying transient failures.
    ///
    /// Unlike `rpc`, this does not check the token's expiry or renew it before
    /// sending.
    ///
    /// # Retries
    ///
    /// Up to `EDGEFIRST_MAX_RETRIES` retries (default 5, i.e. at most 6
    /// attempts) are made when:
    /// * the server answers HTTP 408, 429, 500, 502, 503 or 504, or
    /// * the request fails with a timeout or connection error.
    ///
    /// Before retry `n` the call sleeps `min(2^(n-1), 30)` seconds (1, 2, 4, 8,
    /// 16, 30, 30, ...) multiplied by a random jitter factor in `[1.0, 1.5)`.
    /// On the final attempt a retryable status is no longer retried; the
    /// response is decoded by `process_rpc_response` like any other.
    ///
    /// Any error from `process_rpc_response` (JSON-RPC error, undecodable
    /// body), and any other transport error, is returned immediately.
    ///
    /// # Logging
    ///
    /// The request is trace-logged and, with the `profiling` feature, recorded
    /// on the current span; for `auth.login` the params are replaced by a
    /// redaction marker in both places.
    #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
    async fn rpc_without_auth<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(5usize);

        let url = format!("{}/api", self.url);

        // Build the request once; its body is serialized to bytes once before
        // the retry loop (below), so `Params` needs no `Clone` bound.
        let request = RpcRequest {
            method: method.clone(),
            params,
            ..Default::default()
        };

        // Log request for debugging (log crate) and profiling (tracing crate)
        // NOTE(review): for non-login methods this serializes the request a
        // second time (see `request_body`) even when trace logging and
        // profiling are both disabled.
        let request_json = if method == "auth.login" {
            // Redact auth.login params (contains password)
            serde_json::json!({
                "jsonrpc": "2.0",
                "method": &method,
                "params": "[REDACTED - contains credentials]",
                "id": request.id
            })
            .to_string()
        } else {
            serde_json::to_string(&request)?
        };

        if log_enabled!(Level::Trace) {
            trace!("RPC Request: {}", request_json);
        }

        // Record request on current span for Perfetto when profiling is enabled
        #[cfg(feature = "profiling")]
        tracing::Span::current().record("request", &request_json);

        let request_body = serde_json::to_vec(&request)?;
        let mut last_error: Option<Error> = None;

        // `0..=max_retries`: one initial attempt plus up to `max_retries` retries.
        for attempt in 0..=max_retries {
            if attempt > 0 {
                // Exponential backoff with jitter: base delay of 2^(attempt-1)
                // seconds, capped at 30s.
                // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
                // while ensuring we never retry faster than the base delay
                let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
                let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
                let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
                let delay = Duration::from_millis(delay_ms);
                warn!(
                    "Retry {}/{} for RPC '{}' after {:?}",
                    attempt, max_retries, method, delay
                );
                tokio::time::sleep(delay).await;
            }

            // The token is re-read on every attempt.
            let result = self
                .http
                .post(&url)
                .header("Accept", "application/json")
                .header("Content-Type", "application/json")
                .header("User-Agent", "EdgeFirst Client")
                .header("Authorization", format!("Bearer {}", self.token().await))
                .body(request_body.clone())
                .send()
                .await;

            match result {
                Ok(res) => {
                    let status = res.status();
                    let status_code = status.as_u16();

                    // Check for retryable HTTP status codes before processing response
                    // (all of these are 4xx/5xx, so `unwrap_err` below cannot panic).
                    if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
                        && attempt < max_retries
                    {
                        warn!(
                            "RPC '{}' failed with HTTP {} (retrying)",
                            method, status_code
                        );
                        last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
                        continue;
                    }

                    // Process the response. Every other status (including other
                    // 4xx codes and a retryable code on the final attempt) is
                    // decoded as a JSON-RPC envelope here.
                    match self.process_rpc_response(res).await {
                        Ok(result) => {
                            if attempt > 0 {
                                debug!("RPC '{}' succeeded on retry {}", method, attempt);
                            }
                            return Ok(result);
                        }
                        Err(e) => {
                            // Decode failures and JSON-RPC errors are never retried.
                            if attempt > 0 {
                                error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                            }
                            return Err(e);
                        }
                    }
                }
                Err(e) => {
                    // Transport error (timeout, connection failure, etc.)
                    let is_timeout = e.is_timeout();
                    let is_connect = e.is_connect();

                    if (is_timeout || is_connect) && attempt < max_retries {
                        warn!(
                            "RPC '{}' transport error (retrying): {}",
                            method,
                            if is_timeout {
                                "timeout"
                            } else {
                                "connection failed"
                            }
                        );
                        last_error = Some(Error::HttpError(e));
                        continue;
                    }

                    if attempt > 0 {
                        error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                    }
                    return Err(Error::HttpError(e));
                }
            }
        }

        // Should not reach here: the final attempt always returns above
        // because every `continue` requires `attempt < max_retries`.
        Err(last_error.unwrap_or_else(|| {
            Error::InvalidParameters(format!(
                "RPC '{}' failed after {} retries",
                method, max_retries
            ))
        }))
    }
4348
4349 async fn process_rpc_response<RpcResult>(
4350 &self,
4351 res: reqwest::Response,
4352 ) -> Result<RpcResult, Error>
4353 where
4354 RpcResult: DeserializeOwned,
4355 {
4356 let body = res.bytes().await?;
4357 let response_str = String::from_utf8_lossy(&body);
4358
4359 if log_enabled!(Level::Trace) {
4360 trace!("RPC Response: {}", response_str);
4361 }
4362
4363 // Record response on current span for Perfetto when profiling is enabled
4364 // Truncate large responses to avoid bloating trace files
4365 #[cfg(feature = "profiling")]
4366 {
4367 const MAX_RESPONSE_LEN: usize = 4096;
4368 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
4369 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
4370 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
4371 format!(
4372 "{}...[truncated {} bytes]",
4373 &response_str[..safe_end],
4374 response_str.len() - safe_end
4375 )
4376 } else {
4377 response_str.to_string()
4378 };
4379 tracing::Span::current().record("response", &truncated);
4380 }
4381
4382 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
4383 Ok(response) => response,
4384 Err(err) => {
4385 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4386 return Err(err.into());
4387 }
4388 };
4389
4390 // FIXME: Studio Server always returns 999 as the id.
4391 // if request.id.to_string() != response.id {
4392 // return Err(Error::InvalidRpcId(response.id));
4393 // }
4394
4395 if let Some(error) = response.error {
4396 Err(Error::RpcError(error.code, error.message))
4397 } else if let Some(result) = response.result {
4398 Ok(result)
4399 } else {
4400 Err(Error::InvalidResponse)
4401 }
4402 }
4403}
4404
4405/// Process items in parallel with semaphore concurrency control and progress
4406/// tracking.
4407///
4408/// This helper eliminates boilerplate for parallel item processing with:
4409/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
4410/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
4411/// - Atomic progress counter with automatic item-level updates
4412/// - Progress updates sent after each item completes (not byte-level streaming)
4413/// - Proper error propagation from spawned tasks
4414///
4415/// Note: This is optimized for discrete items with post-completion progress
4416/// updates. For byte-level streaming progress or custom retry logic, use
4417/// specialized implementations.
4418///
4419/// # Arguments
4420///
4421/// * `items` - Collection of items to process in parallel
4422/// * `progress` - Optional progress channel for tracking completion
4423/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
4424/// * `work_fn` - Async function to execute for each item
4425///
4426/// # Examples
4427///
4428/// ```rust,ignore
4429/// // Use default concurrency
4430/// parallel_foreach_items(samples, progress, None, |sample| async move {
4431/// sample.download(&client, file_type).await?;
4432/// Ok(())
4433/// }).await?;
4434/// ```
4435async fn parallel_foreach_items<T, F, Fut>(
4436 items: Vec<T>,
4437 progress: Option<Sender<Progress>>,
4438 concurrency: Option<usize>,
4439 work_fn: F,
4440) -> Result<(), Error>
4441where
4442 T: Send + 'static,
4443 F: Fn(T) -> Fut + Send + Sync + 'static,
4444 Fut: Future<Output = Result<(), Error>> + Send + 'static,
4445{
4446 let total = items.len();
4447 let current = Arc::new(AtomicUsize::new(0));
4448 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
4449 let work_fn = Arc::new(work_fn);
4450
4451 let tasks = items
4452 .into_iter()
4453 .map(|item| {
4454 let sem = sem.clone();
4455 let current = current.clone();
4456 let progress = progress.clone();
4457 let work_fn = work_fn.clone();
4458
4459 tokio::spawn(async move {
4460 let _permit = sem.acquire().await.map_err(|_| {
4461 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4462 })?;
4463
4464 // Execute the actual work
4465 work_fn(item).await?;
4466
4467 // Update progress
4468 if let Some(progress) = &progress {
4469 let current = current.fetch_add(1, Ordering::SeqCst);
4470 let _ = progress
4471 .send(Progress {
4472 current: current + 1,
4473 total,
4474 status: None,
4475 })
4476 .await;
4477 }
4478
4479 Ok::<(), Error>(())
4480 })
4481 })
4482 .collect::<Vec<_>>();
4483
4484 join_all(tasks)
4485 .await
4486 .into_iter()
4487 .collect::<Result<Vec<_>, _>>()?
4488 .into_iter()
4489 .collect::<Result<Vec<_>, _>>()?;
4490
4491 if let Some(progress) = progress {
4492 drop(progress);
4493 }
4494
4495 Ok(())
4496}
4497
4498/// Upload a file to S3 using multipart upload with presigned URLs.
4499///
4500/// Splits a file into chunks (100MB each) and uploads them in parallel using
4501/// S3 multipart upload protocol. Returns completion parameters with ETags for
4502/// finalizing the upload.
4503///
4504/// This function handles:
4505/// - Splitting files into parts based on PART_SIZE (100MB)
4506/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
4507/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
4508/// - Retry logic (handled by reqwest client)
4509/// - Progress tracking across all parts
4510///
4511/// # Arguments
4512///
4513/// * `http` - HTTP client for making requests
4514/// * `part` - Snapshot part info with presigned URLs for each chunk
4515/// * `path` - Local file path to upload
4516/// * `total` - Total bytes across all files for progress calculation
4517/// * `current` - Atomic counter tracking bytes uploaded across all operations
4518/// * `progress` - Optional channel for sending progress updates
4519///
4520/// # Returns
4521///
4522/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
4523async fn upload_multipart(
4524 http: reqwest::Client,
4525 part: SnapshotPart,
4526 path: PathBuf,
4527 total: usize,
4528 confirmed_bytes: Arc<AtomicUsize>,
4529 progress: Option<Sender<Progress>>,
4530) -> Result<SnapshotCompleteMultipartParams, Error> {
4531 let filesize = path.metadata()?.len() as usize;
4532 let n_parts = filesize.div_ceil(PART_SIZE);
4533 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
4534
4535 let key = part.key.ok_or(Error::InvalidResponse)?;
4536 let upload_id = part.upload_id;
4537
4538 let urls = part.urls.clone();
4539
4540 // Pre-allocate ETag slots for all parts
4541 let etags = Arc::new(tokio::sync::Mutex::new(vec![
4542 EtagPart {
4543 etag: "".to_owned(),
4544 part_number: 0,
4545 };
4546 n_parts
4547 ]));
4548
4549 // Per-part byte counters for streaming progress (reset on retry)
4550 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
4551 (0..n_parts)
4552 .map(|_| AtomicUsize::new(0))
4553 .collect::<Vec<_>>(),
4554 );
4555
4556 // Upload all parts in parallel with concurrency limiting
4557 let tasks = (0..n_parts)
4558 .map(|part_idx| {
4559 let http = http.clone();
4560 let url = urls[part_idx].clone();
4561 let etags = etags.clone();
4562 let path = path.to_owned();
4563 let sem = sem.clone();
4564 let progress = progress.clone();
4565 let confirmed_bytes = confirmed_bytes.clone();
4566 let part_bytes = part_bytes.clone();
4567
4568 // Calculate this part's size
4569 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
4570 filesize % PART_SIZE
4571 } else {
4572 PART_SIZE
4573 };
4574
4575 tokio::spawn(async move {
4576 // Acquire semaphore permit to limit concurrent uploads
4577 let _permit = sem.acquire().await.map_err(|_| {
4578 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4579 })?;
4580
4581 // Upload part with streaming progress and retry logic
4582 let etag = upload_part_with_progress(
4583 http,
4584 url,
4585 path,
4586 part_idx,
4587 n_parts,
4588 part_size,
4589 total,
4590 confirmed_bytes.clone(),
4591 part_bytes.clone(),
4592 progress.clone(),
4593 )
4594 .await?;
4595
4596 // Store ETag for this part (needed to complete multipart upload)
4597 let mut etags_guard = etags.lock().await;
4598 etags_guard[part_idx] = EtagPart {
4599 etag,
4600 part_number: part_idx + 1,
4601 };
4602
4603 // Part completed successfully - add to confirmed bytes
4604 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
4605 // Reset part counter since it's now confirmed
4606 part_bytes[part_idx].store(0, Ordering::SeqCst);
4607
4608 // Send final progress update for this part
4609 if let Some(progress) = &progress {
4610 let current = confirmed_bytes.load(Ordering::SeqCst)
4611 + part_bytes
4612 .iter()
4613 .map(|p| p.load(Ordering::SeqCst))
4614 .sum::<usize>();
4615 let _ = progress
4616 .send(Progress {
4617 current,
4618 total,
4619 status: None,
4620 })
4621 .await;
4622 }
4623
4624 Ok::<(), Error>(())
4625 })
4626 })
4627 .collect::<Vec<_>>();
4628
4629 // Wait for all parts to complete (double collect to handle both JoinError and
4630 // inner Error)
4631 join_all(tasks)
4632 .await
4633 .into_iter()
4634 .collect::<Result<Vec<_>, _>>()?
4635 .into_iter()
4636 .collect::<Result<Vec<_>, _>>()?;
4637
4638 Ok(SnapshotCompleteMultipartParams {
4639 key,
4640 upload_id,
4641 etag_list: etags.lock().await.clone(),
4642 })
4643}
4644
4645/// Upload a single part with streaming progress tracking and retry logic.
4646///
4647/// Progress is reported continuously as bytes are sent. On retry, the part's
4648/// progress counter is reset to avoid over-reporting.
4649#[allow(clippy::too_many_arguments)]
4650async fn upload_part_with_progress(
4651 http: reqwest::Client,
4652 url: String,
4653 path: PathBuf,
4654 part_idx: usize,
4655 n_parts: usize,
4656 part_size: usize,
4657 total: usize,
4658 confirmed_bytes: Arc<AtomicUsize>,
4659 part_bytes: Arc<Vec<AtomicUsize>>,
4660 progress: Option<Sender<Progress>>,
4661) -> Result<String, Error> {
4662 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4663 .ok()
4664 .and_then(|s| s.parse().ok())
4665 .unwrap_or(5usize);
4666
4667 // Per-part total upload timeout. Covers the send phase (request body) where
4668 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
4669 // this bounds how long a stalled upload can block before retrying.
4670 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4671 .ok()
4672 .and_then(|s| s.parse().ok())
4673 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
4674
4675 let mut last_error: Option<Error> = None;
4676
4677 for attempt in 0..=max_retries {
4678 if attempt > 0 {
4679 // Reset this part's progress counter before retry
4680 part_bytes[part_idx].store(0, Ordering::SeqCst);
4681
4682 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4683 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4684 warn!(
4685 "Retry {}/{} for part {} after {:?}",
4686 attempt, max_retries, part_idx, delay
4687 );
4688 tokio::time::sleep(delay).await;
4689 }
4690
4691 match upload_part_streaming(
4692 http.clone(),
4693 url.clone(),
4694 path.clone(),
4695 part_idx,
4696 n_parts,
4697 part_size,
4698 total,
4699 upload_timeout_secs,
4700 confirmed_bytes.clone(),
4701 part_bytes.clone(),
4702 progress.clone(),
4703 )
4704 .await
4705 {
4706 Ok(etag) => return Ok(etag),
4707 Err(e) => {
4708 // Check if error is retryable
4709 let is_retryable = matches!(
4710 &e,
4711 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
4712 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
4713 );
4714
4715 if is_retryable && attempt < max_retries {
4716 last_error = Some(e);
4717 continue;
4718 }
4719
4720 return Err(e);
4721 }
4722 }
4723 }
4724
4725 Err(last_error
4726 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
4727}
4728
4729/// Perform the actual upload with streaming progress.
4730#[allow(clippy::too_many_arguments)]
4731async fn upload_part_streaming(
4732 http: reqwest::Client,
4733 url: String,
4734 path: PathBuf,
4735 part_idx: usize,
4736 n_parts: usize,
4737 _part_size: usize,
4738 total: usize,
4739 upload_timeout_secs: u64,
4740 confirmed_bytes: Arc<AtomicUsize>,
4741 part_bytes: Arc<Vec<AtomicUsize>>,
4742 progress: Option<Sender<Progress>>,
4743) -> Result<String, Error> {
4744 let filesize = path.metadata()?.len() as usize;
4745 let mut file = File::open(&path).await?;
4746 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
4747 .await?;
4748 let file = file.take(PART_SIZE as u64);
4749
4750 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
4751 filesize % PART_SIZE
4752 } else {
4753 PART_SIZE
4754 };
4755
4756 // Create stream with progress tracking
4757 let stream = FramedRead::new(file, BytesCodec::new());
4758
4759 // Wrap stream to track bytes sent and report progress
4760 let progress_stream = stream.map(move |result| {
4761 if let Ok(ref bytes) = result {
4762 let bytes_len = bytes.len();
4763 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
4764
4765 // Send progress update (fire-and-forget via try_send to avoid blocking)
4766 if let Some(ref progress) = progress {
4767 let current = confirmed_bytes.load(Ordering::SeqCst)
4768 + part_bytes
4769 .iter()
4770 .map(|p| p.load(Ordering::SeqCst))
4771 .sum::<usize>();
4772 // Best-effort progress reporting: use try_send to avoid blocking.
4773 // If the channel is full or closed, we intentionally skip this update
4774 // to avoid stalling the upload; subsequent updates will still be delivered.
4775 let _ = progress.try_send(Progress {
4776 current,
4777 total,
4778 status: None,
4779 });
4780 }
4781 }
4782 result.map(|b| b.freeze())
4783 });
4784
4785 let body = Body::wrap_stream(progress_stream);
4786
4787 let resp = http
4788 .put(url)
4789 .header(CONTENT_LENGTH, body_length)
4790 .timeout(Duration::from_secs(upload_timeout_secs))
4791 .body(body)
4792 .send()
4793 .await?
4794 .error_for_status()?;
4795
4796 let etag = resp
4797 .headers()
4798 .get("etag")
4799 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
4800 .to_str()
4801 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
4802 .to_owned();
4803
4804 // Studio Server requires etag without the quotes.
4805 let etag = etag
4806 .strip_prefix("\"")
4807 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
4808 let etag = etag
4809 .strip_suffix("\"")
4810 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
4811
4812 Ok(etag.to_owned())
4813}
4814
4815/// Upload a complete file to a presigned S3 URL using HTTP PUT.
4816///
4817/// This is used for populate_samples to upload files to S3 after
4818/// receiving presigned URLs from the server.
4819///
4820/// Includes explicit retry logic with exponential backoff for transient
4821/// failures.
4822async fn upload_file_to_presigned_url(
4823 http: reqwest::Client,
4824 url: &str,
4825 path: PathBuf,
4826) -> Result<(), Error> {
4827 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4828 .ok()
4829 .and_then(|s| s.parse().ok())
4830 .unwrap_or(5usize);
4831
4832 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4833 .ok()
4834 .and_then(|s| s.parse().ok())
4835 .unwrap_or(600u64);
4836
4837 // Read the entire file into memory once
4838 let file_data = fs::read(&path).await?;
4839 let file_size = file_data.len();
4840 let filename = path.file_name().unwrap_or_default().to_string_lossy();
4841
4842 let mut last_error: Option<Error> = None;
4843
4844 for attempt in 0..=max_retries {
4845 if attempt > 0 {
4846 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4847 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4848 warn!(
4849 "Retry {}/{} for upload '{}' after {:?}",
4850 attempt, max_retries, filename, delay
4851 );
4852 tokio::time::sleep(delay).await;
4853 }
4854
4855 // Attempt upload
4856 let result = http
4857 .put(url)
4858 .header(CONTENT_LENGTH, file_size)
4859 .timeout(Duration::from_secs(upload_timeout_secs))
4860 .body(file_data.clone())
4861 .send()
4862 .await;
4863
4864 match result {
4865 Ok(resp) => {
4866 if resp.status().is_success() {
4867 if attempt > 0 {
4868 debug!(
4869 "Upload '{}' succeeded on retry {} ({} bytes)",
4870 filename, attempt, file_size
4871 );
4872 } else {
4873 debug!(
4874 "Successfully uploaded file: {} ({} bytes)",
4875 filename, file_size
4876 );
4877 }
4878 return Ok(());
4879 }
4880
4881 let status = resp.status();
4882 let status_code = status.as_u16();
4883
4884 // Check if error is retryable
4885 let is_retryable =
4886 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
4887
4888 if is_retryable && attempt < max_retries {
4889 let error_text = resp.text().await.unwrap_or_default();
4890 warn!(
4891 "Upload '{}' failed with HTTP {} (retryable): {}",
4892 filename, status_code, error_text
4893 );
4894 last_error = Some(Error::InvalidParameters(format!(
4895 "Upload failed: HTTP {} - {}",
4896 status, error_text
4897 )));
4898 continue;
4899 }
4900
4901 // Non-retryable error or max retries exceeded
4902 let error_text = resp.text().await.unwrap_or_default();
4903 if attempt > 0 {
4904 error!(
4905 "Upload '{}' failed after {} retries: HTTP {} - {}",
4906 filename, attempt, status, error_text
4907 );
4908 }
4909 return Err(Error::InvalidParameters(format!(
4910 "Upload failed: HTTP {} - {}",
4911 status, error_text
4912 )));
4913 }
4914 Err(e) => {
4915 // Transport error (timeout, connection failure, etc.)
4916 let is_timeout = e.is_timeout();
4917 let is_connect = e.is_connect();
4918
4919 if (is_timeout || is_connect) && attempt < max_retries {
4920 warn!(
4921 "Upload '{}' transport error (retrying): {}",
4922 filename,
4923 if is_timeout {
4924 "timeout"
4925 } else {
4926 "connection failed"
4927 }
4928 );
4929 last_error = Some(Error::HttpError(e));
4930 continue;
4931 }
4932
4933 // Non-retryable or max retries exceeded
4934 if attempt > 0 {
4935 error!(
4936 "Upload '{}' failed after {} retries: {}",
4937 filename, attempt, e
4938 );
4939 }
4940 return Err(Error::HttpError(e));
4941 }
4942 }
4943 }
4944
4945 // Should not reach here, but return last error if we do
4946 Err(last_error.unwrap_or_else(|| {
4947 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
4948 }))
4949}
4950
4951/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
4952///
4953/// This is used for populate_samples to upload file content from memory
4954/// (e.g., from ZIP archives) without writing to disk first.
4955///
4956/// Includes explicit retry logic with exponential backoff for transient
4957/// failures.
4958async fn upload_bytes_to_presigned_url(
4959 http: reqwest::Client,
4960 url: &str,
4961 file_data: Vec<u8>,
4962 filename: &str,
4963) -> Result<(), Error> {
4964 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4965 .ok()
4966 .and_then(|s| s.parse().ok())
4967 .unwrap_or(5usize);
4968
4969 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4970 .ok()
4971 .and_then(|s| s.parse().ok())
4972 .unwrap_or(600u64);
4973
4974 let file_size = file_data.len();
4975 let mut last_error: Option<Error> = None;
4976
4977 for attempt in 0..=max_retries {
4978 if attempt > 0 {
4979 // Exponential backoff: 1s, 2s, 4s, 8s, ...
4980 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
4981 warn!(
4982 "Retry {}/{} for upload '{}' after {:?}",
4983 attempt, max_retries, filename, delay
4984 );
4985 tokio::time::sleep(delay).await;
4986 }
4987
4988 // Attempt upload
4989 let result = http
4990 .put(url)
4991 .header(CONTENT_LENGTH, file_size)
4992 .timeout(Duration::from_secs(upload_timeout_secs))
4993 .body(file_data.clone())
4994 .send()
4995 .await;
4996
4997 match result {
4998 Ok(resp) => {
4999 if resp.status().is_success() {
5000 if attempt > 0 {
5001 debug!(
5002 "Upload '{}' succeeded on retry {} ({} bytes)",
5003 filename, attempt, file_size
5004 );
5005 } else {
5006 debug!(
5007 "Successfully uploaded file: {} ({} bytes)",
5008 filename, file_size
5009 );
5010 }
5011 return Ok(());
5012 }
5013
5014 let status = resp.status();
5015 let status_code = status.as_u16();
5016
5017 // Check if error is retryable
5018 let is_retryable =
5019 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5020
5021 if is_retryable && attempt < max_retries {
5022 let error_text = resp.text().await.unwrap_or_default();
5023 warn!(
5024 "Upload '{}' failed with HTTP {} (retryable): {}",
5025 filename, status_code, error_text
5026 );
5027 last_error = Some(Error::InvalidParameters(format!(
5028 "Upload failed: HTTP {} - {}",
5029 status, error_text
5030 )));
5031 continue;
5032 }
5033
5034 // Non-retryable error or max retries exceeded
5035 let error_text = resp.text().await.unwrap_or_default();
5036 if attempt > 0 {
5037 error!(
5038 "Upload '{}' failed after {} retries: HTTP {} - {}",
5039 filename, attempt, status, error_text
5040 );
5041 }
5042 return Err(Error::InvalidParameters(format!(
5043 "Upload failed: HTTP {} - {}",
5044 status, error_text
5045 )));
5046 }
5047 Err(e) => {
5048 // Transport error (timeout, connection failure, etc.)
5049 let is_timeout = e.is_timeout();
5050 let is_connect = e.is_connect();
5051
5052 if (is_timeout || is_connect) && attempt < max_retries {
5053 warn!(
5054 "Upload '{}' transport error (retrying): {}",
5055 filename,
5056 if is_timeout {
5057 "timeout"
5058 } else {
5059 "connection failed"
5060 }
5061 );
5062 last_error = Some(Error::HttpError(e));
5063 continue;
5064 }
5065
5066 // Non-retryable or max retries exceeded
5067 if attempt > 0 {
5068 error!(
5069 "Upload '{}' failed after {} retries: {}",
5070 filename, attempt, e
5071 );
5072 }
5073 return Err(Error::HttpError(e));
5074 }
5075 }
5076 }
5077
5078 // Should not reach here, but return last error if we do
5079 Err(last_error.unwrap_or_else(|| {
5080 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5081 }))
5082}
5083
#[cfg(test)]
mod tests {
    //! Unit tests for name filtering/ranking (`filter_and_sort_by_name`),
    //! flattened download filename construction (`Client::build_filename`),
    //! and token-storage handling when switching servers.

    use super::*;

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions (including multi-part ones) are preserved
        let extensions = [
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = [
            // (items, filter, expected_first)
            (["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (["test", "TEST", "Test Data"], "test", "test"),
            (["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        // `MemoryTokenStorage` and the `TokenStorage` trait are in scope via
        // `use super::*`.

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Verify the token is present in storage before switching servers
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }
}