edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult,
8 NewValidationSession, Organization, Project, ProjectID, SampleID, SamplesCountResult,
9 SamplesListParams, SamplesListResult, Snapshot, SnapshotCreateFromDataset,
10 SnapshotFromDatasetResult, SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage,
11 StartValidationRequest, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
12 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
13 ValidationSessionID,
14 },
15 dataset::{
16 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
17 },
18 retry::{create_retry_policy, log_retry_configuration},
19 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
20};
21use base64::Engine as _;
22use chrono::{DateTime, Utc};
23use directories::ProjectDirs;
24use futures::{StreamExt as _, future::join_all};
25use log::{Level, debug, error, log_enabled, trace, warn};
26use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::{
29 collections::HashMap,
30 ffi::OsStr,
31 fs::create_dir_all,
32 io::{SeekFrom, Write as _},
33 path::{Path, PathBuf},
34 sync::{
35 Arc,
36 atomic::{AtomicUsize, Ordering},
37 },
38 time::Duration,
39 vec,
40};
41use tokio::{
42 fs::{self, File},
43 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
44 sync::{RwLock, Semaphore, mpsc::Sender},
45};
46use tokio_util::codec::{BytesCodec, FramedRead};
47use walkdir::WalkDir;
48
49#[cfg(feature = "polars")]
50use polars::prelude::*;
51
52/// Maps a JSON-RPC error code to a typed `Error` variant when the code is
53/// well-known; otherwise returns `Error::RpcError(code, message)` unchanged.
54///
55/// Scoped to the new DE-2565 methods. Existing methods continue to return
56/// `Error::RpcError` directly.
57///
58/// Server error codes (from `api.go` via `jrpc.Fail`):
59/// - `1` – generic server error
60/// - `3` – validation / bad request
61/// - `10` – internal server error
62/// - `101` – resource not found (e.g. "Cannot find task...", "not found in DB")
63/// - `401` – unauthenticated
64/// - `403` – forbidden
65/// - `413` – payload too large
66pub(crate) fn map_rpc_error(
67 method: &str,
68 code: i32,
69 message: String,
70 task_id: Option<crate::api::TaskID>,
71) -> Error {
72 // Server emits "Cannot find task...", "not found in DB", and other phrasings
73 // for code 101. Code 101 with a task_id is task-not-found by contract
74 // (see api.go), so we return the typed variant unconditionally when the
75 // caller supplied a task_id — message phrasing is treated as informational
76 // and is preserved by the RPC layer for diagnostic logging upstream.
77 if code == 101
78 && let Some(id) = task_id
79 {
80 return Error::TaskNotFound(id);
81 }
82 match code {
83 401 | 403 => Error::PermissionDenied(method.to_string()),
84 413 => Error::PayloadTooLarge {
85 method: method.to_string(),
86 size_hint: None,
87 },
88 _ => Error::RpcError(code, message),
89 }
90}
91
92/// Returns true if `val` is structurally a JSON-RPC 2.0 *error* envelope.
93///
94/// A real envelope must:
95/// 1. Be a JSON object,
96/// 2. Carry a `"jsonrpc"` member (the protocol-version sentinel — JSON-RPC
97/// 2.0 §5 mandates this on every response object),
98/// 3. Carry an `"error"` object that includes a numeric `"code"` field.
99///
100/// This is intentionally stricter than a "looks for a top-level `error`
101/// key" check so that legitimate JSON file payloads (validation traces,
102/// metrics dumps, diagnostics) which happen to include a free-form `error`
103/// field are *not* misclassified as RPC failures.
104///
105/// Extracted so it can be unit-tested without a live server.
106pub(crate) fn is_jsonrpc_error_envelope(val: &serde_json::Value) -> bool {
107 let Some(obj) = val.as_object() else {
108 return false;
109 };
110 // Protocol-version sentinel — only JSON-RPC envelopes carry this.
111 if !obj.contains_key("jsonrpc") {
112 return false;
113 }
114 let Some(err) = obj.get("error").and_then(|e| e.as_object()) else {
115 return false;
116 };
117 err.get("code")
118 .map(|c| c.is_i64() || c.is_u64())
119 .unwrap_or(false)
120}
121
122/// Validates that `group` and `name` are both non-empty strings for chart
123/// operations (`add_chart`, `get_chart`). Extracted so it can be unit-tested
124/// without a live server.
125pub(crate) fn validate_chart_args(group: &str, name: &str) -> Result<(), Error> {
126 if group.is_empty() || name.is_empty() {
127 return Err(Error::InvalidParameters(
128 "chart: group and name must be non-empty".into(),
129 ));
130 }
131 Ok(())
132}
133
/// 100 MiB (104,857,600 bytes).
///
/// NOTE(review): presumably the size of one part in a multipart upload —
/// confirm at the use sites, which are outside this section of the file.
static PART_SIZE: usize = 100 * (1 << 20); // 1 << 20 bytes == 1 MiB
135
/// Source for file content during upload - either a local path or raw bytes.
///
/// `Clone` is derived; note that cloning the `Bytes` variant copies the whole
/// buffer.
#[derive(Clone)]
enum FileSource {
    /// File content from a local filesystem path.
    Path(PathBuf),
    /// File content as raw bytes already held in memory (e.g., from a ZIP
    /// archive).
    Bytes(Vec<u8>),
}
144
/// Maximum number of concurrent tasks for parallel operations.
///
/// Resolution order:
/// 1. The `MAX_TASKS` environment variable, when it parses (after trimming
///    surrounding whitespace) to a positive integer.
/// 2. Otherwise half the available CPUs, clamped to the range `2..=8`
///    (4 CPUs are assumed when the parallelism cannot be queried).
///
/// A value of `0` is rejected and treated like an unset variable: a limit of
/// zero tasks cannot make progress (for example, a semaphore created with
/// zero permits never grants one).
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // Zero would mean "no task may ever run"; fall back to the default.
        .filter(|&limit| limit > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8.
            let cpus = std::thread::available_parallelism().map_or(4, |n| n.get());
            (cpus / 2).clamp(2, 8)
        })
}
157
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment
/// variable, which must parse (after trimming surrounding whitespace) to a
/// positive integer. Unset, unparsable, or zero values fall back to the
/// default of 8: a limit of zero uploads could never make progress.
fn max_upload_tasks() -> usize {
    /// Concurrent part uploads used when no valid override is present.
    const DEFAULT_UPLOAD_TASKS: usize = 8;

    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // Reject zero: it would disable uploads entirely rather than limit them.
        .filter(|&limit| limit > 0)
        .unwrap_or(DEFAULT_UPLOAD_TASKS)
}
168
/// Keeps the items whose name contains `filter` (case-insensitively) and
/// orders them from best to worst match.
///
/// Ranking, highest priority first:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Shorter names before longer ones (length measured in bytes)
/// 4. Lexicographic order as the final tie-breaker
///
/// Searching for "Deer" therefore yields "Deer" ahead of "Reindeer" and
/// "Deer Roundtrip 20251129". An empty `filter` matches every item.
///
/// # Arguments
/// * `items` - Candidates; consumed and returned filtered/sorted.
/// * `filter` - Text to search for.
/// * `get_name` - Extracts the name to compare from an item.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let needle = filter.to_lowercase();

    let mut matches = Vec::new();
    for item in items {
        if get_name(&item).to_lowercase().contains(&needle) {
            matches.push(item);
        }
    }

    matches.sort_by(|lhs, rhs| {
        let left = get_name(lhs);
        let right = get_name(rhs);

        // Booleans are compared right-to-left so that `true` (a match)
        // sorts ahead of `false`.
        (right == filter)
            .cmp(&(left == filter))
            .then_with(|| (right.to_lowercase() == needle).cmp(&(left.to_lowercase() == needle)))
            .then_with(|| left.len().cmp(&right.len()))
            .then_with(|| left.cmp(right))
    });

    matches
}
218
219/// Whether `host` refers to a loopback (machine-local) endpoint.
220///
221/// Used by [`Client::with_url`] to decide whether a plain-`http://` URL is
222/// safe to accept. Loopback traffic never leaves the machine, so the
223/// usual concern about leaking the Studio bearer token in plaintext does
224/// not apply — that's how wiremock and local dev servers connect.
225fn is_loopback_host(host: Option<&url::Host<&str>>) -> bool {
226 match host {
227 Some(url::Host::Ipv4(ip)) => ip.is_loopback(),
228 Some(url::Host::Ipv6(ip)) => ip.is_loopback(),
229 // RFC 6761 reserves "localhost" (and `*.localhost`) as a loopback
230 // name. Compare case-insensitively because URL hosts are matched
231 // that way and developers do type capitalized variants.
232 Some(url::Host::Domain(d)) => {
233 d.eq_ignore_ascii_case("localhost") || d.to_ascii_lowercase().ends_with(".localhost")
234 }
235 None => false,
236 }
237}
238
/// Turns an arbitrary name into a string that is safe to use as a single
/// path component.
///
/// Steps:
/// 1. Surrounding whitespace is trimmed.
/// 2. Only the final path component is kept (`"dir/file.txt"` becomes
///    `"file.txt"`); when the input has no final component the trimmed input
///    is used as-is.
/// 3. The characters `/ \ : * ? " < > |` are each replaced with `_`.
///
/// The placeholder `"unnamed"` is returned when the result would be empty or
/// would be one of the special directory names `"."` / `".."`. Without that
/// last rule an input of `".."` would be returned unchanged (it has no final
/// component and contains no replaced characters), letting the "sanitized"
/// name climb out of the intended directory once joined onto a path.
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // `file_name` yields `None` for inputs such as "..", "." or "foo/..";
    // fall back to the whole input and let the checks below deal with it.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            _ => c,
        })
        .collect();

    // "." and ".." are directory references, not names: never hand them back.
    if matches!(sanitized.as_str(), "" | "." | "..") {
        FALLBACK.to_string()
    } else {
        sanitized
    }
}
264
265/// Progress information for long-running operations.
266///
267/// This struct tracks the current progress of operations like file uploads,
268/// downloads, or dataset processing. It provides the current count, total
269/// count, and an optional status string to enable progress reporting in
270/// applications.
271///
272/// # Multi-Stage Progress
273///
274/// The `status` field enables multi-stage progress tracking. When an operation
275/// has multiple phases, the status field changes to indicate the current phase.
276/// Applications should detect status changes to reset their progress display.
277///
278/// # Operation Progress Details
279///
280/// | Operation | Status | Unit | Notes |
281/// |-----------|--------|------|-------|
282/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
283/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
284/// | [`samples`] | `None` | samples | Paginated API fetch |
285/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
286/// | [`annotations`] | `None` | samples | Samples processed for annotations |
287/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
288/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
289/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
290///
291/// [`download_dataset`]: Client::download_dataset
292/// [`populate_samples`]: Client::populate_samples
293/// [`samples`]: Client::samples
294/// [`sample_names`]: Client::sample_names
295/// [`annotations`]: Client::annotations
296/// [`download_artifact`]: Client::download_artifact
297/// [`download_checkpoint`]: Client::download_checkpoint
298/// [`download_snapshot`]: Client::download_snapshot
299///
300/// # Examples
301///
302/// Basic progress display:
303///
304/// ```rust
305/// use edgefirst_client::Progress;
306///
307/// let progress = Progress {
308/// current: 25,
309/// total: 100,
310/// status: Some("Downloading".to_string()),
311/// };
312/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
313/// println!(
314/// "{}: {:.1}% ({}/{})",
315/// progress.status.as_deref().unwrap_or("Progress"),
316/// percentage,
317/// progress.current,
318/// progress.total
319/// );
320/// ```
321///
322/// Multi-stage progress handling (e.g., for `download_dataset`):
323///
324/// ```rust,ignore
325/// let mut last_status: Option<String> = None;
326///
327/// while let Some(progress) = rx.recv().await {
328/// // Detect stage change and reset progress bar
329/// if progress.status != last_status {
330/// if let Some(ref status) = progress.status {
331/// println!("\n{}", status);
332/// }
333/// last_status = progress.status.clone();
334/// }
335///
336/// let pct = (progress.current as f64 / progress.total as f64) * 100.0;
337/// print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
338/// }
339/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items or bytes.
    pub current: usize,
    /// Total number of items or bytes to process.
    ///
    /// May be `0` when the total is not known up front: file downloads report
    /// `0` when the server response carries no `Content-Length`. Guard
    /// against a zero total before computing a percentage.
    pub total: usize,
    /// Optional status describing the current operation phase.
    ///
    /// When this value changes from `None` to `Some(...)` or between different
    /// values, it indicates a new phase has started. Applications should reset
    /// their progress display when the status changes.
    ///
    /// Currently only [`Client::download_dataset`] uses status changes:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// All other operations use `None` throughout.
    pub status: Option<String>,
}
359
/// Serializable JSON-RPC request envelope, generic over the type of the
/// method parameters.
///
/// The `Default` implementation fills in `jsonrpc: "2.0"`, `id: 0`, an empty
/// `method` and `params: None`.
#[derive(Serialize)]
struct RpcRequest<Params> {
    /// Request identifier.
    id: u64,
    /// JSON-RPC protocol version string.
    jsonrpc: String,
    /// Name of the RPC method to invoke.
    method: String,
    /// Method parameters. No skip attribute is applied, so `None` is still
    /// serialized (as JSON `null` with serde's default `Option` handling).
    params: Option<Params>,
}
367
368impl<T> Default for RpcRequest<T> {
369 fn default() -> Self {
370 RpcRequest {
371 id: 0,
372 jsonrpc: "2.0".to_string(),
373 method: "".to_string(),
374 params: None,
375 }
376 }
377}
378
/// Deserialized `error` member of a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric error code reported by the server.
    code: i32,
    /// Error message reported by the server.
    message: String,
}
384
/// Deserialized JSON-RPC response envelope, generic over the result type.
///
/// Both `error` and `result` are optional, so the type itself does not
/// enforce that exactly one of them is present.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // Deserialized but never read.
    // NOTE(review): the request side uses a numeric `id` (`u64`) while this
    // field expects a JSON string — presumably the server always replies
    // with a string id; confirm, since a numeric id would fail to deserialize.
    #[allow(dead_code)]
    id: String,
    // Deserialized but never read.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Failure details, when the server reported an error.
    error: Option<RpcError>,
    /// Method result, when the call produced one.
    result: Option<RpcResult>,
}
394
/// Deserialization target with no fields.
///
/// NOTE(review): presumably used for RPC calls whose result carries no data
/// of interest — confirm at the call sites. `dead_code` is allowed because
/// nothing is ever read from it.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
398
/// Serializable parameters naming a snapshot and the keys it consists of.
///
/// NOTE(review): marked `dead_code`; presumably superseded by
/// `SnapshotCreateMultipartParams` — confirm before removing.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    /// Name of the snapshot to create.
    snapshot_name: String,
    /// Keys (presumably object/file keys) that make up the snapshot.
    keys: Vec<String>,
}
405
/// Deserialized result carrying the new snapshot's id and a list of URLs.
///
/// NOTE(review): marked `dead_code`; presumably the counterpart of
/// `SnapshotCreateParams` — confirm before removing.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    /// Identifier of the created snapshot.
    snapshot_id: SnapshotID,
    /// URLs returned by the server (presumably upload targets, one per key).
    urls: Vec<String>,
}
412
/// Serializable parameters for creating a snapshot via multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    /// Name of the snapshot to create.
    snapshot_name: String,
    /// Keys of the files that make up the snapshot.
    keys: Vec<String>,
    /// File sizes (presumably in bytes and parallel to `keys` — confirm at
    /// the call site).
    file_sizes: Vec<usize>,
    /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format).
    /// Serialized under the JSON name `type` and omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    snapshot_type: Option<String>,
}
422
/// One value found in the multipart snapshot-creation result.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// an unsigned integer deserializes as `Id`, anything else must match the
/// shape of `SnapshotPart`.
///
/// NOTE(review): presumably the result is a map that mixes a snapshot id
/// with per-file part descriptors — confirm at the call site.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    /// A bare numeric identifier.
    Id(u64),
    /// Upload details for one file.
    Part(SnapshotPart),
}
429
/// Serializable parameters for completing a multipart upload of one key.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    /// Key of the uploaded file.
    key: String,
    /// Identifier of the multipart upload being completed.
    upload_id: String,
    /// ETag / part-number pairs for the uploaded parts.
    etag_list: Vec<EtagPart>,
}
436
/// ETag and part number of one uploaded part.
///
/// The fields are serialized under the names `ETag` and `PartNumber`.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    /// ETag associated with the part.
    #[serde(rename = "ETag")]
    etag: String,
    /// Number of the part within the upload.
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
444
/// Deserialized multipart-upload descriptor for one file.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    /// Key of the file; may be absent in the server response.
    key: Option<String>,
    /// Identifier of the multipart upload.
    upload_id: String,
    /// URLs for the upload (presumably one presigned URL per part — confirm
    /// at the call site).
    urls: Vec<String>,
}
451
/// Serializable parameters pairing a snapshot id with a status string.
///
/// NOTE(review): presumably used to set the snapshot's status on the
/// server — confirm at the call site.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    /// Snapshot the status applies to.
    snapshot_id: SnapshotID,
    /// Status value, sent as a free-form string.
    status: String,
}
457
/// Deserialized snapshot record returned by the server.
///
/// Every field is marked `dead_code`: the response is parsed (so a malformed
/// reply is still detected) but the individual values are not read.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    /// Snapshot identifier.
    #[allow(dead_code)]
    pub id: SnapshotID,
    /// Snapshot UID string.
    #[allow(dead_code)]
    pub uid: String,
    /// Snapshot description.
    #[allow(dead_code)]
    pub description: String,
    /// Date, kept as the raw string sent by the server (not parsed).
    #[allow(dead_code)]
    pub date: String,
    /// Status, kept as the raw string sent by the server.
    #[allow(dead_code)]
    pub status: String,
}
471
/// Serializable parameters for listing images.
///
/// NOTE(review): marked `dead_code`; presumably a legacy request shape —
/// confirm before removing.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    /// Filter selecting which images to list.
    images_filter: ImagesFilter,
    /// String-to-string filter applied to image files (semantics not visible
    /// here).
    image_files_filter: HashMap<String, String>,
    /// Flag presumably requesting ids only instead of full records — confirm
    /// against the server API.
    only_ids: bool,
}
479
/// Serializable image filter that selects by dataset.
///
/// Marked `dead_code` together with `ImageListParams`, its only user in this
/// section of the file.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    /// Dataset whose images are selected.
    dataset_id: DatasetID,
}
485
486/// Main client for interacting with EdgeFirst Studio Server.
487///
488/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
489/// and manages authentication, RPC calls, and data operations. It provides
490/// methods for managing projects, datasets, experiments, training sessions,
491/// and various utility functions for data processing.
492///
493/// The client supports multiple authentication methods and can work with both
494/// SaaS and self-hosted EdgeFirst Studio instances.
495///
496/// # Features
497///
498/// - **Authentication**: Token-based authentication with automatic persistence
499/// - **Dataset Management**: Upload, download, and manipulate datasets
500/// - **Project Operations**: Create and manage projects and experiments
501/// - **Training & Validation**: Submit and monitor ML training jobs
502/// - **Data Integration**: Convert between EdgeFirst datasets and popular
503/// formats
504/// - **Progress Tracking**: Real-time progress updates for long-running
505/// operations
506///
507/// # Examples
508///
509/// ```no_run
510/// use edgefirst_client::{Client, DatasetID};
511/// use std::str::FromStr;
512///
513/// # async fn example() -> Result<(), edgefirst_client::Error> {
514/// // Create a new client and authenticate
515/// let mut client = Client::new()?;
516/// let client = client
517/// .with_login("your-email@example.com", "password")
518/// .await?;
519///
520/// // Or use an existing token
521/// let base_client = Client::new()?;
522/// let client = base_client.with_token("your-token-here")?;
523///
524/// // Get organization and projects
525/// let org = client.organization().await?;
526/// let projects = client.projects(None).await?;
527///
528/// // Work with datasets
529/// let dataset_id = DatasetID::from_str("ds-abc123")?;
530/// let dataset = client.dataset(dataset_id).await?;
531/// # Ok(())
532/// # }
533/// ```
// `Client` derives `Clone`, but `Debug` is implemented by hand below: it
// cannot be derived because the `storage` field holds a `dyn TokenStorage`
// trait object.
#[derive(Clone)]
pub struct Client {
    /// HTTP client configured with a total-request timeout (see
    /// [`Client::new`]).
    http: reqwest::Client,
    /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
    /// timeout). An idle read timeout is still configured on the underlying client, and
    /// some operations (such as uploads) may apply additional per-request timeouts.
    bulk_http: reqwest::Client,
    /// Base URL of the server, stored without a trailing slash.
    url: String,
    /// Current authentication token; an empty string means no token is set.
    /// Clones of a client share this `Arc`, and therefore the same token,
    /// until a builder method replaces it.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
550
551impl std::fmt::Debug for Client {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("Client")
554 .field("url", &self.url)
555 .field("has_storage", &self.storage.is_some())
556 .field("token_path", &self.token_path)
557 .finish()
558 }
559}
560
/// Private context struct for pagination operations.
///
/// Bundles the values a paginated fetch needs so they can be passed around
/// as one argument. `groups` and `labels` are borrowed for `'a`; the other
/// fields are owned.
struct FetchContext<'a> {
    /// Dataset being fetched from.
    dataset_id: DatasetID,
    /// Annotation set to restrict to, if any.
    annotation_set_id: Option<AnnotationSetID>,
    /// Group names (presumably a filter — confirm at the use site).
    groups: &'a [String],
    /// Type names as strings (presumably annotation/file types to request —
    /// confirm at the use site).
    types: Vec<String>,
    /// Map from label name to a numeric value (presumably the label id or
    /// index — confirm at the use site).
    labels: &'a HashMap<String, u64>,
}
569
/// Serializable request payload with no fields.
///
/// NOTE(review): presumably the parameters of the job-listing RPC — confirm
/// at the call site.
#[derive(Debug, Serialize)]
struct JobsListRequest {}
572
/// Serializable request payload describing a job to run.
///
/// NOTE(review): presumably the parameters of the job-run RPC — confirm at
/// the call site.
#[derive(Debug, Serialize)]
struct JobRunRequest {
    /// Name given to this run.
    name: String,
    /// Name of the job to execute.
    job_name: String,
    /// String key/value pairs (presumably environment variables for the job).
    env: std::collections::HashMap<String, String>,
    /// Named parameter values passed to the job.
    data: std::collections::HashMap<String, crate::api::Parameter>,
}
580
/// Serializable request payload identifying a task by its numeric id.
///
/// NOTE(review): presumably the parameters of the job-stop RPC — confirm at
/// the call site.
#[derive(Debug, Serialize)]
struct JobStopRequest {
    /// Numeric id of the task.
    task_id: u64,
}
585
/// Serializable request payload identifying a task by its numeric id.
///
/// NOTE(review): presumably the parameters for listing a task's data files —
/// confirm at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataListRequest {
    /// Numeric id of the task.
    pub(crate) task_id: u64,
}
590
/// Serializable request payload naming one file of a task by folder and
/// file name.
///
/// NOTE(review): presumably the parameters for downloading a task data
/// file — confirm at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataDownloadRequest {
    /// Numeric id of the task.
    pub(crate) task_id: u64,
    /// Folder containing the file.
    pub(crate) folder: String,
    /// Name of the file.
    pub(crate) file: String,
}
597
/// Serializable request payload carrying a chart for a task.
///
/// NOTE(review): presumably the parameters of the chart-add RPC — confirm at
/// the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartAddRequest {
    /// Numeric id of the task.
    pub(crate) task_id: u64,
    /// Chart group name.
    pub(crate) group_name: String,
    /// Chart name within the group.
    pub(crate) chart_name: String,
    /// Optional chart parameters. No skip attribute is applied, so `None`
    /// is still serialized.
    pub(crate) params: Option<crate::api::Parameter>,
    /// Chart data.
    pub(crate) data: crate::api::Parameter,
}
606
/// Serializable request payload naming a chart group of a task.
///
/// NOTE(review): presumably the parameters for listing the charts in a
/// group — confirm at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartListRequest {
    /// Numeric id of the task.
    pub(crate) task_id: u64,
    /// Chart group name.
    pub(crate) group_name: String,
}
612
/// Serializable request payload naming a single chart of a task.
///
/// NOTE(review): presumably the parameters of the chart-get RPC — confirm at
/// the call site.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartGetRequest {
    /// Numeric id of the task.
    pub(crate) task_id: u64,
    /// Chart group name.
    pub(crate) group_name: String,
    /// Chart name within the group.
    pub(crate) chart_name: String,
}
619
/// Serializable request payload naming one file of a session.
///
/// NOTE(review): presumably the parameters for downloading a validation
/// session's data file — confirm at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataDownloadRequest {
    /// Numeric id of the session.
    pub(crate) session_id: u64,
    /// Name of the file.
    pub(crate) filename: String,
}
625
/// Serializable request payload identifying a session by its numeric id.
///
/// NOTE(review): presumably the parameters for listing a validation
/// session's data files — confirm at the call site.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataListRequest {
    /// Numeric id of the session.
    pub(crate) session_id: u64,
}
630
631/// Streams the body of a successful `reqwest` response to a file on disk,
632/// emitting optional progress events.
633///
634/// Both `download_artifact` and `rpc_download` share this logic. The caller is
635/// responsible for creating any required parent directories before calling this
636/// function.
637///
638/// # Arguments
639/// * `resp` - A successful (HTTP 2xx) `reqwest::Response` whose body will
640/// be streamed to `path`.
641/// * `path` - Destination file path (created or truncated).
642/// * `progress` - Optional channel; events carry bytes received and
643/// `Content-Length` total (0 if the server omits it).
644///
645/// # Errors
646/// Returns `Error::IoError` on file I/O failures or propagates stream errors.
647async fn stream_response_to_file(
648 resp: reqwest::Response,
649 path: &std::path::Path,
650 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
651) -> Result<(), Error> {
652 use tokio::io::AsyncWriteExt as _;
653 let total = resp.content_length().unwrap_or(0) as usize;
654 let mut stream = resp.bytes_stream();
655 let mut file = tokio::fs::File::create(path).await?;
656 let mut current = 0usize;
657
658 if let Some(ref tx) = progress {
659 let _ = tx
660 .send(Progress {
661 current: 0,
662 total,
663 status: None,
664 })
665 .await;
666 }
667
668 while let Some(chunk) = stream.next().await {
669 let chunk = chunk?;
670 file.write_all(&chunk).await?;
671 current += chunk.len();
672 if let Some(ref tx) = progress {
673 let _ = tx
674 .send(Progress {
675 current,
676 total,
677 status: None,
678 })
679 .await;
680 }
681 }
682
683 // Flush tokio's internal write buffer to the OS before returning.
684 // tokio::fs::File buffers writes internally; without this, the buffer
685 // may not reach the filesystem before the caller reads the file.
686 file.flush().await?;
687 Ok(())
688}
689
690impl Client {
691 /// Create a new unauthenticated client with the default saas server.
692 ///
693 /// By default, the client uses [`FileTokenStorage`] for token persistence.
694 /// Use [`with_storage`][Self::with_storage],
695 /// [`with_memory_storage`][Self::with_memory_storage],
696 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
697 /// behavior.
698 ///
699 /// To connect to a different server, use [`with_server`][Self::with_server]
700 /// or [`with_token`][Self::with_token] (tokens include the server
701 /// instance).
702 ///
703 /// This client is created without a token and will need to authenticate
704 /// before using methods that require authentication.
705 ///
706 /// # Examples
707 ///
708 /// ```rust,no_run
709 /// use edgefirst_client::Client;
710 ///
711 /// # fn main() -> Result<(), edgefirst_client::Error> {
712 /// // Create client with default file storage
713 /// let client = Client::new()?;
714 ///
715 /// // Create client without token persistence
716 /// let client = Client::new()?.with_memory_storage();
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub fn new() -> Result<Self, Error> {
721 log_retry_configuration();
722
723 // Get timeout from environment or use default
724 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
725 .ok()
726 .and_then(|s| s.parse().ok())
727 .unwrap_or(30); // Default 30s total deadline for API calls
728
729 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
730 // arrive for this duration. Resets after every received chunk, so a
731 // healthy multi-GB transfer will never be interrupted.
732 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
733 .ok()
734 .and_then(|s| s.parse().ok())
735 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
736
737 // Create single HTTP client with URL-based retry policy
738 //
739 // The retry policy classifies requests into two categories:
740 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
741 // errors
742 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
743 //
744 // This allows the same client to handle both API calls and file operations
745 // with appropriate retry behavior for each. See retry.rs for details.
746 let http = reqwest::Client::builder()
747 .connect_timeout(Duration::from_secs(10))
748 .timeout(Duration::from_secs(timeout_secs))
749 .pool_idle_timeout(Duration::from_secs(90))
750 .pool_max_idle_per_host(10)
751 .retry(create_retry_policy())
752 .build()?;
753
754 // Separate HTTP client for bulk transfers (uploads and downloads).
755 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
756 // Uses read_timeout instead: resets after every received chunk, so a
757 // healthy large transfer is never interrupted, but a truly stalled
758 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
759 let bulk_http = reqwest::Client::builder()
760 .connect_timeout(Duration::from_secs(30))
761 .read_timeout(Duration::from_secs(read_timeout_secs))
762 .pool_idle_timeout(Duration::from_secs(90))
763 // Bulk file transfers fan out to many concurrent presigned-URL
764 // uploads — up to `EDGEFIRST_UPLOAD_BATCHES` pipelined batches ×
765 // `max_tasks()` uploads each. Keep enough idle connections warm to
766 // reuse across that fan-out instead of churning new TLS handshakes.
767 .pool_max_idle_per_host(64)
768 .retry(create_retry_policy())
769 .build()?;
770
771 // Default to file storage, loading any existing token
772 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
773 Ok(file_storage) => Arc::new(file_storage),
774 Err(e) => {
775 warn!(
776 "Could not initialize file token storage: {}. Using memory storage.",
777 e
778 );
779 Arc::new(MemoryTokenStorage::new())
780 }
781 };
782
783 // Try to load existing token from storage
784 let token = match storage.load() {
785 Ok(Some(t)) => t,
786 Ok(None) => String::new(),
787 Err(e) => {
788 warn!(
789 "Failed to load token from storage: {}. Starting with empty token.",
790 e
791 );
792 String::new()
793 }
794 };
795
796 // Extract server from token if available
797 let url = if !token.is_empty() {
798 match Self::extract_server_from_token(&token) {
799 Ok(server) => format!("https://{}.edgefirst.studio", server),
800 Err(e) => {
801 warn!(
802 "Failed to extract server from token: {}. Using default server.",
803 e
804 );
805 "https://edgefirst.studio".to_string()
806 }
807 }
808 } else {
809 "https://edgefirst.studio".to_string()
810 };
811
812 Ok(Client {
813 http,
814 bulk_http,
815 url,
816 token: Arc::new(tokio::sync::RwLock::new(token)),
817 storage: Some(storage),
818 token_path: None,
819 })
820 }
821
822 /// Returns a new client connected to the specified server instance.
823 ///
824 /// The server parameter is an instance name that maps to a URL:
825 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
826 /// server)
827 /// - `"test"` → `https://test.edgefirst.studio`
828 /// - `"stage"` → `https://stage.edgefirst.studio`
829 /// - `"dev"` → `https://dev.edgefirst.studio`
830 /// - `"{name}"` → `https://{name}.edgefirst.studio`
831 ///
832 /// # Server Selection Priority
833 ///
834 /// When using the CLI or Python API, server selection follows this
835 /// priority:
836 ///
837 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
838 /// they were issued for. If you have a valid token, its server is used.
839 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
840 /// token is available. If a token exists with a different server, a
841 /// warning is emitted and the token's server takes priority.
842 /// 3. **Default `"saas"`** - If no token and no server specified, the
843 /// production server (`https://edgefirst.studio`) is used.
844 ///
845 /// # Important Notes
846 ///
847 /// - If a token is already set in the client, calling this method will
848 /// **drop the token** as tokens are specific to the server instance.
849 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
850 /// token's server before calling this method.
851 /// - For login operations, call `with_server()` first, then authenticate.
852 ///
853 /// # Examples
854 ///
855 /// ```rust,no_run
856 /// use edgefirst_client::Client;
857 ///
858 /// # fn main() -> Result<(), edgefirst_client::Error> {
859 /// let client = Client::new()?.with_server("test")?;
860 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
861 /// # Ok(())
862 /// # }
863 /// ```
864 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
865 // Resolve the target URL. Full URLs (self-hosted Studio,
866 // wiremock) are validated through `with_url` so the HTTPS rules
867 // there apply uniformly. Short names map to the SaaS pattern.
868 // We extract only the URL string and rebuild the Client below,
869 // because `with_url` preserves the in-memory token (the contract
870 // for self-hosted deployments) whereas `with_server` deliberately
871 // clears it (a different server means a stale token).
872 let url = if server.starts_with("http://") || server.starts_with("https://") {
873 self.with_url(server)?.url().to_string()
874 } else {
875 match server {
876 "" | "saas" => "https://edgefirst.studio".to_string(),
877 name => format!("https://{}.edgefirst.studio", name),
878 }
879 };
880
881 // Clear token from storage when changing servers to prevent
882 // authentication issues with stale tokens from different
883 // instances. This runs whether the caller passed a short name
884 // or a full URL — both reach a new server.
885 if let Some(ref storage) = self.storage
886 && let Err(e) = storage.clear()
887 {
888 warn!(
889 "Failed to clear token from storage when changing servers: {}",
890 e
891 );
892 }
893
894 Ok(Client {
895 url,
896 token: Arc::new(tokio::sync::RwLock::new(String::new())),
897 ..self.clone()
898 })
899 }
900
901 /// Returns a new client pointed at an explicit URL.
902 ///
903 /// Used for self-hosted Studio deployments (e.g.
904 /// `https://studio.example.com`) and for offline integration tests
905 /// against a mock HTTP server (e.g. `http://127.0.0.1:8080`). The
906 /// token is preserved so callers can chain
907 /// `Client::new()?.with_url(...)?.with_token(...)`.
908 ///
909 /// # Errors
910 ///
911 /// Returns [`Error::UrlParseError`] for syntactically invalid URLs and
912 /// [`Error::InsecureUrl`] for plain `http://` URLs that resolve to a
913 /// non-loopback host: the Studio bearer token rides in the
914 /// `Authorization` header, and plain HTTP would leak it in the clear.
915 /// Loopback URLs (`127.0.0.1`, `::1`, `localhost`, `*.localhost`) are
916 /// permitted because traffic never leaves the machine — wiremock and
917 /// local dev servers go through that path.
918 pub fn with_url(&self, url: &str) -> Result<Self, Error> {
919 // Reject malformed inputs early so test failures point at the test
920 // rather than a downstream reqwest send.
921 let parsed = url::Url::parse(url)?;
922 let scheme = parsed.scheme();
923 if scheme == "http" {
924 if !is_loopback_host(parsed.host().as_ref()) {
925 return Err(Error::InsecureUrl(url.to_string()));
926 }
927 } else if scheme != "https" {
928 return Err(Error::InsecureUrl(url.to_string()));
929 }
930 Ok(Client {
931 url: url.trim_end_matches('/').to_string(),
932 ..self.clone()
933 })
934 }
935
936 /// Returns a new client with the specified token storage backend.
937 ///
938 /// Use this to configure custom token storage, such as platform-specific
939 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
940 ///
941 /// # Examples
942 ///
943 /// ```rust,no_run
944 /// use edgefirst_client::{Client, FileTokenStorage};
945 /// use std::{path::PathBuf, sync::Arc};
946 ///
947 /// # fn main() -> Result<(), edgefirst_client::Error> {
948 /// // Use a custom file path for token storage
949 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
950 /// let client = Client::new()?.with_storage(Arc::new(storage));
951 /// # Ok(())
952 /// # }
953 /// ```
954 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
955 // Try to load existing token from the new storage
956 let token = match storage.load() {
957 Ok(Some(t)) => t,
958 Ok(None) => String::new(),
959 Err(e) => {
960 warn!(
961 "Failed to load token from storage: {}. Starting with empty token.",
962 e
963 );
964 String::new()
965 }
966 };
967
968 Client {
969 token: Arc::new(tokio::sync::RwLock::new(token)),
970 storage: Some(storage),
971 token_path: None,
972 ..self
973 }
974 }
975
976 /// Returns a new client with in-memory token storage (no persistence).
977 ///
978 /// Tokens are stored in memory only and lost when the application exits.
979 /// This is useful for testing or when you want to manage token persistence
980 /// externally.
981 ///
982 /// # Examples
983 ///
984 /// ```rust,no_run
985 /// use edgefirst_client::Client;
986 ///
987 /// # fn main() -> Result<(), edgefirst_client::Error> {
988 /// let client = Client::new()?.with_memory_storage();
989 /// # Ok(())
990 /// # }
991 /// ```
992 pub fn with_memory_storage(self) -> Self {
993 Client {
994 token: Arc::new(tokio::sync::RwLock::new(String::new())),
995 storage: Some(Arc::new(MemoryTokenStorage::new())),
996 token_path: None,
997 ..self
998 }
999 }
1000
1001 /// Returns a new client with no token storage.
1002 ///
1003 /// Tokens are not persisted. Use this when you want to manage tokens
1004 /// entirely manually.
1005 ///
1006 /// # Examples
1007 ///
1008 /// ```rust,no_run
1009 /// use edgefirst_client::Client;
1010 ///
1011 /// # fn main() -> Result<(), edgefirst_client::Error> {
1012 /// let client = Client::new()?.with_no_storage();
1013 /// # Ok(())
1014 /// # }
1015 /// ```
1016 pub fn with_no_storage(self) -> Self {
1017 Client {
1018 storage: None,
1019 token_path: None,
1020 ..self
1021 }
1022 }
1023
1024 /// Returns a new client authenticated with the provided username and
1025 /// password.
1026 ///
1027 /// The token is automatically persisted to storage (if configured).
1028 ///
1029 /// # Examples
1030 ///
1031 /// ```rust,no_run
1032 /// use edgefirst_client::Client;
1033 ///
1034 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1035 /// let client = Client::new()?
1036 /// .with_server("test")?
1037 /// .with_login("user@example.com", "password")
1038 /// .await?;
1039 /// # Ok(())
1040 /// # }
1041 /// ```
1042 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
1043 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
1044 let params = HashMap::from([("username", username), ("password", password)]);
1045 let login: LoginResult = self
1046 .rpc_without_auth("auth.login".to_owned(), Some(params))
1047 .await?;
1048
1049 // Validate that the server returned a non-empty token
1050 if login.token.is_empty() {
1051 return Err(Error::EmptyToken);
1052 }
1053
1054 // Persist token to storage if configured
1055 if let Some(ref storage) = self.storage
1056 && let Err(e) = storage.store(&login.token)
1057 {
1058 warn!("Failed to persist token to storage: {}", e);
1059 }
1060
1061 Ok(Client {
1062 token: Arc::new(tokio::sync::RwLock::new(login.token)),
1063 ..self.clone()
1064 })
1065 }
1066
1067 /// Returns a new client which will load and save the token to the specified
1068 /// path.
1069 ///
1070 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
1071 /// [`FileTokenStorage`] instead for more flexible token management.
1072 ///
1073 /// This method is maintained for backwards compatibility with existing
1074 /// code. It disables the default storage and uses file-based storage at
1075 /// the specified path.
1076 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
1077 let token_path = match token_path {
1078 Some(path) => path.to_path_buf(),
1079 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1080 .ok_or_else(|| {
1081 Error::IoError(std::io::Error::new(
1082 std::io::ErrorKind::NotFound,
1083 "Could not determine user config directory",
1084 ))
1085 })?
1086 .config_dir()
1087 .join("token"),
1088 };
1089
1090 debug!("Using token path (legacy): {:?}", token_path);
1091
1092 let token = match token_path.exists() {
1093 true => std::fs::read_to_string(&token_path)?,
1094 false => "".to_string(),
1095 };
1096
1097 if !token.is_empty() {
1098 match self.with_token(&token) {
1099 Ok(client) => Ok(Client {
1100 token_path: Some(token_path),
1101 storage: None, // Disable new storage when using legacy token_path
1102 ..client
1103 }),
1104 Err(e) => {
1105 // Token is corrupted or invalid - remove it and continue with no token
1106 warn!(
1107 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
1108 token_path, e
1109 );
1110 if let Err(remove_err) = std::fs::remove_file(&token_path) {
1111 warn!("Failed to remove corrupted token file: {:?}", remove_err);
1112 }
1113 // Clear any token from default storage to ensure we don't use it
1114 Ok(Client {
1115 token_path: Some(token_path),
1116 storage: None,
1117 token: Arc::new(RwLock::new("".to_string())),
1118 ..self.clone()
1119 })
1120 }
1121 }
1122 } else {
1123 // No token in the legacy file - clear any token from default storage
1124 Ok(Client {
1125 token_path: Some(token_path),
1126 storage: None,
1127 token: Arc::new(RwLock::new("".to_string())),
1128 ..self.clone()
1129 })
1130 }
1131 }
1132
1133 /// Returns a new client authenticated with the provided token.
1134 ///
1135 /// The token is automatically persisted to storage (if configured).
1136 /// The server URL is extracted from the token payload.
1137 ///
1138 /// # Examples
1139 ///
1140 /// ```rust,no_run
1141 /// use edgefirst_client::Client;
1142 ///
1143 /// # fn main() -> Result<(), edgefirst_client::Error> {
1144 /// let client = Client::new()?.with_token("your-jwt-token")?;
1145 /// # Ok(())
1146 /// # }
1147 /// ```
1148 /// Extract server name from JWT token payload.
1149 ///
1150 /// Helper method to parse the JWT token and extract the "server" field
1151 /// from the payload. Returns the server name (e.g., "test", "stage", "")
1152 /// or an error if the token is invalid.
1153 fn extract_server_from_token(token: &str) -> Result<String, Error> {
1154 let token_parts: Vec<&str> = token.split('.').collect();
1155 if token_parts.len() != 3 {
1156 return Err(Error::InvalidToken);
1157 }
1158
1159 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1160 .decode(token_parts[1])
1161 .map_err(|_| Error::InvalidToken)?;
1162 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1163 let server = match payload.get("server") {
1164 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
1165 None => return Err(Error::InvalidToken),
1166 };
1167
1168 Ok(server)
1169 }
1170
1171 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
1172 if token.is_empty() {
1173 return Ok(self.clone());
1174 }
1175
1176 let server = Self::extract_server_from_token(token)?;
1177
1178 // Persist token to storage if configured
1179 if let Some(ref storage) = self.storage
1180 && let Err(e) = storage.store(token)
1181 {
1182 warn!("Failed to persist token to storage: {}", e);
1183 }
1184
1185 Ok(Client {
1186 url: format!("https://{}.edgefirst.studio", server),
1187 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
1188 ..self.clone()
1189 })
1190 }
1191
1192 /// Persist the current token to storage.
1193 ///
1194 /// This is automatically called when using [`with_login`][Self::with_login]
1195 /// or [`with_token`][Self::with_token], so you typically don't need to call
1196 /// this directly.
1197 ///
1198 /// If using the legacy `token_path` configuration, saves to the file path.
1199 /// If using the new storage abstraction, saves to the configured storage.
1200 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1201 pub async fn save_token(&self) -> Result<(), Error> {
1202 let token = self.token.read().await;
1203
1204 // Try new storage first
1205 if let Some(ref storage) = self.storage {
1206 storage.store(&token)?;
1207 debug!("Token saved to storage");
1208 return Ok(());
1209 }
1210
1211 // Fall back to legacy token_path behavior
1212 let path = self.token_path.clone().unwrap_or_else(|| {
1213 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1214 .map(|dirs| dirs.config_dir().join("token"))
1215 .unwrap_or_else(|| PathBuf::from(".token"))
1216 });
1217
1218 create_dir_all(path.parent().ok_or_else(|| {
1219 Error::IoError(std::io::Error::new(
1220 std::io::ErrorKind::InvalidInput,
1221 "Token path has no parent directory",
1222 ))
1223 })?)?;
1224 let mut file = std::fs::File::create(&path)?;
1225 file.write_all(token.as_bytes())?;
1226
1227 debug!("Saved token to {:?}", path);
1228
1229 Ok(())
1230 }
1231
1232 /// Return the version of the EdgeFirst Studio server for the current
1233 /// client connection.
1234 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1235 pub async fn version(&self) -> Result<String, Error> {
1236 let version: HashMap<String, String> = self
1237 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
1238 .await?;
1239 let version = version.get("version").ok_or(Error::InvalidResponse)?;
1240 Ok(version.to_owned())
1241 }
1242
1243 /// Clear the token used to authenticate the client with the server.
1244 ///
1245 /// Clears the token from memory and from storage (if configured).
1246 /// If using the legacy `token_path` configuration, removes the token file.
1247 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1248 pub async fn logout(&self) -> Result<(), Error> {
1249 {
1250 let mut token = self.token.write().await;
1251 *token = "".to_string();
1252 }
1253
1254 // Clear from new storage if configured
1255 if let Some(ref storage) = self.storage
1256 && let Err(e) = storage.clear()
1257 {
1258 warn!("Failed to clear token from storage: {}", e);
1259 }
1260
1261 // Also clear legacy token_path if configured
1262 if let Some(path) = &self.token_path
1263 && path.exists()
1264 {
1265 fs::remove_file(path).await?;
1266 }
1267
1268 Ok(())
1269 }
1270
1271 /// Return the token used to authenticate the client with the server. When
1272 /// logging into the server using a username and password, the token is
1273 /// returned by the server and stored in the client for future interactions.
1274 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1275 pub async fn token(&self) -> String {
1276 self.token.read().await.clone()
1277 }
1278
1279 /// Verify the token used to authenticate the client with the server. This
1280 /// method is used to ensure that the token is still valid and has not
1281 /// expired. If the token is invalid, the server will return an error and
1282 /// the client will need to login again.
1283 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1284 pub async fn verify_token(&self) -> Result<(), Error> {
1285 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1286 .await?;
1287 Ok::<(), Error>(())
1288 }
1289
1290 /// Renew the token used to authenticate the client with the server.
1291 ///
1292 /// Refreshes the token before it expires. If the token has already expired,
1293 /// the server will return an error and you will need to login again.
1294 ///
1295 /// The new token is automatically persisted to storage (if configured).
1296 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1297 pub async fn renew_token(&self) -> Result<(), Error> {
1298 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1299 let result: LoginResult = self
1300 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1301 .await?;
1302
1303 {
1304 let mut token = self.token.write().await;
1305 *token = result.token.clone();
1306 }
1307
1308 // Persist to new storage if configured
1309 if let Some(ref storage) = self.storage
1310 && let Err(e) = storage.store(&result.token)
1311 {
1312 warn!("Failed to persist renewed token to storage: {}", e);
1313 }
1314
1315 // Also persist to legacy token_path if configured
1316 if self.token_path.is_some() {
1317 self.save_token().await?;
1318 }
1319
1320 Ok(())
1321 }
1322
1323 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1324 let token = self.token.read().await;
1325 if token.is_empty() {
1326 return Err(Error::EmptyToken);
1327 }
1328
1329 let token_parts: Vec<&str> = token.split('.').collect();
1330 if token_parts.len() != 3 {
1331 return Err(Error::InvalidToken);
1332 }
1333
1334 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1335 .decode(token_parts[1])
1336 .map_err(|_| Error::InvalidToken)?;
1337 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1338 match payload.get(field) {
1339 Some(value) => Ok(value.to_owned()),
1340 None => Err(Error::InvalidToken),
1341 }
1342 }
1343
1344 /// Returns the URL of the EdgeFirst Studio server for the current client.
1345 pub fn url(&self) -> &str {
1346 &self.url
1347 }
1348
1349 /// Returns the server name for the current client.
1350 ///
1351 /// This extracts the server name from the client's URL:
1352 /// - `https://edgefirst.studio` → `"saas"`
1353 /// - `https://test.edgefirst.studio` → `"test"`
1354 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1355 ///
1356 /// # Examples
1357 ///
1358 /// ```rust,no_run
1359 /// use edgefirst_client::Client;
1360 ///
1361 /// # fn main() -> Result<(), edgefirst_client::Error> {
1362 /// let client = Client::new()?.with_server("test")?;
1363 /// assert_eq!(client.server(), "test");
1364 ///
1365 /// let client = Client::new()?; // default
1366 /// assert_eq!(client.server(), "saas");
1367 /// # Ok(())
1368 /// # }
1369 /// ```
1370 pub fn server(&self) -> &str {
1371 if self.url == "https://edgefirst.studio" {
1372 "saas"
1373 } else if let Some(name) = self.url.strip_prefix("https://") {
1374 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1375 } else {
1376 "saas"
1377 }
1378 }
1379
1380 /// Returns the username associated with the current token.
1381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1382 pub async fn username(&self) -> Result<String, Error> {
1383 match self.token_field("username").await? {
1384 serde_json::Value::String(username) => Ok(username),
1385 _ => Err(Error::InvalidToken),
1386 }
1387 }
1388
1389 /// Returns the expiration time for the current token.
1390 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1391 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1392 let ts = match self.token_field("exp").await? {
1393 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1394 _ => return Err(Error::InvalidToken),
1395 };
1396
1397 match DateTime::<Utc>::from_timestamp(ts, 0) {
1398 Some(dt) => Ok(dt),
1399 None => Err(Error::InvalidToken),
1400 }
1401 }
1402
1403 /// Returns the organization information for the current user.
1404 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1405 pub async fn organization(&self) -> Result<Organization, Error> {
1406 self.rpc::<(), Organization>("org.get".to_owned(), None)
1407 .await
1408 }
1409
1410 /// Returns a list of projects available to the user. The projects are
1411 /// returned as a vector of Project objects. If a name filter is
1412 /// provided, only projects matching the filter are returned.
1413 ///
1414 /// Results are sorted by match quality: exact matches first, then
1415 /// case-insensitive exact matches, then shorter names (more specific),
1416 /// then alphabetically.
1417 ///
1418 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1419 /// Projects contain datasets, trainers, and trainer sessions. Projects
1420 /// are used to group related datasets and trainers together.
1421 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1422 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1423 let projects = self
1424 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1425 .await?;
1426 if let Some(name) = name {
1427 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1428 } else {
1429 Ok(projects)
1430 }
1431 }
1432
1433 /// Return the project with the specified project ID. If the project does
1434 /// not exist, an error is returned.
1435 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1436 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1437 let params = HashMap::from([("project_id", project_id)]);
1438 self.rpc("project.get".to_owned(), Some(params)).await
1439 }
1440
1441 /// Returns a list of datasets available to the user. The datasets are
1442 /// returned as a vector of Dataset objects. If a name filter is
1443 /// provided, only datasets matching the filter are returned.
1444 ///
1445 /// Results are sorted by match quality: exact matches first, then
1446 /// case-insensitive exact matches, then shorter names (more specific),
1447 /// then alphabetically. This ensures "Deer" returns before "Deer
1448 /// Roundtrip".
1449 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1450 pub async fn datasets(
1451 &self,
1452 project_id: ProjectID,
1453 name: Option<&str>,
1454 ) -> Result<Vec<Dataset>, Error> {
1455 let params = HashMap::from([("project_id", project_id)]);
1456 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1457 if let Some(name) = name {
1458 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1459 } else {
1460 Ok(datasets)
1461 }
1462 }
1463
1464 /// Return the dataset with the specified dataset ID. If the dataset does
1465 /// not exist, an error is returned.
1466 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1467 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1468 let params = HashMap::from([("dataset_id", dataset_id)]);
1469 self.rpc("dataset.get".to_owned(), Some(params)).await
1470 }
1471
1472 /// Lists the labels for the specified dataset.
1473 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1474 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1475 let params = HashMap::from([("dataset_id", dataset_id)]);
1476 self.rpc("label.list".to_owned(), Some(params)).await
1477 }
1478
1479 /// Add a new label to the dataset with the specified name.
1480 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1481 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1482 let new_label = NewLabel {
1483 dataset_id,
1484 labels: vec![NewLabelObject {
1485 name: name.to_owned(),
1486 }],
1487 };
1488 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1489 Ok(())
1490 }
1491
1492 /// Removes the label with the specified ID from the dataset. Label IDs are
1493 /// globally unique so the dataset_id is not required.
1494 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1495 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1496 let params = HashMap::from([("label_id", label_id)]);
1497 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1498 Ok(())
1499 }
1500
1501 /// Creates a new dataset in the specified project.
1502 ///
1503 /// # Arguments
1504 ///
1505 /// * `project_id` - The ID of the project to create the dataset in
1506 /// * `name` - The name of the new dataset
1507 /// * `description` - Optional description for the dataset
1508 ///
1509 /// # Returns
1510 ///
1511 /// Returns the dataset ID of the newly created dataset.
1512 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1513 pub async fn create_dataset(
1514 &self,
1515 project_id: &str,
1516 name: &str,
1517 description: Option<&str>,
1518 ) -> Result<DatasetID, Error> {
1519 let mut params = HashMap::new();
1520 params.insert("project_id", project_id);
1521 params.insert("name", name);
1522 if let Some(desc) = description {
1523 params.insert("description", desc);
1524 }
1525
1526 #[derive(Deserialize)]
1527 struct CreateDatasetResult {
1528 id: DatasetID,
1529 }
1530
1531 let result: CreateDatasetResult =
1532 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1533 Ok(result.id)
1534 }
1535
1536 /// Deletes a dataset by marking it as deleted.
1537 ///
1538 /// # Arguments
1539 ///
1540 /// * `dataset_id` - The ID of the dataset to delete
1541 ///
1542 /// # Returns
1543 ///
1544 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1545 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1546 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1547 let params = HashMap::from([("id", dataset_id)]);
1548 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1549 Ok(())
1550 }
1551
1552 /// Updates the label with the specified ID to have the new name or index.
1553 /// Label IDs cannot be changed. Label IDs are globally unique so the
1554 /// dataset_id is not required.
1555 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1556 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1557 #[derive(Serialize)]
1558 struct Params {
1559 dataset_id: DatasetID,
1560 label_id: u64,
1561 label_name: String,
1562 label_index: u64,
1563 }
1564
1565 let _: String = self
1566 .rpc(
1567 "label.update".to_owned(),
1568 Some(Params {
1569 dataset_id: label.dataset_id(),
1570 label_id: label.id(),
1571 label_name: label.name().to_owned(),
1572 label_index: label.index(),
1573 }),
1574 )
1575 .await?;
1576 Ok(())
1577 }
1578
1579 /// Lists the groups for the specified dataset.
1580 ///
1581 /// Groups are used to organize samples into logical subsets such as
1582 /// "train", "val", "test", etc. Each sample can belong to at most one
1583 /// group at a time.
1584 ///
1585 /// # Arguments
1586 ///
1587 /// * `dataset_id` - The ID of the dataset to list groups for
1588 ///
1589 /// # Returns
1590 ///
1591 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1592 /// empty vector if no groups have been created yet.
1593 ///
1594 /// # Errors
1595 ///
1596 /// Returns an error if the dataset does not exist or cannot be accessed.
1597 ///
1598 /// # Example
1599 ///
1600 /// ```rust,no_run
1601 /// # use edgefirst_client::{Client, DatasetID};
1602 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1603 /// let client = Client::new()?.with_token_path(None)?;
1604 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1605 ///
1606 /// let groups = client.groups(dataset_id).await?;
1607 /// for group in groups {
1608 /// println!("{}: {}", group.id, group.name);
1609 /// }
1610 /// # Ok(())
1611 /// # }
1612 /// ```
1613 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1614 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1615 let params = HashMap::from([("dataset_id", dataset_id)]);
1616 self.rpc("groups.list".to_owned(), Some(params)).await
1617 }
1618
1619 /// Gets an existing group by name or creates a new one.
1620 ///
1621 /// This is a convenience method that first checks if a group with the
1622 /// specified name exists, and creates it if not. This is useful when
1623 /// you need to ensure a group exists before assigning samples to it.
1624 ///
1625 /// # Arguments
1626 ///
1627 /// * `dataset_id` - The ID of the dataset
1628 /// * `name` - The name of the group (e.g., "train", "val", "test")
1629 ///
1630 /// # Returns
1631 ///
1632 /// Returns the group ID (either existing or newly created).
1633 ///
1634 /// # Errors
1635 ///
1636 /// Returns an error if:
1637 /// - The dataset does not exist or cannot be accessed
1638 /// - The group creation fails
1639 ///
1640 /// # Concurrency
1641 ///
1642 /// This method handles concurrent creation attempts gracefully. If another
1643 /// process creates the group between the existence check and creation,
1644 /// this method will return the existing group's ID.
1645 ///
1646 /// # Example
1647 ///
1648 /// ```rust,no_run
1649 /// # use edgefirst_client::{Client, DatasetID};
1650 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1651 /// let client = Client::new()?.with_token_path(None)?;
1652 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1653 ///
1654 /// // Get or create a "train" group
1655 /// let train_group_id = client
1656 /// .get_or_create_group(dataset_id.clone(), "train")
1657 /// .await?;
1658 /// println!("Train group ID: {}", train_group_id);
1659 ///
1660 /// // Calling again returns the same ID
1661 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1662 /// assert_eq!(train_group_id, same_id);
1663 /// # Ok(())
1664 /// # }
1665 /// ```
1666 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1667 pub async fn get_or_create_group(
1668 &self,
1669 dataset_id: DatasetID,
1670 name: &str,
1671 ) -> Result<u64, Error> {
1672 // First check if the group already exists
1673 let groups = self.groups(dataset_id).await?;
1674 if let Some(group) = groups.iter().find(|g| g.name == name) {
1675 return Ok(group.id);
1676 }
1677
1678 // Create the group
1679 #[derive(Serialize)]
1680 struct CreateGroupParams {
1681 dataset_id: DatasetID,
1682 group_names: Vec<String>,
1683 group_splits: Vec<i64>,
1684 }
1685
1686 let params = CreateGroupParams {
1687 dataset_id,
1688 group_names: vec![name.to_string()],
1689 group_splits: vec![0], // No automatic splitting
1690 };
1691
1692 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1693 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1694 Ok(group.id)
1695 } else {
1696 // Group might have been created by concurrent call, try fetching again
1697 let groups = self.groups(dataset_id).await?;
1698 groups
1699 .iter()
1700 .find(|g| g.name == name)
1701 .map(|g| g.id)
1702 .ok_or_else(|| {
1703 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1704 })
1705 }
1706 }
1707
1708 /// Sets the group for a sample.
1709 ///
1710 /// Assigns a sample to a specific group. Each sample can belong to at most
1711 /// one group at a time. Setting a new group replaces any existing group
1712 /// assignment.
1713 ///
1714 /// # Arguments
1715 ///
1716 /// * `sample_id` - The ID of the sample (image) to update
1717 /// * `group_id` - The ID of the group to assign. Use
1718 /// [`get_or_create_group`] to obtain a group ID from a name.
1719 ///
1720 /// # Returns
1721 ///
1722 /// Returns `Ok(())` on success.
1723 ///
1724 /// # Errors
1725 ///
1726 /// Returns an error if:
1727 /// - The sample does not exist
1728 /// - The group does not exist
1729 /// - Insufficient permissions to modify the sample
1730 ///
1731 /// # Example
1732 ///
1733 /// ```rust,no_run
1734 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1735 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1736 /// let client = Client::new()?.with_token_path(None)?;
1737 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1738 /// let sample_id: SampleID = 12345.into();
1739 ///
1740 /// // Get or create the "val" group
1741 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1742 ///
1743 /// // Assign the sample to the "val" group
1744 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1745 /// # Ok(())
1746 /// # }
1747 /// ```
1748 ///
1749 /// [`get_or_create_group`]: Self::get_or_create_group
1750 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1751 pub async fn set_sample_group_id(
1752 &self,
1753 sample_id: SampleID,
1754 group_id: u64,
1755 ) -> Result<(), Error> {
1756 #[derive(Serialize)]
1757 struct SetGroupParams {
1758 image_id: SampleID,
1759 group_id: u64,
1760 }
1761
1762 let params = SetGroupParams {
1763 image_id: sample_id,
1764 group_id,
1765 };
1766 let _: String = self
1767 .rpc("image.set_group_id".to_owned(), Some(params))
1768 .await?;
1769 Ok(())
1770 }
1771
1772 /// Downloads dataset samples to the local filesystem.
1773 ///
1774 /// # Arguments
1775 ///
1776 /// * `dataset_id` - The unique identifier of the dataset
1777 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1778 /// * `file_types` - File types to download. Supported types:
1779 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1780 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1781 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1782 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1783 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1784 /// - `FileType::RadarCube` - Radar cube data (.png format)
1785 /// - `FileType::All` - All sensor types (expands to all of the above)
1786 /// * `output` - Local directory to save downloaded files
1787 /// * `flatten` - If true, download all files to output root without
1788 /// sequence subdirectories. When flattening, filenames are prefixed with
1789 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1790 /// unavailable) unless the filename already starts with
1791 /// `{sequence_name}_`, to avoid conflicts between sequences.
1792 /// * `progress` - Optional channel for progress updates
1793 ///
1794 /// # Progress
1795 ///
1796 /// This operation has two phases with distinct progress reporting:
1797 ///
1798 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1799 /// from the server. Progress counts samples fetched.
1800 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1801 /// files to disk. Progress counts samples completed (each sample may
1802 /// have multiple files for different sensor types).
1803 ///
1804 /// Applications should detect the status change from `None` to
1805 /// `"Downloading"` to reset their progress bar for the second phase.
1806 ///
1807 /// # Returns
1808 ///
1809 /// Returns `Ok(())` on success or an error if download fails.
1810 ///
1811 /// # Example
1812 ///
1813 /// ```rust,no_run
1814 /// # use edgefirst_client::{Client, DatasetID, FileType};
1815 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1816 /// let client = Client::new()?.with_token_path(None)?;
1817 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1818 ///
1819 /// // Download with sequence subdirectories (default)
1820 /// client
1821 /// .download_dataset(
1822 /// dataset_id,
1823 /// &[],
1824 /// &[FileType::Image],
1825 /// "./data".into(),
1826 /// false,
1827 /// None,
1828 /// )
1829 /// .await?;
1830 ///
1831 /// // Download flattened (all files in one directory)
1832 /// client
1833 /// .download_dataset(
1834 /// dataset_id,
1835 /// &[],
1836 /// &[FileType::Image],
1837 /// "./data".into(),
1838 /// true,
1839 /// None,
1840 /// )
1841 /// .await?;
1842 ///
1843 /// // Download all sensor types
1844 /// client
1845 /// .download_dataset(
1846 /// dataset_id,
1847 /// &[],
1848 /// &FileType::expand_types(&[FileType::All]),
1849 /// "./data".into(),
1850 /// false,
1851 /// None,
1852 /// )
1853 /// .await?;
1854 /// # Ok(())
1855 /// # }
1856 /// ```
1857 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1858 pub async fn download_dataset(
1859 &self,
1860 dataset_id: DatasetID,
1861 groups: &[String],
1862 file_types: &[FileType],
1863 output: PathBuf,
1864 flatten: bool,
1865 progress: Option<Sender<Progress>>,
1866 ) -> Result<(), Error> {
1867 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1868 let samples = self
1869 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1870 .await?;
1871 fs::create_dir_all(&output).await?;
1872
1873 // Phase 2: Download actual files using direct semaphore pattern
1874 let total = samples.len();
1875 let current = Arc::new(AtomicUsize::new(0));
1876 let sem = Arc::new(Semaphore::new(max_tasks()));
1877
1878 // Send initial progress for download phase
1879 if let Some(ref progress) = progress {
1880 let _ = progress
1881 .send(Progress {
1882 current: 0,
1883 total,
1884 status: Some("Downloading".to_string()),
1885 })
1886 .await;
1887 }
1888
1889 let tasks = samples
1890 .into_iter()
1891 .map(|sample| {
1892 let client = self.clone();
1893 let file_types = file_types.to_vec();
1894 let output = output.clone();
1895 let progress = progress.clone();
1896 let current = current.clone();
1897 let sem = sem.clone();
1898
1899 tokio::spawn(async move {
1900 let _permit = sem.acquire().await.map_err(|_| {
1901 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1902 })?;
1903
1904 for file_type in &file_types {
1905 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1906 let (file_ext, is_image) = match file_type {
1907 FileType::Image => (
1908 infer::get(&data)
1909 .expect("Failed to identify image file format for sample")
1910 .extension()
1911 .to_string(),
1912 true,
1913 ),
1914 other => (other.file_extension().to_string(), false),
1915 };
1916
1917 // Determine target directory based on sequence membership and
1918 // flatten option
1919 // - flatten=false + sequence_name: dataset/sequence_name/
1920 // - flatten=false + no sequence: dataset/ (root level)
1921 // - flatten=true: dataset/ (all files in output root)
1922 // NOTE: group (train/val/test) is NOT used for directory structure
1923 let sequence_dir = sample
1924 .sequence_name()
1925 .map(|name| sanitize_path_component(name));
1926
1927 let target_dir = if flatten {
1928 output.clone()
1929 } else {
1930 sequence_dir
1931 .as_ref()
1932 .map(|seq| output.join(seq))
1933 .unwrap_or_else(|| output.clone())
1934 };
1935 fs::create_dir_all(&target_dir).await?;
1936
1937 let sanitized_sample_name = sample
1938 .name()
1939 .map(|name| sanitize_path_component(&name))
1940 .unwrap_or_else(|| "unknown".to_string());
1941
1942 let image_name = sample.image_name().map(sanitize_path_component);
1943
1944 // Construct filename with smart prefixing for flatten mode
1945 // When flatten=true and sample belongs to a sequence:
1946 // - Check if filename already starts with "{sequence_name}_"
1947 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1948 // - If yes, use filename as-is (already uniquely named)
1949 let file_name = if is_image {
1950 if let Some(img_name) = image_name {
1951 Client::build_filename(
1952 &img_name,
1953 flatten,
1954 sequence_dir.as_ref(),
1955 sample.frame_number(),
1956 )
1957 } else {
1958 format!("{}.{}", sanitized_sample_name, file_ext)
1959 }
1960 } else {
1961 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1962 Client::build_filename(
1963 &base_name,
1964 flatten,
1965 sequence_dir.as_ref(),
1966 sample.frame_number(),
1967 )
1968 };
1969
1970 let file_path = target_dir.join(&file_name);
1971
1972 let mut file = File::create(&file_path).await?;
1973 file.write_all(&data).await?;
1974 }
1975 }
1976
1977 // Update progress after sample completes
1978 if let Some(progress) = &progress {
1979 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
1980 let _ = progress
1981 .send(Progress {
1982 current: completed,
1983 total,
1984 status: Some("Downloading".to_string()),
1985 })
1986 .await;
1987 }
1988
1989 Ok::<(), Error>(())
1990 })
1991 })
1992 .collect::<Vec<_>>();
1993
1994 join_all(tasks)
1995 .await
1996 .into_iter()
1997 .collect::<Result<Vec<_>, _>>()?
1998 .into_iter()
1999 .collect::<Result<Vec<_>, _>>()?;
2000
2001 Ok(())
2002 }
2003
/// Produces the on-disk filename for a downloaded file, adding a sequence
/// prefix when files from several sequences share one flattened directory.
///
/// # Behavior
///
/// - `flatten == false`, or the sample has no sequence: `base_name` is
///   returned as-is.
/// - `flatten == true` and the sample is in a sequence:
///   - `base_name` already starts with `{sequence_name}_`: returned as-is.
///   - otherwise: `{sequence_name}_{frame}_{base_name}` when a frame number
///     is known, `{sequence_name}_{base_name}` when it is not.
fn build_filename(
    base_name: &str,
    flatten: bool,
    sequence_name: Option<&String>,
    frame_number: Option<u32>,
) -> String {
    // Prefixing only applies to sequence members in flatten mode.
    let seq_name = match sequence_name {
        Some(seq_name) if flatten => seq_name,
        _ => return base_name.to_string(),
    };

    let prefix = format!("{}_", seq_name);

    // A name that already carries the sequence prefix is left untouched.
    if base_name.starts_with(&prefix) {
        return base_name.to_string();
    }

    // Otherwise add the sequence prefix, plus the frame number when known.
    match frame_number {
        Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
        None => format!("{}{}", prefix, base_name),
    }
}
2043
2044 /// List available annotation sets for the specified dataset.
2045 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2046 pub async fn annotation_sets(
2047 &self,
2048 dataset_id: DatasetID,
2049 ) -> Result<Vec<AnnotationSet>, Error> {
2050 let params = HashMap::from([("dataset_id", dataset_id)]);
2051 self.rpc("annset.list".to_owned(), Some(params)).await
2052 }
2053
2054 /// Create a new annotation set for the specified dataset.
2055 ///
2056 /// # Arguments
2057 ///
2058 /// * `dataset_id` - The ID of the dataset to create the annotation set in
2059 /// * `name` - The name of the new annotation set
2060 /// * `description` - Optional description for the annotation set
2061 ///
2062 /// # Returns
2063 ///
2064 /// Returns the annotation set ID of the newly created annotation set.
2065 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2066 pub async fn create_annotation_set(
2067 &self,
2068 dataset_id: DatasetID,
2069 name: &str,
2070 description: Option<&str>,
2071 ) -> Result<AnnotationSetID, Error> {
2072 #[derive(Serialize)]
2073 struct Params<'a> {
2074 dataset_id: DatasetID,
2075 name: &'a str,
2076 operator: &'a str,
2077 #[serde(skip_serializing_if = "Option::is_none")]
2078 description: Option<&'a str>,
2079 }
2080
2081 #[derive(Deserialize)]
2082 struct CreateAnnotationSetResult {
2083 id: AnnotationSetID,
2084 }
2085
2086 let username = self.username().await?;
2087 let result: CreateAnnotationSetResult = self
2088 .rpc(
2089 "annset.add".to_owned(),
2090 Some(Params {
2091 dataset_id,
2092 name,
2093 operator: &username,
2094 description,
2095 }),
2096 )
2097 .await?;
2098 Ok(result.id)
2099 }
2100
2101 /// Deletes an annotation set by marking it as deleted.
2102 ///
2103 /// # Arguments
2104 ///
2105 /// * `annotation_set_id` - The ID of the annotation set to delete
2106 ///
2107 /// # Returns
2108 ///
2109 /// Returns `Ok(())` if the annotation set was successfully marked as
2110 /// deleted.
2111 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2112 pub async fn delete_annotation_set(
2113 &self,
2114 annotation_set_id: AnnotationSetID,
2115 ) -> Result<(), Error> {
2116 let params = HashMap::from([("id", annotation_set_id)]);
2117 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
2118 Ok(())
2119 }
2120
2121 /// Retrieve the annotation set with the specified ID.
2122 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2123 pub async fn annotation_set(
2124 &self,
2125 annotation_set_id: AnnotationSetID,
2126 ) -> Result<AnnotationSet, Error> {
2127 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
2128 self.rpc("annset.get".to_owned(), Some(params)).await
2129 }
2130
2131 /// Get the annotations for the specified annotation set with the
2132 /// requested annotation types. The annotation types are used to filter
2133 /// the annotations returned. The groups parameter is used to filter for
2134 /// dataset groups (train, val, test). Images which do not have any
2135 /// annotations are also included in the result as long as they are in the
2136 /// requested groups (when specified).
2137 ///
2138 /// The result is a vector of Annotations objects which contain the
2139 /// full dataset along with the annotations for the specified types.
2140 ///
2141 /// # Progress
2142 ///
2143 /// Reports progress with `status: None` as samples are fetched and
2144 /// processed for their annotations. Progress unit is samples processed
2145 /// (not individual annotations).
2146 ///
2147 /// To get the annotations as a DataFrame, use the `samples_dataframe`
2148 /// method instead.
2149 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2150 pub async fn annotations(
2151 &self,
2152 annotation_set_id: AnnotationSetID,
2153 groups: &[String],
2154 annotation_types: &[AnnotationType],
2155 progress: Option<Sender<Progress>>,
2156 ) -> Result<Vec<Annotation>, Error> {
2157 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
2158 let labels = self
2159 .labels(dataset_id)
2160 .await?
2161 .into_iter()
2162 .map(|label| (label.name().to_string(), label.index()))
2163 .collect::<HashMap<_, _>>();
2164 let total = self
2165 .samples_count(
2166 dataset_id,
2167 Some(annotation_set_id),
2168 annotation_types,
2169 groups,
2170 &[],
2171 )
2172 .await?
2173 .total as usize;
2174
2175 if total == 0 {
2176 return Ok(vec![]);
2177 }
2178
2179 let context = FetchContext {
2180 dataset_id,
2181 annotation_set_id: Some(annotation_set_id),
2182 groups,
2183 types: annotation_types.iter().map(|t| t.to_string()).collect(),
2184 labels: &labels,
2185 };
2186
2187 self.fetch_annotations_paginated(context, total, progress)
2188 .await
2189 }
2190
2191 async fn fetch_annotations_paginated(
2192 &self,
2193 context: FetchContext<'_>,
2194 total: usize,
2195 progress: Option<Sender<Progress>>,
2196 ) -> Result<Vec<Annotation>, Error> {
2197 let mut annotations = vec![];
2198 let mut continue_token: Option<String> = None;
2199 let mut current = 0;
2200
2201 loop {
2202 let params = SamplesListParams {
2203 dataset_id: context.dataset_id,
2204 annotation_set_id: context.annotation_set_id,
2205 types: context.types.clone(),
2206 group_names: context.groups.to_vec(),
2207 continue_token,
2208 };
2209
2210 let result: SamplesListResult =
2211 self.rpc("samples.list".to_owned(), Some(params)).await?;
2212 current += result.samples.len();
2213 continue_token = result.continue_token;
2214
2215 if result.samples.is_empty() {
2216 break;
2217 }
2218
2219 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
2220
2221 if let Some(progress) = &progress {
2222 let _ = progress
2223 .send(Progress {
2224 current,
2225 total,
2226 status: None,
2227 })
2228 .await;
2229 }
2230
2231 match &continue_token {
2232 Some(token) if !token.is_empty() => continue,
2233 _ => break,
2234 }
2235 }
2236
2237 drop(progress);
2238 Ok(annotations)
2239 }
2240
2241 fn process_sample_annotations(
2242 &self,
2243 samples: &[Sample],
2244 labels: &HashMap<String, u64>,
2245 annotations: &mut Vec<Annotation>,
2246 ) {
2247 for sample in samples {
2248 if sample.annotations().is_empty() {
2249 let mut annotation = Annotation::new();
2250 annotation.set_sample_id(sample.id());
2251 annotation.set_name(sample.name());
2252 annotation.set_sequence_name(sample.sequence_name().cloned());
2253 annotation.set_frame_number(sample.frame_number());
2254 annotation.set_group(sample.group().cloned());
2255 annotations.push(annotation);
2256 continue;
2257 }
2258
2259 for annotation in sample.annotations() {
2260 let mut annotation = annotation.clone();
2261 annotation.set_sample_id(sample.id());
2262 annotation.set_name(sample.name());
2263 annotation.set_sequence_name(sample.sequence_name().cloned());
2264 annotation.set_frame_number(sample.frame_number());
2265 annotation.set_group(sample.group().cloned());
2266 Self::set_label_index_from_map(&mut annotation, labels);
2267 annotations.push(annotation);
2268 }
2269 }
2270 }
2271
2272 /// Delete annotations in bulk from specified samples.
2273 ///
2274 /// This method calls the `annotation.bulk.del` API to efficiently remove
2275 /// annotations from multiple samples at once. Useful for clearing
2276 /// annotations before re-importing updated data.
2277 ///
2278 /// # Arguments
2279 /// * `annotation_set_id` - The annotation set containing the annotations
2280 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2281 /// for masks
2282 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2283 ///
2284 /// # Example
2285 /// ```no_run
2286 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2287 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2288 /// # let client = Client::new()?.with_login("user", "pass").await?;
2289 /// let annotation_set_id = AnnotationSetID::from(123);
2290 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2291 ///
2292 /// client
2293 /// .delete_annotations_bulk(
2294 /// annotation_set_id,
2295 /// &["box".to_string(), "seg".to_string()],
2296 /// &sample_ids,
2297 /// )
2298 /// .await?;
2299 /// # Ok(())
2300 /// # }
2301 /// ```
2302 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2303 pub async fn delete_annotations_bulk(
2304 &self,
2305 annotation_set_id: AnnotationSetID,
2306 annotation_types: &[String],
2307 sample_ids: &[SampleID],
2308 ) -> Result<(), Error> {
2309 use crate::api::AnnotationBulkDeleteParams;
2310
2311 let params = AnnotationBulkDeleteParams {
2312 annotation_set_id: annotation_set_id.into(),
2313 annotation_types: annotation_types.to_vec(),
2314 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2315 delete_all: None,
2316 };
2317
2318 let _: String = self
2319 .rpc("annotation.bulk.del".to_owned(), Some(params))
2320 .await?;
2321 Ok(())
2322 }
2323
2324 /// Add annotations in bulk.
2325 ///
2326 /// This method calls the `annotation.add_bulk` API to efficiently add
2327 /// multiple annotations at once. The annotations must be in server format
2328 /// with image_id references.
2329 ///
2330 /// # Arguments
2331 /// * `annotation_set_id` - The annotation set to add annotations to
2332 /// * `annotations` - Vector of server-format annotations to add
2333 ///
2334 /// # Returns
2335 /// Vector of created annotation records from the server.
2336 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2337 pub async fn add_annotations_bulk(
2338 &self,
2339 annotation_set_id: AnnotationSetID,
2340 annotations: Vec<crate::api::ServerAnnotation>,
2341 ) -> Result<Vec<serde_json::Value>, Error> {
2342 use crate::api::AnnotationAddBulkParams;
2343
2344 let params = AnnotationAddBulkParams {
2345 annotation_set_id: annotation_set_id.into(),
2346 annotations,
2347 };
2348
2349 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2350 .await
2351 }
2352
/// Extracts a frame number from an image filename of the form
/// `{sequence_name}_{frame}[.ext]`.
///
/// The extension is removed, the stem must start with `sequence_name`
/// followed by an underscore, and the remainder must parse as a `u32`.
///
/// Returns `None` when either argument is `None`, when the stem does not
/// have the expected prefix, or when the remainder is not a valid `u32`.
fn parse_frame_from_image_name(
    image_name: Option<&String>,
    sequence_name: Option<&String>,
) -> Option<u32> {
    use std::path::Path;

    // Both the sequence and the image name are required.
    let (Some(sequence), Some(name)) = (sequence_name, image_name) else {
        return None;
    };

    // Work on the filename without its extension.
    let stem = Path::new(name).file_stem()?.to_str()?;

    // What follows "{sequence}_" is the candidate frame number.
    let frame_text = stem.strip_prefix(sequence.as_str())?.strip_prefix('_')?;
    frame_text.parse::<u32>().ok()
}
2377
2378 /// Helper to set label index from a label map
2379 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2380 if let Some(label) = annotation.label() {
2381 annotation.set_label_index(Some(labels[label.as_str()]));
2382 }
2383 }
2384
2385 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2386 pub async fn samples_count(
2387 &self,
2388 dataset_id: DatasetID,
2389 annotation_set_id: Option<AnnotationSetID>,
2390 annotation_types: &[AnnotationType],
2391 groups: &[String],
2392 types: &[FileType],
2393 ) -> Result<SamplesCountResult, Error> {
2394 // Use server type names for API calls (e.g., "box" instead of "box2d")
2395 let types = annotation_types
2396 .iter()
2397 .map(|t| t.as_server_type().to_string())
2398 .chain(types.iter().map(|t| t.to_string()))
2399 .collect::<Vec<_>>();
2400
2401 let params = SamplesListParams {
2402 dataset_id,
2403 annotation_set_id,
2404 group_names: groups.to_vec(),
2405 types,
2406 continue_token: None,
2407 };
2408
2409 self.rpc("samples.count".to_owned(), Some(params)).await
2410 }
2411
2412 /// Fetches samples from a dataset with optional annotation and file type
2413 /// filters.
2414 ///
2415 /// # Arguments
2416 ///
2417 /// * `dataset_id` - The dataset to fetch samples from
2418 /// * `annotation_set_id` - Optional annotation set to include annotations
2419 /// from
2420 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2421 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2422 /// * `types` - File types to include metadata for
2423 /// * `progress` - Optional channel for progress updates
2424 ///
2425 /// # Progress
2426 ///
2427 /// Reports progress with `status: None` as samples are fetched from the
2428 /// server in paginated batches. Progress unit is samples fetched.
2429 ///
2430 /// # Returns
2431 ///
2432 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2433 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2434 pub async fn samples(
2435 &self,
2436 dataset_id: DatasetID,
2437 annotation_set_id: Option<AnnotationSetID>,
2438 annotation_types: &[AnnotationType],
2439 groups: &[String],
2440 types: &[FileType],
2441 progress: Option<Sender<Progress>>,
2442 ) -> Result<Vec<Sample>, Error> {
2443 // Use server type names for API calls (e.g., "box" instead of "box2d")
2444 let types_vec = annotation_types
2445 .iter()
2446 .map(|t| t.as_server_type().to_string())
2447 .chain(types.iter().map(|t| t.to_string()))
2448 .collect::<Vec<_>>();
2449 let labels = self
2450 .labels(dataset_id)
2451 .await?
2452 .into_iter()
2453 .map(|label| (label.name().to_string(), label.index()))
2454 .collect::<HashMap<_, _>>();
2455 let total = self
2456 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2457 .await?
2458 .total as usize;
2459
2460 if total == 0 {
2461 return Ok(vec![]);
2462 }
2463
2464 let context = FetchContext {
2465 dataset_id,
2466 annotation_set_id,
2467 groups,
2468 types: types_vec,
2469 labels: &labels,
2470 };
2471
2472 self.fetch_samples_paginated(context, total, progress).await
2473 }
2474
2475 /// Get all sample names in a dataset.
2476 ///
2477 /// This is an efficient method for checking which samples already exist,
2478 /// useful for resuming interrupted imports. It only retrieves sample names
2479 /// without loading full annotation data.
2480 ///
2481 /// # Arguments
2482 ///
2483 /// * `dataset_id` - The dataset to query
2484 /// * `groups` - Optional group filter (empty = all groups)
2485 /// * `progress` - Optional progress channel
2486 ///
2487 /// # Progress
2488 ///
2489 /// Reports progress with `status: None` as sample names are fetched from
2490 /// the server in paginated batches. Progress unit is samples fetched.
2491 ///
2492 /// # Returns
2493 ///
2494 /// A HashSet of sample names (image_name field) that exist in the dataset.
2495 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2496 pub async fn sample_names(
2497 &self,
2498 dataset_id: DatasetID,
2499 groups: &[String],
2500 progress: Option<Sender<Progress>>,
2501 ) -> Result<std::collections::HashSet<String>, Error> {
2502 use std::collections::HashSet;
2503
2504 let total = self
2505 .samples_count(dataset_id, None, &[], groups, &[])
2506 .await?
2507 .total as usize;
2508
2509 if total == 0 {
2510 return Ok(HashSet::new());
2511 }
2512
2513 let mut names = HashSet::with_capacity(total);
2514 let mut continue_token: Option<String> = None;
2515 let mut current = 0;
2516
2517 loop {
2518 let params = SamplesListParams {
2519 dataset_id,
2520 annotation_set_id: None,
2521 types: vec![], // No type filter - we just want names
2522 group_names: groups.to_vec(),
2523 continue_token: continue_token.clone(),
2524 };
2525
2526 let result: SamplesListResult =
2527 self.rpc("samples.list".to_owned(), Some(params)).await?;
2528 current += result.samples.len();
2529 continue_token = result.continue_token;
2530
2531 if result.samples.is_empty() {
2532 break;
2533 }
2534
2535 // Extract sample names (normalized without extension)
2536 for sample in result.samples {
2537 if let Some(name) = sample.name() {
2538 names.insert(name);
2539 }
2540 }
2541
2542 if let Some(ref p) = progress {
2543 let _ = p
2544 .send(Progress {
2545 current,
2546 total,
2547 status: None,
2548 })
2549 .await;
2550 }
2551
2552 match &continue_token {
2553 Some(token) if !token.is_empty() => continue,
2554 _ => break,
2555 }
2556 }
2557
2558 Ok(names)
2559 }
2560
2561 async fn fetch_samples_paginated(
2562 &self,
2563 context: FetchContext<'_>,
2564 total: usize,
2565 progress: Option<Sender<Progress>>,
2566 ) -> Result<Vec<Sample>, Error> {
2567 let mut samples = vec![];
2568 let mut continue_token: Option<String> = None;
2569 let mut current = 0;
2570
2571 loop {
2572 let params = SamplesListParams {
2573 dataset_id: context.dataset_id,
2574 annotation_set_id: context.annotation_set_id,
2575 types: context.types.clone(),
2576 group_names: context.groups.to_vec(),
2577 continue_token: continue_token.clone(),
2578 };
2579
2580 let result: SamplesListResult =
2581 self.rpc("samples.list".to_owned(), Some(params)).await?;
2582 current += result.samples.len();
2583 continue_token = result.continue_token;
2584
2585 if result.samples.is_empty() {
2586 break;
2587 }
2588
2589 samples.append(
2590 &mut result
2591 .samples
2592 .into_iter()
2593 .map(|s| {
2594 // Use server's frame_number if valid (>= 0 after deserialization)
2595 // Otherwise parse from image_name as fallback
2596 // This ensures we respect explicit frame_number from uploads
2597 // while still handling legacy data that only has filename encoding
2598 let frame_number = s.frame_number.or_else(|| {
2599 Self::parse_frame_from_image_name(
2600 s.image_name.as_ref(),
2601 s.sequence_name.as_ref(),
2602 )
2603 });
2604
2605 let mut anns = s.annotations().to_vec();
2606 for ann in &mut anns {
2607 // Set annotation fields from parent sample
2608 ann.set_name(s.name());
2609 ann.set_group(s.group().cloned());
2610 ann.set_sequence_name(s.sequence_name().cloned());
2611 ann.set_frame_number(frame_number);
2612 Self::set_label_index_from_map(ann, context.labels);
2613 }
2614 s.with_annotations(anns).with_frame_number(frame_number)
2615 })
2616 .collect::<Vec<_>>(),
2617 );
2618
2619 if let Some(progress) = &progress {
2620 let _ = progress
2621 .send(Progress {
2622 current,
2623 total,
2624 status: None,
2625 })
2626 .await;
2627 }
2628
2629 match &continue_token {
2630 Some(token) if !token.is_empty() => continue,
2631 _ => break,
2632 }
2633 }
2634
2635 drop(progress);
2636 Ok(samples)
2637 }
2638
2639 /// Populates (imports) samples into a dataset using the `samples.populate2`
2640 /// API.
2641 ///
2642 /// This method creates new samples in the specified dataset, optionally
2643 /// with annotations and sensor data files. For each sample, the `files`
2644 /// field is checked for local file paths. If a filename is a valid path
2645 /// to an existing file, the file will be automatically uploaded to S3
2646 /// using presigned URLs returned by the server. The filename in the
2647 /// request is replaced with the basename (path removed) before sending
2648 /// to the server.
2649 ///
2650 /// # Important Notes
2651 ///
2652 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2653 /// annotations. Without it, the server will accept the request but will
2654 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2655 /// available annotation sets for a dataset, or create a new one via the
2656 /// Studio UI.
2657 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2658 /// boxes. Divide pixel coordinates by image width/height before creating
2659 /// [`Box2d`](crate::Box2d) annotations.
2660 /// - **Files are uploaded automatically** when the filename is a valid
2661 /// local path. The method will replace the full path with just the
2662 /// basename before sending to the server.
2663 /// - **Image dimensions are extracted automatically** for image files using
2664 /// the `imagesize` crate. The width/height are sent to the server and
2665 /// stored in the `image_files` table. These dimensions are returned by
2666 /// `samples.list` and used in [`samples_dataframe`](crate::samples_dataframe)
2667 /// to populate the `size` column.
2668 /// - **UUIDs are generated automatically** if not provided. If you need
2669 /// deterministic UUIDs, set `sample.uuid` explicitly before calling.
2670 ///
2671 /// # Arguments
2672 ///
2673 /// * `dataset_id` - The ID of the dataset to populate
2674 /// * `annotation_set_id` - **Required** if samples contain annotations,
2675 /// otherwise they will be ignored. Query with
2676 /// [`Client::annotation_sets`].
2677 /// * `samples` - Vector of samples to import with metadata and file
2678 /// references. For files, use the full local path - it will be uploaded
2679 /// automatically. UUIDs and image dimensions will be
2680 /// auto-generated/extracted if not provided.
2681 /// * `progress` - Optional channel for progress updates
2682 ///
2683 /// # Progress
2684 ///
2685 /// Reports progress with `status: None` as each sample's files are
2686 /// uploaded. Progress unit is samples (not individual files). Each
2687 /// sample may contain multiple files (image, lidar, radar, etc.) which
2688 /// are all uploaded before the sample is counted as complete.
2689 ///
2690 /// # Returns
2691 ///
2692 /// Returns the API result with sample UUIDs and upload status.
2693 ///
2694 /// # Example
2695 ///
2696 /// ```no_run
2697 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2698 ///
2699 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2700 /// # let client = Client::new()?.with_login("user", "pass").await?;
2701 /// # let dataset_id = DatasetID::from(1);
2702 /// // Query available annotation sets for the dataset
2703 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2704 /// let annotation_set_id = annotation_sets
2705 /// .first()
2706 /// .ok_or_else(|| {
2707 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2708 /// })?
2709 /// .id();
2710 ///
2711 /// // Create sample with annotation (UUID will be auto-generated)
2712 /// let mut sample = Sample::new();
2713 /// sample.width = Some(1920);
2714 /// sample.height = Some(1080);
2715 /// sample.group = Some("train".to_string());
2716 ///
2717 /// // Add file - use full path to local file, it will be uploaded automatically
2718 /// sample.files = vec![SampleFile::with_filename(
2719 /// "image".to_string(),
2720 /// "/path/to/image.jpg".to_string(),
2721 /// )];
2722 ///
2723 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2724 /// let mut annotation = Annotation::new();
2725 /// annotation.set_label(Some("person".to_string()));
2726 /// // Normalize pixel coordinates by dividing by image dimensions
2727 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2728 /// annotation.set_box2d(Some(bbox));
2729 /// sample.annotations = vec![annotation];
2730 ///
2731 /// // Populate with annotation_set_id (REQUIRED for annotations)
2732 /// let result = client
2733 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2734 /// .await?;
2735 /// # Ok(())
2736 /// # }
2737 /// ```
2738 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2739 pub async fn populate_samples(
2740 &self,
2741 dataset_id: DatasetID,
2742 annotation_set_id: Option<AnnotationSetID>,
2743 samples: Vec<Sample>,
2744 progress: Option<Sender<Progress>>,
2745 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2746 self.populate_samples_with_concurrency(
2747 dataset_id,
2748 annotation_set_id,
2749 samples,
2750 progress,
2751 None,
2752 )
2753 .await
2754 }
2755
2756 /// Populate samples with custom upload concurrency.
2757 ///
2758 /// Same as [`populate_samples`](Self::populate_samples) but allows
2759 /// specifying the maximum number of concurrent file uploads. Use this
2760 /// for bulk imports where higher concurrency can significantly reduce
2761 /// upload time.
2762 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2763 pub async fn populate_samples_with_concurrency(
2764 &self,
2765 dataset_id: DatasetID,
2766 annotation_set_id: Option<AnnotationSetID>,
2767 samples: Vec<Sample>,
2768 progress: Option<Sender<Progress>>,
2769 concurrency: Option<usize>,
2770 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2771 use crate::api::SamplesPopulateParams;
2772
2773 // Track which files need to be uploaded
2774 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2775
2776 // Process samples to detect local files and generate UUIDs
2777 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2778
2779 let has_files_to_upload = !files_to_upload.is_empty();
2780
2781 // Call populate API with presigned_urls=true if we have files to upload
2782 let params = SamplesPopulateParams {
2783 dataset_id,
2784 annotation_set_id,
2785 presigned_urls: Some(has_files_to_upload),
2786 samples,
2787 };
2788
2789 let results: Vec<crate::SamplesPopulateResult> = self
2790 .rpc("samples.populate2".to_owned(), Some(params))
2791 .await?;
2792
2793 // Upload files if we have any
2794 if has_files_to_upload {
2795 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2796 .await?;
2797 }
2798
2799 Ok(results)
2800 }
2801
2802 fn prepare_samples_for_upload(
2803 &self,
2804 samples: Vec<Sample>,
2805 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2806 ) -> Result<Vec<Sample>, Error> {
2807 Ok(samples
2808 .into_iter()
2809 .map(|mut sample| {
2810 // Generate UUID if not provided
2811 if sample.uuid.is_none() {
2812 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2813 }
2814
2815 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2816
2817 // Process files: detect local paths and queue for upload
2818 let files_copy = sample.files.clone();
2819 let updated_files: Vec<crate::SampleFile> = files_copy
2820 .iter()
2821 .map(|file| {
2822 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2823 })
2824 .collect();
2825
2826 sample.files = updated_files;
2827 sample
2828 })
2829 .collect())
2830 }
2831
2832 fn process_sample_file(
2833 &self,
2834 file: &crate::SampleFile,
2835 sample_uuid: &str,
2836 sample: &mut Sample,
2837 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2838 ) -> crate::SampleFile {
2839 use std::path::Path;
2840
2841 // Handle files with raw bytes (e.g., from ZIP archives)
2842 if let Some(bytes) = file.bytes()
2843 && let Some(filename) = file.filename()
2844 {
2845 // For image files with bytes, try to extract dimensions if not already set
2846 if file.file_type() == "image"
2847 && (sample.width.is_none() || sample.height.is_none())
2848 && let Ok(size) = imagesize::blob_size(bytes)
2849 {
2850 sample.width = Some(size.width as u32);
2851 sample.height = Some(size.height as u32);
2852 }
2853
2854 // Store the bytes for later upload
2855 files_to_upload.push((
2856 sample_uuid.to_string(),
2857 file.file_type().to_string(),
2858 FileSource::Bytes(bytes.to_vec()),
2859 filename.to_string(),
2860 ));
2861
2862 // Return SampleFile with just the filename
2863 return crate::SampleFile::with_filename(
2864 file.file_type().to_string(),
2865 filename.to_string(),
2866 );
2867 }
2868
2869 // Handle files with local paths
2870 if let Some(filename) = file.filename() {
2871 let path = Path::new(filename);
2872
2873 // Check if this is a valid local file path
2874 if path.exists()
2875 && path.is_file()
2876 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2877 {
2878 // For image files, try to extract dimensions if not already set
2879 if file.file_type() == "image"
2880 && (sample.width.is_none() || sample.height.is_none())
2881 && let Ok(size) = imagesize::size(path)
2882 {
2883 sample.width = Some(size.width as u32);
2884 sample.height = Some(size.height as u32);
2885 }
2886
2887 // Store the full path for later upload
2888 files_to_upload.push((
2889 sample_uuid.to_string(),
2890 file.file_type().to_string(),
2891 FileSource::Path(path.to_path_buf()),
2892 basename.to_string(),
2893 ));
2894
2895 // Return SampleFile with just the basename
2896 return crate::SampleFile::with_filename(
2897 file.file_type().to_string(),
2898 basename.to_string(),
2899 );
2900 }
2901 }
2902 // Return the file unchanged if not a local path
2903 file.clone()
2904 }
2905
2906 async fn upload_sample_files(
2907 &self,
2908 results: &[crate::SamplesPopulateResult],
2909 files_to_upload: Vec<(String, String, FileSource, String)>,
2910 progress: Option<Sender<Progress>>,
2911 concurrency: Option<usize>,
2912 ) -> Result<(), Error> {
2913 // Build a map from (sample_uuid, basename) -> file source
2914 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2915 for (uuid, _file_type, source, basename) in files_to_upload {
2916 upload_map.insert((uuid, basename), source);
2917 }
2918
2919 let http = self.bulk_http.clone();
2920
2921 // Extract the data we need for parallel upload
2922 let upload_tasks: Vec<_> = results
2923 .iter()
2924 .map(|result| (result.uuid.clone(), result.urls.clone()))
2925 .collect();
2926
2927 parallel_foreach_items(
2928 upload_tasks,
2929 progress.clone(),
2930 concurrency,
2931 move |(uuid, urls)| {
2932 let http = http.clone();
2933 let upload_map = upload_map.clone();
2934
2935 async move {
2936 // Upload all files for this sample
2937 for url_info in &urls {
2938 if let Some(source) =
2939 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2940 {
2941 match source {
2942 FileSource::Path(path) => {
2943 upload_file_to_presigned_url(
2944 http.clone(),
2945 &url_info.url,
2946 path.clone(),
2947 )
2948 .await?;
2949 }
2950 FileSource::Bytes(bytes) => {
2951 upload_bytes_to_presigned_url(
2952 http.clone(),
2953 &url_info.url,
2954 bytes.clone(),
2955 &url_info.filename,
2956 )
2957 .await?;
2958 }
2959 }
2960 }
2961 }
2962
2963 Ok(())
2964 }
2965 },
2966 )
2967 .await
2968 }
2969
2970 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2971 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2972 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
2973 if !url.starts_with("http://") && !url.starts_with("https://") {
2974 return Err(Error::InvalidParameters(format!(
2975 "Invalid URL (must be absolute): {}",
2976 url
2977 )));
2978 }
2979
2980 let resp = self.bulk_http.get(url).send().await?;
2981
2982 if !resp.status().is_success() {
2983 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2984 }
2985
2986 let bytes = resp.bytes().await?;
2987 Ok(bytes.to_vec())
2988 }
2989
2990 /// Get samples as a DataFrame with complete 2025.10 schema.
2991 ///
2992 /// This is the recommended method for obtaining dataset annotations in
2993 /// DataFrame format. It includes all sample metadata (size, location,
2994 /// pose, degradation) as optional columns.
2995 ///
2996 /// # Arguments
2997 ///
2998 /// * `dataset_id` - Dataset identifier
2999 /// * `annotation_set_id` - Optional annotation set filter
3000 /// * `groups` - Dataset groups to include (train, val, test)
3001 /// * `types` - Annotation types to filter (bbox, box3d, mask)
3002 /// * `progress` - Optional progress callback
3003 ///
3004 /// # Progress
3005 ///
3006 /// Reports progress with `status: None` as samples are fetched from the
3007 /// server in paginated batches. Progress unit is samples fetched. This
3008 /// method delegates to [`samples()`](Self::samples) and shares its
3009 /// progress behavior.
3010 ///
3011 /// # Example
3012 ///
3013 /// ```rust,no_run
3014 /// use edgefirst_client::Client;
3015 ///
3016 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3017 /// # let client = Client::new()?;
3018 /// # let dataset_id = 1.into();
3019 /// # let annotation_set_id = 1.into();
3020 /// let df = client
3021 /// .samples_dataframe(
3022 /// dataset_id,
3023 /// Some(annotation_set_id),
3024 /// &["train".to_string()],
3025 /// &[],
3026 /// None,
3027 /// )
3028 /// .await?;
3029 /// println!("DataFrame shape: {:?}", df.shape());
3030 /// # Ok(())
3031 /// # }
3032 /// ```
#[cfg(feature = "polars")]
#[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
pub async fn samples_dataframe(
    &self,
    dataset_id: DatasetID,
    annotation_set_id: Option<AnnotationSetID>,
    groups: &[String],
    types: &[AnnotationType],
    progress: Option<Sender<Progress>>,
) -> Result<DataFrame, Error> {
    // Fetch the samples first. Note the argument order of the call: `types`
    // is passed ahead of `groups`, the reverse of this method's parameters.
    let fetched = self
        .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
        .await?;

    // Conversion to a DataFrame is delegated to the dataset module.
    crate::dataset::samples_dataframe(&fetched)
}
3050
3051 /// Update image dimensions for existing samples in a dataset.
3052 ///
3053 /// This is useful for backfilling width/height data on samples that were
3054 /// uploaded before dimension extraction was added, or where dimensions
3055 /// could not be determined at upload time.
3056 ///
3057 /// # Arguments
3058 ///
3059 /// * `dataset_id` - The dataset containing the samples
3060 /// * `updates` - List of dimension updates (sample ID, width, height)
3061 ///
3062 /// # Returns
3063 ///
3064 /// The number of samples that were successfully updated.
3065 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, updates), fields(dataset_id = %dataset_id, count = updates.len())))]
3066 pub async fn update_sample_dimensions(
3067 &self,
3068 dataset_id: DatasetID,
3069 updates: Vec<crate::SampleDimensionUpdate>,
3070 ) -> Result<u64, Error> {
3071 use crate::api::SamplesUpdateDimensionsParams;
3072
3073 if updates.is_empty() {
3074 return Ok(0);
3075 }
3076
3077 // Batch in groups of 500 to stay within server limits
3078 let mut total_updated = 0u64;
3079 for chunk in updates.chunks(500) {
3080 let params = SamplesUpdateDimensionsParams {
3081 dataset_id,
3082 samples: chunk.to_vec(),
3083 };
3084 let result: crate::SamplesUpdateDimensionsResult = self
3085 .rpc("samples.update_dimensions".to_owned(), Some(params))
3086 .await?;
3087 total_updated += result.updated;
3088 }
3089 Ok(total_updated)
3090 }
3091
3092 /// Backfill missing image dimensions for a dataset.
3093 ///
3094 /// Downloads image data for samples that are missing width/height,
3095 /// extracts the dimensions using the `imagesize` crate, and updates
3096 /// the server with the computed values.
3097 ///
3098 /// This is a one-time repair operation for datasets that were uploaded
3099 /// before the client added automatic dimension extraction.
3100 ///
3101 /// # Arguments
3102 ///
3103 /// * `dataset_id` - The dataset to backfill
3104 /// * `progress` - Optional progress channel
3105 ///
3106 /// # Returns
3107 ///
3108 /// The number of samples whose dimensions were updated.
3109 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(dataset_id = %dataset_id)))]
3110 pub async fn backfill_sample_dimensions(
3111 &self,
3112 dataset_id: DatasetID,
3113 progress: Option<Sender<Progress>>,
3114 ) -> Result<u64, Error> {
3115 // Fetch all samples; listing progress is not forwarded to the caller
3116 // since it would interleave with the dimension-computing phase.
3117 let samples = self.samples(dataset_id, None, &[], &[], &[], None).await?;
3118
3119 // Filter to samples missing dimensions
3120 let missing: Vec<&Sample> = samples
3121 .iter()
3122 .filter(|s| s.width.is_none() || s.height.is_none())
3123 .collect();
3124
3125 if missing.is_empty() {
3126 return Ok(0);
3127 }
3128
3129 let total = missing.len();
3130 let mut updates: Vec<crate::SampleDimensionUpdate> = Vec::with_capacity(total);
3131
3132 for (i, sample) in missing.into_iter().enumerate() {
3133 let current = i + 1;
3134
3135 let Some(id) = sample.id() else {
3136 Self::send_progress(&progress, current, total).await;
3137 continue;
3138 };
3139
3140 let Some(url) = sample.image_url() else {
3141 #[cfg(feature = "profiling")]
3142 tracing::warn!(sample_id = %id, "skipping sample: no image URL");
3143 Self::send_progress(&progress, current, total).await;
3144 continue;
3145 };
3146
3147 // Download image data to determine dimensions
3148 let resp = self.bulk_http.get(url).send().await;
3149 let Ok(resp) = resp else {
3150 #[cfg(feature = "profiling")]
3151 tracing::warn!(sample_id = %id, "skipping sample: download failed");
3152 Self::send_progress(&progress, current, total).await;
3153 continue;
3154 };
3155
3156 // Skip non-success responses (e.g. 404, 500) rather than parsing error pages
3157 if !resp.status().is_success() {
3158 #[cfg(feature = "profiling")]
3159 tracing::warn!(sample_id = %id, status = %resp.status(), "skipping sample: non-success HTTP status");
3160 Self::send_progress(&progress, current, total).await;
3161 continue;
3162 }
3163
3164 let Ok(bytes) = resp.bytes().await else {
3165 #[cfg(feature = "profiling")]
3166 tracing::warn!(sample_id = %id, "skipping sample: failed to read response body");
3167 Self::send_progress(&progress, current, total).await;
3168 continue;
3169 };
3170
3171 // Extract dimensions from the downloaded image
3172 let Ok(size) = imagesize::blob_size(&bytes) else {
3173 #[cfg(feature = "profiling")]
3174 tracing::warn!(sample_id = %id, "skipping sample: could not determine dimensions");
3175 Self::send_progress(&progress, current, total).await;
3176 continue;
3177 };
3178
3179 let (Ok(width), Ok(height)) = (u32::try_from(size.width), u32::try_from(size.height))
3180 else {
3181 #[cfg(feature = "profiling")]
3182 tracing::warn!(sample_id = %id, width = size.width, height = size.height, "skipping sample: dimensions overflow u32");
3183 Self::send_progress(&progress, current, total).await;
3184 continue;
3185 };
3186
3187 updates.push(crate::SampleDimensionUpdate { id, width, height });
3188 Self::send_progress(&progress, current, total).await;
3189 }
3190
3191 // Send updates to server
3192 self.update_sample_dimensions(dataset_id, updates).await
3193 }
3194
3195 /// Emit a progress event if a progress channel is provided.
3196 async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
3197 if let Some(tx) = progress {
3198 let _ = tx
3199 .send(Progress {
3200 current,
3201 total,
3202 status: Some("Computing dimensions".to_string()),
3203 })
3204 .await;
3205 }
3206 }
3207
3208 /// List available snapshots. If a name is provided, only snapshots
3209 /// containing that name are returned.
3210 ///
3211 /// Results are sorted by match quality: exact matches first, then
3212 /// case-insensitive exact matches, then shorter descriptions (more
3213 /// specific), then alphabetically.
3214 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3215 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
3216 let snapshots: Vec<Snapshot> = self
3217 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
3218 .await?;
3219 if let Some(name) = name {
3220 Ok(filter_and_sort_by_name(snapshots, name, |s| {
3221 s.description()
3222 }))
3223 } else {
3224 Ok(snapshots)
3225 }
3226 }
3227
3228 /// Get the snapshot with the specified id.
3229 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3230 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
3231 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3232 self.rpc("snapshots.get".to_owned(), Some(params)).await
3233 }
3234
3235 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
3236 ///
3237 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
3238 /// pairs) that serve two primary purposes:
3239 ///
3240 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
3241 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
3242 /// restored with AGTG (Automatic Ground Truth Generation) and optional
3243 /// auto-depth processing.
3244 ///
3245 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
3246 /// migration between EdgeFirst Studio instances using the create →
3247 /// download → upload → restore workflow.
3248 ///
3249 /// Large files are automatically chunked into 100MB parts and uploaded
3250 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
3251 /// is streamed without loading into memory, maintaining constant memory
3252 /// usage.
3253 ///
3254 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3255 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
3256 /// better for large files to avoid timeout issues. Higher values (16-32)
3257 /// are better for many small files.
3258 ///
3259 /// # Arguments
3260 ///
3261 /// * `path` - Local file path to MCAP file or directory containing
3262 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
3263 /// * `progress` - Optional channel to receive upload progress updates
3264 ///
3265 /// # Progress
3266 ///
3267 /// Reports progress with `status: None` as file data is uploaded. Progress
3268 /// unit is bytes uploaded. For single files, total is the file size. For
3269 /// directories, total is the combined size of all files.
3270 ///
3271 /// # Returns
3272 ///
3273 /// Returns a `Snapshot` object with ID, description, status, path, and
3274 /// creation timestamp on success.
3275 ///
3276 /// # Errors
3277 ///
3278 /// Returns an error if:
3279 /// * Path doesn't exist or contains invalid UTF-8
3280 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
3281 /// * Upload fails or network error occurs
3282 /// * Server rejects the snapshot
3283 ///
3284 /// # Example
3285 ///
3286 /// ```no_run
3287 /// # use edgefirst_client::{Client, Progress};
3288 /// # use tokio::sync::mpsc;
3289 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3290 /// let client = Client::new()?.with_token_path(None)?;
3291 ///
3292 /// // Upload MCAP file with progress tracking
3293 /// let (tx, mut rx) = mpsc::channel(1);
3294 /// tokio::spawn(async move {
3295 /// while let Some(Progress {
3296 /// current,
3297 /// total,
3298 /// status,
3299 /// }) = rx.recv().await
3300 /// {
3301 /// println!(
3302 /// "{}: {}/{} bytes ({:.1}%)",
3303 /// status.as_deref().unwrap_or("Upload"),
3304 /// current,
3305 /// total,
3306 /// (current as f64 / total as f64) * 100.0
3307 /// );
3308 /// }
3309 /// });
3310 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
3311 /// println!("Created snapshot: {:?}", snapshot.id());
3312 ///
3313 /// // Upload dataset directory (no progress)
3314 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
3315 /// # Ok(())
3316 /// # }
3317 /// ```
3318 ///
3319 /// # See Also
3320 ///
3321 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3322 /// dataset
3323 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3324 /// data
3325 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3326 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3327 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
3328 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3329 pub async fn create_snapshot(
3330 &self,
3331 path: &str,
3332 progress: Option<Sender<Progress>>,
3333 ) -> Result<Snapshot, Error> {
3334 let path = Path::new(path);
3335
3336 if path.is_dir() {
3337 let path_str = path.to_str().ok_or_else(|| {
3338 Error::IoError(std::io::Error::new(
3339 std::io::ErrorKind::InvalidInput,
3340 "Path contains invalid UTF-8",
3341 ))
3342 })?;
3343 return self.create_snapshot_folder(path_str, progress).await;
3344 }
3345
3346 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3347 Error::IoError(std::io::Error::new(
3348 std::io::ErrorKind::InvalidInput,
3349 "Invalid filename",
3350 ))
3351 })?;
3352 let total = path.metadata()?.len() as usize;
3353 let current = Arc::new(AtomicUsize::new(0));
3354
3355 if let Some(progress) = &progress {
3356 let _ = progress
3357 .send(Progress {
3358 current: 0,
3359 total,
3360 status: None,
3361 })
3362 .await;
3363 }
3364
3365 let params = SnapshotCreateMultipartParams {
3366 snapshot_name: name.to_owned(),
3367 keys: vec![name.to_owned()],
3368 file_sizes: vec![total],
3369 snapshot_type: None,
3370 };
3371 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3372 .rpc(
3373 "snapshots.create_upload_url_multipart".to_owned(),
3374 Some(params),
3375 )
3376 .await?;
3377
3378 let snapshot_id = match multipart.get("snapshot_id") {
3379 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3380 _ => return Err(Error::InvalidResponse),
3381 };
3382
3383 let snapshot = self.snapshot(snapshot_id).await?;
3384 let part_prefix = snapshot
3385 .path()
3386 .split("::/")
3387 .last()
3388 .ok_or(Error::InvalidResponse)?
3389 .to_owned();
3390 let part_key = format!("{}/{}", part_prefix, name);
3391 let mut part = match multipart.get(&part_key) {
3392 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3393 _ => return Err(Error::InvalidResponse),
3394 }
3395 .clone();
3396 part.key = Some(part_key);
3397
3398 let params = upload_multipart(
3399 self.bulk_http.clone(),
3400 part.clone(),
3401 path.to_path_buf(),
3402 total,
3403 current,
3404 progress.clone(),
3405 )
3406 .await?;
3407
3408 let complete: String = self
3409 .rpc(
3410 "snapshots.complete_multipart_upload".to_owned(),
3411 Some(params),
3412 )
3413 .await?;
3414 debug!("Snapshot Multipart Complete: {:?}", complete);
3415
3416 let params: SnapshotStatusParams = SnapshotStatusParams {
3417 snapshot_id,
3418 status: "available".to_owned(),
3419 };
3420 let _: SnapshotStatusResult = self
3421 .rpc("snapshots.update".to_owned(), Some(params))
3422 .await?;
3423
3424 if let Some(progress) = progress {
3425 drop(progress);
3426 }
3427
3428 self.snapshot(snapshot_id).await
3429 }
3430
3431 async fn create_snapshot_folder(
3432 &self,
3433 path: &str,
3434 progress: Option<Sender<Progress>>,
3435 ) -> Result<Snapshot, Error> {
3436 let path = Path::new(path);
3437 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3438 Error::IoError(std::io::Error::new(
3439 std::io::ErrorKind::InvalidInput,
3440 "Invalid directory name",
3441 ))
3442 })?;
3443
3444 let files = WalkDir::new(path)
3445 .into_iter()
3446 .filter_map(|entry| entry.ok())
3447 .filter(|entry| entry.file_type().is_file())
3448 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3449 .collect::<Vec<_>>();
3450
3451 let total: usize = files
3452 .iter()
3453 .filter_map(|file| path.join(file).metadata().ok())
3454 .map(|metadata| metadata.len() as usize)
3455 .sum();
3456 let current = Arc::new(AtomicUsize::new(0));
3457
3458 if let Some(progress) = &progress {
3459 let _ = progress
3460 .send(Progress {
3461 current: 0,
3462 total,
3463 status: None,
3464 })
3465 .await;
3466 }
3467
3468 let keys = files
3469 .iter()
3470 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3471 .collect::<Vec<_>>();
3472 let file_sizes = files
3473 .iter()
3474 .filter_map(|key| path.join(key).metadata().ok())
3475 .map(|metadata| metadata.len() as usize)
3476 .collect::<Vec<_>>();
3477
3478 let params = SnapshotCreateMultipartParams {
3479 snapshot_name: name.to_owned(),
3480 keys,
3481 file_sizes,
3482 snapshot_type: None,
3483 };
3484
3485 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3486 .rpc(
3487 "snapshots.create_upload_url_multipart".to_owned(),
3488 Some(params),
3489 )
3490 .await?;
3491
3492 let snapshot_id = match multipart.get("snapshot_id") {
3493 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3494 _ => return Err(Error::InvalidResponse),
3495 };
3496
3497 let snapshot = self.snapshot(snapshot_id).await?;
3498 let part_prefix = snapshot
3499 .path()
3500 .split("::/")
3501 .last()
3502 .ok_or(Error::InvalidResponse)?
3503 .to_owned();
3504
3505 for file in files {
3506 let file_str = file.to_str().ok_or_else(|| {
3507 Error::IoError(std::io::Error::new(
3508 std::io::ErrorKind::InvalidInput,
3509 "File path contains invalid UTF-8",
3510 ))
3511 })?;
3512 let part_key = format!("{}/{}", part_prefix, file_str);
3513 let mut part = match multipart.get(&part_key) {
3514 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3515 _ => return Err(Error::InvalidResponse),
3516 }
3517 .clone();
3518 part.key = Some(part_key);
3519
3520 let params = upload_multipart(
3521 self.bulk_http.clone(),
3522 part.clone(),
3523 path.join(file),
3524 total,
3525 current.clone(),
3526 progress.clone(),
3527 )
3528 .await?;
3529
3530 let complete: String = self
3531 .rpc(
3532 "snapshots.complete_multipart_upload".to_owned(),
3533 Some(params),
3534 )
3535 .await?;
3536 debug!("Snapshot Part Complete: {:?}", complete);
3537 }
3538
3539 let params = SnapshotStatusParams {
3540 snapshot_id,
3541 status: "available".to_owned(),
3542 };
3543 let _: SnapshotStatusResult = self
3544 .rpc("snapshots.update".to_owned(), Some(params))
3545 .await?;
3546
3547 if let Some(progress) = progress {
3548 drop(progress);
3549 }
3550
3551 self.snapshot(snapshot_id).await
3552 }
3553
3554 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3555 ///
3556 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3557 /// This format is the native EdgeFirst Dataset Format used for efficient
3558 /// dataset storage and transfer.
3559 ///
3560 /// # Arguments
3561 ///
3562 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3563 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3564 /// * `description` - Optional description for the snapshot
3565 /// * `progress` - Optional progress channel for upload tracking
3566 ///
3567 /// # File Requirements
3568 ///
3569 /// - Arrow file must have `.arrow` extension
3570 /// - ZIP file must have `.zip` extension
3571 /// - Both files must exist and be readable
3572 ///
3573 /// # Example
3574 ///
3575 /// ```no_run
3576 /// # use edgefirst_client::Client;
3577 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3578 /// let client = Client::new()?.with_token_path(None)?;
3579 ///
3580 /// let snapshot = client
3581 /// .create_snapshot_edgefirst_format(
3582 /// "dataset.arrow",
3583 /// "dataset.zip",
3584 /// Some("My Dataset Snapshot"),
3585 /// None,
3586 /// )
3587 /// .await?;
3588 /// println!("Created snapshot: {}", snapshot.id());
3589 /// # Ok(())
3590 /// # }
3591 /// ```
3592 ///
3593 /// # See Also
3594 ///
3595 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3596 /// folder
3597 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3598 /// dataset
3599 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3600 pub async fn create_snapshot_edgefirst_format(
3601 &self,
3602 arrow_path: &str,
3603 zip_path: &str,
3604 description: Option<&str>,
3605 progress: Option<Sender<Progress>>,
3606 ) -> Result<Snapshot, Error> {
3607 let arrow_path = Path::new(arrow_path);
3608 let zip_path = Path::new(zip_path);
3609
3610 // Validate files exist
3611 if !arrow_path.exists() {
3612 return Err(Error::IoError(std::io::Error::new(
3613 std::io::ErrorKind::NotFound,
3614 format!("Arrow file not found: {}", arrow_path.display()),
3615 )));
3616 }
3617 if !zip_path.exists() {
3618 return Err(Error::IoError(std::io::Error::new(
3619 std::io::ErrorKind::NotFound,
3620 format!("ZIP file not found: {}", zip_path.display()),
3621 )));
3622 }
3623
3624 // Get file names
3625 let arrow_name = arrow_path
3626 .file_name()
3627 .and_then(|n| n.to_str())
3628 .ok_or_else(|| {
3629 Error::IoError(std::io::Error::new(
3630 std::io::ErrorKind::InvalidInput,
3631 "Invalid Arrow filename",
3632 ))
3633 })?;
3634 let zip_name = zip_path
3635 .file_name()
3636 .and_then(|n| n.to_str())
3637 .ok_or_else(|| {
3638 Error::IoError(std::io::Error::new(
3639 std::io::ErrorKind::InvalidInput,
3640 "Invalid ZIP filename",
3641 ))
3642 })?;
3643
3644 // Generate snapshot name from arrow file (without extension)
3645 let snapshot_name = description
3646 .map(|s| s.to_string())
3647 .or_else(|| {
3648 arrow_path
3649 .file_stem()
3650 .and_then(|s| s.to_str())
3651 .map(|s| s.to_string())
3652 })
3653 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3654
3655 // Calculate file sizes
3656 let arrow_size = arrow_path.metadata()?.len() as usize;
3657 let zip_size = zip_path.metadata()?.len() as usize;
3658 let total = arrow_size + zip_size;
3659 let current = Arc::new(AtomicUsize::new(0));
3660
3661 if let Some(progress) = &progress {
3662 let _ = progress
3663 .send(Progress {
3664 current: 0,
3665 total,
3666 status: None,
3667 })
3668 .await;
3669 }
3670
3671 // Create multipart upload request with "ziparrow" type
3672 let params = SnapshotCreateMultipartParams {
3673 snapshot_name,
3674 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3675 file_sizes: vec![arrow_size, zip_size],
3676 snapshot_type: Some("ziparrow".to_string()),
3677 };
3678
3679 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3680 .rpc(
3681 "snapshots.create_upload_url_multipart".to_owned(),
3682 Some(params),
3683 )
3684 .await?;
3685
3686 let snapshot_id = match multipart.get("snapshot_id") {
3687 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3688 _ => return Err(Error::InvalidResponse),
3689 };
3690
3691 let snapshot = self.snapshot(snapshot_id).await?;
3692 let part_prefix = snapshot
3693 .path()
3694 .split("::/")
3695 .last()
3696 .ok_or(Error::InvalidResponse)?
3697 .to_owned();
3698
3699 // Upload Arrow file
3700 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3701 let mut arrow_part = match multipart.get(&arrow_key) {
3702 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3703 _ => return Err(Error::InvalidResponse),
3704 };
3705 arrow_part.key = Some(arrow_key);
3706
3707 let params = upload_multipart(
3708 self.bulk_http.clone(),
3709 arrow_part,
3710 arrow_path.to_path_buf(),
3711 total,
3712 current.clone(),
3713 progress.clone(),
3714 )
3715 .await?;
3716
3717 let _: String = self
3718 .rpc(
3719 "snapshots.complete_multipart_upload".to_owned(),
3720 Some(params),
3721 )
3722 .await?;
3723 debug!("Arrow file upload complete");
3724
3725 // Upload ZIP file
3726 let zip_key = format!("{}/{}", part_prefix, zip_name);
3727 let mut zip_part = match multipart.get(&zip_key) {
3728 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3729 _ => return Err(Error::InvalidResponse),
3730 };
3731 zip_part.key = Some(zip_key);
3732
3733 let params = upload_multipart(
3734 self.bulk_http.clone(),
3735 zip_part,
3736 zip_path.to_path_buf(),
3737 total,
3738 current.clone(),
3739 progress.clone(),
3740 )
3741 .await?;
3742
3743 let _: String = self
3744 .rpc(
3745 "snapshots.complete_multipart_upload".to_owned(),
3746 Some(params),
3747 )
3748 .await?;
3749 debug!("ZIP file upload complete");
3750
3751 // Mark snapshot as available
3752 let params = SnapshotStatusParams {
3753 snapshot_id,
3754 status: "available".to_owned(),
3755 };
3756 let _: SnapshotStatusResult = self
3757 .rpc("snapshots.update".to_owned(), Some(params))
3758 .await?;
3759
3760 if let Some(progress) = progress {
3761 drop(progress);
3762 }
3763
3764 self.snapshot(snapshot_id).await
3765 }
3766
3767 /// Delete a snapshot from EdgeFirst Studio.
3768 ///
3769 /// Permanently removes a snapshot and its associated data. This operation
3770 /// cannot be undone.
3771 ///
3772 /// # Arguments
3773 ///
3774 /// * `snapshot_id` - The snapshot ID to delete
3775 ///
3776 /// # Errors
3777 ///
3778 /// Returns an error if:
3779 /// * Snapshot doesn't exist
3780 /// * User lacks permission to delete the snapshot
3781 /// * Server error occurs
3782 ///
3783 /// # Example
3784 ///
3785 /// ```no_run
3786 /// # use edgefirst_client::{Client, SnapshotID};
3787 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3788 /// let client = Client::new()?.with_token_path(None)?;
3789 /// let snapshot_id = SnapshotID::from(123);
3790 /// client.delete_snapshot(snapshot_id).await?;
3791 /// # Ok(())
3792 /// # }
3793 /// ```
3794 ///
3795 /// # See Also
3796 ///
3797 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3798 /// * [`snapshots`](Self::snapshots) - List all snapshots
3799 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3800 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3801 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3802 let _: serde_json::Value = self
3803 .rpc("snapshots.delete".to_owned(), Some(params))
3804 .await?;
3805 Ok(())
3806 }
3807
3808 /// Create a snapshot from an existing dataset on the server.
3809 ///
3810 /// Triggers server-side snapshot generation which exports the dataset's
3811 /// images and annotations into a downloadable EdgeFirst Dataset Format
3812 /// snapshot.
3813 ///
3814 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3815 /// while restore creates a dataset from a snapshot, this method creates a
3816 /// snapshot from a dataset.
3817 ///
3818 /// # Arguments
3819 ///
3820 /// * `dataset_id` - The dataset ID to create snapshot from
3821 /// * `description` - Description for the created snapshot
3822 ///
3823 /// # Returns
3824 ///
3825 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3826 /// for monitoring progress.
3827 ///
3828 /// # Errors
3829 ///
3830 /// Returns an error if:
3831 /// * Dataset doesn't exist
3832 /// * User lacks permission to access the dataset
3833 /// * Server rejects the request
3834 ///
3835 /// # Example
3836 ///
3837 /// ```no_run
3838 /// # use edgefirst_client::{Client, DatasetID};
3839 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3840 /// let client = Client::new()?.with_token_path(None)?;
3841 /// let dataset_id = DatasetID::from(123);
3842 ///
3843 /// // Create snapshot from dataset (all annotation sets)
3844 /// let result = client
3845 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3846 /// .await?;
3847 /// println!("Created snapshot: {:?}", result.id);
3848 ///
3849 /// // Monitor progress via task ID
3850 /// if let Some(task_id) = result.task_id {
3851 /// println!("Task: {}", task_id);
3852 /// }
3853 /// # Ok(())
3854 /// # }
3855 /// ```
3856 ///
3857 /// # See Also
3858 ///
3859 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3860 /// snapshot
3861 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3862 /// dataset
3863 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3864 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3865 pub async fn create_snapshot_from_dataset(
3866 &self,
3867 dataset_id: DatasetID,
3868 description: &str,
3869 annotation_set_id: Option<AnnotationSetID>,
3870 ) -> Result<SnapshotFromDatasetResult, Error> {
3871 // Resolve annotation_set_id: use provided value or fetch default
3872 let annotation_set_id = match annotation_set_id {
3873 Some(id) => id,
3874 None => {
3875 // Fetch annotation sets and find default ("annotations") or use first
3876 let sets = self.annotation_sets(dataset_id).await?;
3877 if sets.is_empty() {
3878 return Err(Error::InvalidParameters(
3879 "No annotation sets available for dataset".to_owned(),
3880 ));
3881 }
3882 // Look for "annotations" set (default), otherwise use first
3883 sets.iter()
3884 .find(|s| s.name() == "annotations")
3885 .unwrap_or(&sets[0])
3886 .id()
3887 }
3888 };
3889 let params = SnapshotCreateFromDataset {
3890 description: description.to_owned(),
3891 dataset_id,
3892 annotation_set_id,
3893 };
3894 self.rpc("snapshots.create".to_owned(), Some(params)).await
3895 }
3896
3897 /// Download a snapshot from EdgeFirst Studio to local storage.
3898 ///
3899 /// Downloads all files in a snapshot (single MCAP file or directory of
3900 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3901 /// downloaded concurrently with progress tracking.
3902 ///
3903 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3904 /// downloads (default: half of CPU cores, min 2, max 8).
3905 ///
3906 /// # Arguments
3907 ///
3908 /// * `snapshot_id` - The snapshot ID to download
3909 /// * `output` - Local directory path to save downloaded files
3910 /// * `progress` - Optional channel to receive download progress updates
3911 ///
3912 /// # Progress
3913 ///
3914 /// Reports progress with `status: None` as file data is received. Progress
3915 /// unit is bytes downloaded across all files combined. The total
3916 /// accumulates as file sizes become known (from HTTP Content-Length
3917 /// headers), so both `current` and `total` may increase during
3918 /// download.
3919 ///
3920 /// # Errors
3921 ///
3922 /// Returns an error if:
3923 /// * Snapshot doesn't exist
3924 /// * Output directory cannot be created
3925 /// * Download fails or network error occurs
3926 ///
3927 /// # Example
3928 ///
3929 /// ```no_run
3930 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3931 /// # use tokio::sync::mpsc;
3932 /// # use std::path::PathBuf;
3933 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3934 /// let client = Client::new()?.with_token_path(None)?;
3935 /// let snapshot_id = SnapshotID::from(123);
3936 ///
3937 /// // Download with progress tracking
3938 /// let (tx, mut rx) = mpsc::channel(1);
3939 /// tokio::spawn(async move {
3940 /// while let Some(Progress {
3941 /// current,
3942 /// total,
3943 /// status,
3944 /// }) = rx.recv().await
3945 /// {
3946 /// println!(
3947 /// "{}: {}/{} bytes",
3948 /// status.as_deref().unwrap_or("Download"),
3949 /// current,
3950 /// total
3951 /// );
3952 /// }
3953 /// });
3954 /// client
3955 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
3956 /// .await?;
3957 /// # Ok(())
3958 /// # }
3959 /// ```
3960 ///
3961 /// # See Also
3962 ///
3963 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3964 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3965 /// dataset
3966 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3967 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
3968 pub async fn download_snapshot(
3969 &self,
3970 snapshot_id: SnapshotID,
3971 output: PathBuf,
3972 progress: Option<Sender<Progress>>,
3973 ) -> Result<(), Error> {
3974 fs::create_dir_all(&output).await?;
3975
3976 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3977 let items: HashMap<String, String> = self
3978 .rpc("snapshots.create_download_url".to_owned(), Some(params))
3979 .await?;
3980
3981 // Single-phase: each task holds its semaphore permit for the full
3982 // lifetime of the request (GET → headers → stream → disk). This bounds
3983 // the number of simultaneously-open connections to max_tasks() and
3984 // avoids accumulating all responses in memory before streaming.
3985 //
3986 // total is updated atomically as each response's Content-Length header
3987 // arrives, so progress tracking is accurate without a separate phase.
3988 let http = self.bulk_http.clone();
3989 let current = Arc::new(AtomicUsize::new(0));
3990 let total = Arc::new(AtomicUsize::new(0));
3991 let sem = Arc::new(Semaphore::new(max_tasks()));
3992
3993 let tasks = items
3994 .into_iter()
3995 .map(|(key, url)| {
3996 let http = http.clone();
3997 let output = output.clone();
3998 let progress = progress.clone();
3999 let current = current.clone();
4000 let total = total.clone();
4001 let sem = sem.clone();
4002
4003 tokio::spawn(async move {
4004 let _permit = sem.acquire().await.map_err(|_| {
4005 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4006 })?;
4007
4008 let res = http.get(url).send().await?;
4009 let res = res.error_for_status()?;
4010
4011 // Contribute this file's size to the running total so the
4012 // caller's progress bar knows the overall scope.
4013 if let Some(len) = res.content_length() {
4014 total.fetch_add(len as usize, Ordering::SeqCst);
4015 }
4016
4017 let mut file = File::create(output.join(key)).await?;
4018 let mut stream = res.bytes_stream();
4019
4020 while let Some(chunk) = stream.next().await {
4021 let chunk = chunk?;
4022 file.write_all(&chunk).await?;
4023 let len = chunk.len();
4024
4025 if let Some(progress) = &progress {
4026 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
4027 let tot = total.load(Ordering::SeqCst);
4028 let _ = progress
4029 .send(Progress {
4030 current: cur,
4031 total: tot,
4032 status: None,
4033 })
4034 .await;
4035 }
4036 }
4037
4038 Ok::<(), Error>(())
4039 })
4040 })
4041 .collect::<Vec<_>>();
4042
4043 join_all(tasks)
4044 .await
4045 .into_iter()
4046 .collect::<Result<Vec<_>, _>>()?
4047 .into_iter()
4048 .collect::<Result<Vec<_>, _>>()?;
4049
4050 Ok(())
4051 }
4052
4053 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
4054 ///
4055 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
4056 /// the specified project. For MCAP files, supports:
4057 ///
4058 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
4059 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
4060 /// present)
4061 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
4062 /// * **Topic filtering**: Select specific MCAP topics to restore
4063 ///
4064 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
4065 /// dataset structure.
4066 ///
4067 /// # Arguments
4068 ///
4069 /// * `project_id` - Target project ID
4070 /// * `snapshot_id` - Snapshot ID to restore
4071 /// * `topics` - MCAP topics to include (empty = all topics)
4072 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
4073 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
4074 /// * `dataset_name` - Optional custom dataset name
4075 /// * `dataset_description` - Optional dataset description
4076 ///
4077 /// # Returns
4078 ///
4079 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
4080 ///
4081 /// # Errors
4082 ///
4083 /// Returns an error if:
4084 /// * Snapshot or project doesn't exist
4085 /// * Snapshot format is invalid
4086 /// * Server rejects restoration parameters
4087 ///
4088 /// # Example
4089 ///
4090 /// ```no_run
4091 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
4092 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4093 /// let client = Client::new()?.with_token_path(None)?;
4094 /// let project_id = ProjectID::from(1);
4095 /// let snapshot_id = SnapshotID::from(123);
4096 ///
4097 /// // Restore MCAP with AGTG for "person" and "car" detection
4098 /// let result = client
4099 /// .restore_snapshot(
4100 /// project_id,
4101 /// snapshot_id,
4102 /// &[], // All topics
4103 /// &["person".to_string(), "car".to_string()], // AGTG labels
4104 /// true, // Auto-depth
4105 /// Some("Highway Dataset"),
4106 /// Some("Collected on I-95"),
4107 /// )
4108 /// .await?;
4109 /// println!("Restored to dataset: {:?}", result.dataset_id);
4110 /// # Ok(())
4111 /// # }
4112 /// ```
4113 ///
4114 /// # See Also
4115 ///
4116 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4117 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
4118 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
4119 #[allow(clippy::too_many_arguments)]
4120 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4121 pub async fn restore_snapshot(
4122 &self,
4123 project_id: ProjectID,
4124 snapshot_id: SnapshotID,
4125 topics: &[String],
4126 autolabel: &[String],
4127 autodepth: bool,
4128 dataset_name: Option<&str>,
4129 dataset_description: Option<&str>,
4130 ) -> Result<SnapshotRestoreResult, Error> {
4131 let params = SnapshotRestore {
4132 project_id,
4133 snapshot_id,
4134 fps: 1,
4135 autodepth,
4136 agtg_pipeline: !autolabel.is_empty(),
4137 autolabel: autolabel.to_vec(),
4138 topics: topics.to_vec(),
4139 dataset_name: dataset_name.map(|s| s.to_owned()),
4140 dataset_description: dataset_description.map(|s| s.to_owned()),
4141 };
4142 self.rpc("snapshots.restore".to_owned(), Some(params)).await
4143 }
4144
4145 /// Returns a list of experiments available to the user. The experiments
4146 /// are returned as a vector of Experiment objects. If name is provided
4147 /// then only experiments containing this string are returned.
4148 ///
4149 /// Results are sorted by match quality: exact matches first, then
4150 /// case-insensitive exact matches, then shorter names (more specific),
4151 /// then alphabetically.
4152 ///
4153 /// Experiments provide a method of organizing training and validation
4154 /// sessions together and are akin to an Experiment in MLFlow terminology.
4155 /// Each experiment can have multiple trainer sessions associated with it,
4156 /// these would be akin to runs in MLFlow terminology.
4157 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4158 pub async fn experiments(
4159 &self,
4160 project_id: ProjectID,
4161 name: Option<&str>,
4162 ) -> Result<Vec<Experiment>, Error> {
4163 let params = HashMap::from([("project_id", project_id)]);
4164 let experiments: Vec<Experiment> =
4165 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
4166 if let Some(name) = name {
4167 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
4168 } else {
4169 Ok(experiments)
4170 }
4171 }
4172
4173 /// Return the experiment with the specified experiment ID. If the
4174 /// experiment does not exist, an error is returned.
4175 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4176 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
4177 let params = HashMap::from([("trainer_id", experiment_id)]);
4178 self.rpc("trainer.get".to_owned(), Some(params)).await
4179 }
4180
4181 /// Returns a list of trainer sessions available to the user. The trainer
4182 /// sessions are returned as a vector of TrainingSession objects. If name
4183 /// is provided then only trainer sessions containing this string are
4184 /// returned.
4185 ///
4186 /// Results are sorted by match quality: exact matches first, then
4187 /// case-insensitive exact matches, then shorter names (more specific),
4188 /// then alphabetically.
4189 ///
4190 /// Trainer sessions are akin to runs in MLFlow terminology. These
4191 /// represent an actual training session which will produce metrics and
4192 /// model artifacts.
4193 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4194 pub async fn training_sessions(
4195 &self,
4196 experiment_id: ExperimentID,
4197 name: Option<&str>,
4198 ) -> Result<Vec<TrainingSession>, Error> {
4199 let params = HashMap::from([("trainer_id", experiment_id)]);
4200 let sessions: Vec<TrainingSession> = self
4201 .rpc("trainer.session.list".to_owned(), Some(params))
4202 .await?;
4203 if let Some(name) = name {
4204 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
4205 } else {
4206 Ok(sessions)
4207 }
4208 }
4209
4210 /// Return the trainer session with the specified trainer session ID. If
4211 /// the trainer session does not exist, an error is returned.
4212 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4213 pub async fn training_session(
4214 &self,
4215 session_id: TrainingSessionID,
4216 ) -> Result<TrainingSession, Error> {
4217 let params = HashMap::from([("trainer_session_id", session_id)]);
4218 self.rpc("trainer.session.get".to_owned(), Some(params))
4219 .await
4220 }
4221
4222 /// List validation sessions for the given project.
4223 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4224 pub async fn validation_sessions(
4225 &self,
4226 project_id: ProjectID,
4227 ) -> Result<Vec<ValidationSession>, Error> {
4228 let params = HashMap::from([("project_id", project_id)]);
4229 self.rpc("validate.session.list".to_owned(), Some(params))
4230 .await
4231 }
4232
4233 /// Retrieve a specific validation session.
4234 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4235 pub async fn validation_session(
4236 &self,
4237 session_id: ValidationSessionID,
4238 ) -> Result<ValidationSession, Error> {
4239 let params = HashMap::from([("validate_session_id", session_id)]);
4240 self.rpc("validate.session.get".to_owned(), Some(params))
4241 .await
4242 }
4243
4244 /// Create a new validation session via Studio's `cloud.server.start`.
4245 ///
4246 /// Pass `is_local: true` in the [`StartValidationRequest`] to create
4247 /// a **user-managed** session: the database row is created and the
4248 /// session is fully usable for data uploads / downloads / metrics,
4249 /// but no EC2 instance is provisioned and no automated validator
4250 /// pipeline is started. That is the mode our integration tests use
4251 /// — they create a session, exercise the wrapper APIs against it,
4252 /// then call [`Client::delete_validation_sessions`] in teardown so
4253 /// no stray sessions accumulate on the test account.
4254 ///
4255 /// Returns a [`NewValidationSession`] carrying the backing task id
4256 /// and the freshly-minted validation session id.
4257 ///
4258 /// # Errors
4259 ///
4260 /// Surfaces any RPC error from `cloud.server.start`. Common cases:
4261 /// `RpcError(101, …)` if a required entity is missing (project,
4262 /// training session, dataset, …); `PermissionDenied` if the caller
4263 /// can't write to the target project.
4264 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, req)))]
4265 pub async fn start_validation_session(
4266 &self,
4267 req: StartValidationRequest,
4268 ) -> Result<NewValidationSession, Error> {
4269 // Build the params shape the server expects. `cloud.server.start`
4270 // is intentionally generic — different server types pull
4271 // different fields out of `params` — so we serialize manually to
4272 // match the JS frontend's call site verbatim (see
4273 // `dve-frontend/src/components/ValidationPage/StartValidatorModal.vue`).
4274 let mut body = serde_json::Map::new();
4275 body.insert(
4276 "type".into(),
4277 serde_json::Value::String("validation".into()),
4278 );
4279 body.insert("name".into(), serde_json::Value::String(req.name));
4280 body.insert("project_id".into(), serde_json::to_value(req.project_id)?);
4281 body.insert(
4282 "training_session_id".into(),
4283 serde_json::to_value(req.training_session_id)?,
4284 );
4285 body.insert(
4286 "model_file".into(),
4287 serde_json::Value::String(req.model_file),
4288 );
4289 body.insert("val_type".into(), serde_json::Value::String(req.val_type));
4290 body.insert("is_local".into(), serde_json::Value::Bool(req.is_local));
4291 body.insert(
4292 "is_kubernetes".into(),
4293 serde_json::Value::Bool(req.is_kubernetes),
4294 );
4295
4296 // `validate.session` reads its config from `params.params` (one
4297 // extra envelope level). The outer `params` wrapper is required
4298 // even when the inner map is empty.
4299 let inner = serde_json::to_value(req.params)?;
4300 let mut outer = serde_json::Map::new();
4301 outer.insert("params".into(), inner);
4302 body.insert("params".into(), serde_json::Value::Object(outer));
4303
4304 if let Some(d) = req.description {
4305 body.insert("description".into(), serde_json::Value::String(d));
4306 }
4307 if let Some(id) = req.dataset_id {
4308 body.insert("dataset_id".into(), serde_json::to_value(id)?);
4309 }
4310 if let Some(id) = req.annotation_set_id {
4311 body.insert("annotation_set_id".into(), serde_json::to_value(id)?);
4312 }
4313 if let Some(id) = req.snapshot_id {
4314 body.insert("snapshot_id".into(), serde_json::to_value(id)?);
4315 }
4316
4317 self.rpc("cloud.server.start".to_owned(), Some(body)).await
4318 }
4319
4320 /// Delete one or more validation sessions via
4321 /// `validate.session.delete`.
4322 ///
4323 /// Used by integration tests to tear down sessions they created
4324 /// with [`Client::start_validation_session`]; idempotent against
4325 /// already-deleted ids on the server side (the RPC accepts the
4326 /// list, deletes what it can, and surfaces an error only if none
4327 /// of the ids were resolvable).
4328 ///
4329 /// # Errors
4330 ///
4331 /// Surfaces any RPC error from `validate.session.delete`. A
4332 /// `PermissionDenied` indicates the caller lacks
4333 /// `TrainerWrite` on at least one of the listed sessions.
4334 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4335 pub async fn delete_validation_sessions(
4336 &self,
4337 session_ids: &[ValidationSessionID],
4338 ) -> Result<(), Error> {
4339 let mut body = serde_json::Map::new();
4340 body.insert("session_ids".into(), serde_json::to_value(session_ids)?);
4341 let _: serde_json::Value = self
4342 .rpc("validate.session.delete".to_owned(), Some(body))
4343 .await?;
4344 Ok(())
4345 }
4346
4347 /// List the artifacts for the specified trainer session. The artifacts
4348 /// are returned as a vector of strings.
4349 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4350 pub async fn artifacts(
4351 &self,
4352 training_session_id: TrainingSessionID,
4353 ) -> Result<Vec<Artifact>, Error> {
4354 let params = HashMap::from([("training_session_id", training_session_id)]);
4355 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
4356 .await
4357 }
4358
4359 /// Download the model artifact for the specified trainer session to the
4360 /// specified file path, if path is not provided it will be downloaded to
4361 /// the current directory with the same filename.
4362 ///
4363 /// # Progress
4364 ///
4365 /// Reports progress with `status: None` as file data is received. Progress
4366 /// unit is bytes downloaded. Total is determined from the HTTP
4367 /// Content-Length header (may be 0 if server doesn't provide it).
4368 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4369 pub async fn download_artifact(
4370 &self,
4371 training_session_id: TrainingSessionID,
4372 modelname: &str,
4373 filename: Option<PathBuf>,
4374 progress: Option<Sender<Progress>>,
4375 ) -> Result<(), Error> {
4376 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
4377 let resp = self
4378 .bulk_http
4379 .get(format!(
4380 "{}/download_model?training_session_id={}&file={}",
4381 self.url,
4382 training_session_id.value(),
4383 modelname
4384 ))
4385 .header("Authorization", format!("Bearer {}", self.token().await))
4386 .send()
4387 .await?;
4388 if !resp.status().is_success() {
4389 let err = resp.error_for_status_ref().unwrap_err();
4390 return Err(Error::HttpError(err));
4391 }
4392
4393 if let Some(parent) = filename.parent() {
4394 fs::create_dir_all(parent).await?;
4395 }
4396
4397 stream_response_to_file(resp, &filename, progress).await
4398 }
4399
4400 /// Download the model checkpoint associated with the specified trainer
4401 /// session to the specified file path, if path is not provided it will be
4402 /// downloaded to the current directory with the same filename.
4403 ///
4404 /// There is no API for listing checkpoints it is expected that trainers are
4405 /// aware of possible checkpoints and their names within the checkpoint
4406 /// folder on the server.
4407 ///
4408 /// # Progress
4409 ///
4410 /// Reports progress with `status: None` as file data is received. Progress
4411 /// unit is bytes downloaded. Total is determined from the HTTP
4412 /// Content-Length header (may be 0 if server doesn't provide it).
4413 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4414 pub async fn download_checkpoint(
4415 &self,
4416 training_session_id: TrainingSessionID,
4417 checkpoint: &str,
4418 filename: Option<PathBuf>,
4419 progress: Option<Sender<Progress>>,
4420 ) -> Result<(), Error> {
4421 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
4422 let resp = self
4423 .bulk_http
4424 .get(format!(
4425 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
4426 self.url,
4427 training_session_id.value(),
4428 checkpoint
4429 ))
4430 .header("Authorization", format!("Bearer {}", self.token().await))
4431 .send()
4432 .await?;
4433 if !resp.status().is_success() {
4434 let err = resp.error_for_status_ref().unwrap_err();
4435 return Err(Error::HttpError(err));
4436 }
4437
4438 if let Some(parent) = filename.parent() {
4439 fs::create_dir_all(parent).await?;
4440 }
4441
4442 stream_response_to_file(resp, &filename, progress).await
4443 }
4444
4445 /// Return a list of tasks for the current user.
4446 ///
4447 /// # Arguments
4448 ///
4449 /// * `name` - Optional filter for task name (client-side substring match)
4450 /// * `workflow` - Optional filter for workflow/task type. If provided,
4451 /// filters server-side by exact match. Valid values include: "trainer",
4452 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
4453 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
4454 /// "convertor", "twostage"
4455 /// * `status` - Optional filter for task status (e.g., "running",
4456 /// "complete", "error")
4457 /// * `manager` - Optional filter for task manager type (e.g., "aws",
4458 /// "user", "kubernetes")
4459 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4460 pub async fn tasks(
4461 &self,
4462 name: Option<&str>,
4463 workflow: Option<&str>,
4464 status: Option<&str>,
4465 manager: Option<&str>,
4466 ) -> Result<Vec<Task>, Error> {
4467 let mut params = TasksListParams {
4468 continue_token: None,
4469 types: workflow.map(|w| vec![w.to_owned()]),
4470 status: status.map(|s| vec![s.to_owned()]),
4471 manager: manager.map(|m| vec![m.to_owned()]),
4472 };
4473 let mut tasks = Vec::new();
4474
4475 loop {
4476 let result = self
4477 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4478 .await?;
4479 tasks.extend(result.tasks);
4480
4481 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4482 params.continue_token = None;
4483 } else {
4484 params.continue_token = result.continue_token;
4485 }
4486
4487 if params.continue_token.is_none() {
4488 break;
4489 }
4490 }
4491
4492 if let Some(name) = name {
4493 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4494 }
4495
4496 Ok(tasks)
4497 }
4498
4499 /// Submits a job (app run) to the server and returns the resulting `Job`
4500 /// record (which carries the linked task id alongside the cloud-batch
4501 /// metadata).
4502 ///
4503 /// # Arguments
4504 /// * `app_name` - The name of the registered app to run (e.g., `"edgefirst-validator"`).
4505 /// * `job_name` - A user-defined label for this run.
4506 /// * `env` - Environment variables passed to the job (string-string map).
4507 /// * `data` - Job input payload (e.g., session ids, parameters).
4508 ///
4509 /// # Returns
4510 /// The full `Job` record returned by the server (wraps the BK_BATCH object),
4511 /// including AWS Batch job ID, state, and the linked `task_id`. Callers that
4512 /// only need the task ID can call `.task_id()` on the returned `Job`.
4513 pub async fn job_run(
4514 &self,
4515 app_name: &str,
4516 job_name: &str,
4517 env: std::collections::HashMap<String, String>,
4518 data: std::collections::HashMap<String, crate::api::Parameter>,
4519 ) -> Result<crate::api::Job, Error> {
4520 let req = JobRunRequest {
4521 name: app_name.to_owned(),
4522 job_name: job_name.to_owned(),
4523 env,
4524 data,
4525 };
4526 let resp: crate::api::Job = match self.rpc("job.run".to_owned(), Some(&req)).await {
4527 Ok(r) => r,
4528 Err(Error::RpcError(code, msg)) => {
4529 return Err(map_rpc_error("job.run", code, msg, None));
4530 }
4531 Err(e) => return Err(e),
4532 };
4533 Ok(resp)
4534 }
4535
4536 /// Requests a running job task be stopped.
4537 ///
4538 /// Returns `Ok(())` if the stop request was accepted by the server. The
4539 /// task may still take time to fully terminate; poll `task_info` if you
4540 /// need to wait for shutdown.
4541 pub async fn job_stop(&self, task_id: crate::api::TaskID) -> Result<(), Error> {
4542 let req = JobStopRequest {
4543 task_id: task_id.value(),
4544 };
4545 // We don't care about the response body; deserialize as serde_json::Value.
4546 let _resp: serde_json::Value = match self.rpc("job.stop".to_owned(), Some(&req)).await {
4547 Ok(r) => r,
4548 Err(Error::RpcError(code, msg)) => {
4549 return Err(map_rpc_error("job.stop", code, msg, Some(task_id)));
4550 }
4551 Err(e) => return Err(e),
4552 };
4553 Ok(())
4554 }
4555
4556 /// Lists job (app-run) entries visible to the authenticated user.
4557 ///
4558 /// The server returns AWS Batch-wrapper entries (not bare `Task` objects),
4559 /// surfacing cloud-batch state (`RUNNING`/`SUCCEEDED`/...) and the linked
4560 /// `task_id`. Use `Job::task_id()` + `Client::task_info` to fetch the
4561 /// underlying task details.
4562 ///
4563 /// The server does not support server-side filters, so the optional
4564 /// `name` argument is applied client-side as a substring match against
4565 /// each job's `job_name`.
4566 pub async fn jobs(&self, name: Option<&str>) -> Result<Vec<crate::api::Job>, Error> {
4567 let req = JobsListRequest {};
4568 let mut jobs: Vec<crate::api::Job> = match self.rpc("job.list".to_owned(), Some(&req)).await
4569 {
4570 Ok(r) => r,
4571 Err(Error::RpcError(code, msg)) => {
4572 return Err(map_rpc_error("job.list", code, msg, None));
4573 }
4574 Err(e) => return Err(e),
4575 };
4576 if let Some(name) = name {
4577 let needle = name.to_lowercase();
4578 jobs.retain(|j| j.job_name.to_lowercase().contains(&needle));
4579 jobs.sort_by(|a, b| a.job_name.cmp(&b.job_name));
4580 }
4581 Ok(jobs)
4582 }
4583
4584 /// Retrieve the task information and status.
4585 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4586 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4587 self.rpc(
4588 "task.get".to_owned(),
4589 Some(HashMap::from([("id", task_id)])),
4590 )
4591 .await
4592 }
4593
4594 /// Updates the tasks status.
4595 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4596 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4597 let status = TaskStatus {
4598 task_id,
4599 status: status.to_owned(),
4600 };
4601 self.rpc("docker.update.status".to_owned(), Some(status))
4602 .await
4603 }
4604
4605 /// Defines the stages for the task. The stages are defined as a mapping
4606 /// from stage names to their descriptions. Once stages are defined their
4607 /// status can be updated using the update_stage method.
4608 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4609 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4610 let stages: Vec<HashMap<String, String>> = stages
4611 .iter()
4612 .map(|(key, value)| {
4613 let mut stage_map = HashMap::new();
4614 stage_map.insert(key.to_string(), value.to_string());
4615 stage_map
4616 })
4617 .collect();
4618 let params = TaskStages { task_id, stages };
4619 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4620 Ok(())
4621 }
4622
4623 /// Updates the progress of the task for the provided stage and status
4624 /// information.
4625 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4626 pub async fn update_stage(
4627 &self,
4628 task_id: TaskID,
4629 stage: &str,
4630 status: &str,
4631 message: &str,
4632 percentage: u8,
4633 ) -> Result<(), Error> {
4634 let stage = Stage::new(
4635 Some(task_id),
4636 stage.to_owned(),
4637 Some(status.to_owned()),
4638 Some(message.to_owned()),
4639 percentage,
4640 );
4641 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4642 Ok(())
4643 }
4644
4645 /// Authenticated fetch from the Studio server using the bulk HTTP client
4646 /// (no total-request timeout; idle read timeout per chunk).
4647 ///
4648 /// **Buffers the entire response body into memory.** Suitable for small to
4649 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4650 /// checkpoints), prefer a streaming approach that writes directly to disk.
4651 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4652 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4653 let req = self
4654 .bulk_http
4655 .get(format!("{}/{}", self.url, query))
4656 .header("User-Agent", "EdgeFirst Client")
4657 .header("Authorization", format!("Bearer {}", self.token().await));
4658 let resp = req.send().await?;
4659
4660 if resp.status().is_success() {
4661 let body = resp.bytes().await?;
4662
4663 if log_enabled!(Level::Trace) {
4664 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4665 }
4666
4667 Ok(body.to_vec())
4668 } else {
4669 let err = resp.error_for_status_ref().unwrap_err();
4670 Err(Error::HttpError(err))
4671 }
4672 }
4673
4674 /// Sends a multipart post request to the server. This is used by the
4675 /// upload and download APIs which do not use JSON-RPC but instead transfer
4676 /// files using multipart/form-data.
4677 ///
4678 /// The result field is deserialized as `serde_json::Value` rather than
4679 /// `String` because different server endpoints return different shapes —
4680 /// `val.data.upload` returns a plain string while `task.data.upload`
4681 /// returns an object `{"message":…,"path":…,"size":…}`. All current
4682 /// callers discard the return value so this is backwards-compatible.
4683 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4684 pub async fn post_multipart(
4685 &self,
4686 method: &str,
4687 form: Form,
4688 ) -> Result<serde_json::Value, Error> {
4689 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4690 .ok()
4691 .and_then(|s| s.parse().ok())
4692 .unwrap_or(600u64);
4693
4694 let req = self
4695 .http
4696 .post(format!("{}/api?method={}", self.url, method))
4697 .header("Accept", "application/json")
4698 .header("User-Agent", "EdgeFirst Client")
4699 .header("Authorization", format!("Bearer {}", self.token().await))
4700 .timeout(Duration::from_secs(upload_timeout_secs))
4701 .multipart(form);
4702 let resp = req.send().await?;
4703
4704 if resp.status().is_success() {
4705 let body = resp.bytes().await?;
4706
4707 if log_enabled!(Level::Trace) {
4708 trace!(
4709 "POST Multipart Response: {}",
4710 String::from_utf8_lossy(&body)
4711 );
4712 }
4713
4714 let response: RpcResponse<serde_json::Value> = match serde_json::from_slice(&body) {
4715 Ok(response) => response,
4716 Err(err) => {
4717 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4718 return Err(err.into());
4719 }
4720 };
4721
4722 if let Some(error) = response.error {
4723 Err(Error::RpcError(error.code, error.message))
4724 } else if let Some(result) = response.result {
4725 Ok(result)
4726 } else {
4727 Err(Error::InvalidResponse)
4728 }
4729 } else {
4730 // HTTP-level failure on the multipart upload. Map 413 to the
4731 // typed `PayloadTooLarge` variant so callers see the same error
4732 // type from both single-file rpc_download paths and multipart
4733 // upload paths; everything else falls through to HttpError.
4734 let status = resp.status();
4735 if status.as_u16() == 413 {
4736 return Err(Error::PayloadTooLarge {
4737 method: method.to_string(),
4738 size_hint: None,
4739 });
4740 }
4741 let err = resp.error_for_status_ref().unwrap_err();
4742 Err(Error::HttpError(err))
4743 }
4744 }
4745
4746 /// Internal helper: POST a JSON-RPC request and stream the binary response
4747 /// to `output_path`. The response is assumed to be raw binary (not a JSON
4748 /// envelope). Use for endpoints that return file contents directly.
4749 ///
4750 /// On HTTP non-success, the response body is read as text and surfaced
4751 /// via `Error::RpcError(status_code, body)`.
4752 pub(crate) async fn rpc_download<P: Serialize>(
4753 &self,
4754 method: &str,
4755 params: &P,
4756 output_path: &std::path::Path,
4757 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
4758 ) -> Result<(), Error> {
4759 let envelope = serde_json::json!({
4760 "jsonrpc": "2.0",
4761 "id": 0,
4762 "method": method,
4763 "params": params,
4764 });
4765
4766 let url = format!("{}/api", self.url);
4767 let resp = self
4768 .bulk_http
4769 .post(&url)
4770 .header("Authorization", format!("Bearer {}", self.token().await))
4771 .json(&envelope)
4772 .send()
4773 .await?;
4774
4775 let status = resp.status();
4776 if !status.is_success() {
4777 if status.as_u16() == 413 {
4778 return Err(Error::PayloadTooLarge {
4779 method: method.to_string(),
4780 size_hint: None,
4781 });
4782 }
4783 let body = resp.text().await.unwrap_or_default();
4784 return Err(Error::RpcError(status.as_u16() as i32, body));
4785 }
4786
4787 // HTTP 200 with Content-Type: application/json can mean two things:
4788 // (a) a JSON-RPC error envelope when the server failed mid-way
4789 // (e.g. {"jsonrpc":"2.0","error":{"code":N,"message":"..."}}),
4790 // (b) a legitimate JSON file payload — validation traces, chart
4791 // bodies, metrics, etc., are typically served with this MIME.
4792 //
4793 // Disambiguate structurally: a JSON-RPC 2.0 envelope is required to
4794 // carry a `jsonrpc` member, and an *error* envelope further requires
4795 // an `error.code` integer (per RFC 8259 + JSON-RPC 2.0 §5). Only
4796 // decode the body as an error if both markers are present. This is
4797 // strict enough to leave legitimate JSON artifacts that happen to
4798 // contain a free-form `error` field (metrics, diagnostics, log
4799 // dumps) untouched, while still catching every real server
4800 // failure.
4801 let content_type = resp
4802 .headers()
4803 .get(reqwest::header::CONTENT_TYPE)
4804 .and_then(|v| v.to_str().ok())
4805 .unwrap_or("")
4806 .to_owned();
4807 if content_type.contains("application/json") {
4808 let body = resp.bytes().await?;
4809 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&body)
4810 && is_jsonrpc_error_envelope(&val)
4811 && let Some(err_obj) = val.get("error")
4812 {
4813 let code = err_obj.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
4814 let message = err_obj
4815 .get("message")
4816 .and_then(|m| m.as_str())
4817 .unwrap_or("unknown error")
4818 .to_string();
4819 return Err(Error::RpcError(code, message));
4820 }
4821 // Not an error envelope — body is a JSON file. Write it to disk
4822 // and emit a single completion progress event so callers (e.g.,
4823 // Python download_data progress callbacks) see the download
4824 // finish.
4825 //
4826 // `Path::parent` returns `Some("")` for a bare filename like
4827 // "metrics.json"; `create_dir_all("")` errors out with
4828 // `NotFound`, so only create the parent when it actually names
4829 // a directory.
4830 if let Some(parent) = output_path.parent()
4831 && !parent.as_os_str().is_empty()
4832 {
4833 tokio::fs::create_dir_all(parent).await?;
4834 }
4835 let mut file = tokio::fs::File::create(output_path).await?;
4836 file.write_all(&body).await?;
4837 file.flush().await?;
4838 if let Some(tx) = progress {
4839 let total = body.len();
4840 // Use the awaited send for the final event so completion
4841 // handlers are never silently dropped.
4842 let _ = tx
4843 .send(Progress {
4844 current: total,
4845 total,
4846 status: None,
4847 })
4848 .await;
4849 }
4850 return Ok(());
4851 }
4852
4853 // Same empty-parent guard for the streaming download path: passing
4854 // a bare filename like "metrics.json" must write to the current
4855 // directory rather than failing on `create_dir_all("")`.
4856 if let Some(parent) = output_path.parent()
4857 && !parent.as_os_str().is_empty()
4858 {
4859 tokio::fs::create_dir_all(parent).await?;
4860 }
4861
4862 stream_response_to_file(resp, output_path, progress).await
4863 }
4864
4865 /// Send a JSON-RPC request to the server. The method is the name of the
4866 /// method to call on the server. The params are the parameters to pass to
4867 /// the method. The method and params are serialized into a JSON-RPC
4868 /// request and sent to the server. The response is deserialized into
4869 /// the specified type and returned to the caller.
4870 ///
4871 /// NOTE: This API would generally not be called directly and instead users
4872 /// should use the higher-level methods provided by the client.
4873 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4874 pub async fn rpc<Params, RpcResult>(
4875 &self,
4876 method: String,
4877 params: Option<Params>,
4878 ) -> Result<RpcResult, Error>
4879 where
4880 Params: Serialize,
4881 RpcResult: DeserializeOwned,
4882 {
4883 let auth_expires = self.token_expiration().await?;
4884 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4885 self.renew_token().await?;
4886 }
4887
4888 self.rpc_without_auth(method, params).await
4889 }
4890
4891 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
4892 async fn rpc_without_auth<Params, RpcResult>(
4893 &self,
4894 method: String,
4895 params: Option<Params>,
4896 ) -> Result<RpcResult, Error>
4897 where
4898 Params: Serialize,
4899 RpcResult: DeserializeOwned,
4900 {
4901 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4902 .ok()
4903 .and_then(|s| s.parse().ok())
4904 .unwrap_or(5usize);
4905
4906 let url = format!("{}/api", self.url);
4907
4908 // Serialize request body once before retry loop to avoid Clone bound on Params
4909 let request = RpcRequest {
4910 method: method.clone(),
4911 params,
4912 ..Default::default()
4913 };
4914
4915 // Log request for debugging (log crate) and profiling (tracing crate)
4916 let request_json = if method == "auth.login" {
4917 // Redact auth.login params (contains password)
4918 serde_json::json!({
4919 "jsonrpc": "2.0",
4920 "method": &method,
4921 "params": "[REDACTED - contains credentials]",
4922 "id": request.id
4923 })
4924 .to_string()
4925 } else {
4926 serde_json::to_string(&request)?
4927 };
4928
4929 if log_enabled!(Level::Trace) {
4930 trace!("RPC Request: {}", request_json);
4931 }
4932
4933 // Record request on current span for Perfetto when profiling is enabled
4934 #[cfg(feature = "profiling")]
4935 tracing::Span::current().record("request", &request_json);
4936
4937 let request_body = serde_json::to_vec(&request)?;
4938 let mut last_error: Option<Error> = None;
4939
4940 for attempt in 0..=max_retries {
4941 if attempt > 0 {
4942 // Exponential backoff with jitter: base delay * 2^attempt, capped at 30s
4943 // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
4944 // while ensuring we never retry faster than the base delay
4945 let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
4946 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
4947 let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
4948 let delay = Duration::from_millis(delay_ms);
4949 warn!(
4950 "Retry {}/{} for RPC '{}' after {:?}",
4951 attempt, max_retries, method, delay
4952 );
4953 tokio::time::sleep(delay).await;
4954 }
4955
4956 let result = self
4957 .http
4958 .post(&url)
4959 .header("Accept", "application/json")
4960 .header("Content-Type", "application/json")
4961 .header("User-Agent", "EdgeFirst Client")
4962 .header("Authorization", format!("Bearer {}", self.token().await))
4963 .body(request_body.clone())
4964 .send()
4965 .await;
4966
4967 match result {
4968 Ok(res) => {
4969 let status = res.status();
4970 let status_code = status.as_u16();
4971
4972 // Check for retryable HTTP status codes before processing response
4973 if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
4974 && attempt < max_retries
4975 {
4976 warn!(
4977 "RPC '{}' failed with HTTP {} (retrying)",
4978 method, status_code
4979 );
4980 last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
4981 continue;
4982 }
4983
4984 // Process the response
4985 match self.process_rpc_response(res).await {
4986 Ok(result) => {
4987 if attempt > 0 {
4988 debug!("RPC '{}' succeeded on retry {}", method, attempt);
4989 }
4990 return Ok(result);
4991 }
4992 Err(e) => {
4993 // Don't retry client errors (4xx except 408, 429)
4994 if attempt > 0 {
4995 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
4996 }
4997 return Err(e);
4998 }
4999 }
5000 }
5001 Err(e) => {
5002 // Transport error (timeout, connection failure, etc.)
5003 let is_timeout = e.is_timeout();
5004 let is_connect = e.is_connect();
5005
5006 if (is_timeout || is_connect) && attempt < max_retries {
5007 warn!(
5008 "RPC '{}' transport error (retrying): {}",
5009 method,
5010 if is_timeout {
5011 "timeout"
5012 } else {
5013 "connection failed"
5014 }
5015 );
5016 last_error = Some(Error::HttpError(e));
5017 continue;
5018 }
5019
5020 if attempt > 0 {
5021 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
5022 }
5023 return Err(Error::HttpError(e));
5024 }
5025 }
5026 }
5027
5028 // Should not reach here
5029 Err(last_error.unwrap_or_else(|| {
5030 Error::InvalidParameters(format!(
5031 "RPC '{}' failed after {} retries",
5032 method, max_retries
5033 ))
5034 }))
5035 }
5036
5037 async fn process_rpc_response<RpcResult>(
5038 &self,
5039 res: reqwest::Response,
5040 ) -> Result<RpcResult, Error>
5041 where
5042 RpcResult: DeserializeOwned,
5043 {
5044 let body = res.bytes().await?;
5045 let response_str = String::from_utf8_lossy(&body);
5046
5047 if log_enabled!(Level::Trace) {
5048 trace!("RPC Response: {}", response_str);
5049 }
5050
5051 // Record response on current span for Perfetto when profiling is enabled
5052 // Truncate large responses to avoid bloating trace files
5053 #[cfg(feature = "profiling")]
5054 {
5055 const MAX_RESPONSE_LEN: usize = 4096;
5056 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
5057 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
5058 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
5059 format!(
5060 "{}...[truncated {} bytes]",
5061 &response_str[..safe_end],
5062 response_str.len() - safe_end
5063 )
5064 } else {
5065 response_str.to_string()
5066 };
5067 tracing::Span::current().record("response", &truncated);
5068 }
5069
5070 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
5071 Ok(response) => response,
5072 Err(err) => {
5073 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
5074 return Err(err.into());
5075 }
5076 };
5077
5078 // FIXME: Studio Server always returns 999 as the id.
5079 // if request.id.to_string() != response.id {
5080 // return Err(Error::InvalidRpcId(response.id));
5081 // }
5082
5083 if let Some(error) = response.error {
5084 Err(Error::RpcError(error.code, error.message))
5085 } else if let Some(result) = response.result {
5086 Ok(result)
5087 } else {
5088 Err(Error::InvalidResponse)
5089 }
5090 }
5091}
5092
5093/// Process items in parallel with semaphore concurrency control and progress
5094/// tracking.
5095///
5096/// This helper eliminates boilerplate for parallel item processing with:
5097/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
5098/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
5099/// - Atomic progress counter with automatic item-level updates
5100/// - Progress updates sent after each item completes (not byte-level streaming)
5101/// - Proper error propagation from spawned tasks
5102///
5103/// Note: This is optimized for discrete items with post-completion progress
5104/// updates. For byte-level streaming progress or custom retry logic, use
5105/// specialized implementations.
5106///
5107/// # Arguments
5108///
5109/// * `items` - Collection of items to process in parallel
5110/// * `progress` - Optional progress channel for tracking completion
5111/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
5112/// * `work_fn` - Async function to execute for each item
5113///
5114/// # Examples
5115///
5116/// ```rust,ignore
5117/// // Use default concurrency
5118/// parallel_foreach_items(samples, progress, None, |sample| async move {
5119/// sample.download(&client, file_type).await?;
5120/// Ok(())
5121/// }).await?;
5122/// ```
5123async fn parallel_foreach_items<T, F, Fut>(
5124 items: Vec<T>,
5125 progress: Option<Sender<Progress>>,
5126 concurrency: Option<usize>,
5127 work_fn: F,
5128) -> Result<(), Error>
5129where
5130 T: Send + 'static,
5131 F: Fn(T) -> Fut + Send + Sync + 'static,
5132 Fut: Future<Output = Result<(), Error>> + Send + 'static,
5133{
5134 let total = items.len();
5135 let current = Arc::new(AtomicUsize::new(0));
5136 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
5137 let work_fn = Arc::new(work_fn);
5138
5139 let tasks = items
5140 .into_iter()
5141 .map(|item| {
5142 let sem = sem.clone();
5143 let current = current.clone();
5144 let progress = progress.clone();
5145 let work_fn = work_fn.clone();
5146
5147 tokio::spawn(async move {
5148 let _permit = sem.acquire().await.map_err(|_| {
5149 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5150 })?;
5151
5152 // Execute the actual work
5153 work_fn(item).await?;
5154
5155 // Update progress
5156 if let Some(progress) = &progress {
5157 let current = current.fetch_add(1, Ordering::SeqCst);
5158 let _ = progress
5159 .send(Progress {
5160 current: current + 1,
5161 total,
5162 status: None,
5163 })
5164 .await;
5165 }
5166
5167 Ok::<(), Error>(())
5168 })
5169 })
5170 .collect::<Vec<_>>();
5171
5172 join_all(tasks)
5173 .await
5174 .into_iter()
5175 .collect::<Result<Vec<_>, _>>()?
5176 .into_iter()
5177 .collect::<Result<Vec<_>, _>>()?;
5178
5179 if let Some(progress) = progress {
5180 drop(progress);
5181 }
5182
5183 Ok(())
5184}
5185
5186/// Upload a file to S3 using multipart upload with presigned URLs.
5187///
5188/// Splits a file into chunks (100MB each) and uploads them in parallel using
5189/// S3 multipart upload protocol. Returns completion parameters with ETags for
5190/// finalizing the upload.
5191///
5192/// This function handles:
5193/// - Splitting files into parts based on PART_SIZE (100MB)
5194/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
5195/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
5196/// - Retry logic (handled by reqwest client)
5197/// - Progress tracking across all parts
5198///
5199/// # Arguments
5200///
5201/// * `http` - HTTP client for making requests
5202/// * `part` - Snapshot part info with presigned URLs for each chunk
5203/// * `path` - Local file path to upload
5204/// * `total` - Total bytes across all files for progress calculation
5205/// * `current` - Atomic counter tracking bytes uploaded across all operations
5206/// * `progress` - Optional channel for sending progress updates
5207///
5208/// # Returns
5209///
5210/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
5211async fn upload_multipart(
5212 http: reqwest::Client,
5213 part: SnapshotPart,
5214 path: PathBuf,
5215 total: usize,
5216 confirmed_bytes: Arc<AtomicUsize>,
5217 progress: Option<Sender<Progress>>,
5218) -> Result<SnapshotCompleteMultipartParams, Error> {
5219 let filesize = path.metadata()?.len() as usize;
5220 let n_parts = filesize.div_ceil(PART_SIZE);
5221 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
5222
5223 let key = part.key.ok_or(Error::InvalidResponse)?;
5224 let upload_id = part.upload_id;
5225
5226 let urls = part.urls.clone();
5227
5228 // Pre-allocate ETag slots for all parts
5229 let etags = Arc::new(tokio::sync::Mutex::new(vec![
5230 EtagPart {
5231 etag: "".to_owned(),
5232 part_number: 0,
5233 };
5234 n_parts
5235 ]));
5236
5237 // Per-part byte counters for streaming progress (reset on retry)
5238 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
5239 (0..n_parts)
5240 .map(|_| AtomicUsize::new(0))
5241 .collect::<Vec<_>>(),
5242 );
5243
5244 // Upload all parts in parallel with concurrency limiting
5245 let tasks = (0..n_parts)
5246 .map(|part_idx| {
5247 let http = http.clone();
5248 let url = urls[part_idx].clone();
5249 let etags = etags.clone();
5250 let path = path.to_owned();
5251 let sem = sem.clone();
5252 let progress = progress.clone();
5253 let confirmed_bytes = confirmed_bytes.clone();
5254 let part_bytes = part_bytes.clone();
5255
5256 // Calculate this part's size
5257 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5258 filesize % PART_SIZE
5259 } else {
5260 PART_SIZE
5261 };
5262
5263 tokio::spawn(async move {
5264 // Acquire semaphore permit to limit concurrent uploads
5265 let _permit = sem.acquire().await.map_err(|_| {
5266 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5267 })?;
5268
5269 // Upload part with streaming progress and retry logic
5270 let etag = upload_part_with_progress(
5271 http,
5272 url,
5273 path,
5274 part_idx,
5275 n_parts,
5276 part_size,
5277 total,
5278 confirmed_bytes.clone(),
5279 part_bytes.clone(),
5280 progress.clone(),
5281 )
5282 .await?;
5283
5284 // Store ETag for this part (needed to complete multipart upload)
5285 let mut etags_guard = etags.lock().await;
5286 etags_guard[part_idx] = EtagPart {
5287 etag,
5288 part_number: part_idx + 1,
5289 };
5290
5291 // Part completed successfully - add to confirmed bytes
5292 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
5293 // Reset part counter since it's now confirmed
5294 part_bytes[part_idx].store(0, Ordering::SeqCst);
5295
5296 // Send final progress update for this part
5297 if let Some(progress) = &progress {
5298 let current = confirmed_bytes.load(Ordering::SeqCst)
5299 + part_bytes
5300 .iter()
5301 .map(|p| p.load(Ordering::SeqCst))
5302 .sum::<usize>();
5303 let _ = progress
5304 .send(Progress {
5305 current,
5306 total,
5307 status: None,
5308 })
5309 .await;
5310 }
5311
5312 Ok::<(), Error>(())
5313 })
5314 })
5315 .collect::<Vec<_>>();
5316
5317 // Wait for all parts to complete (double collect to handle both JoinError and
5318 // inner Error)
5319 join_all(tasks)
5320 .await
5321 .into_iter()
5322 .collect::<Result<Vec<_>, _>>()?
5323 .into_iter()
5324 .collect::<Result<Vec<_>, _>>()?;
5325
5326 Ok(SnapshotCompleteMultipartParams {
5327 key,
5328 upload_id,
5329 etag_list: etags.lock().await.clone(),
5330 })
5331}
5332
5333/// Upload a single part with streaming progress tracking and retry logic.
5334///
5335/// Progress is reported continuously as bytes are sent. On retry, the part's
5336/// progress counter is reset to avoid over-reporting.
5337#[allow(clippy::too_many_arguments)]
5338async fn upload_part_with_progress(
5339 http: reqwest::Client,
5340 url: String,
5341 path: PathBuf,
5342 part_idx: usize,
5343 n_parts: usize,
5344 part_size: usize,
5345 total: usize,
5346 confirmed_bytes: Arc<AtomicUsize>,
5347 part_bytes: Arc<Vec<AtomicUsize>>,
5348 progress: Option<Sender<Progress>>,
5349) -> Result<String, Error> {
5350 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5351 .ok()
5352 .and_then(|s| s.parse().ok())
5353 .unwrap_or(5usize);
5354
5355 // Per-part total upload timeout. Covers the send phase (request body) where
5356 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
5357 // this bounds how long a stalled upload can block before retrying.
5358 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5359 .ok()
5360 .and_then(|s| s.parse().ok())
5361 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
5362
5363 let mut last_error: Option<Error> = None;
5364
5365 for attempt in 0..=max_retries {
5366 if attempt > 0 {
5367 // Reset this part's progress counter before retry
5368 part_bytes[part_idx].store(0, Ordering::SeqCst);
5369
5370 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5371 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5372 warn!(
5373 "Retry {}/{} for part {} after {:?}",
5374 attempt, max_retries, part_idx, delay
5375 );
5376 tokio::time::sleep(delay).await;
5377 }
5378
5379 match upload_part_streaming(
5380 http.clone(),
5381 url.clone(),
5382 path.clone(),
5383 part_idx,
5384 n_parts,
5385 part_size,
5386 total,
5387 upload_timeout_secs,
5388 confirmed_bytes.clone(),
5389 part_bytes.clone(),
5390 progress.clone(),
5391 )
5392 .await
5393 {
5394 Ok(etag) => return Ok(etag),
5395 Err(e) => {
5396 // Check if error is retryable
5397 let is_retryable = matches!(
5398 &e,
5399 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
5400 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
5401 );
5402
5403 if is_retryable && attempt < max_retries {
5404 last_error = Some(e);
5405 continue;
5406 }
5407
5408 return Err(e);
5409 }
5410 }
5411 }
5412
5413 Err(last_error
5414 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
5415}
5416
5417/// Perform the actual upload with streaming progress.
5418#[allow(clippy::too_many_arguments)]
5419async fn upload_part_streaming(
5420 http: reqwest::Client,
5421 url: String,
5422 path: PathBuf,
5423 part_idx: usize,
5424 n_parts: usize,
5425 _part_size: usize,
5426 total: usize,
5427 upload_timeout_secs: u64,
5428 confirmed_bytes: Arc<AtomicUsize>,
5429 part_bytes: Arc<Vec<AtomicUsize>>,
5430 progress: Option<Sender<Progress>>,
5431) -> Result<String, Error> {
5432 let filesize = path.metadata()?.len() as usize;
5433 let mut file = File::open(&path).await?;
5434 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
5435 .await?;
5436 let file = file.take(PART_SIZE as u64);
5437
5438 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5439 filesize % PART_SIZE
5440 } else {
5441 PART_SIZE
5442 };
5443
5444 // Create stream with progress tracking
5445 let stream = FramedRead::new(file, BytesCodec::new());
5446
5447 // Wrap stream to track bytes sent and report progress
5448 let progress_stream = stream.map(move |result| {
5449 if let Ok(ref bytes) = result {
5450 let bytes_len = bytes.len();
5451 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
5452
5453 // Send progress update (fire-and-forget via try_send to avoid blocking)
5454 if let Some(ref progress) = progress {
5455 let current = confirmed_bytes.load(Ordering::SeqCst)
5456 + part_bytes
5457 .iter()
5458 .map(|p| p.load(Ordering::SeqCst))
5459 .sum::<usize>();
5460 // Best-effort progress reporting: use try_send to avoid blocking.
5461 // If the channel is full or closed, we intentionally skip this update
5462 // to avoid stalling the upload; subsequent updates will still be delivered.
5463 let _ = progress.try_send(Progress {
5464 current,
5465 total,
5466 status: None,
5467 });
5468 }
5469 }
5470 result.map(|b| b.freeze())
5471 });
5472
5473 let body = Body::wrap_stream(progress_stream);
5474
5475 let resp = http
5476 .put(url)
5477 .header(CONTENT_LENGTH, body_length)
5478 .timeout(Duration::from_secs(upload_timeout_secs))
5479 .body(body)
5480 .send()
5481 .await?
5482 .error_for_status()?;
5483
5484 let etag = resp
5485 .headers()
5486 .get("etag")
5487 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
5488 .to_str()
5489 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
5490 .to_owned();
5491
5492 // Studio Server requires etag without the quotes.
5493 let etag = etag
5494 .strip_prefix("\"")
5495 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
5496 let etag = etag
5497 .strip_suffix("\"")
5498 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
5499
5500 Ok(etag.to_owned())
5501}
5502
5503/// Upload a complete file to a presigned S3 URL using HTTP PUT.
5504///
5505/// This is used for populate_samples to upload files to S3 after
5506/// receiving presigned URLs from the server.
5507///
5508/// Includes explicit retry logic with exponential backoff for transient
5509/// failures.
5510/// Classify a reqwest transport error (one where no HTTP response was received)
5511/// as a transient failure worth retrying.
5512///
5513/// Presigned-URL uploads buffer the body in memory and a PUT to the same object
5514/// key is idempotent, so replaying any transport-level failure is safe. Besides
5515/// timeouts and connect failures this covers request/body send errors such as
5516/// hyper's `IncompleteMessage` (a peer closing a keep-alive connection mid-send)
5517/// — transients that pipelined, high-concurrency uploads provoke far more often
5518/// than serial ones, and which the previous `is_timeout() || is_connect()` gate
5519/// missed (aborting the whole upload on a single blip).
5520fn is_retryable_upload_error(e: &reqwest::Error) -> bool {
5521 e.is_timeout() || e.is_connect() || e.is_request() || e.is_body()
5522}
5523
5524async fn upload_file_to_presigned_url(
5525 http: reqwest::Client,
5526 url: &str,
5527 path: PathBuf,
5528) -> Result<(), Error> {
5529 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5530 .ok()
5531 .and_then(|s| s.parse().ok())
5532 .unwrap_or(5usize);
5533
5534 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5535 .ok()
5536 .and_then(|s| s.parse().ok())
5537 .unwrap_or(600u64);
5538
5539 // Read the entire file into memory once
5540 let file_data = fs::read(&path).await?;
5541 let file_size = file_data.len();
5542 let filename = path.file_name().unwrap_or_default().to_string_lossy();
5543
5544 let mut last_error: Option<Error> = None;
5545
5546 for attempt in 0..=max_retries {
5547 if attempt > 0 {
5548 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5549 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5550 warn!(
5551 "Retry {}/{} for upload '{}' after {:?}",
5552 attempt, max_retries, filename, delay
5553 );
5554 tokio::time::sleep(delay).await;
5555 }
5556
5557 // Attempt upload
5558 let result = http
5559 .put(url)
5560 .header(CONTENT_LENGTH, file_size)
5561 .timeout(Duration::from_secs(upload_timeout_secs))
5562 .body(file_data.clone())
5563 .send()
5564 .await;
5565
5566 match result {
5567 Ok(resp) => {
5568 if resp.status().is_success() {
5569 if attempt > 0 {
5570 debug!(
5571 "Upload '{}' succeeded on retry {} ({} bytes)",
5572 filename, attempt, file_size
5573 );
5574 } else {
5575 debug!(
5576 "Successfully uploaded file: {} ({} bytes)",
5577 filename, file_size
5578 );
5579 }
5580 return Ok(());
5581 }
5582
5583 let status = resp.status();
5584 let status_code = status.as_u16();
5585
5586 // Check if error is retryable
5587 let is_retryable =
5588 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5589
5590 if is_retryable && attempt < max_retries {
5591 let error_text = resp.text().await.unwrap_or_default();
5592 warn!(
5593 "Upload '{}' failed with HTTP {} (retryable): {}",
5594 filename, status_code, error_text
5595 );
5596 last_error = Some(Error::InvalidParameters(format!(
5597 "Upload failed: HTTP {} - {}",
5598 status, error_text
5599 )));
5600 continue;
5601 }
5602
5603 // Non-retryable error or max retries exceeded
5604 let error_text = resp.text().await.unwrap_or_default();
5605 if attempt > 0 {
5606 error!(
5607 "Upload '{}' failed after {} retries: HTTP {} - {}",
5608 filename, attempt, status, error_text
5609 );
5610 }
5611 return Err(Error::InvalidParameters(format!(
5612 "Upload failed: HTTP {} - {}",
5613 status, error_text
5614 )));
5615 }
5616 Err(e) => {
5617 // Transport error: no HTTP response was received. The body is
5618 // buffered in memory and the PUT is idempotent, so any transient
5619 // transport failure is safe to replay (see
5620 // `is_retryable_upload_error`).
5621 if is_retryable_upload_error(&e) && attempt < max_retries {
5622 warn!("Upload '{}' transport error (retrying): {}", filename, e);
5623 last_error = Some(Error::HttpError(e));
5624 continue;
5625 }
5626
5627 // Non-retryable or max retries exceeded
5628 if attempt > 0 {
5629 error!(
5630 "Upload '{}' failed after {} retries: {}",
5631 filename, attempt, e
5632 );
5633 }
5634 return Err(Error::HttpError(e));
5635 }
5636 }
5637 }
5638
5639 // Should not reach here, but return last error if we do
5640 Err(last_error.unwrap_or_else(|| {
5641 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5642 }))
5643}
5644
5645/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
5646///
5647/// This is used for populate_samples to upload file content from memory
5648/// (e.g., from ZIP archives) without writing to disk first.
5649///
5650/// Includes explicit retry logic with exponential backoff for transient
5651/// failures.
5652async fn upload_bytes_to_presigned_url(
5653 http: reqwest::Client,
5654 url: &str,
5655 file_data: Vec<u8>,
5656 filename: &str,
5657) -> Result<(), Error> {
5658 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5659 .ok()
5660 .and_then(|s| s.parse().ok())
5661 .unwrap_or(5usize);
5662
5663 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5664 .ok()
5665 .and_then(|s| s.parse().ok())
5666 .unwrap_or(600u64);
5667
5668 let file_size = file_data.len();
5669 let mut last_error: Option<Error> = None;
5670
5671 for attempt in 0..=max_retries {
5672 if attempt > 0 {
5673 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5674 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5675 warn!(
5676 "Retry {}/{} for upload '{}' after {:?}",
5677 attempt, max_retries, filename, delay
5678 );
5679 tokio::time::sleep(delay).await;
5680 }
5681
5682 // Attempt upload
5683 let result = http
5684 .put(url)
5685 .header(CONTENT_LENGTH, file_size)
5686 .timeout(Duration::from_secs(upload_timeout_secs))
5687 .body(file_data.clone())
5688 .send()
5689 .await;
5690
5691 match result {
5692 Ok(resp) => {
5693 if resp.status().is_success() {
5694 if attempt > 0 {
5695 debug!(
5696 "Upload '{}' succeeded on retry {} ({} bytes)",
5697 filename, attempt, file_size
5698 );
5699 } else {
5700 debug!(
5701 "Successfully uploaded file: {} ({} bytes)",
5702 filename, file_size
5703 );
5704 }
5705 return Ok(());
5706 }
5707
5708 let status = resp.status();
5709 let status_code = status.as_u16();
5710
5711 // Check if error is retryable
5712 let is_retryable =
5713 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5714
5715 if is_retryable && attempt < max_retries {
5716 let error_text = resp.text().await.unwrap_or_default();
5717 warn!(
5718 "Upload '{}' failed with HTTP {} (retryable): {}",
5719 filename, status_code, error_text
5720 );
5721 last_error = Some(Error::InvalidParameters(format!(
5722 "Upload failed: HTTP {} - {}",
5723 status, error_text
5724 )));
5725 continue;
5726 }
5727
5728 // Non-retryable error or max retries exceeded
5729 let error_text = resp.text().await.unwrap_or_default();
5730 if attempt > 0 {
5731 error!(
5732 "Upload '{}' failed after {} retries: HTTP {} - {}",
5733 filename, attempt, status, error_text
5734 );
5735 }
5736 return Err(Error::InvalidParameters(format!(
5737 "Upload failed: HTTP {} - {}",
5738 status, error_text
5739 )));
5740 }
5741 Err(e) => {
5742 // Transport error: no HTTP response was received. The body is
5743 // buffered in memory and the PUT is idempotent, so any transient
5744 // transport failure is safe to replay (see
5745 // `is_retryable_upload_error`).
5746 if is_retryable_upload_error(&e) && attempt < max_retries {
5747 warn!("Upload '{}' transport error (retrying): {}", filename, e);
5748 last_error = Some(Error::HttpError(e));
5749 continue;
5750 }
5751
5752 // Non-retryable or max retries exceeded
5753 if attempt > 0 {
5754 error!(
5755 "Upload '{}' failed after {} retries: {}",
5756 filename, attempt, e
5757 );
5758 }
5759 return Err(Error::HttpError(e));
5760 }
5761 }
5762 }
5763
5764 // Should not reach here, but return last error if we do
5765 Err(last_error.unwrap_or_else(|| {
5766 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5767 }))
5768}
5769
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned `Vec<String>` fed to the
    /// name-sorting tests.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Exact hits must outrank substring hits regardless of input order.
        let candidates = owned(&["Deer Roundtrip 123", "Deer", "Reindeer", "DEER"]);
        let ranked = filter_and_sort_by_name(candidates, "Deer", |name| name.as_str());
        assert_eq!(ranked[0], "Deer"); // case-sensitive exact match leads
        assert_eq!(ranked[1], "DEER"); // case-insensitive exact match follows
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Among substring hits, the shorter (more specific) name wins.
        let candidates = owned(&["Test Dataset ABC", "Test", "Test Dataset"]);
        let ranked = filter_and_sort_by_name(candidates, "Test", |name| name.as_str());
        assert_eq!(ranked[0], "Test"); // exact match
        assert_eq!(ranked[1], "Test Dataset"); // shorter substring hit
        assert_eq!(ranked[2], "Test Dataset ABC"); // longer substring hit
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // The filter "case" has to match regardless of letter case.
        let candidates = owned(&["UPPERCASE", "lowercase", "MixedCase"]);
        let ranked = filter_and_sort_by_name(candidates, "case", |name| name.as_str());
        assert_eq!(ranked.len(), 3); // every candidate matches
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // A filter that matches nothing yields an empty result.
        let candidates = owned(&["Apple", "Banana"]);
        let ranked = filter_and_sort_by_name(candidates, "Cherry", |name| name.as_str());
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Names of equal length fall back to alphabetical order.
        let candidates = owned(&["TestC", "TestA", "TestB"]);
        let ranked = filter_and_sort_by_name(candidates, "Test", |name| name.as_str());
        assert_eq!(ranked, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // flatten=false: the base name comes back untouched, with or without
        // sequence information.
        let sequence = String::from("seq");
        let with_sequence = Client::build_filename("image.jpg", false, Some(&sequence), Some(42));
        assert_eq!(with_sequence, "image.jpg");

        let without_sequence = Client::build_filename("test.png", false, None, None);
        assert_eq!(without_sequence, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // flatten=true but no sequence: nothing to prefix, name is unchanged.
        let built = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(built, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // flatten=true, sequence + frame given, name not yet prefixed:
        // the "<sequence>_<frame>_" prefix is added.
        let sequence = String::from("deer_sequence");
        let built = Client::build_filename("image.camera.jpeg", true, Some(&sequence), Some(42));
        assert_eq!(built, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // flatten=true, sequence given but no frame: prefix is the sequence only.
        let sequence = String::from("sequence_A");
        let built = Client::build_filename("image.jpg", true, Some(&sequence), None);
        assert_eq!(built, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // flatten=true and the name already begins with "<sequence>_":
        // it must be returned as-is.
        let sequence = String::from("deer_sequence");
        let built = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&sequence),
            Some(42),
        );
        assert_eq!(built, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: the existing prefix carries a different frame number than
        // the one supplied; the existing prefix is still respected.
        let sequence = String::from("sequence_A");
        let built = Client::build_filename("sequence_A_001.jpg", true, Some(&sequence), Some(2));
        assert_eq!(built, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: the sequence name occurs inside the file name but not at
        // its start, so the prefix is still added.
        let sequence = String::from("sequence_A");
        let built = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&sequence),
            Some(5),
        );
        assert_eq!(built, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Prefixing must never disturb the (possibly compound) extension.
        let sequence = String::from("seq");
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&sequence), Some(1));
            let expected_suffix = format!(".{}", ext);
            assert!(
                result.ends_with(&expected_suffix),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Sequence names made only of sanitized characters (underscores) work
        // as prefixes too.
        let sequence = String::from("seq_name_with_underscores");
        let built = Client::build_filename("sample_001.jpg", true, Some(&sequence), Some(10));
        assert_eq!(built, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // More filter_and_sort_by_name coverage: exact-match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Looking up "Deer" must put "Deer" first, never a longer name such as
        // "Deer Roundtrip 20251129".
        let candidates = owned(&[
            "Deer Roundtrip 20251129",
            "White-Tailed Deer",
            "Deer",
            "Deer Snapshot Test",
            "Reindeer Dataset",
        ]);

        let ranked = filter_and_sort_by_name(candidates, "Deer", |name| name.as_str());

        // The head of the list has to be the exact match.
        let head = ranked.first();
        assert_eq!(
            head.map(|name| name.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            head
        );

        // Every candidate contains "deer" case-insensitively, so none is dropped.
        assert_eq!(ranked.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // A case-sensitive exact match outranks case-insensitive exact matches.
        let candidates = owned(&["DEER", "deer", "Deer", "Deer Test"]);

        let ranked = filter_and_sort_by_name(candidates, "Deer", |name| name.as_str());

        // First tier: the case-sensitive exact match.
        assert_eq!(ranked[0], "Deer");
        // Second tier: the case-insensitive exact matches, in either order.
        assert!(ranked[1] == "DEER" || ranked[1] == "deer");
        assert!(ranked[2] == "DEER" || ranked[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Snapshot lookup for "Deer" when several similarly named snapshots exist.
        let candidates = owned(&[
            "Unit Testing - Deer Dataset Backup",
            "Deer",
            "Deer Snapshot 2025-01-15",
            "Original Deer",
        ]);

        let ranked = filter_and_sort_by_name(candidates, "Deer", |name| name.as_str());

        // The exact match has to lead so callers get a deterministic pick.
        assert_eq!(
            ranked[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Dataset lookup for "Deer" when several datasets mention "Deer".
        let candidates = owned(&[
            "Deer Roundtrip",
            "Deer",
            "deer",
            "White-Tailed Deer",
            "Deer-V2",
        ]);

        let ranked = filter_and_sort_by_name(candidates, "Deer", |name| name.as_str());

        // Case-sensitive exact match leads ...
        assert_eq!(ranked[0], "Deer");
        // ... followed by the case-insensitive exact match ...
        assert_eq!(ranked[1], "deer");
        // ... and shorter names rank ahead of longer ones.
        let short_rank = ranked.iter().position(|name| name == "Deer-V2").unwrap();
        let long_rank = ranked
            .iter()
            .position(|name| name == "Deer Roundtrip")
            .unwrap();
        assert!(short_rank < long_rank);
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // The first element must always be the best match; each entry is
        // (candidates, filter, expected first element).
        let scenarios: [(&[&str], &str, &str); 3] = [
            (&["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (&["test", "TEST", "Test Data"], "test", "test"),
            (&["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (names, filter, expected_first) in scenarios {
            let ranked = filter_and_sort_by_name(owned(names), filter, |name| name.as_str());

            let head = ranked.first();
            assert_eq!(
                head.map(|name| name.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                head
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Seed an in-memory store with a token and hand it to a client.
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // The token is still there before the switch.
        let before = storage.load().unwrap();
        assert_eq!(before, Some("test-token".to_string()));

        // Switching servers is expected to wipe the stored token.
        let _new_client = client.with_server("test").unwrap();

        let after = storage.load().unwrap();
        assert_eq!(after, None);
    }

    #[test]
    fn test_with_server_clears_storage_even_for_full_url() {
        // Regression guard: `with_server` once delegated straight to
        // `with_url` for full URLs, which kept the bearer token. Its contract
        // is that a token issued by the previous server is no longer trusted
        // after a switch.
        use crate::storage::MemoryTokenStorage;

        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("token-from-old-server").unwrap();
        let client = Client::new().unwrap().with_storage(storage.clone());
        let before = storage.load().unwrap();
        assert_eq!(before, Some("token-from-old-server".to_string()));

        // Point at a self-hosted Studio via a full URL: storage is cleared and
        // the returned client targets the new URL.
        let new_client = client
            .with_server("https://studio.example.com")
            .expect("https full URL through with_server");
        let after = storage.load().unwrap();
        assert_eq!(after, None);
        assert_eq!(new_client.url(), "https://studio.example.com");

        // The in-memory token of the new client has to be blank as well.
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let in_mem = runtime.block_on(async { new_client.token.read().await.clone() });
        assert!(in_mem.is_empty(), "expected blank token, got {in_mem:?}");
    }

    #[test]
    fn test_with_server_rejects_insecure_full_url() {
        // Full URLs given to `with_server` go through the same HTTPS rule as
        // `with_url`: plain http to a public host is refused, since the bearer
        // token would otherwise travel in cleartext on the next login.
        let client = Client::new().unwrap();
        let outcome = client.with_server("http://studio.example.com");
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    // ===== HTTPS enforcement in with_url =====
    //
    // The bearer token is sent in the Authorization header, so plain http://
    // to a public host would expose it. Such URLs must be rejected, while
    // wiremock / local development URLs (loopback addresses, "localhost",
    // "*.localhost") still have to pass.

    #[test]
    fn with_url_accepts_https_public_host() {
        let client = Client::new().unwrap();
        let outcome = client.with_url("https://studio.example.com");
        let out = outcome.expect("https public host must be accepted");
        assert_eq!(out.url(), "https://studio.example.com");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv4() {
        let client = Client::new().unwrap();
        let outcome = client.with_url("http://127.0.0.1:8080");
        let out = outcome.expect("http://127.0.0.1 must be accepted (loopback)");
        assert_eq!(out.url(), "http://127.0.0.1:8080");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv6() {
        let client = Client::new().unwrap();
        let outcome = client.with_url("http://[::1]:8080");
        let out = outcome.expect("http://[::1] must be accepted (loopback)");
        assert!(out.url().starts_with("http://[::1]"));
    }

    #[test]
    fn with_url_accepts_http_localhost() {
        let client = Client::new().unwrap();

        // Plain "localhost" with a port.
        let with_port = client.with_url("http://localhost:8080");
        with_port.expect("http://localhost must be accepted");

        // Host matching is case-insensitive.
        let upper_case = client.with_url("http://LOCALHOST");
        upper_case.expect("http://LOCALHOST must be accepted (case-insensitive)");

        // Sub-domains of localhost count as local too.
        let sub_domain = client.with_url("http://wiremock.localhost");
        sub_domain.expect("http://*.localhost must be accepted");
    }

    #[test]
    fn with_url_rejects_http_public_host() {
        let client = Client::new().unwrap();
        let outcome = client.with_url("http://studio.example.com");
        // The error must echo the offending URL.
        match outcome.unwrap_err() {
            Error::InsecureUrl(u) => assert_eq!(u, "http://studio.example.com"),
            other => panic!("expected InsecureUrl, got {other:?}"),
        }
    }

    #[test]
    fn with_url_rejects_http_public_ip() {
        let client = Client::new().unwrap();
        // 8.8.8.8 is a public, non-loopback address.
        let outcome = client.with_url("http://8.8.8.8");
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    #[test]
    fn with_url_rejects_non_http_scheme() {
        let client = Client::new().unwrap();
        // file:// parses as a URL but is no usable RPC transport, so it must
        // not be accepted silently.
        let outcome = client.with_url("file:///etc/passwd");
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }
}
6202
#[cfg(test)]
mod tests_map_rpc_error {
    use super::*;
    use crate::api::TaskID;

    /// Task identifier shared by the tests that need one.
    fn sample_task_id() -> TaskID {
        TaskID::try_from("task-1a2b").unwrap()
    }

    #[test]
    fn maps_not_found_with_task_id_to_typed_variant() {
        // Code 101 with a "not found" message and a task id yields TaskNotFound.
        let mapped = map_rpc_error(
            "task.data.list",
            101,
            String::from("task not found"),
            Some(sample_task_id()),
        );
        assert!(matches!(mapped, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_cannot_find_phrasing_to_typed_variant() {
        // The DVE server says "Cannot find task..."; a plain "not found"
        // substring check missed it and callers got a generic RpcError.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("Cannot find task with id 6789"),
            Some(sample_task_id()),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "'Cannot find task' should map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_does_not_exist_phrasing_to_typed_variant() {
        let mapped = map_rpc_error(
            "task.chart.get",
            101,
            String::from("task does not exist"),
            Some(sample_task_id()),
        );
        assert!(matches!(mapped, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_code_101_with_unknown_phrasing_when_task_id_supplied() {
        // Code 101 means "resource not found" on the server side, so even an
        // unfamiliar message must give the typed variant; callers can then
        // write a stable `match`.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("completely novel server message"),
            Some(sample_task_id()),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "code 101 + task_id should always map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_permission_codes_to_typed_variant() {
        for code in [401, 403] {
            let mapped = map_rpc_error("task.chart.add", code, String::from("denied"), None);
            assert!(
                matches!(mapped, Error::PermissionDenied(_)),
                "code {} did not map",
                code
            );
        }
    }

    #[test]
    fn permission_denied_records_method_for_diagnostics() {
        let mapped = map_rpc_error("task.data.upload", 403, String::from("forbidden"), None);
        match mapped {
            Error::PermissionDenied(method) => assert_eq!(method, "task.data.upload"),
            other => panic!("expected PermissionDenied, got {:?}", other),
        }
    }

    #[test]
    fn maps_payload_too_large_to_typed_variant() {
        let mapped = map_rpc_error(
            "val.data.upload",
            413,
            String::from("request too large"),
            None,
        );
        match mapped {
            Error::PayloadTooLarge { method, size_hint } => {
                assert_eq!(method, "val.data.upload");
                assert!(size_hint.is_none());
            }
            other => panic!("expected PayloadTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn falls_through_to_generic_rpc_error_for_unknown_codes() {
        let mapped = map_rpc_error("task.data.list", -99999, String::from("weird"), None);
        match mapped {
            Error::RpcError(code, msg) => {
                assert_eq!(code, -99999);
                assert_eq!(msg, "weird");
            }
            other => panic!("expected RpcError, got {:?}", other),
        }
    }

    #[test]
    fn not_found_without_task_id_falls_through() {
        // Without a task id there is no task to name, so code 101 stays a
        // generic RpcError.
        let mapped = map_rpc_error("task.data.list", 101, String::from("not found"), None);
        assert!(matches!(mapped, Error::RpcError(101, _)));
    }

    #[test]
    fn code_101_with_task_id_always_maps_even_with_unrelated_message() {
        // This test used to expect fall-through for messages other than
        // "not found". Because code 101 is contractually "resource not found"
        // (see api.go), a present task id now always produces the typed
        // variant, giving callers a stable error type.
        let mapped = map_rpc_error(
            "task.data.list",
            101,
            String::from("permission denied"),
            Some(sample_task_id()),
        );
        assert!(matches!(mapped, Error::TaskNotFound(_)));
    }
}
6336
#[cfg(test)]
mod tests_jobs {
    use super::*;

    #[test]
    fn jobs_list_request_serializes_to_empty_object() {
        // The list request has no fields, so its wire form is `{}`.
        let encoded = serde_json::to_value(&JobsListRequest {}).unwrap();
        assert_eq!(encoded, serde_json::json!({}));
    }

    #[test]
    fn job_deserializes_from_bk_batch_shape() {
        // Payload shaped like a BK_BATCH record, including a key
        // ("extra_field") that the test names as ignored.
        let payload = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789,
            "docker_task": {},
            "extra_field": "ignored"
        }"#;
        let parsed: crate::api::Job = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.code, "edgefirst-validator:2.9.5");
        assert_eq!(parsed.state, "RUNNING");
        // The raw field and the typed accessor must agree on the task id.
        assert_eq!(parsed.task_id, 6789);
        assert_eq!(parsed.task_id().value(), 6789);
    }
}
6367
#[cfg(test)]
mod tests_job_run {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn job_run_request_serializes_with_expected_fields() {
        // Populate every field, then check each one in the serialized JSON.
        let env = HashMap::from([("LOG_LEVEL".into(), "info".into())]);
        let data = HashMap::from([("validation_session_id".into(), Parameter::Integer(2707))]);
        let request = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "post-profile-run".into(),
            env,
            data,
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["name"], "edgefirst-validator");
        assert_eq!(encoded["job_name"], "post-profile-run");
        assert_eq!(encoded["env"]["LOG_LEVEL"], "info");
        assert_eq!(encoded["data"]["validation_session_id"], 2707);
    }

    #[test]
    fn job_run_response_deserializes_as_job() {
        // job.run answers with the full BK_BATCH record, which is parsed as a Job.
        let payload = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "post-profile-run",
            "job_id": "aws-batch-job-xxx",
            "state": "SUBMITTED",
            "task_id": 6789
        }"#;
        let parsed: crate::api::Job = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.task_id, 6789);
        assert_eq!(parsed.job_id, "aws-batch-job-xxx");
        assert_eq!(parsed.state, "SUBMITTED");
    }
}
6406
#[cfg(test)]
mod tests_job_stop {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn job_stop_request_serializes_with_task_id() {
        // The numeric task id must appear under the "task_id" key.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let raw_id = task_id.value();
        let request = JobStopRequest { task_id: raw_id };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["task_id"], raw_id);
    }
}
6422
#[cfg(test)]
mod tests_task_data_list_request {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_list_request_serializes_with_task_id() {
        // The numeric task id must appear under the "task_id" key.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let raw_id = task_id.value();
        let request = TaskDataListRequest { task_id: raw_id };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["task_id"], raw_id);
    }
}
6438
#[cfg(test)]
mod tests_task_data_download {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_download_request_serializes_with_all_fields() {
        // All three fields must be serialized under their own keys.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let raw_id = task_id.value();
        let request = TaskDataDownloadRequest {
            task_id: raw_id,
            folder: "predictions".into(),
            file: "predictions.parquet".into(),
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["task_id"], raw_id);
        assert_eq!(encoded["folder"], "predictions");
        assert_eq!(encoded["file"], "predictions.parquet");
    }
}
6458
#[cfg(test)]
mod tests_task_chart_add {
    use super::*;
    use crate::api::{Parameter, TaskID};

    #[test]
    fn task_chart_add_request_serializes_with_correct_fields() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let raw_id = task_id.value();

        // Minimal chart payload: an object with a single "type" entry.
        let chart_fields = std::collections::HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]);
        let request = TaskChartAddRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: None,
            data: Parameter::Object(chart_fields),
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["task_id"], raw_id);
        assert_eq!(encoded["group_name"], "metrics");
        assert_eq!(encoded["chart_name"], "loss");
        assert_eq!(encoded["data"]["type"], "line");
        // `params: None` must not show up as a non-null value.
        assert!(encoded["params"].is_null());
    }
}
6486
#[cfg(test)]
mod tests_task_chart_list {
    use super::*;
    use crate::api::TaskID;

    /// Pins the wire shape of a chart-list request with an empty group name.
    ///
    /// The key is NOT dropped from the payload: it is serialized as an empty
    /// string. Indexing a `serde_json::Value` with a missing key yields
    /// `Value::Null`, which would fail the `== ""` assertion below, so this
    /// test fails if the field ever becomes skipped on serialization.
    #[test]
    fn task_chart_list_request_serializes_empty_group_name_as_empty_string() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskChartListRequest {
            task_id: task_id.value(),
            group_name: String::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        // Present as "" rather than omitted.
        assert_eq!(json["group_name"], "");
    }
}
6504
#[cfg(test)]
mod tests_task_chart_get {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_chart_get_request_serializes_with_all_fields() {
        // All three fields must be serialized under their own keys.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let raw_id = task_id.value();
        let request = TaskChartGetRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["task_id"], raw_id);
        assert_eq!(encoded["group_name"], "metrics");
        assert_eq!(encoded["chart_name"], "loss");
    }
}
6524
#[cfg(test)]
mod tests_val_data_download {
    use super::*;

    #[test]
    fn val_data_download_request_serializes() {
        // Session id and file name must be serialized under their own keys.
        let request = ValDataDownloadRequest {
            session_id: 2707,
            filename: "trace/imx95.json".into(),
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["session_id"], 2707);
        assert_eq!(encoded["filename"], "trace/imx95.json");
    }
}
6540
#[cfg(test)]
mod tests_val_data_list {
    use super::*;

    #[test]
    fn val_data_list_request_serializes() {
        // The whole payload is exactly one "session_id" entry.
        let request = ValDataListRequest { session_id: 2707 };
        let encoded = serde_json::to_value(&request).unwrap();
        let expected = serde_json::json!({"session_id": 2707});
        assert_eq!(encoded, expected);
    }
}
6554
#[cfg(test)]
mod tests_jsonrpc_envelope_detection {
    use super::*;

    #[test]
    fn detects_real_envelope() {
        // Version marker, id, and an error object with a numeric code.
        let envelope = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": 101, "message": "Cannot find task" },
        });
        assert!(is_jsonrpc_error_envelope(&envelope));
    }

    #[test]
    fn rejects_plain_json_artifact_with_error_field() {
        // A diagnostics file may carry its own free-form `error` object; the
        // key collision alone must not make it look like an RPC envelope.
        let artifact = serde_json::json!({
            "metric": "loss",
            "value": 0.42,
            "error": { "code": "ENV_NOT_FOUND", "message": "missing var" },
        });
        let detected = is_jsonrpc_error_envelope(&artifact);
        assert!(
            !detected,
            "missing jsonrpc sentinel should mean 'not an envelope'"
        );
    }

    #[test]
    fn rejects_envelope_missing_jsonrpc_sentinel() {
        // An `error` block on its own, without the protocol-version marker.
        let candidate = serde_json::json!({
            "id": 0,
            "error": { "code": 101, "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&candidate));
    }

    #[test]
    fn rejects_envelope_with_non_object_error_field() {
        // Looks like JSON-RPC by accident, but `error` is a plain string.
        let candidate = serde_json::json!({
            "jsonrpc": "2.0",
            "error": "something went wrong",
        });
        assert!(!is_jsonrpc_error_envelope(&candidate));
    }

    #[test]
    fn rejects_envelope_without_error_code() {
        // Genuine envelopes always include an integer error.code; without
        // one the value is not classified as an envelope.
        let candidate = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "message": "no code" },
        });
        assert!(!is_jsonrpc_error_envelope(&candidate));
    }

    #[test]
    fn rejects_envelope_with_non_numeric_error_code() {
        // A string error code does not qualify either.
        let candidate = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": "ENOENT", "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&candidate));
    }

    #[test]
    fn rejects_non_object_root() {
        // Metrics dumps are often a top-level array; that is never an envelope.
        let candidate = serde_json::json!([1, 2, 3]);
        assert!(!is_jsonrpc_error_envelope(&candidate));
    }

    #[test]
    fn accepts_unsigned_error_code() {
        // The server code is nominally i32, but JSON does not distinguish
        // signed from unsigned numbers, so both forms are accepted.
        let envelope = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": 101u32, "message": "x" },
        });
        assert!(is_jsonrpc_error_envelope(&envelope));
    }
}
6644
#[cfg(test)]
mod tests_validate_chart_args {
    use super::*;

    /// Asserts that the given group/name pair is refused with
    /// `Error::InvalidParameters`.
    fn assert_rejected(group: &str, name: &str) {
        let err = validate_chart_args(group, name).unwrap_err();
        assert!(matches!(err, Error::InvalidParameters(_)));
    }

    #[test]
    fn rejects_empty_group() {
        assert_rejected("", "name");
    }

    #[test]
    fn rejects_empty_name() {
        assert_rejected("group", "");
    }

    #[test]
    fn rejects_both_empty() {
        assert_rejected("", "");
    }

    #[test]
    fn accepts_valid_args() {
        let outcome = validate_chart_args("group", "name");
        assert!(outcome.is_ok());
    }

    #[test]
    fn accepts_unicode_args() {
        // Non-ASCII names are fine; emptiness is the only thing rejected.
        let outcome = validate_chart_args("metrics-集合", "损失");
        assert!(outcome.is_ok());
    }
}
6678
6679// ---------------------------------------------------------------------------
6680// Additional offline tests for request shapes + helpers added in DE-2565.
6681//
6682// These focus on the wire-shape and helper logic that does not require a
6683// live Studio server — they significantly boost coverage of client.rs.
6684// ---------------------------------------------------------------------------
6685
#[cfg(test)]
mod tests_job_run_request_shape {
    // Offline checks of the JSON produced when a `JobRunRequest` is
    // serialized. Nothing here talks to a server.
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn empty_env_and_data_serialize_as_empty_objects() {
        // Empty maps must serialize as `{}` for both `env` and `data`.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "smoke".into(),
            env: HashMap::new(),
            data: HashMap::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["env"], serde_json::json!({}));
        assert_eq!(json["data"], serde_json::json!({}));
    }

    #[test]
    fn data_passes_through_parameter_object_payloads() {
        // Serializes (one way; nothing is parsed back into `JobRunRequest`)
        // a `data` map holding the Boolean, Integer, Real and String variants
        // plus a nested `Parameter::Object`, and checks that each one appears
        // in the JSON as a plain value / nested object. Array payloads are
        // not covered by this test.
        let nested = Parameter::Object(HashMap::from([(
            "y_axis".into(),
            Parameter::String("log".into()),
        )]));
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "feat".into(),
            env: HashMap::new(),
            data: HashMap::from([
                ("flag".into(), Parameter::Boolean(true)),
                ("epochs".into(), Parameter::Integer(50)),
                ("lr".into(), Parameter::Real(1e-3)),
                ("name".into(), Parameter::String("hello".into())),
                ("opts".into(), nested),
            ]),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["data"]["flag"], true);
        assert_eq!(json["data"]["epochs"], 50);
        // Only the sign is checked for the float to avoid an exact
        // floating-point comparison.
        assert!(json["data"]["lr"].as_f64().unwrap() > 0.0);
        assert_eq!(json["data"]["name"], "hello");
        // The object payload the test name refers to: it must survive as a
        // nested JSON object.
        assert_eq!(json["data"]["opts"]["y_axis"], "log");
    }
}
6729
#[cfg(test)]
mod tests_task_data_chart_request_shape {
    use super::*;
    use crate::api::{Parameter, TaskID};
    use std::collections::HashMap;

    // Task identifier shared by every test in this module.
    fn sample_task_id() -> TaskID {
        TaskID::try_from("task-1a2b").unwrap()
    }

    // Builds a `Parameter::Object` holding a single string-valued entry.
    fn single_string_object(key: &'static str, value: &'static str) -> Parameter {
        Parameter::Object(HashMap::from([(
            key.into(),
            Parameter::String(value.into()),
        )]))
    }

    #[test]
    fn chart_add_request_with_params_serializes_object() {
        let id = sample_task_id();
        let request = TaskChartAddRequest {
            task_id: id.value(),
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: Some(single_string_object("y_axis", "log")),
            data: single_string_object("type", "line"),
        };
        let wire = serde_json::to_value(&request).unwrap();
        // `params` must come out as a JSON object keyed by the entry name.
        assert_eq!(wire["params"]["y_axis"], "log");
    }

    #[test]
    fn task_data_list_request_round_trips() {
        let id = sample_task_id();
        // A struct with a single field has only one possible field order,
        // so comparing the exact serialized text is meaningful.
        let expected = format!("{{\"task_id\":{}}}", id.value());
        let request = TaskDataListRequest {
            task_id: id.value(),
        };
        let actual = serde_json::to_string(&request).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn task_data_download_request_treats_folder_and_file_independently() {
        let id = sample_task_id();
        let request = TaskDataDownloadRequest {
            task_id: id.value(),
            folder: "validation/run-01".into(),
            file: "metrics.json".into(),
        };
        let wire = serde_json::to_value(&request).unwrap();
        // Folder and file travel as two separate fields rather than one
        // combined path, so callers never have to escape slashes themselves.
        assert_eq!(wire["folder"], "validation/run-01");
        assert_eq!(wire["file"], "metrics.json");
    }
}
6784
#[cfg(test)]
mod tests_val_data_request_shape {
    use super::*;

    #[test]
    fn val_data_list_round_trips() {
        // Serialize to text, then parse that text as a generic JSON value
        // and inspect the field.
        let request = ValDataListRequest { session_id: 2707 };
        let text = serde_json::to_string(&request).unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&text).unwrap();
        assert_eq!(parsed["session_id"], 2707);
    }

    #[test]
    fn val_data_download_round_trips_with_nested_path() {
        // The filename contains a slash; it must reach the JSON unchanged.
        let request = ValDataDownloadRequest {
            session_id: 2707,
            filename: "subfolder/imx95.json".into(),
        };
        let text = serde_json::to_string(&request).unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&text).unwrap();
        assert_eq!(parsed["session_id"], 2707);
        assert_eq!(parsed["filename"], "subfolder/imx95.json");
    }
}
6809
#[cfg(test)]
mod tests_progress_struct {
    use super::*;

    #[test]
    fn progress_can_be_constructed_with_zero_total() {
        // A server may omit Content-Length, so a progress event with a zero
        // total still has to be representable. Building the struct
        // field-by-field also guards the public field-level API.
        let zeroed = Progress {
            current: 0,
            total: 0,
            status: None,
        };
        let Progress {
            current,
            total,
            status,
        } = &zeroed;
        assert_eq!(*current, 0);
        assert_eq!(*total, 0);
        assert!(status.is_none());
    }

    #[test]
    fn progress_tracks_current_independently_of_total() {
        let in_flight = Progress {
            current: 123,
            total: 456,
            status: Some("Downloading".into()),
        };
        let Progress {
            current,
            total,
            status,
        } = &in_flight;
        assert_eq!(*current, 123);
        assert_eq!(*total, 456);
        assert_eq!(status.as_deref(), Some("Downloading"));
    }

    #[test]
    fn progress_can_be_cloned() {
        // Progress sinks may need to keep their own copy independently of
        // the channel, so `Clone` must be available and must preserve every
        // field.
        let original = Progress {
            current: 10,
            total: 20,
            status: Some("phase".into()),
        };
        let duplicate = original.clone();
        assert_eq!(duplicate.current, original.current);
        assert_eq!(duplicate.total, original.total);
        assert_eq!(duplicate.status, original.status);
    }
}
6855
#[cfg(test)]
mod tests_bare_filename_parent {
    // Pins the `Path::parent` behavior behind the empty-parent guard in
    // `rpc_download`: a caller passing a bare filename such as
    // "metrics.json" should download into the current directory rather than
    // fail on `create_dir_all("")`.
    use std::path::Path;

    #[test]
    fn bare_filename_parent_is_empty_path() {
        // The guard relies on this invariant. Should a future Rust release
        // change what `Path::parent` returns for a bare filename, the guard
        // has to be revisited.
        let bare = Path::new("metrics.json");
        let parent = bare
            .parent()
            .expect("bare filename always has Some parent");
        let parent_is_empty = parent.as_os_str().is_empty();
        assert!(
            parent_is_empty,
            "Path::parent for bare filename should be empty, got: {parent:?}"
        );
    }

    #[test]
    fn path_with_directory_has_non_empty_parent() {
        // Companion case: a path that includes a directory yields a
        // non-empty parent, which is when `create_dir_all` should run.
        let nested = Path::new("dir/metrics.json");
        let parent = nested
            .parent()
            .expect("path-with-dir always has Some parent");
        let parent_is_empty = parent.as_os_str().is_empty();
        assert!(!parent_is_empty);
        assert_eq!(parent, Path::new("dir"));
    }
}