edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult,
8 NewValidationSession, Organization, Project, ProjectID, SampleID, SamplesCountResult,
9 SamplesListParams, SamplesListResult, Snapshot, SnapshotCreateFromDataset,
10 SnapshotFromDatasetResult, SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage,
11 StartValidationRequest, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
12 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
13 ValidationSessionID,
14 },
15 dataset::{
16 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
17 },
18 retry::{create_retry_policy, log_retry_configuration},
19 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
20};
21use base64::Engine as _;
22use chrono::{DateTime, Utc};
23use directories::ProjectDirs;
24use futures::{StreamExt as _, future::join_all};
25use log::{Level, debug, error, log_enabled, trace, warn};
26use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::{
29 collections::HashMap,
30 ffi::OsStr,
31 fs::create_dir_all,
32 io::{SeekFrom, Write as _},
33 path::{Path, PathBuf},
34 sync::{
35 Arc,
36 atomic::{AtomicUsize, Ordering},
37 },
38 time::Duration,
39 vec,
40};
41use tokio::{
42 fs::{self, File},
43 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
44 sync::{RwLock, Semaphore, mpsc::Sender},
45};
46use tokio_util::codec::{BytesCodec, FramedRead};
47use walkdir::WalkDir;
48
49#[cfg(feature = "polars")]
50use polars::prelude::*;
51
52/// Maps a JSON-RPC error code to a typed `Error` variant when the code is
53/// well-known; otherwise returns `Error::RpcError(code, message)` unchanged.
54///
55/// Scoped to the new DE-2565 methods. Existing methods continue to return
56/// `Error::RpcError` directly.
57///
58/// Server error codes (from `api.go` via `jrpc.Fail`):
59/// - `1` – generic server error
60/// - `3` – validation / bad request
61/// - `10` – internal server error
62/// - `101` – resource not found (e.g. "Cannot find task...", "not found in DB")
63/// - `401` – unauthenticated
64/// - `403` – forbidden
65/// - `413` – payload too large
66pub(crate) fn map_rpc_error(
67 method: &str,
68 code: i32,
69 message: String,
70 task_id: Option<crate::api::TaskID>,
71) -> Error {
72 // Server emits "Cannot find task...", "not found in DB", and other phrasings
73 // for code 101. Code 101 with a task_id is task-not-found by contract
74 // (see api.go), so we return the typed variant unconditionally when the
75 // caller supplied a task_id — message phrasing is treated as informational
76 // and is preserved by the RPC layer for diagnostic logging upstream.
77 if code == 101
78 && let Some(id) = task_id
79 {
80 return Error::TaskNotFound(id);
81 }
82 match code {
83 401 | 403 => Error::PermissionDenied(method.to_string()),
84 413 => Error::PayloadTooLarge {
85 method: method.to_string(),
86 size_hint: None,
87 },
88 _ => Error::RpcError(code, message),
89 }
90}
91
92/// Returns true if `val` is structurally a JSON-RPC 2.0 *error* envelope.
93///
94/// A real envelope must:
95/// 1. Be a JSON object,
96/// 2. Carry a `"jsonrpc"` member (the protocol-version sentinel — JSON-RPC
97/// 2.0 §5 mandates this on every response object),
98/// 3. Carry an `"error"` object that includes a numeric `"code"` field.
99///
100/// This is intentionally stricter than a "looks for a top-level `error`
101/// key" check so that legitimate JSON file payloads (validation traces,
102/// metrics dumps, diagnostics) which happen to include a free-form `error`
103/// field are *not* misclassified as RPC failures.
104///
105/// Extracted so it can be unit-tested without a live server.
106pub(crate) fn is_jsonrpc_error_envelope(val: &serde_json::Value) -> bool {
107 let Some(obj) = val.as_object() else {
108 return false;
109 };
110 // Protocol-version sentinel — only JSON-RPC envelopes carry this.
111 if !obj.contains_key("jsonrpc") {
112 return false;
113 }
114 let Some(err) = obj.get("error").and_then(|e| e.as_object()) else {
115 return false;
116 };
117 err.get("code")
118 .map(|c| c.is_i64() || c.is_u64())
119 .unwrap_or(false)
120}
121
122/// Validates that `group` and `name` are both non-empty strings for chart
123/// operations (`add_chart`, `get_chart`). Extracted so it can be unit-tested
124/// without a live server.
125pub(crate) fn validate_chart_args(group: &str, name: &str) -> Result<(), Error> {
126 if group.is_empty() || name.is_empty() {
127 return Err(Error::InvalidParameters(
128 "chart: group and name must be non-empty".into(),
129 ));
130 }
131 Ok(())
132}
133
/// Part size in bytes (100 MiB). Presumably the chunk size used to split
/// files for multipart uploads; confirm against the upload code.
static PART_SIZE: usize = 100 * 1024 * 1024;
135
/// Source for file content during upload - either a local path or raw bytes.
///
/// Note that cloning a `Bytes` value copies the whole byte buffer.
#[derive(Clone)]
enum FileSource {
    /// File content from a local filesystem path.
    Path(PathBuf),
    /// File content as raw bytes (e.g., from a ZIP archive).
    Bytes(Vec<u8>),
}
144
/// Concurrency limit for fan-out work such as batched file uploads.
///
/// Reads the `MAX_TASKS` environment variable; surrounding whitespace is
/// ignored. An explicit positive value is used as-is, without clamping.
///
/// When the variable is unset, is not a valid unsigned integer, or is `0`, the
/// default is used: half the available CPUs, clamped to `2..=8`. If the CPU
/// count cannot be determined, 4 CPUs are assumed.
///
/// `0` is rejected on purpose. Used as a concurrency limit, it would let no
/// work proceed.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
157
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment
/// variable; surrounding whitespace is ignored.
///
/// Defaults to 8 when the variable is unset, is not a valid unsigned integer,
/// or is `0`. A zero limit would allow no part uploads to run.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8) // Default to 8 concurrent part uploads
}
168
/// Filters items by name and sorts by match quality.
///
/// An item is kept when its name contains `filter`, ignoring case. Survivors
/// are ordered by match quality, best first:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match, with shorter names (by byte length) first
/// 4. Case-sensitive alphabetical order, to keep the result stable
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer". An empty `filter` keeps every
/// item.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    // Lowercase each name exactly once and cache its rank key beside the
    // item. Lowercasing inside the sort comparator would allocate two
    // strings per comparison.
    // Key = (not exact, not case-insensitive exact, byte length). Since
    // `false < true`, better matches sort first.
    let mut ranked: Vec<((bool, bool, usize), T)> = items
        .into_iter()
        .filter_map(|item| {
            let name = get_name(&item);
            let name_lower = name.to_lowercase();
            if !name_lower.contains(&filter_lower) {
                return None;
            }
            let key = (name != filter, name_lower != filter_lower, name.len());
            Some((key, item))
        })
        .collect();

    // Stable sort. Ties on the cached key fall back to alphabetical order of
    // the original (case-sensitive) names.
    ranked.sort_by(|(key_a, a), (key_b, b)| {
        key_a.cmp(key_b).then_with(|| get_name(a).cmp(get_name(b)))
    });

    ranked.into_iter().map(|(_, item)| item).collect()
}
218
219/// Whether `host` refers to a loopback (machine-local) endpoint.
220///
221/// Used by [`Client::with_url`] to decide whether a plain-`http://` URL is
222/// safe to accept. Loopback traffic never leaves the machine, so the
223/// usual concern about leaking the Studio bearer token in plaintext does
224/// not apply — that's how wiremock and local dev servers connect.
225fn is_loopback_host(host: Option<&url::Host<&str>>) -> bool {
226 match host {
227 Some(url::Host::Ipv4(ip)) => ip.is_loopback(),
228 Some(url::Host::Ipv6(ip)) => ip.is_loopback(),
229 // RFC 6761 reserves "localhost" (and `*.localhost`) as a loopback
230 // name. Compare case-insensitively because URL hosts are matched
231 // that way and developers do type capitalized variants.
232 Some(url::Host::Domain(d)) => {
233 d.eq_ignore_ascii_case("localhost") || d.to_ascii_lowercase().ends_with(".localhost")
234 }
235 None => false,
236 }
237}
238
/// Converts an arbitrary name into a single path component that is safe to
/// join onto a directory.
///
/// - Surrounding whitespace is trimmed. An empty result becomes `"unnamed"`.
/// - Only the final component of a path-like name is kept
///   (`"dir/file.txt"` → `"file.txt"`).
/// - Characters reserved on common filesystems (`/ \ : * ? " < > |`) and
///   control characters (including NUL) are replaced with `_`.
/// - The special components `"."` and `".."` become `"_"` and `"__"`. The
///   result can therefore never refer to the current or parent directory.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // `file_name()` returns `None` for paths ending in `.` or `..`, so the
    // fallback may still be one of those special components. They are
    // neutralised below.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            // NUL makes path APIs fail outright; other control characters
            // are invalid on Windows and confusing elsewhere.
            c if c.is_control() => '_',
            _ => c,
        })
        .collect();

    if sanitized.is_empty() {
        "unnamed".to_string()
    } else if sanitized == "." || sanitized == ".." {
        // Joining "." or ".." onto a directory would escape the intended
        // location (path traversal), so replace every dot.
        "_".repeat(sanitized.len())
    } else {
        sanitized
    }
}
264
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count, total
/// count, and an optional status string to enable progress reporting in
/// applications.
///
/// # Multi-Stage Progress
///
/// The `status` field enables multi-stage progress tracking. When an operation
/// has multiple phases, the status field changes to indicate the current phase.
/// Applications should detect status changes to reset their progress display.
///
/// # Operation Progress Details
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Basic progress display:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "{}: {:.1}% ({}/{})",
///     progress.status.as_deref().unwrap_or("Progress"),
///     percentage,
///     progress.current,
///     progress.total
/// );
/// ```
///
/// Multi-stage progress handling (e.g., for `download_dataset`):
///
/// ```rust,ignore
/// let mut last_status: Option<String> = None;
///
/// while let Some(progress) = rx.recv().await {
///     // Detect stage change and reset progress bar
///     if progress.status != last_status {
///         if let Some(ref status) = progress.status {
///             println!("\n{}", status);
///         }
///         last_status = progress.status.clone();
///     }
///
///     let pct = (progress.current as f64 / progress.total as f64) * 100.0;
///     print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items or bytes.
    pub current: usize,
    /// Total number of items or bytes to process.
    ///
    /// May be `0` when the total is not known in advance. For example,
    /// byte-level downloads report `0` when the server omits the
    /// `Content-Length` header. Guard percentage calculations against
    /// division by zero.
    pub total: usize,
    /// Optional status describing the current operation phase.
    ///
    /// When this value changes from `None` to `Some(...)` or between different
    /// values, it indicates a new phase has started. Applications should reset
    /// their progress display when the status changes.
    ///
    /// Currently only [`Client::download_dataset`] uses status changes:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// All other operations use `None` throughout.
    pub status: Option<String>,
}
359
/// JSON-RPC 2.0 request envelope.
///
/// `params` has no `skip_serializing_if`, so a `None` value is serialized as
/// `null`.
#[derive(Serialize)]
struct RpcRequest<Params> {
    /// Request identifier.
    id: u64,
    /// Protocol version; `"2.0"` when built via [`Default`].
    jsonrpc: String,
    /// Name of the remote method to invoke.
    method: String,
    /// Method-specific parameters.
    params: Option<Params>,
}

// Written by hand rather than derived. `#[derive(Default)]` would add a
// `T: Default` bound, whereas this impl lets `RpcRequest<P>::default()` work
// for any parameter type `P`.
impl<T> Default for RpcRequest<T> {
    fn default() -> Self {
        RpcRequest {
            id: 0,
            jsonrpc: "2.0".to_string(),
            method: "".to_string(),
            params: None,
        }
    }
}

/// The `error` member of a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric error code (see `map_rpc_error` for the well-known values).
    code: i32,
    /// Error message reported by the server.
    message: String,
}

/// JSON-RPC 2.0 response envelope. Per the protocol, either `error` or
/// `result` is present.
///
/// NOTE(review): requests send a numeric `id` (`u64`), but responses
/// deserialize `id` as `String`. A server that echoed the id back as a JSON
/// number would fail to deserialize here. Confirm the server always returns
/// a string id.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Result type for RPC calls whose `result` payload is not inspected. Any
/// JSON object deserializes into it, because serde ignores unknown fields by
/// default.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
398
/// Parameters for creating a snapshot from a list of object keys.
///
/// NOTE(review): marked `allow(dead_code)`; appears superseded by
/// [`SnapshotCreateMultipartParams`]. Confirm whether it is still needed.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}

/// Result of non-multipart snapshot creation: the new snapshot id and a list
/// of URLs (presumably one presigned upload URL per key; confirm).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for creating a snapshot that is uploaded via multipart
/// transfers.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    /// File sizes in bytes; presumably parallel to `keys` index-for-index
    /// (confirm).
    file_sizes: Vec<usize>,
    /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    snapshot_type: Option<String>,
}

/// One value in the multipart-create result map.
///
/// `untagged` makes serde try variants in declaration order. A bare integer
/// therefore deserializes as `Id`, and an object with the `SnapshotPart`
/// shape as `Part`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters that finalize a multipart upload for one object key.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}

/// An uploaded part's ETag and part number, serialized with the S3-style
/// `ETag` / `PartNumber` field names.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Multipart upload descriptor for one file. It holds the optional object
/// key, the upload id, and a list of URLs (presumably one presigned URL per
/// part; confirm).
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}

/// Parameters for setting a snapshot's `status` string.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by the status RPC. Its fields are currently
/// never read, hence the `allow(dead_code)` attributes.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}

/// Parameters for an image-listing RPC; currently unused (`allow(dead_code)`).
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    /// Presumably requests ids only instead of full records (confirm).
    only_ids: bool,
}

/// Dataset filter nested inside [`ImageListParams`].
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
485
/// Main client for interacting with EdgeFirst Studio Server.
///
/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
/// and manages authentication, RPC calls, and data operations. It provides
/// methods for managing projects, datasets, experiments, training sessions,
/// and various utility functions for data processing.
///
/// The client supports multiple authentication methods and can work with both
/// SaaS and self-hosted EdgeFirst Studio instances.
///
/// Plain `clone()`s share the same token cell (`Arc<RwLock<String>>`) and
/// storage backend (`Arc<dyn TokenStorage>`).
///
/// # Features
///
/// - **Authentication**: Token-based authentication with automatic persistence
/// - **Dataset Management**: Upload, download, and manipulate datasets
/// - **Project Operations**: Create and manage projects and experiments
/// - **Training & Validation**: Submit and monitor ML training jobs
/// - **Data Integration**: Convert between EdgeFirst datasets and popular
///   formats
/// - **Progress Tracking**: Real-time progress updates for long-running
///   operations
///
/// # Examples
///
/// ```no_run
/// use edgefirst_client::{Client, DatasetID};
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), edgefirst_client::Error> {
/// // Create a new client and authenticate
/// let mut client = Client::new()?;
/// let client = client
///     .with_login("your-email@example.com", "password")
///     .await?;
///
/// // Or use an existing token
/// let base_client = Client::new()?;
/// let client = base_client.with_token("your-token-here")?;
///
/// // Get organization and projects
/// let org = client.organization().await?;
/// let projects = client.projects(None).await?;
///
/// // Work with datasets
/// let dataset_id = DatasetID::from_str("ds-abc123")?;
/// let dataset = client.dataset(dataset_id).await?;
/// # Ok(())
/// # }
/// ```
// Implementation note: Client is Clone but cannot derive Debug due to
// dyn TokenStorage, so `Debug` is implemented by hand.
#[derive(Clone)]
pub struct Client {
    /// HTTP client for regular API calls, built with a total per-request
    /// timeout (`EDGEFIRST_TIMEOUT`).
    http: reqwest::Client,
    /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
    /// timeout). An idle read timeout is still configured on the underlying client, and
    /// some operations (such as uploads) may apply additional per-request timeouts.
    bulk_http: reqwest::Client,
    /// Base server URL, e.g. `https://edgefirst.studio`.
    url: String,
    /// Current bearer token; an empty string means unauthenticated.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
550
551impl std::fmt::Debug for Client {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("Client")
554 .field("url", &self.url)
555 .field("has_storage", &self.storage.is_some())
556 .field("token_path", &self.token_path)
557 .finish()
558 }
559}
560
/// Private context struct for pagination operations
struct FetchContext<'a> {
    dataset_id: DatasetID,
    /// Optional annotation set to restrict results to.
    annotation_set_id: Option<AnnotationSetID>,
    groups: &'a [String],
    types: Vec<String>,
    /// Presumably label name → label id (confirm against callers).
    labels: &'a HashMap<String, u64>,
}

/// Empty parameter object for the jobs-list RPC (serializes as `{}`).
#[derive(Debug, Serialize)]
struct JobsListRequest {}

/// Parameters for starting a job run.
#[derive(Debug, Serialize)]
struct JobRunRequest {
    name: String,
    job_name: String,
    /// Presumably environment variables passed to the job (confirm).
    env: std::collections::HashMap<String, String>,
    data: std::collections::HashMap<String, crate::api::Parameter>,
}

/// Parameters for stopping the job identified by `task_id`.
#[derive(Debug, Serialize)]
struct JobStopRequest {
    task_id: u64,
}

/// Parameters for listing the data files attached to a task.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataListRequest {
    pub(crate) task_id: u64,
}

/// Parameters for downloading one task data file, addressed by folder and
/// file name.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataDownloadRequest {
    pub(crate) task_id: u64,
    pub(crate) folder: String,
    pub(crate) file: String,
}

/// Parameters for adding a chart to a task. `params` serializes as `null`
/// when `None`.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartAddRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
    pub(crate) chart_name: String,
    pub(crate) params: Option<crate::api::Parameter>,
    pub(crate) data: crate::api::Parameter,
}

/// Parameters for listing the charts in one task chart group.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartListRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
}

/// Parameters for fetching a single chart by group and chart name.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartGetRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
    pub(crate) chart_name: String,
}

/// Parameters for downloading one file of a validation session.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataDownloadRequest {
    pub(crate) session_id: u64,
    pub(crate) filename: String,
}

/// Parameters for listing the files of a validation session.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataListRequest {
    pub(crate) session_id: u64,
}
630
631/// Streams the body of a successful `reqwest` response to a file on disk,
632/// emitting optional progress events.
633///
634/// Both `download_artifact` and `rpc_download` share this logic. The caller is
635/// responsible for creating any required parent directories before calling this
636/// function.
637///
638/// # Arguments
639/// * `resp` - A successful (HTTP 2xx) `reqwest::Response` whose body will
640/// be streamed to `path`.
641/// * `path` - Destination file path (created or truncated).
642/// * `progress` - Optional channel; events carry bytes received and
643/// `Content-Length` total (0 if the server omits it).
644///
645/// # Errors
646/// Returns `Error::IoError` on file I/O failures or propagates stream errors.
647async fn stream_response_to_file(
648 resp: reqwest::Response,
649 path: &std::path::Path,
650 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
651) -> Result<(), Error> {
652 use tokio::io::AsyncWriteExt as _;
653 let total = resp.content_length().unwrap_or(0) as usize;
654 let mut stream = resp.bytes_stream();
655 let mut file = tokio::fs::File::create(path).await?;
656 let mut current = 0usize;
657
658 if let Some(ref tx) = progress {
659 let _ = tx
660 .send(Progress {
661 current: 0,
662 total,
663 status: None,
664 })
665 .await;
666 }
667
668 while let Some(chunk) = stream.next().await {
669 let chunk = chunk?;
670 file.write_all(&chunk).await?;
671 current += chunk.len();
672 if let Some(ref tx) = progress {
673 let _ = tx
674 .send(Progress {
675 current,
676 total,
677 status: None,
678 })
679 .await;
680 }
681 }
682
683 // Flush tokio's internal write buffer to the OS before returning.
684 // tokio::fs::File buffers writes internally; without this, the buffer
685 // may not reach the filesystem before the caller reads the file.
686 file.flush().await?;
687 Ok(())
688}
689
690impl Client {
691 /// Create a new unauthenticated client with the default saas server.
692 ///
693 /// By default, the client uses [`FileTokenStorage`] for token persistence.
694 /// Use [`with_storage`][Self::with_storage],
695 /// [`with_memory_storage`][Self::with_memory_storage],
696 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
697 /// behavior.
698 ///
699 /// To connect to a different server, use [`with_server`][Self::with_server]
700 /// or [`with_token`][Self::with_token] (tokens include the server
701 /// instance).
702 ///
703 /// This client is created without a token and will need to authenticate
704 /// before using methods that require authentication.
705 ///
706 /// # Examples
707 ///
708 /// ```rust,no_run
709 /// use edgefirst_client::Client;
710 ///
711 /// # fn main() -> Result<(), edgefirst_client::Error> {
712 /// // Create client with default file storage
713 /// let client = Client::new()?;
714 ///
715 /// // Create client without token persistence
716 /// let client = Client::new()?.with_memory_storage();
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub fn new() -> Result<Self, Error> {
721 log_retry_configuration();
722
723 // Get timeout from environment or use default
724 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
725 .ok()
726 .and_then(|s| s.parse().ok())
727 .unwrap_or(30); // Default 30s total deadline for API calls
728
729 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
730 // arrive for this duration. Resets after every received chunk, so a
731 // healthy multi-GB transfer will never be interrupted.
732 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
733 .ok()
734 .and_then(|s| s.parse().ok())
735 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
736
737 // Create single HTTP client with URL-based retry policy
738 //
739 // The retry policy classifies requests into two categories:
740 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
741 // errors
742 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
743 //
744 // This allows the same client to handle both API calls and file operations
745 // with appropriate retry behavior for each. See retry.rs for details.
746 let http = reqwest::Client::builder()
747 .connect_timeout(Duration::from_secs(10))
748 .timeout(Duration::from_secs(timeout_secs))
749 .pool_idle_timeout(Duration::from_secs(90))
750 .pool_max_idle_per_host(10)
751 .retry(create_retry_policy())
752 .build()?;
753
754 // Separate HTTP client for bulk transfers (uploads and downloads).
755 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
756 // Uses read_timeout instead: resets after every received chunk, so a
757 // healthy large transfer is never interrupted, but a truly stalled
758 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
759 let bulk_http = reqwest::Client::builder()
760 .connect_timeout(Duration::from_secs(30))
761 .read_timeout(Duration::from_secs(read_timeout_secs))
762 .pool_idle_timeout(Duration::from_secs(90))
763 // Bulk file transfers fan out to many concurrent presigned-URL
764 // uploads — up to `EDGEFIRST_UPLOAD_BATCHES` pipelined batches ×
765 // `max_tasks()` uploads each. Keep enough idle connections warm to
766 // reuse across that fan-out instead of churning new TLS handshakes.
767 .pool_max_idle_per_host(64)
768 .retry(create_retry_policy())
769 .build()?;
770
771 // Default to file storage, loading any existing token
772 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
773 Ok(file_storage) => Arc::new(file_storage),
774 Err(e) => {
775 warn!(
776 "Could not initialize file token storage: {}. Using memory storage.",
777 e
778 );
779 Arc::new(MemoryTokenStorage::new())
780 }
781 };
782
783 // Try to load existing token from storage
784 let token = match storage.load() {
785 Ok(Some(t)) => t,
786 Ok(None) => String::new(),
787 Err(e) => {
788 warn!(
789 "Failed to load token from storage: {}. Starting with empty token.",
790 e
791 );
792 String::new()
793 }
794 };
795
796 // Extract server from token if available
797 let url = if !token.is_empty() {
798 match Self::extract_server_from_token(&token) {
799 Ok(server) => format!("https://{}.edgefirst.studio", server),
800 Err(e) => {
801 warn!(
802 "Failed to extract server from token: {}. Using default server.",
803 e
804 );
805 "https://edgefirst.studio".to_string()
806 }
807 }
808 } else {
809 "https://edgefirst.studio".to_string()
810 };
811
812 Ok(Client {
813 http,
814 bulk_http,
815 url,
816 token: Arc::new(tokio::sync::RwLock::new(token)),
817 storage: Some(storage),
818 token_path: None,
819 })
820 }
821
822 /// Returns a new client connected to the specified server instance.
823 ///
824 /// The server parameter is an instance name that maps to a URL:
825 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
826 /// server)
827 /// - `"test"` → `https://test.edgefirst.studio`
828 /// - `"stage"` → `https://stage.edgefirst.studio`
829 /// - `"dev"` → `https://dev.edgefirst.studio`
830 /// - `"{name}"` → `https://{name}.edgefirst.studio`
831 ///
832 /// # Server Selection Priority
833 ///
834 /// When using the CLI or Python API, server selection follows this
835 /// priority:
836 ///
837 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
838 /// they were issued for. If you have a valid token, its server is used.
839 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
840 /// token is available. If a token exists with a different server, a
841 /// warning is emitted and the token's server takes priority.
842 /// 3. **Default `"saas"`** - If no token and no server specified, the
843 /// production server (`https://edgefirst.studio`) is used.
844 ///
845 /// # Important Notes
846 ///
847 /// - If a token is already set in the client, calling this method will
848 /// **drop the token** as tokens are specific to the server instance.
849 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
850 /// token's server before calling this method.
851 /// - For login operations, call `with_server()` first, then authenticate.
852 ///
853 /// # Examples
854 ///
855 /// ```rust,no_run
856 /// use edgefirst_client::Client;
857 ///
858 /// # fn main() -> Result<(), edgefirst_client::Error> {
859 /// let client = Client::new()?.with_server("test")?;
860 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
861 /// # Ok(())
862 /// # }
863 /// ```
864 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
865 // Resolve the target URL. Full URLs (self-hosted Studio,
866 // wiremock) are validated through `with_url` so the HTTPS rules
867 // there apply uniformly. Short names map to the SaaS pattern.
868 // We extract only the URL string and rebuild the Client below,
869 // because `with_url` preserves the in-memory token (the contract
870 // for self-hosted deployments) whereas `with_server` deliberately
871 // clears it (a different server means a stale token).
872 let url = if server.starts_with("http://") || server.starts_with("https://") {
873 self.with_url(server)?.url().to_string()
874 } else {
875 match server {
876 "" | "saas" => "https://edgefirst.studio".to_string(),
877 name => format!("https://{}.edgefirst.studio", name),
878 }
879 };
880
881 // Clear token from storage when changing servers to prevent
882 // authentication issues with stale tokens from different
883 // instances. This runs whether the caller passed a short name
884 // or a full URL — both reach a new server.
885 if let Some(ref storage) = self.storage
886 && let Err(e) = storage.clear()
887 {
888 warn!(
889 "Failed to clear token from storage when changing servers: {}",
890 e
891 );
892 }
893
894 Ok(Client {
895 url,
896 token: Arc::new(tokio::sync::RwLock::new(String::new())),
897 ..self.clone()
898 })
899 }
900
901 /// Returns a new client pointed at an explicit URL.
902 ///
903 /// Used for self-hosted Studio deployments (e.g.
904 /// `https://studio.example.com`) and for offline integration tests
905 /// against a mock HTTP server (e.g. `http://127.0.0.1:8080`). The
906 /// token is preserved so callers can chain
907 /// `Client::new()?.with_url(...)?.with_token(...)`.
908 ///
909 /// # Errors
910 ///
911 /// Returns [`Error::UrlParseError`] for syntactically invalid URLs and
912 /// [`Error::InsecureUrl`] for plain `http://` URLs that resolve to a
913 /// non-loopback host: the Studio bearer token rides in the
914 /// `Authorization` header, and plain HTTP would leak it in the clear.
915 /// Loopback URLs (`127.0.0.1`, `::1`, `localhost`, `*.localhost`) are
916 /// permitted because traffic never leaves the machine — wiremock and
917 /// local dev servers go through that path.
918 pub fn with_url(&self, url: &str) -> Result<Self, Error> {
919 // Reject malformed inputs early so test failures point at the test
920 // rather than a downstream reqwest send.
921 let parsed = url::Url::parse(url)?;
922 let scheme = parsed.scheme();
923 if scheme == "http" {
924 if !is_loopback_host(parsed.host().as_ref()) {
925 return Err(Error::InsecureUrl(url.to_string()));
926 }
927 } else if scheme != "https" {
928 return Err(Error::InsecureUrl(url.to_string()));
929 }
930 Ok(Client {
931 url: url.trim_end_matches('/').to_string(),
932 ..self.clone()
933 })
934 }
935
936 /// Returns a new client with the specified token storage backend.
937 ///
938 /// Use this to configure custom token storage, such as platform-specific
939 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
940 ///
941 /// # Examples
942 ///
943 /// ```rust,no_run
944 /// use edgefirst_client::{Client, FileTokenStorage};
945 /// use std::{path::PathBuf, sync::Arc};
946 ///
947 /// # fn main() -> Result<(), edgefirst_client::Error> {
948 /// // Use a custom file path for token storage
949 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
950 /// let client = Client::new()?.with_storage(Arc::new(storage));
951 /// # Ok(())
952 /// # }
953 /// ```
954 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
955 // Try to load existing token from the new storage
956 let token = match storage.load() {
957 Ok(Some(t)) => t,
958 Ok(None) => String::new(),
959 Err(e) => {
960 warn!(
961 "Failed to load token from storage: {}. Starting with empty token.",
962 e
963 );
964 String::new()
965 }
966 };
967
968 Client {
969 token: Arc::new(tokio::sync::RwLock::new(token)),
970 storage: Some(storage),
971 token_path: None,
972 ..self
973 }
974 }
975
976 /// Returns a new client with in-memory token storage (no persistence).
977 ///
978 /// Tokens are stored in memory only and lost when the application exits.
979 /// This is useful for testing or when you want to manage token persistence
980 /// externally.
981 ///
982 /// # Examples
983 ///
984 /// ```rust,no_run
985 /// use edgefirst_client::Client;
986 ///
987 /// # fn main() -> Result<(), edgefirst_client::Error> {
988 /// let client = Client::new()?.with_memory_storage();
989 /// # Ok(())
990 /// # }
991 /// ```
992 pub fn with_memory_storage(self) -> Self {
993 Client {
994 token: Arc::new(tokio::sync::RwLock::new(String::new())),
995 storage: Some(Arc::new(MemoryTokenStorage::new())),
996 token_path: None,
997 ..self
998 }
999 }
1000
1001 /// Returns a new client with no token storage.
1002 ///
1003 /// Tokens are not persisted. Use this when you want to manage tokens
1004 /// entirely manually.
1005 ///
1006 /// # Examples
1007 ///
1008 /// ```rust,no_run
1009 /// use edgefirst_client::Client;
1010 ///
1011 /// # fn main() -> Result<(), edgefirst_client::Error> {
1012 /// let client = Client::new()?.with_no_storage();
1013 /// # Ok(())
1014 /// # }
1015 /// ```
1016 pub fn with_no_storage(self) -> Self {
1017 Client {
1018 storage: None,
1019 token_path: None,
1020 ..self
1021 }
1022 }
1023
1024 /// Returns a new client authenticated with the provided username and
1025 /// password.
1026 ///
1027 /// The token is automatically persisted to storage (if configured).
1028 ///
1029 /// # Examples
1030 ///
1031 /// ```rust,no_run
1032 /// use edgefirst_client::Client;
1033 ///
1034 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1035 /// let client = Client::new()?
1036 /// .with_server("test")?
1037 /// .with_login("user@example.com", "password")
1038 /// .await?;
1039 /// # Ok(())
1040 /// # }
1041 /// ```
1042 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
1043 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
1044 let params = HashMap::from([("username", username), ("password", password)]);
1045 let login: LoginResult = self
1046 .rpc_without_auth("auth.login".to_owned(), Some(params))
1047 .await?;
1048
1049 // Validate that the server returned a non-empty token
1050 if login.token.is_empty() {
1051 return Err(Error::EmptyToken);
1052 }
1053
1054 // Persist token to storage if configured
1055 if let Some(ref storage) = self.storage
1056 && let Err(e) = storage.store(&login.token)
1057 {
1058 warn!("Failed to persist token to storage: {}", e);
1059 }
1060
1061 Ok(Client {
1062 token: Arc::new(tokio::sync::RwLock::new(login.token)),
1063 ..self.clone()
1064 })
1065 }
1066
1067 /// Returns a new client which will load and save the token to the specified
1068 /// path.
1069 ///
1070 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
1071 /// [`FileTokenStorage`] instead for more flexible token management.
1072 ///
1073 /// This method is maintained for backwards compatibility with existing
1074 /// code. It disables the default storage and uses file-based storage at
1075 /// the specified path.
1076 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
1077 let token_path = match token_path {
1078 Some(path) => path.to_path_buf(),
1079 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1080 .ok_or_else(|| {
1081 Error::IoError(std::io::Error::new(
1082 std::io::ErrorKind::NotFound,
1083 "Could not determine user config directory",
1084 ))
1085 })?
1086 .config_dir()
1087 .join("token"),
1088 };
1089
1090 debug!("Using token path (legacy): {:?}", token_path);
1091
1092 let token = match token_path.exists() {
1093 true => std::fs::read_to_string(&token_path)?,
1094 false => "".to_string(),
1095 };
1096
1097 if !token.is_empty() {
1098 match self.with_token(&token) {
1099 Ok(client) => Ok(Client {
1100 token_path: Some(token_path),
1101 storage: None, // Disable new storage when using legacy token_path
1102 ..client
1103 }),
1104 Err(e) => {
1105 // Token is corrupted or invalid - remove it and continue with no token
1106 warn!(
1107 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
1108 token_path, e
1109 );
1110 if let Err(remove_err) = std::fs::remove_file(&token_path) {
1111 warn!("Failed to remove corrupted token file: {:?}", remove_err);
1112 }
1113 // Clear any token from default storage to ensure we don't use it
1114 Ok(Client {
1115 token_path: Some(token_path),
1116 storage: None,
1117 token: Arc::new(RwLock::new("".to_string())),
1118 ..self.clone()
1119 })
1120 }
1121 }
1122 } else {
1123 // No token in the legacy file - clear any token from default storage
1124 Ok(Client {
1125 token_path: Some(token_path),
1126 storage: None,
1127 token: Arc::new(RwLock::new("".to_string())),
1128 ..self.clone()
1129 })
1130 }
1131 }
1132
1133 /// Returns a new client authenticated with the provided token.
1134 ///
1135 /// The token is automatically persisted to storage (if configured).
1136 /// The server URL is extracted from the token payload.
1137 ///
1138 /// # Examples
1139 ///
1140 /// ```rust,no_run
1141 /// use edgefirst_client::Client;
1142 ///
1143 /// # fn main() -> Result<(), edgefirst_client::Error> {
1144 /// let client = Client::new()?.with_token("your-jwt-token")?;
1145 /// # Ok(())
1146 /// # }
1147 /// ```
1148 /// Extract server name from JWT token payload.
1149 ///
1150 /// Helper method to parse the JWT token and extract the "server" field
1151 /// from the payload. Returns the server name (e.g., "test", "stage", "")
1152 /// or an error if the token is invalid.
1153 fn extract_server_from_token(token: &str) -> Result<String, Error> {
1154 let token_parts: Vec<&str> = token.split('.').collect();
1155 if token_parts.len() != 3 {
1156 return Err(Error::InvalidToken);
1157 }
1158
1159 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1160 .decode(token_parts[1])
1161 .map_err(|_| Error::InvalidToken)?;
1162 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1163 let server = match payload.get("server") {
1164 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
1165 None => return Err(Error::InvalidToken),
1166 };
1167
1168 Ok(server)
1169 }
1170
1171 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
1172 if token.is_empty() {
1173 return Ok(self.clone());
1174 }
1175
1176 let server = Self::extract_server_from_token(token)?;
1177
1178 // Persist token to storage if configured
1179 if let Some(ref storage) = self.storage
1180 && let Err(e) = storage.store(token)
1181 {
1182 warn!("Failed to persist token to storage: {}", e);
1183 }
1184
1185 Ok(Client {
1186 url: format!("https://{}.edgefirst.studio", server),
1187 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
1188 ..self.clone()
1189 })
1190 }
1191
1192 /// Persist the current token to storage.
1193 ///
1194 /// This is automatically called when using [`with_login`][Self::with_login]
1195 /// or [`with_token`][Self::with_token], so you typically don't need to call
1196 /// this directly.
1197 ///
1198 /// If using the legacy `token_path` configuration, saves to the file path.
1199 /// If using the new storage abstraction, saves to the configured storage.
1200 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1201 pub async fn save_token(&self) -> Result<(), Error> {
1202 let token = self.token.read().await;
1203
1204 // Try new storage first
1205 if let Some(ref storage) = self.storage {
1206 storage.store(&token)?;
1207 debug!("Token saved to storage");
1208 return Ok(());
1209 }
1210
1211 // Fall back to legacy token_path behavior
1212 let path = self.token_path.clone().unwrap_or_else(|| {
1213 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1214 .map(|dirs| dirs.config_dir().join("token"))
1215 .unwrap_or_else(|| PathBuf::from(".token"))
1216 });
1217
1218 create_dir_all(path.parent().ok_or_else(|| {
1219 Error::IoError(std::io::Error::new(
1220 std::io::ErrorKind::InvalidInput,
1221 "Token path has no parent directory",
1222 ))
1223 })?)?;
1224 let mut file = std::fs::File::create(&path)?;
1225 file.write_all(token.as_bytes())?;
1226
1227 debug!("Saved token to {:?}", path);
1228
1229 Ok(())
1230 }
1231
1232 /// Return the version of the EdgeFirst Studio server for the current
1233 /// client connection.
1234 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1235 pub async fn version(&self) -> Result<String, Error> {
1236 let version: HashMap<String, String> = self
1237 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
1238 .await?;
1239 let version = version.get("version").ok_or(Error::InvalidResponse)?;
1240 Ok(version.to_owned())
1241 }
1242
1243 /// Clear the token used to authenticate the client with the server.
1244 ///
1245 /// Clears the token from memory and from storage (if configured).
1246 /// If using the legacy `token_path` configuration, removes the token file.
1247 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1248 pub async fn logout(&self) -> Result<(), Error> {
1249 {
1250 let mut token = self.token.write().await;
1251 *token = "".to_string();
1252 }
1253
1254 // Clear from new storage if configured
1255 if let Some(ref storage) = self.storage
1256 && let Err(e) = storage.clear()
1257 {
1258 warn!("Failed to clear token from storage: {}", e);
1259 }
1260
1261 // Also clear legacy token_path if configured
1262 if let Some(path) = &self.token_path
1263 && path.exists()
1264 {
1265 fs::remove_file(path).await?;
1266 }
1267
1268 Ok(())
1269 }
1270
1271 /// Return the token used to authenticate the client with the server. When
1272 /// logging into the server using a username and password, the token is
1273 /// returned by the server and stored in the client for future interactions.
1274 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1275 pub async fn token(&self) -> String {
1276 self.token.read().await.clone()
1277 }
1278
1279 /// Verify the token used to authenticate the client with the server. This
1280 /// method is used to ensure that the token is still valid and has not
1281 /// expired. If the token is invalid, the server will return an error and
1282 /// the client will need to login again.
1283 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1284 pub async fn verify_token(&self) -> Result<(), Error> {
1285 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1286 .await?;
1287 Ok::<(), Error>(())
1288 }
1289
1290 /// Renew the token used to authenticate the client with the server.
1291 ///
1292 /// Refreshes the token before it expires. If the token has already expired,
1293 /// the server will return an error and you will need to login again.
1294 ///
1295 /// The new token is automatically persisted to storage (if configured).
1296 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1297 pub async fn renew_token(&self) -> Result<(), Error> {
1298 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1299 let result: LoginResult = self
1300 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1301 .await?;
1302
1303 {
1304 let mut token = self.token.write().await;
1305 *token = result.token.clone();
1306 }
1307
1308 // Persist to new storage if configured
1309 if let Some(ref storage) = self.storage
1310 && let Err(e) = storage.store(&result.token)
1311 {
1312 warn!("Failed to persist renewed token to storage: {}", e);
1313 }
1314
1315 // Also persist to legacy token_path if configured
1316 if self.token_path.is_some() {
1317 self.save_token().await?;
1318 }
1319
1320 Ok(())
1321 }
1322
1323 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1324 let token = self.token.read().await;
1325 if token.is_empty() {
1326 return Err(Error::EmptyToken);
1327 }
1328
1329 let token_parts: Vec<&str> = token.split('.').collect();
1330 if token_parts.len() != 3 {
1331 return Err(Error::InvalidToken);
1332 }
1333
1334 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1335 .decode(token_parts[1])
1336 .map_err(|_| Error::InvalidToken)?;
1337 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1338 match payload.get(field) {
1339 Some(value) => Ok(value.to_owned()),
1340 None => Err(Error::InvalidToken),
1341 }
1342 }
1343
1344 /// Returns the URL of the EdgeFirst Studio server for the current client.
1345 pub fn url(&self) -> &str {
1346 &self.url
1347 }
1348
1349 /// Returns the server name for the current client.
1350 ///
1351 /// This extracts the server name from the client's URL:
1352 /// - `https://edgefirst.studio` → `"saas"`
1353 /// - `https://test.edgefirst.studio` → `"test"`
1354 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1355 ///
1356 /// # Examples
1357 ///
1358 /// ```rust,no_run
1359 /// use edgefirst_client::Client;
1360 ///
1361 /// # fn main() -> Result<(), edgefirst_client::Error> {
1362 /// let client = Client::new()?.with_server("test")?;
1363 /// assert_eq!(client.server(), "test");
1364 ///
1365 /// let client = Client::new()?; // default
1366 /// assert_eq!(client.server(), "saas");
1367 /// # Ok(())
1368 /// # }
1369 /// ```
1370 pub fn server(&self) -> &str {
1371 if self.url == "https://edgefirst.studio" {
1372 "saas"
1373 } else if let Some(name) = self.url.strip_prefix("https://") {
1374 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1375 } else {
1376 "saas"
1377 }
1378 }
1379
1380 /// Returns the username associated with the current token.
1381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1382 pub async fn username(&self) -> Result<String, Error> {
1383 match self.token_field("username").await? {
1384 serde_json::Value::String(username) => Ok(username),
1385 _ => Err(Error::InvalidToken),
1386 }
1387 }
1388
1389 /// Returns the expiration time for the current token.
1390 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1391 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1392 let ts = match self.token_field("exp").await? {
1393 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1394 _ => return Err(Error::InvalidToken),
1395 };
1396
1397 match DateTime::<Utc>::from_timestamp(ts, 0) {
1398 Some(dt) => Ok(dt),
1399 None => Err(Error::InvalidToken),
1400 }
1401 }
1402
1403 /// Returns the organization information for the current user.
1404 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1405 pub async fn organization(&self) -> Result<Organization, Error> {
1406 self.rpc::<(), Organization>("org.get".to_owned(), None)
1407 .await
1408 }
1409
1410 /// Returns a list of projects available to the user. The projects are
1411 /// returned as a vector of Project objects. If a name filter is
1412 /// provided, only projects matching the filter are returned.
1413 ///
1414 /// Results are sorted by match quality: exact matches first, then
1415 /// case-insensitive exact matches, then shorter names (more specific),
1416 /// then alphabetically.
1417 ///
1418 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1419 /// Projects contain datasets, trainers, and trainer sessions. Projects
1420 /// are used to group related datasets and trainers together.
1421 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1422 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1423 let projects = self
1424 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1425 .await?;
1426 if let Some(name) = name {
1427 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1428 } else {
1429 Ok(projects)
1430 }
1431 }
1432
1433 /// Return the project with the specified project ID. If the project does
1434 /// not exist, an error is returned.
1435 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1436 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1437 let params = HashMap::from([("project_id", project_id)]);
1438 self.rpc("project.get".to_owned(), Some(params)).await
1439 }
1440
1441 /// Returns a list of datasets available to the user. The datasets are
1442 /// returned as a vector of Dataset objects. If a name filter is
1443 /// provided, only datasets matching the filter are returned.
1444 ///
1445 /// Results are sorted by match quality: exact matches first, then
1446 /// case-insensitive exact matches, then shorter names (more specific),
1447 /// then alphabetically. This ensures "Deer" returns before "Deer
1448 /// Roundtrip".
1449 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1450 pub async fn datasets(
1451 &self,
1452 project_id: ProjectID,
1453 name: Option<&str>,
1454 ) -> Result<Vec<Dataset>, Error> {
1455 let params = HashMap::from([("project_id", project_id)]);
1456 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1457 if let Some(name) = name {
1458 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1459 } else {
1460 Ok(datasets)
1461 }
1462 }
1463
1464 /// Return the dataset with the specified dataset ID. If the dataset does
1465 /// not exist, an error is returned.
1466 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1467 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1468 let params = HashMap::from([("dataset_id", dataset_id)]);
1469 self.rpc("dataset.get".to_owned(), Some(params)).await
1470 }
1471
1472 /// Lists the labels for the specified dataset.
1473 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1474 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1475 let params = HashMap::from([("dataset_id", dataset_id)]);
1476 self.rpc("label.list".to_owned(), Some(params)).await
1477 }
1478
1479 /// Add a new label to the dataset with the specified name.
1480 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1481 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1482 let new_label = NewLabel {
1483 dataset_id,
1484 labels: vec![NewLabelObject {
1485 name: name.to_owned(),
1486 }],
1487 };
1488 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1489 Ok(())
1490 }
1491
1492 /// Add multiple labels to the dataset in a single request.
1493 ///
1494 /// Equivalent to calling [`add_label`](Self::add_label) for each name but in
1495 /// one round-trip. Useful before a bulk/concurrent upload: pre-creating the
1496 /// full label set serially avoids many concurrent `populate2` calls racing to
1497 /// create the same label server-side. Names already present are not
1498 /// duplicated by the server. A no-op when `names` is empty.
1499 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, names), fields(dataset_id = %dataset_id, count = names.len())))]
1500 pub async fn add_labels(&self, dataset_id: DatasetID, names: &[String]) -> Result<(), Error> {
1501 if names.is_empty() {
1502 return Ok(());
1503 }
1504 let new_label = NewLabel {
1505 dataset_id,
1506 labels: names
1507 .iter()
1508 .map(|name| NewLabelObject { name: name.clone() })
1509 .collect(),
1510 };
1511 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1512 Ok(())
1513 }
1514
1515 /// Removes the label with the specified ID from the dataset. Label IDs are
1516 /// globally unique so the dataset_id is not required.
1517 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1518 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1519 let params = HashMap::from([("label_id", label_id)]);
1520 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1521 Ok(())
1522 }
1523
1524 /// Creates a new dataset in the specified project.
1525 ///
1526 /// # Arguments
1527 ///
1528 /// * `project_id` - The ID of the project to create the dataset in
1529 /// * `name` - The name of the new dataset
1530 /// * `description` - Optional description for the dataset
1531 ///
1532 /// # Returns
1533 ///
1534 /// Returns the dataset ID of the newly created dataset.
1535 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1536 pub async fn create_dataset(
1537 &self,
1538 project_id: &str,
1539 name: &str,
1540 description: Option<&str>,
1541 ) -> Result<DatasetID, Error> {
1542 let mut params = HashMap::new();
1543 params.insert("project_id", project_id);
1544 params.insert("name", name);
1545 if let Some(desc) = description {
1546 params.insert("description", desc);
1547 }
1548
1549 #[derive(Deserialize)]
1550 struct CreateDatasetResult {
1551 id: DatasetID,
1552 }
1553
1554 let result: CreateDatasetResult =
1555 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1556 Ok(result.id)
1557 }
1558
1559 /// Deletes a dataset by marking it as deleted.
1560 ///
1561 /// # Arguments
1562 ///
1563 /// * `dataset_id` - The ID of the dataset to delete
1564 ///
1565 /// # Returns
1566 ///
1567 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1568 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1569 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1570 let params = HashMap::from([("id", dataset_id)]);
1571 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1572 Ok(())
1573 }
1574
1575 /// Updates the label with the specified ID to have the new name or index.
1576 /// Label IDs cannot be changed. Label IDs are globally unique so the
1577 /// dataset_id is not required.
1578 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1579 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1580 #[derive(Serialize)]
1581 struct Params {
1582 dataset_id: DatasetID,
1583 label_id: u64,
1584 label_name: String,
1585 label_index: u64,
1586 }
1587
1588 let _: String = self
1589 .rpc(
1590 "label.update".to_owned(),
1591 Some(Params {
1592 dataset_id: label.dataset_id(),
1593 label_id: label.id(),
1594 label_name: label.name().to_owned(),
1595 label_index: label.index(),
1596 }),
1597 )
1598 .await?;
1599 Ok(())
1600 }
1601
1602 /// Lists the groups for the specified dataset.
1603 ///
1604 /// Groups are used to organize samples into logical subsets such as
1605 /// "train", "val", "test", etc. Each sample can belong to at most one
1606 /// group at a time.
1607 ///
1608 /// # Arguments
1609 ///
1610 /// * `dataset_id` - The ID of the dataset to list groups for
1611 ///
1612 /// # Returns
1613 ///
1614 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1615 /// empty vector if no groups have been created yet.
1616 ///
1617 /// # Errors
1618 ///
1619 /// Returns an error if the dataset does not exist or cannot be accessed.
1620 ///
1621 /// # Example
1622 ///
1623 /// ```rust,no_run
1624 /// # use edgefirst_client::{Client, DatasetID};
1625 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1626 /// let client = Client::new()?.with_token_path(None)?;
1627 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1628 ///
1629 /// let groups = client.groups(dataset_id).await?;
1630 /// for group in groups {
1631 /// println!("{}: {}", group.id, group.name);
1632 /// }
1633 /// # Ok(())
1634 /// # }
1635 /// ```
1636 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1637 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1638 let params = HashMap::from([("dataset_id", dataset_id)]);
1639 self.rpc("groups.list".to_owned(), Some(params)).await
1640 }
1641
1642 /// Gets an existing group by name or creates a new one.
1643 ///
1644 /// This is a convenience method that first checks if a group with the
1645 /// specified name exists, and creates it if not. This is useful when
1646 /// you need to ensure a group exists before assigning samples to it.
1647 ///
1648 /// # Arguments
1649 ///
1650 /// * `dataset_id` - The ID of the dataset
1651 /// * `name` - The name of the group (e.g., "train", "val", "test")
1652 ///
1653 /// # Returns
1654 ///
1655 /// Returns the group ID (either existing or newly created).
1656 ///
1657 /// # Errors
1658 ///
1659 /// Returns an error if:
1660 /// - The dataset does not exist or cannot be accessed
1661 /// - The group creation fails
1662 ///
1663 /// # Concurrency
1664 ///
1665 /// This method handles concurrent creation attempts gracefully. If another
1666 /// process creates the group between the existence check and creation,
1667 /// this method will return the existing group's ID.
1668 ///
1669 /// # Example
1670 ///
1671 /// ```rust,no_run
1672 /// # use edgefirst_client::{Client, DatasetID};
1673 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1674 /// let client = Client::new()?.with_token_path(None)?;
1675 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1676 ///
1677 /// // Get or create a "train" group
1678 /// let train_group_id = client
1679 /// .get_or_create_group(dataset_id.clone(), "train")
1680 /// .await?;
1681 /// println!("Train group ID: {}", train_group_id);
1682 ///
1683 /// // Calling again returns the same ID
1684 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1685 /// assert_eq!(train_group_id, same_id);
1686 /// # Ok(())
1687 /// # }
1688 /// ```
1689 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1690 pub async fn get_or_create_group(
1691 &self,
1692 dataset_id: DatasetID,
1693 name: &str,
1694 ) -> Result<u64, Error> {
1695 // First check if the group already exists
1696 let groups = self.groups(dataset_id).await?;
1697 if let Some(group) = groups.iter().find(|g| g.name == name) {
1698 return Ok(group.id);
1699 }
1700
1701 // Create the group
1702 #[derive(Serialize)]
1703 struct CreateGroupParams {
1704 dataset_id: DatasetID,
1705 group_names: Vec<String>,
1706 group_splits: Vec<i64>,
1707 }
1708
1709 let params = CreateGroupParams {
1710 dataset_id,
1711 group_names: vec![name.to_string()],
1712 group_splits: vec![0], // No automatic splitting
1713 };
1714
1715 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1716 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1717 Ok(group.id)
1718 } else {
1719 // Group might have been created by concurrent call, try fetching again
1720 let groups = self.groups(dataset_id).await?;
1721 groups
1722 .iter()
1723 .find(|g| g.name == name)
1724 .map(|g| g.id)
1725 .ok_or_else(|| {
1726 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1727 })
1728 }
1729 }
1730
1731 /// Sets the group for a sample.
1732 ///
1733 /// Assigns a sample to a specific group. Each sample can belong to at most
1734 /// one group at a time. Setting a new group replaces any existing group
1735 /// assignment.
1736 ///
1737 /// # Arguments
1738 ///
1739 /// * `sample_id` - The ID of the sample (image) to update
1740 /// * `group_id` - The ID of the group to assign. Use
1741 /// [`get_or_create_group`] to obtain a group ID from a name.
1742 ///
1743 /// # Returns
1744 ///
1745 /// Returns `Ok(())` on success.
1746 ///
1747 /// # Errors
1748 ///
1749 /// Returns an error if:
1750 /// - The sample does not exist
1751 /// - The group does not exist
1752 /// - Insufficient permissions to modify the sample
1753 ///
1754 /// # Example
1755 ///
1756 /// ```rust,no_run
1757 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1758 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1759 /// let client = Client::new()?.with_token_path(None)?;
1760 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1761 /// let sample_id: SampleID = 12345.into();
1762 ///
1763 /// // Get or create the "val" group
1764 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1765 ///
1766 /// // Assign the sample to the "val" group
1767 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1768 /// # Ok(())
1769 /// # }
1770 /// ```
1771 ///
1772 /// [`get_or_create_group`]: Self::get_or_create_group
1773 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1774 pub async fn set_sample_group_id(
1775 &self,
1776 sample_id: SampleID,
1777 group_id: u64,
1778 ) -> Result<(), Error> {
1779 #[derive(Serialize)]
1780 struct SetGroupParams {
1781 image_id: SampleID,
1782 group_id: u64,
1783 }
1784
1785 let params = SetGroupParams {
1786 image_id: sample_id,
1787 group_id,
1788 };
1789 let _: String = self
1790 .rpc("image.set_group_id".to_owned(), Some(params))
1791 .await?;
1792 Ok(())
1793 }
1794
1795 /// Downloads dataset samples to the local filesystem.
1796 ///
1797 /// # Arguments
1798 ///
1799 /// * `dataset_id` - The unique identifier of the dataset
1800 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1801 /// * `file_types` - File types to download. Supported types:
1802 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1803 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1804 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1805 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1806 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1807 /// - `FileType::RadarCube` - Radar cube data (.png format)
1808 /// - `FileType::All` - All sensor types (expands to all of the above)
1809 /// * `output` - Local directory to save downloaded files
1810 /// * `flatten` - If true, download all files to output root without
1811 /// sequence subdirectories. When flattening, filenames are prefixed with
1812 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1813 /// unavailable) unless the filename already starts with
1814 /// `{sequence_name}_`, to avoid conflicts between sequences.
1815 /// * `progress` - Optional channel for progress updates
1816 ///
1817 /// # Progress
1818 ///
1819 /// This operation has two phases with distinct progress reporting:
1820 ///
1821 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1822 /// from the server. Progress counts samples fetched.
1823 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1824 /// files to disk. Progress counts samples completed (each sample may
1825 /// have multiple files for different sensor types).
1826 ///
1827 /// Applications should detect the status change from `None` to
1828 /// `"Downloading"` to reset their progress bar for the second phase.
1829 ///
1830 /// # Returns
1831 ///
1832 /// Returns `Ok(())` on success or an error if download fails.
1833 ///
1834 /// # Example
1835 ///
1836 /// ```rust,no_run
1837 /// # use edgefirst_client::{Client, DatasetID, FileType};
1838 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1839 /// let client = Client::new()?.with_token_path(None)?;
1840 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1841 ///
1842 /// // Download with sequence subdirectories (default)
1843 /// client
1844 /// .download_dataset(
1845 /// dataset_id,
1846 /// &[],
1847 /// &[FileType::Image],
1848 /// "./data".into(),
1849 /// false,
1850 /// None,
1851 /// )
1852 /// .await?;
1853 ///
1854 /// // Download flattened (all files in one directory)
1855 /// client
1856 /// .download_dataset(
1857 /// dataset_id,
1858 /// &[],
1859 /// &[FileType::Image],
1860 /// "./data".into(),
1861 /// true,
1862 /// None,
1863 /// )
1864 /// .await?;
1865 ///
1866 /// // Download all sensor types
1867 /// client
1868 /// .download_dataset(
1869 /// dataset_id,
1870 /// &[],
1871 /// &FileType::expand_types(&[FileType::All]),
1872 /// "./data".into(),
1873 /// false,
1874 /// None,
1875 /// )
1876 /// .await?;
1877 /// # Ok(())
1878 /// # }
1879 /// ```
1880 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1881 pub async fn download_dataset(
1882 &self,
1883 dataset_id: DatasetID,
1884 groups: &[String],
1885 file_types: &[FileType],
1886 output: PathBuf,
1887 flatten: bool,
1888 progress: Option<Sender<Progress>>,
1889 ) -> Result<(), Error> {
1890 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1891 let samples = self
1892 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1893 .await?;
1894 fs::create_dir_all(&output).await?;
1895
1896 // Phase 2: Download actual files using direct semaphore pattern
1897 let total = samples.len();
1898 let current = Arc::new(AtomicUsize::new(0));
1899 let sem = Arc::new(Semaphore::new(max_tasks()));
1900
1901 // Send initial progress for download phase
1902 if let Some(ref progress) = progress {
1903 let _ = progress
1904 .send(Progress {
1905 current: 0,
1906 total,
1907 status: Some("Downloading".to_string()),
1908 })
1909 .await;
1910 }
1911
1912 let tasks = samples
1913 .into_iter()
1914 .map(|sample| {
1915 let client = self.clone();
1916 let file_types = file_types.to_vec();
1917 let output = output.clone();
1918 let progress = progress.clone();
1919 let current = current.clone();
1920 let sem = sem.clone();
1921
1922 tokio::spawn(async move {
1923 let _permit = sem.acquire().await.map_err(|_| {
1924 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1925 })?;
1926
1927 for file_type in &file_types {
1928 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1929 let (file_ext, is_image) = match file_type {
1930 FileType::Image => (
1931 infer::get(&data)
1932 .expect("Failed to identify image file format for sample")
1933 .extension()
1934 .to_string(),
1935 true,
1936 ),
1937 other => (other.file_extension().to_string(), false),
1938 };
1939
1940 // Determine target directory based on sequence membership and
1941 // flatten option
1942 // - flatten=false + sequence_name: dataset/sequence_name/
1943 // - flatten=false + no sequence: dataset/ (root level)
1944 // - flatten=true: dataset/ (all files in output root)
1945 // NOTE: group (train/val/test) is NOT used for directory structure
1946 let sequence_dir = sample
1947 .sequence_name()
1948 .map(|name| sanitize_path_component(name));
1949
1950 let target_dir = if flatten {
1951 output.clone()
1952 } else {
1953 sequence_dir
1954 .as_ref()
1955 .map(|seq| output.join(seq))
1956 .unwrap_or_else(|| output.clone())
1957 };
1958 fs::create_dir_all(&target_dir).await?;
1959
1960 let sanitized_sample_name = sample
1961 .name()
1962 .map(|name| sanitize_path_component(&name))
1963 .unwrap_or_else(|| "unknown".to_string());
1964
1965 let image_name = sample.image_name().map(sanitize_path_component);
1966
1967 // Construct filename with smart prefixing for flatten mode
1968 // When flatten=true and sample belongs to a sequence:
1969 // - Check if filename already starts with "{sequence_name}_"
1970 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1971 // - If yes, use filename as-is (already uniquely named)
1972 let file_name = if is_image {
1973 if let Some(img_name) = image_name {
1974 Client::build_filename(
1975 &img_name,
1976 flatten,
1977 sequence_dir.as_ref(),
1978 sample.frame_number(),
1979 )
1980 } else {
1981 format!("{}.{}", sanitized_sample_name, file_ext)
1982 }
1983 } else {
1984 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1985 Client::build_filename(
1986 &base_name,
1987 flatten,
1988 sequence_dir.as_ref(),
1989 sample.frame_number(),
1990 )
1991 };
1992
1993 let file_path = target_dir.join(&file_name);
1994
1995 let mut file = File::create(&file_path).await?;
1996 file.write_all(&data).await?;
1997 }
1998 }
1999
2000 // Update progress after sample completes
2001 if let Some(progress) = &progress {
2002 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
2003 let _ = progress
2004 .send(Progress {
2005 current: completed,
2006 total,
2007 status: Some("Downloading".to_string()),
2008 })
2009 .await;
2010 }
2011
2012 Ok::<(), Error>(())
2013 })
2014 })
2015 .collect::<Vec<_>>();
2016
2017 join_all(tasks)
2018 .await
2019 .into_iter()
2020 .collect::<Result<Vec<_>, _>>()?
2021 .into_iter()
2022 .collect::<Result<Vec<_>, _>>()?;
2023
2024 Ok(())
2025 }
2026
2027 /// Builds a filename with smart prefixing for flatten mode.
2028 ///
2029 /// When flattening sequences into a single directory, this function ensures
2030 /// unique filenames by checking if the sequence prefix already exists and
2031 /// adding it if necessary.
2032 ///
2033 /// # Logic
2034 ///
2035 /// - If `flatten=false`: returns `base_name` unchanged
2036 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
2037 /// - If `flatten=true` and in sequence:
2038 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
2039 /// unchanged
2040 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
2041 /// `{sequence_name}_{base_name}`
2042 fn build_filename(
2043 base_name: &str,
2044 flatten: bool,
2045 sequence_name: Option<&String>,
2046 frame_number: Option<u32>,
2047 ) -> String {
2048 if !flatten || sequence_name.is_none() {
2049 return base_name.to_string();
2050 }
2051
2052 let seq_name = sequence_name.unwrap();
2053 let prefix = format!("{}_", seq_name);
2054
2055 // Check if already prefixed with sequence name
2056 if base_name.starts_with(&prefix) {
2057 base_name.to_string()
2058 } else {
2059 // Add sequence (and optionally frame) prefix
2060 match frame_number {
2061 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
2062 None => format!("{}{}", prefix, base_name),
2063 }
2064 }
2065 }
2066
2067 /// List available annotation sets for the specified dataset.
2068 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2069 pub async fn annotation_sets(
2070 &self,
2071 dataset_id: DatasetID,
2072 ) -> Result<Vec<AnnotationSet>, Error> {
2073 let params = HashMap::from([("dataset_id", dataset_id)]);
2074 self.rpc("annset.list".to_owned(), Some(params)).await
2075 }
2076
2077 /// Create a new annotation set for the specified dataset.
2078 ///
2079 /// # Arguments
2080 ///
2081 /// * `dataset_id` - The ID of the dataset to create the annotation set in
2082 /// * `name` - The name of the new annotation set
2083 /// * `description` - Optional description for the annotation set
2084 ///
2085 /// # Returns
2086 ///
2087 /// Returns the annotation set ID of the newly created annotation set.
2088 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2089 pub async fn create_annotation_set(
2090 &self,
2091 dataset_id: DatasetID,
2092 name: &str,
2093 description: Option<&str>,
2094 ) -> Result<AnnotationSetID, Error> {
2095 #[derive(Serialize)]
2096 struct Params<'a> {
2097 dataset_id: DatasetID,
2098 name: &'a str,
2099 operator: &'a str,
2100 #[serde(skip_serializing_if = "Option::is_none")]
2101 description: Option<&'a str>,
2102 }
2103
2104 #[derive(Deserialize)]
2105 struct CreateAnnotationSetResult {
2106 id: AnnotationSetID,
2107 }
2108
2109 let username = self.username().await?;
2110 let result: CreateAnnotationSetResult = self
2111 .rpc(
2112 "annset.add".to_owned(),
2113 Some(Params {
2114 dataset_id,
2115 name,
2116 operator: &username,
2117 description,
2118 }),
2119 )
2120 .await?;
2121 Ok(result.id)
2122 }
2123
2124 /// Deletes an annotation set by marking it as deleted.
2125 ///
2126 /// # Arguments
2127 ///
2128 /// * `annotation_set_id` - The ID of the annotation set to delete
2129 ///
2130 /// # Returns
2131 ///
2132 /// Returns `Ok(())` if the annotation set was successfully marked as
2133 /// deleted.
2134 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2135 pub async fn delete_annotation_set(
2136 &self,
2137 annotation_set_id: AnnotationSetID,
2138 ) -> Result<(), Error> {
2139 let params = HashMap::from([("id", annotation_set_id)]);
2140 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
2141 Ok(())
2142 }
2143
2144 /// Retrieve the annotation set with the specified ID.
2145 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2146 pub async fn annotation_set(
2147 &self,
2148 annotation_set_id: AnnotationSetID,
2149 ) -> Result<AnnotationSet, Error> {
2150 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
2151 self.rpc("annset.get".to_owned(), Some(params)).await
2152 }
2153
2154 /// Get the annotations for the specified annotation set with the
2155 /// requested annotation types. The annotation types are used to filter
2156 /// the annotations returned. The groups parameter is used to filter for
2157 /// dataset groups (train, val, test). Images which do not have any
2158 /// annotations are also included in the result as long as they are in the
2159 /// requested groups (when specified).
2160 ///
2161 /// The result is a vector of Annotations objects which contain the
2162 /// full dataset along with the annotations for the specified types.
2163 ///
2164 /// # Progress
2165 ///
2166 /// Reports progress with `status: None` as samples are fetched and
2167 /// processed for their annotations. Progress unit is samples processed
2168 /// (not individual annotations).
2169 ///
2170 /// To get the annotations as a DataFrame, use the `samples_dataframe`
2171 /// method instead.
2172 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2173 pub async fn annotations(
2174 &self,
2175 annotation_set_id: AnnotationSetID,
2176 groups: &[String],
2177 annotation_types: &[AnnotationType],
2178 progress: Option<Sender<Progress>>,
2179 ) -> Result<Vec<Annotation>, Error> {
2180 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
2181 let labels = self
2182 .labels(dataset_id)
2183 .await?
2184 .into_iter()
2185 .map(|label| (label.name().to_string(), label.index()))
2186 .collect::<HashMap<_, _>>();
2187 let total = self
2188 .samples_count(
2189 dataset_id,
2190 Some(annotation_set_id),
2191 annotation_types,
2192 groups,
2193 &[],
2194 )
2195 .await?
2196 .total as usize;
2197
2198 if total == 0 {
2199 return Ok(vec![]);
2200 }
2201
2202 let context = FetchContext {
2203 dataset_id,
2204 annotation_set_id: Some(annotation_set_id),
2205 groups,
2206 types: annotation_types.iter().map(|t| t.to_string()).collect(),
2207 labels: &labels,
2208 };
2209
2210 self.fetch_annotations_paginated(context, total, progress)
2211 .await
2212 }
2213
2214 async fn fetch_annotations_paginated(
2215 &self,
2216 context: FetchContext<'_>,
2217 total: usize,
2218 progress: Option<Sender<Progress>>,
2219 ) -> Result<Vec<Annotation>, Error> {
2220 let mut annotations = vec![];
2221 let mut continue_token: Option<String> = None;
2222 let mut current = 0;
2223
2224 loop {
2225 let params = SamplesListParams {
2226 dataset_id: context.dataset_id,
2227 annotation_set_id: context.annotation_set_id,
2228 types: context.types.clone(),
2229 group_names: context.groups.to_vec(),
2230 continue_token,
2231 };
2232
2233 let result: SamplesListResult =
2234 self.rpc("samples.list".to_owned(), Some(params)).await?;
2235 current += result.samples.len();
2236 continue_token = result.continue_token;
2237
2238 if result.samples.is_empty() {
2239 break;
2240 }
2241
2242 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
2243
2244 if let Some(progress) = &progress {
2245 let _ = progress
2246 .send(Progress {
2247 current,
2248 total,
2249 status: None,
2250 })
2251 .await;
2252 }
2253
2254 match &continue_token {
2255 Some(token) if !token.is_empty() => continue,
2256 _ => break,
2257 }
2258 }
2259
2260 drop(progress);
2261 Ok(annotations)
2262 }
2263
2264 fn process_sample_annotations(
2265 &self,
2266 samples: &[Sample],
2267 labels: &HashMap<String, u64>,
2268 annotations: &mut Vec<Annotation>,
2269 ) {
2270 for sample in samples {
2271 if sample.annotations().is_empty() {
2272 let mut annotation = Annotation::new();
2273 annotation.set_sample_id(sample.id());
2274 annotation.set_name(sample.name());
2275 annotation.set_sequence_name(sample.sequence_name().cloned());
2276 annotation.set_frame_number(sample.frame_number());
2277 annotation.set_group(sample.group().cloned());
2278 annotations.push(annotation);
2279 continue;
2280 }
2281
2282 for annotation in sample.annotations() {
2283 let mut annotation = annotation.clone();
2284 annotation.set_sample_id(sample.id());
2285 annotation.set_name(sample.name());
2286 annotation.set_sequence_name(sample.sequence_name().cloned());
2287 annotation.set_frame_number(sample.frame_number());
2288 annotation.set_group(sample.group().cloned());
2289 Self::set_label_index_from_map(&mut annotation, labels);
2290 annotations.push(annotation);
2291 }
2292 }
2293 }
2294
2295 /// Delete annotations in bulk from specified samples.
2296 ///
2297 /// This method calls the `annotation.bulk.del` API to efficiently remove
2298 /// annotations from multiple samples at once. Useful for clearing
2299 /// annotations before re-importing updated data.
2300 ///
2301 /// # Arguments
2302 /// * `annotation_set_id` - The annotation set containing the annotations
2303 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2304 /// for masks
2305 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2306 ///
2307 /// # Example
2308 /// ```no_run
2309 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2310 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2311 /// # let client = Client::new()?.with_login("user", "pass").await?;
2312 /// let annotation_set_id = AnnotationSetID::from(123);
2313 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2314 ///
2315 /// client
2316 /// .delete_annotations_bulk(
2317 /// annotation_set_id,
2318 /// &["box".to_string(), "seg".to_string()],
2319 /// &sample_ids,
2320 /// )
2321 /// .await?;
2322 /// # Ok(())
2323 /// # }
2324 /// ```
2325 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2326 pub async fn delete_annotations_bulk(
2327 &self,
2328 annotation_set_id: AnnotationSetID,
2329 annotation_types: &[String],
2330 sample_ids: &[SampleID],
2331 ) -> Result<(), Error> {
2332 use crate::api::AnnotationBulkDeleteParams;
2333
2334 let params = AnnotationBulkDeleteParams {
2335 annotation_set_id: annotation_set_id.into(),
2336 annotation_types: annotation_types.to_vec(),
2337 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2338 delete_all: None,
2339 };
2340
2341 let _: String = self
2342 .rpc("annotation.bulk.del".to_owned(), Some(params))
2343 .await?;
2344 Ok(())
2345 }
2346
2347 /// Add annotations in bulk.
2348 ///
2349 /// This method calls the `annotation.add_bulk` API to efficiently add
2350 /// multiple annotations at once. The annotations must be in server format
2351 /// with image_id references.
2352 ///
2353 /// # Arguments
2354 /// * `annotation_set_id` - The annotation set to add annotations to
2355 /// * `annotations` - Vector of server-format annotations to add
2356 ///
2357 /// # Returns
2358 /// Vector of created annotation records from the server.
2359 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2360 pub async fn add_annotations_bulk(
2361 &self,
2362 annotation_set_id: AnnotationSetID,
2363 annotations: Vec<crate::api::ServerAnnotation>,
2364 ) -> Result<Vec<serde_json::Value>, Error> {
2365 use crate::api::AnnotationAddBulkParams;
2366
2367 let params = AnnotationAddBulkParams {
2368 annotation_set_id: annotation_set_id.into(),
2369 annotations,
2370 };
2371
2372 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2373 .await
2374 }
2375
2376 /// Helper to parse frame number from image_name when sequence_name is
2377 /// present. This ensures frame_number is always derived from the image
2378 /// filename, not from the server's frame_number field (which may be
2379 /// inconsistent).
2380 ///
2381 /// Returns Some(frame_number) if sequence_name is present and frame can be
2382 /// parsed, otherwise None.
2383 fn parse_frame_from_image_name(
2384 image_name: Option<&String>,
2385 sequence_name: Option<&String>,
2386 ) -> Option<u32> {
2387 use std::path::Path;
2388
2389 let sequence = sequence_name?;
2390 let name = image_name?;
2391
2392 // Extract stem (remove extension)
2393 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2394
2395 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2396 stem.strip_prefix(sequence)
2397 .and_then(|suffix| suffix.strip_prefix('_'))
2398 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2399 }
2400
2401 /// Helper to set label index from a label map
2402 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2403 if let Some(label) = annotation.label() {
2404 annotation.set_label_index(Some(labels[label.as_str()]));
2405 }
2406 }
2407
2408 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2409 pub async fn samples_count(
2410 &self,
2411 dataset_id: DatasetID,
2412 annotation_set_id: Option<AnnotationSetID>,
2413 annotation_types: &[AnnotationType],
2414 groups: &[String],
2415 types: &[FileType],
2416 ) -> Result<SamplesCountResult, Error> {
2417 // Use server type names for API calls (e.g., "box" instead of "box2d")
2418 let types = annotation_types
2419 .iter()
2420 .map(|t| t.as_server_type().to_string())
2421 .chain(types.iter().map(|t| t.to_string()))
2422 .collect::<Vec<_>>();
2423
2424 let params = SamplesListParams {
2425 dataset_id,
2426 annotation_set_id,
2427 group_names: groups.to_vec(),
2428 types,
2429 continue_token: None,
2430 };
2431
2432 self.rpc("samples.count".to_owned(), Some(params)).await
2433 }
2434
2435 /// Fetches samples from a dataset with optional annotation and file type
2436 /// filters.
2437 ///
2438 /// # Arguments
2439 ///
2440 /// * `dataset_id` - The dataset to fetch samples from
2441 /// * `annotation_set_id` - Optional annotation set to include annotations
2442 /// from
2443 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2444 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2445 /// * `types` - File types to include metadata for
2446 /// * `progress` - Optional channel for progress updates
2447 ///
2448 /// # Progress
2449 ///
2450 /// Reports progress with `status: None` as samples are fetched from the
2451 /// server in paginated batches. Progress unit is samples fetched.
2452 ///
2453 /// # Returns
2454 ///
2455 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2456 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2457 pub async fn samples(
2458 &self,
2459 dataset_id: DatasetID,
2460 annotation_set_id: Option<AnnotationSetID>,
2461 annotation_types: &[AnnotationType],
2462 groups: &[String],
2463 types: &[FileType],
2464 progress: Option<Sender<Progress>>,
2465 ) -> Result<Vec<Sample>, Error> {
2466 // Use server type names for API calls (e.g., "box" instead of "box2d")
2467 let types_vec = annotation_types
2468 .iter()
2469 .map(|t| t.as_server_type().to_string())
2470 .chain(types.iter().map(|t| t.to_string()))
2471 .collect::<Vec<_>>();
2472 let labels = self
2473 .labels(dataset_id)
2474 .await?
2475 .into_iter()
2476 .map(|label| (label.name().to_string(), label.index()))
2477 .collect::<HashMap<_, _>>();
2478 let total = self
2479 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2480 .await?
2481 .total as usize;
2482
2483 if total == 0 {
2484 return Ok(vec![]);
2485 }
2486
2487 let context = FetchContext {
2488 dataset_id,
2489 annotation_set_id,
2490 groups,
2491 types: types_vec,
2492 labels: &labels,
2493 };
2494
2495 self.fetch_samples_paginated(context, total, progress).await
2496 }
2497
2498 /// Get all sample names in a dataset.
2499 ///
2500 /// This is an efficient method for checking which samples already exist,
2501 /// useful for resuming interrupted imports. It only retrieves sample names
2502 /// without loading full annotation data.
2503 ///
2504 /// # Arguments
2505 ///
2506 /// * `dataset_id` - The dataset to query
2507 /// * `groups` - Optional group filter (empty = all groups)
2508 /// * `progress` - Optional progress channel
2509 ///
2510 /// # Progress
2511 ///
2512 /// Reports progress with `status: None` as sample names are fetched from
2513 /// the server in paginated batches. Progress unit is samples fetched.
2514 ///
2515 /// # Returns
2516 ///
2517 /// A HashSet of sample names (image_name field) that exist in the dataset.
2518 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2519 pub async fn sample_names(
2520 &self,
2521 dataset_id: DatasetID,
2522 groups: &[String],
2523 progress: Option<Sender<Progress>>,
2524 ) -> Result<std::collections::HashSet<String>, Error> {
2525 use std::collections::HashSet;
2526
2527 let total = self
2528 .samples_count(dataset_id, None, &[], groups, &[])
2529 .await?
2530 .total as usize;
2531
2532 if total == 0 {
2533 return Ok(HashSet::new());
2534 }
2535
2536 let mut names = HashSet::with_capacity(total);
2537 let mut continue_token: Option<String> = None;
2538 let mut current = 0;
2539
2540 loop {
2541 let params = SamplesListParams {
2542 dataset_id,
2543 annotation_set_id: None,
2544 types: vec![], // No type filter - we just want names
2545 group_names: groups.to_vec(),
2546 continue_token: continue_token.clone(),
2547 };
2548
2549 let result: SamplesListResult =
2550 self.rpc("samples.list".to_owned(), Some(params)).await?;
2551 current += result.samples.len();
2552 continue_token = result.continue_token;
2553
2554 if result.samples.is_empty() {
2555 break;
2556 }
2557
2558 // Extract sample names (normalized without extension)
2559 for sample in result.samples {
2560 if let Some(name) = sample.name() {
2561 names.insert(name);
2562 }
2563 }
2564
2565 if let Some(ref p) = progress {
2566 let _ = p
2567 .send(Progress {
2568 current,
2569 total,
2570 status: None,
2571 })
2572 .await;
2573 }
2574
2575 match &continue_token {
2576 Some(token) if !token.is_empty() => continue,
2577 _ => break,
2578 }
2579 }
2580
2581 Ok(names)
2582 }
2583
2584 async fn fetch_samples_paginated(
2585 &self,
2586 context: FetchContext<'_>,
2587 total: usize,
2588 progress: Option<Sender<Progress>>,
2589 ) -> Result<Vec<Sample>, Error> {
2590 let mut samples = vec![];
2591 let mut continue_token: Option<String> = None;
2592 let mut current = 0;
2593
2594 loop {
2595 let params = SamplesListParams {
2596 dataset_id: context.dataset_id,
2597 annotation_set_id: context.annotation_set_id,
2598 types: context.types.clone(),
2599 group_names: context.groups.to_vec(),
2600 continue_token: continue_token.clone(),
2601 };
2602
2603 let result: SamplesListResult =
2604 self.rpc("samples.list".to_owned(), Some(params)).await?;
2605 current += result.samples.len();
2606 continue_token = result.continue_token;
2607
2608 if result.samples.is_empty() {
2609 break;
2610 }
2611
2612 samples.append(
2613 &mut result
2614 .samples
2615 .into_iter()
2616 .map(|s| {
2617 // Use server's frame_number if valid (>= 0 after deserialization)
2618 // Otherwise parse from image_name as fallback
2619 // This ensures we respect explicit frame_number from uploads
2620 // while still handling legacy data that only has filename encoding
2621 let frame_number = s.frame_number.or_else(|| {
2622 Self::parse_frame_from_image_name(
2623 s.image_name.as_ref(),
2624 s.sequence_name.as_ref(),
2625 )
2626 });
2627
2628 let mut anns = s.annotations().to_vec();
2629 for ann in &mut anns {
2630 // Set annotation fields from parent sample
2631 ann.set_name(s.name());
2632 ann.set_group(s.group().cloned());
2633 ann.set_sequence_name(s.sequence_name().cloned());
2634 ann.set_frame_number(frame_number);
2635 Self::set_label_index_from_map(ann, context.labels);
2636 }
2637 s.with_annotations(anns).with_frame_number(frame_number)
2638 })
2639 .collect::<Vec<_>>(),
2640 );
2641
2642 if let Some(progress) = &progress {
2643 let _ = progress
2644 .send(Progress {
2645 current,
2646 total,
2647 status: None,
2648 })
2649 .await;
2650 }
2651
2652 match &continue_token {
2653 Some(token) if !token.is_empty() => continue,
2654 _ => break,
2655 }
2656 }
2657
2658 drop(progress);
2659 Ok(samples)
2660 }
2661
2662 /// Populates (imports) samples into a dataset using the `samples.populate2`
2663 /// API.
2664 ///
2665 /// This method creates new samples in the specified dataset, optionally
2666 /// with annotations and sensor data files. For each sample, the `files`
2667 /// field is checked for local file paths. If a filename is a valid path
2668 /// to an existing file, the file will be automatically uploaded to S3
2669 /// using presigned URLs returned by the server. The filename in the
2670 /// request is replaced with the basename (path removed) before sending
2671 /// to the server.
2672 ///
2673 /// # Important Notes
2674 ///
2675 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2676 /// annotations. Without it, the server will accept the request but will
2677 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2678 /// available annotation sets for a dataset, or create a new one via the
2679 /// Studio UI.
2680 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2681 /// boxes. Divide pixel coordinates by image width/height before creating
2682 /// [`Box2d`](crate::Box2d) annotations.
2683 /// - **Files are uploaded automatically** when the filename is a valid
2684 /// local path. The method will replace the full path with just the
2685 /// basename before sending to the server.
2686 /// - **Image dimensions are extracted automatically** for image files using
2687 /// the `imagesize` crate. The width/height are sent to the server and
2688 /// stored in the `image_files` table. These dimensions are returned by
2689 /// `samples.list` and used in [`samples_dataframe`](crate::samples_dataframe)
2690 /// to populate the `size` column.
2691 /// - **UUIDs are generated automatically** if not provided. If you need
2692 /// deterministic UUIDs, set `sample.uuid` explicitly before calling.
2693 ///
2694 /// # Arguments
2695 ///
2696 /// * `dataset_id` - The ID of the dataset to populate
2697 /// * `annotation_set_id` - **Required** if samples contain annotations,
2698 /// otherwise they will be ignored. Query with
2699 /// [`Client::annotation_sets`].
2700 /// * `samples` - Vector of samples to import with metadata and file
2701 /// references. For files, use the full local path - it will be uploaded
2702 /// automatically. UUIDs and image dimensions will be
2703 /// auto-generated/extracted if not provided.
2704 /// * `progress` - Optional channel for progress updates
2705 ///
2706 /// # Progress
2707 ///
2708 /// Reports progress with `status: None` as each sample's files are
2709 /// uploaded. Progress unit is samples (not individual files). Each
2710 /// sample may contain multiple files (image, lidar, radar, etc.) which
2711 /// are all uploaded before the sample is counted as complete.
2712 ///
2713 /// # Returns
2714 ///
2715 /// Returns the API result with sample UUIDs and upload status.
2716 ///
2717 /// # Example
2718 ///
2719 /// ```no_run
2720 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2721 ///
2722 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2723 /// # let client = Client::new()?.with_login("user", "pass").await?;
2724 /// # let dataset_id = DatasetID::from(1);
2725 /// // Query available annotation sets for the dataset
2726 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2727 /// let annotation_set_id = annotation_sets
2728 /// .first()
2729 /// .ok_or_else(|| {
2730 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2731 /// })?
2732 /// .id();
2733 ///
2734 /// // Create sample with annotation (UUID will be auto-generated)
2735 /// let mut sample = Sample::new();
2736 /// sample.width = Some(1920);
2737 /// sample.height = Some(1080);
2738 /// sample.group = Some("train".to_string());
2739 ///
2740 /// // Add file - use full path to local file, it will be uploaded automatically
2741 /// sample.files = vec![SampleFile::with_filename(
2742 /// "image".to_string(),
2743 /// "/path/to/image.jpg".to_string(),
2744 /// )];
2745 ///
2746 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2747 /// let mut annotation = Annotation::new();
2748 /// annotation.set_label(Some("person".to_string()));
2749 /// // Normalize pixel coordinates by dividing by image dimensions
2750 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2751 /// annotation.set_box2d(Some(bbox));
2752 /// sample.annotations = vec![annotation];
2753 ///
2754 /// // Populate with annotation_set_id (REQUIRED for annotations)
2755 /// let result = client
2756 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2757 /// .await?;
2758 /// # Ok(())
2759 /// # }
2760 /// ```
2761 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2762 pub async fn populate_samples(
2763 &self,
2764 dataset_id: DatasetID,
2765 annotation_set_id: Option<AnnotationSetID>,
2766 samples: Vec<Sample>,
2767 progress: Option<Sender<Progress>>,
2768 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2769 self.populate_samples_with_concurrency(
2770 dataset_id,
2771 annotation_set_id,
2772 samples,
2773 progress,
2774 None,
2775 )
2776 .await
2777 }
2778
2779 /// Populate samples with custom upload concurrency.
2780 ///
2781 /// Same as [`populate_samples`](Self::populate_samples) but allows
2782 /// specifying the maximum number of concurrent file uploads. Use this
2783 /// for bulk imports where higher concurrency can significantly reduce
2784 /// upload time.
2785 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2786 pub async fn populate_samples_with_concurrency(
2787 &self,
2788 dataset_id: DatasetID,
2789 annotation_set_id: Option<AnnotationSetID>,
2790 samples: Vec<Sample>,
2791 progress: Option<Sender<Progress>>,
2792 concurrency: Option<usize>,
2793 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2794 use crate::api::SamplesPopulateParams;
2795 #[cfg(feature = "profiling")]
2796 use tracing::Instrument as _;
2797
2798 // Track which files need to be uploaded
2799 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2800
2801 // Process samples to detect local files and generate UUIDs. This is
2802 // synchronous CPU/metadata work; the span uses `.entered()` since it
2803 // runs on the current task with no await inside.
2804 let samples = {
2805 #[cfg(feature = "profiling")]
2806 let _prepare_span = tracing::info_span!("prepare_samples", n = samples.len()).entered();
2807 self.prepare_samples_for_upload(samples, &mut files_to_upload)?
2808 };
2809
2810 let has_files_to_upload = !files_to_upload.is_empty();
2811
2812 // Call populate API with presigned_urls=true if we have files to upload
2813 let params = SamplesPopulateParams {
2814 dataset_id,
2815 annotation_set_id,
2816 presigned_urls: Some(has_files_to_upload),
2817 samples,
2818 };
2819
2820 #[cfg(feature = "profiling")]
2821 let rpc_start = std::time::Instant::now();
2822 let results: Vec<crate::SamplesPopulateResult> = self
2823 .rpc("samples.populate2".to_owned(), Some(params))
2824 .await?;
2825 #[cfg(feature = "profiling")]
2826 upload_stats::add_rpc_nanos(rpc_start.elapsed().as_nanos() as u64);
2827
2828 // Upload files if we have any. The S3 fan-out is async, so the span is
2829 // attached to the future with `.instrument()` (not `.entered()`) to stay
2830 // correct when this batch overlaps others.
2831 if has_files_to_upload {
2832 #[cfg(feature = "profiling")]
2833 let n_files = files_to_upload.len();
2834 #[cfg(feature = "profiling")]
2835 let upload_start = std::time::Instant::now();
2836 let upload_fut =
2837 self.upload_sample_files(&results, files_to_upload, progress, concurrency);
2838 #[cfg(feature = "profiling")]
2839 let upload_fut =
2840 upload_fut.instrument(tracing::info_span!("upload_files", files = n_files));
2841 upload_fut.await?;
2842 #[cfg(feature = "profiling")]
2843 upload_stats::add_upload_nanos(upload_start.elapsed().as_nanos() as u64);
2844 }
2845
2846 Ok(results)
2847 }
2848
2849 fn prepare_samples_for_upload(
2850 &self,
2851 samples: Vec<Sample>,
2852 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2853 ) -> Result<Vec<Sample>, Error> {
2854 Ok(samples
2855 .into_iter()
2856 .map(|mut sample| {
2857 // Generate UUID if not provided
2858 if sample.uuid.is_none() {
2859 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2860 }
2861
2862 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2863
2864 // Process files: detect local paths and queue for upload
2865 let files_copy = sample.files.clone();
2866 let updated_files: Vec<crate::SampleFile> = files_copy
2867 .iter()
2868 .map(|file| {
2869 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2870 })
2871 .collect();
2872
2873 sample.files = updated_files;
2874 sample
2875 })
2876 .collect())
2877 }
2878
2879 fn process_sample_file(
2880 &self,
2881 file: &crate::SampleFile,
2882 sample_uuid: &str,
2883 sample: &mut Sample,
2884 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2885 ) -> crate::SampleFile {
2886 use std::path::Path;
2887
2888 // Handle files with raw bytes (e.g., from ZIP archives)
2889 if let Some(bytes) = file.bytes()
2890 && let Some(filename) = file.filename()
2891 {
2892 // For image files with bytes, try to extract dimensions if not already set
2893 if file.file_type() == "image"
2894 && (sample.width.is_none() || sample.height.is_none())
2895 && let Ok(size) = imagesize::blob_size(bytes)
2896 {
2897 sample.width = Some(size.width as u32);
2898 sample.height = Some(size.height as u32);
2899 }
2900
2901 // Store the bytes for later upload
2902 files_to_upload.push((
2903 sample_uuid.to_string(),
2904 file.file_type().to_string(),
2905 FileSource::Bytes(bytes.to_vec()),
2906 filename.to_string(),
2907 ));
2908
2909 // Return SampleFile with just the filename
2910 return crate::SampleFile::with_filename(
2911 file.file_type().to_string(),
2912 filename.to_string(),
2913 );
2914 }
2915
2916 // Handle files with local paths
2917 if let Some(filename) = file.filename() {
2918 let path = Path::new(filename);
2919
2920 // Check if this is a valid local file path
2921 if path.exists()
2922 && path.is_file()
2923 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2924 {
2925 // For image files, try to extract dimensions if not already set
2926 if file.file_type() == "image"
2927 && (sample.width.is_none() || sample.height.is_none())
2928 && let Ok(size) = imagesize::size(path)
2929 {
2930 sample.width = Some(size.width as u32);
2931 sample.height = Some(size.height as u32);
2932 }
2933
2934 // Store the full path for later upload
2935 files_to_upload.push((
2936 sample_uuid.to_string(),
2937 file.file_type().to_string(),
2938 FileSource::Path(path.to_path_buf()),
2939 basename.to_string(),
2940 ));
2941
2942 // Return SampleFile with just the basename
2943 return crate::SampleFile::with_filename(
2944 file.file_type().to_string(),
2945 basename.to_string(),
2946 );
2947 }
2948 }
2949 // Return the file unchanged if not a local path
2950 file.clone()
2951 }
2952
2953 async fn upload_sample_files(
2954 &self,
2955 results: &[crate::SamplesPopulateResult],
2956 files_to_upload: Vec<(String, String, FileSource, String)>,
2957 progress: Option<Sender<Progress>>,
2958 concurrency: Option<usize>,
2959 ) -> Result<(), Error> {
2960 // Build a map from (sample_uuid, basename) -> file source
2961 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2962 for (uuid, _file_type, source, basename) in files_to_upload {
2963 upload_map.insert((uuid, basename), source);
2964 }
2965
2966 let http = self.bulk_http.clone();
2967
2968 // Extract the data we need for parallel upload
2969 let upload_tasks: Vec<_> = results
2970 .iter()
2971 .map(|result| (result.uuid.clone(), result.urls.clone()))
2972 .collect();
2973
2974 parallel_foreach_items(
2975 upload_tasks,
2976 progress.clone(),
2977 concurrency,
2978 move |(uuid, urls)| {
2979 let http = http.clone();
2980 let upload_map = upload_map.clone();
2981
2982 async move {
2983 // Upload all files for this sample
2984 for url_info in &urls {
2985 if let Some(source) =
2986 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2987 {
2988 match source {
2989 FileSource::Path(path) => {
2990 upload_file_to_presigned_url(
2991 http.clone(),
2992 &url_info.url,
2993 path.clone(),
2994 )
2995 .await?;
2996 }
2997 FileSource::Bytes(bytes) => {
2998 upload_bytes_to_presigned_url(
2999 http.clone(),
3000 &url_info.url,
3001 bytes.clone(),
3002 &url_info.filename,
3003 )
3004 .await?;
3005 }
3006 }
3007 }
3008 }
3009
3010 Ok(())
3011 }
3012 },
3013 )
3014 .await
3015 }
3016
3017 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3018 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
3019 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
3020 if !url.starts_with("http://") && !url.starts_with("https://") {
3021 return Err(Error::InvalidParameters(format!(
3022 "Invalid URL (must be absolute): {}",
3023 url
3024 )));
3025 }
3026
3027 let resp = self.bulk_http.get(url).send().await?;
3028
3029 if !resp.status().is_success() {
3030 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
3031 }
3032
3033 let bytes = resp.bytes().await?;
3034 Ok(bytes.to_vec())
3035 }
3036
3037 /// Get samples as a DataFrame with complete 2025.10 schema.
3038 ///
3039 /// This is the recommended method for obtaining dataset annotations in
3040 /// DataFrame format. It includes all sample metadata (size, location,
3041 /// pose, degradation) as optional columns.
3042 ///
3043 /// # Arguments
3044 ///
3045 /// * `dataset_id` - Dataset identifier
3046 /// * `annotation_set_id` - Optional annotation set filter
3047 /// * `groups` - Dataset groups to include (train, val, test)
3048 /// * `types` - Annotation types to filter (bbox, box3d, mask)
3049 /// * `progress` - Optional progress callback
3050 ///
3051 /// # Progress
3052 ///
3053 /// Reports progress with `status: None` as samples are fetched from the
3054 /// server in paginated batches. Progress unit is samples fetched. This
3055 /// method delegates to [`samples()`](Self::samples) and shares its
3056 /// progress behavior.
3057 ///
3058 /// # Example
3059 ///
3060 /// ```rust,no_run
3061 /// use edgefirst_client::Client;
3062 ///
3063 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3064 /// # let client = Client::new()?;
3065 /// # let dataset_id = 1.into();
3066 /// # let annotation_set_id = 1.into();
3067 /// let df = client
3068 /// .samples_dataframe(
3069 /// dataset_id,
3070 /// Some(annotation_set_id),
3071 /// &["train".to_string()],
3072 /// &[],
3073 /// None,
3074 /// )
3075 /// .await?;
3076 /// println!("DataFrame shape: {:?}", df.shape());
3077 /// # Ok(())
3078 /// # }
3079 /// ```
3080 #[cfg(feature = "polars")]
3081 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3082 pub async fn samples_dataframe(
3083 &self,
3084 dataset_id: DatasetID,
3085 annotation_set_id: Option<AnnotationSetID>,
3086 groups: &[String],
3087 types: &[AnnotationType],
3088 progress: Option<Sender<Progress>>,
3089 ) -> Result<DataFrame, Error> {
3090 use crate::dataset::samples_dataframe;
3091
3092 let samples = self
3093 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
3094 .await?;
3095 samples_dataframe(&samples)
3096 }
3097
3098 /// Update image dimensions for existing samples in a dataset.
3099 ///
3100 /// This is useful for backfilling width/height data on samples that were
3101 /// uploaded before dimension extraction was added, or where dimensions
3102 /// could not be determined at upload time.
3103 ///
3104 /// # Arguments
3105 ///
3106 /// * `dataset_id` - The dataset containing the samples
3107 /// * `updates` - List of dimension updates (sample ID, width, height)
3108 ///
3109 /// # Returns
3110 ///
3111 /// The number of samples that were successfully updated.
3112 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, updates), fields(dataset_id = %dataset_id, count = updates.len())))]
3113 pub async fn update_sample_dimensions(
3114 &self,
3115 dataset_id: DatasetID,
3116 updates: Vec<crate::SampleDimensionUpdate>,
3117 ) -> Result<u64, Error> {
3118 use crate::api::SamplesUpdateDimensionsParams;
3119
3120 if updates.is_empty() {
3121 return Ok(0);
3122 }
3123
3124 // Batch in groups of 500 to stay within server limits
3125 let mut total_updated = 0u64;
3126 for chunk in updates.chunks(500) {
3127 let params = SamplesUpdateDimensionsParams {
3128 dataset_id,
3129 samples: chunk.to_vec(),
3130 };
3131 let result: crate::SamplesUpdateDimensionsResult = self
3132 .rpc("samples.update_dimensions".to_owned(), Some(params))
3133 .await?;
3134 total_updated += result.updated;
3135 }
3136 Ok(total_updated)
3137 }
3138
3139 /// Backfill missing image dimensions for a dataset.
3140 ///
3141 /// Downloads image data for samples that are missing width/height,
3142 /// extracts the dimensions using the `imagesize` crate, and updates
3143 /// the server with the computed values.
3144 ///
3145 /// This is a one-time repair operation for datasets that were uploaded
3146 /// before the client added automatic dimension extraction.
3147 ///
3148 /// # Arguments
3149 ///
3150 /// * `dataset_id` - The dataset to backfill
3151 /// * `progress` - Optional progress channel
3152 ///
3153 /// # Returns
3154 ///
3155 /// The number of samples whose dimensions were updated.
3156 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(dataset_id = %dataset_id)))]
3157 pub async fn backfill_sample_dimensions(
3158 &self,
3159 dataset_id: DatasetID,
3160 progress: Option<Sender<Progress>>,
3161 ) -> Result<u64, Error> {
3162 // Fetch all samples; listing progress is not forwarded to the caller
3163 // since it would interleave with the dimension-computing phase.
3164 let samples = self.samples(dataset_id, None, &[], &[], &[], None).await?;
3165
3166 // Filter to samples missing dimensions
3167 let missing: Vec<&Sample> = samples
3168 .iter()
3169 .filter(|s| s.width.is_none() || s.height.is_none())
3170 .collect();
3171
3172 if missing.is_empty() {
3173 return Ok(0);
3174 }
3175
3176 let total = missing.len();
3177 let mut updates: Vec<crate::SampleDimensionUpdate> = Vec::with_capacity(total);
3178
3179 for (i, sample) in missing.into_iter().enumerate() {
3180 let current = i + 1;
3181
3182 let Some(id) = sample.id() else {
3183 Self::send_progress(&progress, current, total).await;
3184 continue;
3185 };
3186
3187 let Some(url) = sample.image_url() else {
3188 #[cfg(feature = "profiling")]
3189 tracing::warn!(sample_id = %id, "skipping sample: no image URL");
3190 Self::send_progress(&progress, current, total).await;
3191 continue;
3192 };
3193
3194 // Download image data to determine dimensions
3195 let resp = self.bulk_http.get(url).send().await;
3196 let Ok(resp) = resp else {
3197 #[cfg(feature = "profiling")]
3198 tracing::warn!(sample_id = %id, "skipping sample: download failed");
3199 Self::send_progress(&progress, current, total).await;
3200 continue;
3201 };
3202
3203 // Skip non-success responses (e.g. 404, 500) rather than parsing error pages
3204 if !resp.status().is_success() {
3205 #[cfg(feature = "profiling")]
3206 tracing::warn!(sample_id = %id, status = %resp.status(), "skipping sample: non-success HTTP status");
3207 Self::send_progress(&progress, current, total).await;
3208 continue;
3209 }
3210
3211 let Ok(bytes) = resp.bytes().await else {
3212 #[cfg(feature = "profiling")]
3213 tracing::warn!(sample_id = %id, "skipping sample: failed to read response body");
3214 Self::send_progress(&progress, current, total).await;
3215 continue;
3216 };
3217
3218 // Extract dimensions from the downloaded image
3219 let Ok(size) = imagesize::blob_size(&bytes) else {
3220 #[cfg(feature = "profiling")]
3221 tracing::warn!(sample_id = %id, "skipping sample: could not determine dimensions");
3222 Self::send_progress(&progress, current, total).await;
3223 continue;
3224 };
3225
3226 let (Ok(width), Ok(height)) = (u32::try_from(size.width), u32::try_from(size.height))
3227 else {
3228 #[cfg(feature = "profiling")]
3229 tracing::warn!(sample_id = %id, width = size.width, height = size.height, "skipping sample: dimensions overflow u32");
3230 Self::send_progress(&progress, current, total).await;
3231 continue;
3232 };
3233
3234 updates.push(crate::SampleDimensionUpdate { id, width, height });
3235 Self::send_progress(&progress, current, total).await;
3236 }
3237
3238 // Send updates to server
3239 self.update_sample_dimensions(dataset_id, updates).await
3240 }
3241
3242 /// Emit a progress event if a progress channel is provided.
3243 async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
3244 if let Some(tx) = progress {
3245 let _ = tx
3246 .send(Progress {
3247 current,
3248 total,
3249 status: Some("Computing dimensions".to_string()),
3250 })
3251 .await;
3252 }
3253 }
3254
3255 /// List available snapshots. If a name is provided, only snapshots
3256 /// containing that name are returned.
3257 ///
3258 /// Results are sorted by match quality: exact matches first, then
3259 /// case-insensitive exact matches, then shorter descriptions (more
3260 /// specific), then alphabetically.
3261 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3262 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
3263 let snapshots: Vec<Snapshot> = self
3264 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
3265 .await?;
3266 if let Some(name) = name {
3267 Ok(filter_and_sort_by_name(snapshots, name, |s| {
3268 s.description()
3269 }))
3270 } else {
3271 Ok(snapshots)
3272 }
3273 }
3274
3275 /// Get the snapshot with the specified id.
3276 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3277 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
3278 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3279 self.rpc("snapshots.get".to_owned(), Some(params)).await
3280 }
3281
3282 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
3283 ///
3284 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
3285 /// pairs) that serve two primary purposes:
3286 ///
3287 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
3288 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
3289 /// restored with AGTG (Automatic Ground Truth Generation) and optional
3290 /// auto-depth processing.
3291 ///
3292 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
3293 /// migration between EdgeFirst Studio instances using the create →
3294 /// download → upload → restore workflow.
3295 ///
3296 /// Large files are automatically chunked into 100MB parts and uploaded
3297 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
3298 /// is streamed without loading into memory, maintaining constant memory
3299 /// usage.
3300 ///
3301 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3302 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
3303 /// better for large files to avoid timeout issues. Higher values (16-32)
3304 /// are better for many small files.
3305 ///
3306 /// # Arguments
3307 ///
3308 /// * `path` - Local file path to MCAP file or directory containing
3309 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
3310 /// * `progress` - Optional channel to receive upload progress updates
3311 ///
3312 /// # Progress
3313 ///
3314 /// Reports progress with `status: None` as file data is uploaded. Progress
3315 /// unit is bytes uploaded. For single files, total is the file size. For
3316 /// directories, total is the combined size of all files.
3317 ///
3318 /// # Returns
3319 ///
3320 /// Returns a `Snapshot` object with ID, description, status, path, and
3321 /// creation timestamp on success.
3322 ///
3323 /// # Errors
3324 ///
3325 /// Returns an error if:
3326 /// * Path doesn't exist or contains invalid UTF-8
3327 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
3328 /// * Upload fails or network error occurs
3329 /// * Server rejects the snapshot
3330 ///
3331 /// # Example
3332 ///
3333 /// ```no_run
3334 /// # use edgefirst_client::{Client, Progress};
3335 /// # use tokio::sync::mpsc;
3336 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3337 /// let client = Client::new()?.with_token_path(None)?;
3338 ///
3339 /// // Upload MCAP file with progress tracking
3340 /// let (tx, mut rx) = mpsc::channel(1);
3341 /// tokio::spawn(async move {
3342 /// while let Some(Progress {
3343 /// current,
3344 /// total,
3345 /// status,
3346 /// }) = rx.recv().await
3347 /// {
3348 /// println!(
3349 /// "{}: {}/{} bytes ({:.1}%)",
3350 /// status.as_deref().unwrap_or("Upload"),
3351 /// current,
3352 /// total,
3353 /// (current as f64 / total as f64) * 100.0
3354 /// );
3355 /// }
3356 /// });
3357 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
3358 /// println!("Created snapshot: {:?}", snapshot.id());
3359 ///
3360 /// // Upload dataset directory (no progress)
3361 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
3362 /// # Ok(())
3363 /// # }
3364 /// ```
3365 ///
3366 /// # See Also
3367 ///
3368 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3369 /// dataset
3370 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3371 /// data
3372 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3373 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3374 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
3375 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3376 pub async fn create_snapshot(
3377 &self,
3378 path: &str,
3379 progress: Option<Sender<Progress>>,
3380 ) -> Result<Snapshot, Error> {
3381 let path = Path::new(path);
3382
3383 if path.is_dir() {
3384 let path_str = path.to_str().ok_or_else(|| {
3385 Error::IoError(std::io::Error::new(
3386 std::io::ErrorKind::InvalidInput,
3387 "Path contains invalid UTF-8",
3388 ))
3389 })?;
3390 return self.create_snapshot_folder(path_str, progress).await;
3391 }
3392
3393 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3394 Error::IoError(std::io::Error::new(
3395 std::io::ErrorKind::InvalidInput,
3396 "Invalid filename",
3397 ))
3398 })?;
3399 let total = path.metadata()?.len() as usize;
3400 let current = Arc::new(AtomicUsize::new(0));
3401
3402 if let Some(progress) = &progress {
3403 let _ = progress
3404 .send(Progress {
3405 current: 0,
3406 total,
3407 status: None,
3408 })
3409 .await;
3410 }
3411
3412 let params = SnapshotCreateMultipartParams {
3413 snapshot_name: name.to_owned(),
3414 keys: vec![name.to_owned()],
3415 file_sizes: vec![total],
3416 snapshot_type: None,
3417 };
3418 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3419 .rpc(
3420 "snapshots.create_upload_url_multipart".to_owned(),
3421 Some(params),
3422 )
3423 .await?;
3424
3425 let snapshot_id = match multipart.get("snapshot_id") {
3426 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3427 _ => return Err(Error::InvalidResponse),
3428 };
3429
3430 let snapshot = self.snapshot(snapshot_id).await?;
3431 let part_prefix = snapshot
3432 .path()
3433 .split("::/")
3434 .last()
3435 .ok_or(Error::InvalidResponse)?
3436 .to_owned();
3437 let part_key = format!("{}/{}", part_prefix, name);
3438 let mut part = match multipart.get(&part_key) {
3439 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3440 _ => return Err(Error::InvalidResponse),
3441 }
3442 .clone();
3443 part.key = Some(part_key);
3444
3445 let params = upload_multipart(
3446 self.bulk_http.clone(),
3447 part.clone(),
3448 path.to_path_buf(),
3449 total,
3450 current,
3451 progress.clone(),
3452 )
3453 .await?;
3454
3455 let complete: String = self
3456 .rpc(
3457 "snapshots.complete_multipart_upload".to_owned(),
3458 Some(params),
3459 )
3460 .await?;
3461 debug!("Snapshot Multipart Complete: {:?}", complete);
3462
3463 let params: SnapshotStatusParams = SnapshotStatusParams {
3464 snapshot_id,
3465 status: "available".to_owned(),
3466 };
3467 let _: SnapshotStatusResult = self
3468 .rpc("snapshots.update".to_owned(), Some(params))
3469 .await?;
3470
3471 if let Some(progress) = progress {
3472 drop(progress);
3473 }
3474
3475 self.snapshot(snapshot_id).await
3476 }
3477
3478 async fn create_snapshot_folder(
3479 &self,
3480 path: &str,
3481 progress: Option<Sender<Progress>>,
3482 ) -> Result<Snapshot, Error> {
3483 let path = Path::new(path);
3484 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3485 Error::IoError(std::io::Error::new(
3486 std::io::ErrorKind::InvalidInput,
3487 "Invalid directory name",
3488 ))
3489 })?;
3490
3491 let files = WalkDir::new(path)
3492 .into_iter()
3493 .filter_map(|entry| entry.ok())
3494 .filter(|entry| entry.file_type().is_file())
3495 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3496 .collect::<Vec<_>>();
3497
3498 let total: usize = files
3499 .iter()
3500 .filter_map(|file| path.join(file).metadata().ok())
3501 .map(|metadata| metadata.len() as usize)
3502 .sum();
3503 let current = Arc::new(AtomicUsize::new(0));
3504
3505 if let Some(progress) = &progress {
3506 let _ = progress
3507 .send(Progress {
3508 current: 0,
3509 total,
3510 status: None,
3511 })
3512 .await;
3513 }
3514
3515 let keys = files
3516 .iter()
3517 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3518 .collect::<Vec<_>>();
3519 let file_sizes = files
3520 .iter()
3521 .filter_map(|key| path.join(key).metadata().ok())
3522 .map(|metadata| metadata.len() as usize)
3523 .collect::<Vec<_>>();
3524
3525 let params = SnapshotCreateMultipartParams {
3526 snapshot_name: name.to_owned(),
3527 keys,
3528 file_sizes,
3529 snapshot_type: None,
3530 };
3531
3532 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3533 .rpc(
3534 "snapshots.create_upload_url_multipart".to_owned(),
3535 Some(params),
3536 )
3537 .await?;
3538
3539 let snapshot_id = match multipart.get("snapshot_id") {
3540 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3541 _ => return Err(Error::InvalidResponse),
3542 };
3543
3544 let snapshot = self.snapshot(snapshot_id).await?;
3545 let part_prefix = snapshot
3546 .path()
3547 .split("::/")
3548 .last()
3549 .ok_or(Error::InvalidResponse)?
3550 .to_owned();
3551
3552 for file in files {
3553 let file_str = file.to_str().ok_or_else(|| {
3554 Error::IoError(std::io::Error::new(
3555 std::io::ErrorKind::InvalidInput,
3556 "File path contains invalid UTF-8",
3557 ))
3558 })?;
3559 let part_key = format!("{}/{}", part_prefix, file_str);
3560 let mut part = match multipart.get(&part_key) {
3561 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3562 _ => return Err(Error::InvalidResponse),
3563 }
3564 .clone();
3565 part.key = Some(part_key);
3566
3567 let params = upload_multipart(
3568 self.bulk_http.clone(),
3569 part.clone(),
3570 path.join(file),
3571 total,
3572 current.clone(),
3573 progress.clone(),
3574 )
3575 .await?;
3576
3577 let complete: String = self
3578 .rpc(
3579 "snapshots.complete_multipart_upload".to_owned(),
3580 Some(params),
3581 )
3582 .await?;
3583 debug!("Snapshot Part Complete: {:?}", complete);
3584 }
3585
3586 let params = SnapshotStatusParams {
3587 snapshot_id,
3588 status: "available".to_owned(),
3589 };
3590 let _: SnapshotStatusResult = self
3591 .rpc("snapshots.update".to_owned(), Some(params))
3592 .await?;
3593
3594 if let Some(progress) = progress {
3595 drop(progress);
3596 }
3597
3598 self.snapshot(snapshot_id).await
3599 }
3600
3601 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3602 ///
3603 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3604 /// This format is the native EdgeFirst Dataset Format used for efficient
3605 /// dataset storage and transfer.
3606 ///
3607 /// # Arguments
3608 ///
3609 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3610 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3611 /// * `description` - Optional description for the snapshot
3612 /// * `progress` - Optional progress channel for upload tracking
3613 ///
3614 /// # File Requirements
3615 ///
3616 /// - Arrow file must have `.arrow` extension
3617 /// - ZIP file must have `.zip` extension
3618 /// - Both files must exist and be readable
3619 ///
3620 /// # Example
3621 ///
3622 /// ```no_run
3623 /// # use edgefirst_client::Client;
3624 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3625 /// let client = Client::new()?.with_token_path(None)?;
3626 ///
3627 /// let snapshot = client
3628 /// .create_snapshot_edgefirst_format(
3629 /// "dataset.arrow",
3630 /// "dataset.zip",
3631 /// Some("My Dataset Snapshot"),
3632 /// None,
3633 /// )
3634 /// .await?;
3635 /// println!("Created snapshot: {}", snapshot.id());
3636 /// # Ok(())
3637 /// # }
3638 /// ```
3639 ///
3640 /// # See Also
3641 ///
3642 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3643 /// folder
3644 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3645 /// dataset
3646 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3647 pub async fn create_snapshot_edgefirst_format(
3648 &self,
3649 arrow_path: &str,
3650 zip_path: &str,
3651 description: Option<&str>,
3652 progress: Option<Sender<Progress>>,
3653 ) -> Result<Snapshot, Error> {
3654 let arrow_path = Path::new(arrow_path);
3655 let zip_path = Path::new(zip_path);
3656
3657 // Validate files exist
3658 if !arrow_path.exists() {
3659 return Err(Error::IoError(std::io::Error::new(
3660 std::io::ErrorKind::NotFound,
3661 format!("Arrow file not found: {}", arrow_path.display()),
3662 )));
3663 }
3664 if !zip_path.exists() {
3665 return Err(Error::IoError(std::io::Error::new(
3666 std::io::ErrorKind::NotFound,
3667 format!("ZIP file not found: {}", zip_path.display()),
3668 )));
3669 }
3670
3671 // Get file names
3672 let arrow_name = arrow_path
3673 .file_name()
3674 .and_then(|n| n.to_str())
3675 .ok_or_else(|| {
3676 Error::IoError(std::io::Error::new(
3677 std::io::ErrorKind::InvalidInput,
3678 "Invalid Arrow filename",
3679 ))
3680 })?;
3681 let zip_name = zip_path
3682 .file_name()
3683 .and_then(|n| n.to_str())
3684 .ok_or_else(|| {
3685 Error::IoError(std::io::Error::new(
3686 std::io::ErrorKind::InvalidInput,
3687 "Invalid ZIP filename",
3688 ))
3689 })?;
3690
3691 // Generate snapshot name from arrow file (without extension)
3692 let snapshot_name = description
3693 .map(|s| s.to_string())
3694 .or_else(|| {
3695 arrow_path
3696 .file_stem()
3697 .and_then(|s| s.to_str())
3698 .map(|s| s.to_string())
3699 })
3700 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3701
3702 // Calculate file sizes
3703 let arrow_size = arrow_path.metadata()?.len() as usize;
3704 let zip_size = zip_path.metadata()?.len() as usize;
3705 let total = arrow_size + zip_size;
3706 let current = Arc::new(AtomicUsize::new(0));
3707
3708 if let Some(progress) = &progress {
3709 let _ = progress
3710 .send(Progress {
3711 current: 0,
3712 total,
3713 status: None,
3714 })
3715 .await;
3716 }
3717
3718 // Create multipart upload request with "ziparrow" type
3719 let params = SnapshotCreateMultipartParams {
3720 snapshot_name,
3721 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3722 file_sizes: vec![arrow_size, zip_size],
3723 snapshot_type: Some("ziparrow".to_string()),
3724 };
3725
3726 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3727 .rpc(
3728 "snapshots.create_upload_url_multipart".to_owned(),
3729 Some(params),
3730 )
3731 .await?;
3732
3733 let snapshot_id = match multipart.get("snapshot_id") {
3734 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3735 _ => return Err(Error::InvalidResponse),
3736 };
3737
3738 let snapshot = self.snapshot(snapshot_id).await?;
3739 let part_prefix = snapshot
3740 .path()
3741 .split("::/")
3742 .last()
3743 .ok_or(Error::InvalidResponse)?
3744 .to_owned();
3745
3746 // Upload Arrow file
3747 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3748 let mut arrow_part = match multipart.get(&arrow_key) {
3749 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3750 _ => return Err(Error::InvalidResponse),
3751 };
3752 arrow_part.key = Some(arrow_key);
3753
3754 let params = upload_multipart(
3755 self.bulk_http.clone(),
3756 arrow_part,
3757 arrow_path.to_path_buf(),
3758 total,
3759 current.clone(),
3760 progress.clone(),
3761 )
3762 .await?;
3763
3764 let _: String = self
3765 .rpc(
3766 "snapshots.complete_multipart_upload".to_owned(),
3767 Some(params),
3768 )
3769 .await?;
3770 debug!("Arrow file upload complete");
3771
3772 // Upload ZIP file
3773 let zip_key = format!("{}/{}", part_prefix, zip_name);
3774 let mut zip_part = match multipart.get(&zip_key) {
3775 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3776 _ => return Err(Error::InvalidResponse),
3777 };
3778 zip_part.key = Some(zip_key);
3779
3780 let params = upload_multipart(
3781 self.bulk_http.clone(),
3782 zip_part,
3783 zip_path.to_path_buf(),
3784 total,
3785 current.clone(),
3786 progress.clone(),
3787 )
3788 .await?;
3789
3790 let _: String = self
3791 .rpc(
3792 "snapshots.complete_multipart_upload".to_owned(),
3793 Some(params),
3794 )
3795 .await?;
3796 debug!("ZIP file upload complete");
3797
3798 // Mark snapshot as available
3799 let params = SnapshotStatusParams {
3800 snapshot_id,
3801 status: "available".to_owned(),
3802 };
3803 let _: SnapshotStatusResult = self
3804 .rpc("snapshots.update".to_owned(), Some(params))
3805 .await?;
3806
3807 if let Some(progress) = progress {
3808 drop(progress);
3809 }
3810
3811 self.snapshot(snapshot_id).await
3812 }
3813
3814 /// Delete a snapshot from EdgeFirst Studio.
3815 ///
3816 /// Permanently removes a snapshot and its associated data. This operation
3817 /// cannot be undone.
3818 ///
3819 /// # Arguments
3820 ///
3821 /// * `snapshot_id` - The snapshot ID to delete
3822 ///
3823 /// # Errors
3824 ///
3825 /// Returns an error if:
3826 /// * Snapshot doesn't exist
3827 /// * User lacks permission to delete the snapshot
3828 /// * Server error occurs
3829 ///
3830 /// # Example
3831 ///
3832 /// ```no_run
3833 /// # use edgefirst_client::{Client, SnapshotID};
3834 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3835 /// let client = Client::new()?.with_token_path(None)?;
3836 /// let snapshot_id = SnapshotID::from(123);
3837 /// client.delete_snapshot(snapshot_id).await?;
3838 /// # Ok(())
3839 /// # }
3840 /// ```
3841 ///
3842 /// # See Also
3843 ///
3844 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3845 /// * [`snapshots`](Self::snapshots) - List all snapshots
3846 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3847 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3848 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3849 let _: serde_json::Value = self
3850 .rpc("snapshots.delete".to_owned(), Some(params))
3851 .await?;
3852 Ok(())
3853 }
3854
3855 /// Create a snapshot from an existing dataset on the server.
3856 ///
3857 /// Triggers server-side snapshot generation which exports the dataset's
3858 /// images and annotations into a downloadable EdgeFirst Dataset Format
3859 /// snapshot.
3860 ///
3861 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3862 /// while restore creates a dataset from a snapshot, this method creates a
3863 /// snapshot from a dataset.
3864 ///
3865 /// # Arguments
3866 ///
3867 /// * `dataset_id` - The dataset ID to create snapshot from
3868 /// * `description` - Description for the created snapshot
3869 ///
3870 /// # Returns
3871 ///
3872 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3873 /// for monitoring progress.
3874 ///
3875 /// # Errors
3876 ///
3877 /// Returns an error if:
3878 /// * Dataset doesn't exist
3879 /// * User lacks permission to access the dataset
3880 /// * Server rejects the request
3881 ///
3882 /// # Example
3883 ///
3884 /// ```no_run
3885 /// # use edgefirst_client::{Client, DatasetID};
3886 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3887 /// let client = Client::new()?.with_token_path(None)?;
3888 /// let dataset_id = DatasetID::from(123);
3889 ///
3890 /// // Create snapshot from dataset (all annotation sets)
3891 /// let result = client
3892 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3893 /// .await?;
3894 /// println!("Created snapshot: {:?}", result.id);
3895 ///
3896 /// // Monitor progress via task ID
3897 /// if let Some(task_id) = result.task_id {
3898 /// println!("Task: {}", task_id);
3899 /// }
3900 /// # Ok(())
3901 /// # }
3902 /// ```
3903 ///
3904 /// # See Also
3905 ///
3906 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3907 /// snapshot
3908 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3909 /// dataset
3910 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3911 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3912 pub async fn create_snapshot_from_dataset(
3913 &self,
3914 dataset_id: DatasetID,
3915 description: &str,
3916 annotation_set_id: Option<AnnotationSetID>,
3917 ) -> Result<SnapshotFromDatasetResult, Error> {
3918 // Resolve annotation_set_id: use provided value or fetch default
3919 let annotation_set_id = match annotation_set_id {
3920 Some(id) => id,
3921 None => {
3922 // Fetch annotation sets and find default ("annotations") or use first
3923 let sets = self.annotation_sets(dataset_id).await?;
3924 if sets.is_empty() {
3925 return Err(Error::InvalidParameters(
3926 "No annotation sets available for dataset".to_owned(),
3927 ));
3928 }
3929 // Look for "annotations" set (default), otherwise use first
3930 sets.iter()
3931 .find(|s| s.name() == "annotations")
3932 .unwrap_or(&sets[0])
3933 .id()
3934 }
3935 };
3936 let params = SnapshotCreateFromDataset {
3937 description: description.to_owned(),
3938 dataset_id,
3939 annotation_set_id,
3940 };
3941 self.rpc("snapshots.create".to_owned(), Some(params)).await
3942 }
3943
3944 /// Download a snapshot from EdgeFirst Studio to local storage.
3945 ///
3946 /// Downloads all files in a snapshot (single MCAP file or directory of
3947 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3948 /// downloaded concurrently with progress tracking.
3949 ///
3950 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3951 /// downloads (default: half of CPU cores, min 2, max 8).
3952 ///
3953 /// # Arguments
3954 ///
3955 /// * `snapshot_id` - The snapshot ID to download
3956 /// * `output` - Local directory path to save downloaded files
3957 /// * `progress` - Optional channel to receive download progress updates
3958 ///
3959 /// # Progress
3960 ///
3961 /// Reports progress with `status: None` as file data is received. Progress
3962 /// unit is bytes downloaded across all files combined. The total
3963 /// accumulates as file sizes become known (from HTTP Content-Length
3964 /// headers), so both `current` and `total` may increase during
3965 /// download.
3966 ///
3967 /// # Errors
3968 ///
3969 /// Returns an error if:
3970 /// * Snapshot doesn't exist
3971 /// * Output directory cannot be created
3972 /// * Download fails or network error occurs
3973 ///
3974 /// # Example
3975 ///
3976 /// ```no_run
3977 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3978 /// # use tokio::sync::mpsc;
3979 /// # use std::path::PathBuf;
3980 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3981 /// let client = Client::new()?.with_token_path(None)?;
3982 /// let snapshot_id = SnapshotID::from(123);
3983 ///
3984 /// // Download with progress tracking
3985 /// let (tx, mut rx) = mpsc::channel(1);
3986 /// tokio::spawn(async move {
3987 /// while let Some(Progress {
3988 /// current,
3989 /// total,
3990 /// status,
3991 /// }) = rx.recv().await
3992 /// {
3993 /// println!(
3994 /// "{}: {}/{} bytes",
3995 /// status.as_deref().unwrap_or("Download"),
3996 /// current,
3997 /// total
3998 /// );
3999 /// }
4000 /// });
4001 /// client
4002 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
4003 /// .await?;
4004 /// # Ok(())
4005 /// # }
4006 /// ```
4007 ///
4008 /// # See Also
4009 ///
4010 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4011 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
4012 /// dataset
4013 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
4014 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
4015 pub async fn download_snapshot(
4016 &self,
4017 snapshot_id: SnapshotID,
4018 output: PathBuf,
4019 progress: Option<Sender<Progress>>,
4020 ) -> Result<(), Error> {
4021 fs::create_dir_all(&output).await?;
4022
4023 let params = HashMap::from([("snapshot_id", snapshot_id)]);
4024 let items: HashMap<String, String> = self
4025 .rpc("snapshots.create_download_url".to_owned(), Some(params))
4026 .await?;
4027
4028 // Single-phase: each task holds its semaphore permit for the full
4029 // lifetime of the request (GET → headers → stream → disk). This bounds
4030 // the number of simultaneously-open connections to max_tasks() and
4031 // avoids accumulating all responses in memory before streaming.
4032 //
4033 // total is updated atomically as each response's Content-Length header
4034 // arrives, so progress tracking is accurate without a separate phase.
4035 let http = self.bulk_http.clone();
4036 let current = Arc::new(AtomicUsize::new(0));
4037 let total = Arc::new(AtomicUsize::new(0));
4038 let sem = Arc::new(Semaphore::new(max_tasks()));
4039
4040 let tasks = items
4041 .into_iter()
4042 .map(|(key, url)| {
4043 let http = http.clone();
4044 let output = output.clone();
4045 let progress = progress.clone();
4046 let current = current.clone();
4047 let total = total.clone();
4048 let sem = sem.clone();
4049
4050 tokio::spawn(async move {
4051 let _permit = sem.acquire().await.map_err(|_| {
4052 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4053 })?;
4054
4055 let res = http.get(url).send().await?;
4056 let res = res.error_for_status()?;
4057
4058 // Contribute this file's size to the running total so the
4059 // caller's progress bar knows the overall scope.
4060 if let Some(len) = res.content_length() {
4061 total.fetch_add(len as usize, Ordering::SeqCst);
4062 }
4063
4064 let mut file = File::create(output.join(key)).await?;
4065 let mut stream = res.bytes_stream();
4066
4067 while let Some(chunk) = stream.next().await {
4068 let chunk = chunk?;
4069 file.write_all(&chunk).await?;
4070 let len = chunk.len();
4071
4072 if let Some(progress) = &progress {
4073 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
4074 let tot = total.load(Ordering::SeqCst);
4075 let _ = progress
4076 .send(Progress {
4077 current: cur,
4078 total: tot,
4079 status: None,
4080 })
4081 .await;
4082 }
4083 }
4084
4085 Ok::<(), Error>(())
4086 })
4087 })
4088 .collect::<Vec<_>>();
4089
4090 join_all(tasks)
4091 .await
4092 .into_iter()
4093 .collect::<Result<Vec<_>, _>>()?
4094 .into_iter()
4095 .collect::<Result<Vec<_>, _>>()?;
4096
4097 Ok(())
4098 }
4099
4100 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
4101 ///
4102 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
4103 /// the specified project. For MCAP files, supports:
4104 ///
4105 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
4106 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
4107 /// present)
4108 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
4109 /// * **Topic filtering**: Select specific MCAP topics to restore
4110 ///
4111 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
4112 /// dataset structure.
4113 ///
4114 /// # Arguments
4115 ///
4116 /// * `project_id` - Target project ID
4117 /// * `snapshot_id` - Snapshot ID to restore
4118 /// * `topics` - MCAP topics to include (empty = all topics)
4119 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
4120 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
4121 /// * `dataset_name` - Optional custom dataset name
4122 /// * `dataset_description` - Optional dataset description
4123 ///
4124 /// # Returns
4125 ///
4126 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
4127 ///
4128 /// # Errors
4129 ///
4130 /// Returns an error if:
4131 /// * Snapshot or project doesn't exist
4132 /// * Snapshot format is invalid
4133 /// * Server rejects restoration parameters
4134 ///
4135 /// # Example
4136 ///
4137 /// ```no_run
4138 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
4139 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4140 /// let client = Client::new()?.with_token_path(None)?;
4141 /// let project_id = ProjectID::from(1);
4142 /// let snapshot_id = SnapshotID::from(123);
4143 ///
4144 /// // Restore MCAP with AGTG for "person" and "car" detection
4145 /// let result = client
4146 /// .restore_snapshot(
4147 /// project_id,
4148 /// snapshot_id,
4149 /// &[], // All topics
4150 /// &["person".to_string(), "car".to_string()], // AGTG labels
4151 /// true, // Auto-depth
4152 /// Some("Highway Dataset"),
4153 /// Some("Collected on I-95"),
4154 /// )
4155 /// .await?;
4156 /// println!("Restored to dataset: {:?}", result.dataset_id);
4157 /// # Ok(())
4158 /// # }
4159 /// ```
4160 ///
4161 /// # See Also
4162 ///
4163 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4164 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
4165 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
4166 #[allow(clippy::too_many_arguments)]
4167 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4168 pub async fn restore_snapshot(
4169 &self,
4170 project_id: ProjectID,
4171 snapshot_id: SnapshotID,
4172 topics: &[String],
4173 autolabel: &[String],
4174 autodepth: bool,
4175 dataset_name: Option<&str>,
4176 dataset_description: Option<&str>,
4177 ) -> Result<SnapshotRestoreResult, Error> {
4178 let params = SnapshotRestore {
4179 project_id,
4180 snapshot_id,
4181 fps: 1,
4182 autodepth,
4183 agtg_pipeline: !autolabel.is_empty(),
4184 autolabel: autolabel.to_vec(),
4185 topics: topics.to_vec(),
4186 dataset_name: dataset_name.map(|s| s.to_owned()),
4187 dataset_description: dataset_description.map(|s| s.to_owned()),
4188 };
4189 self.rpc("snapshots.restore".to_owned(), Some(params)).await
4190 }
4191
4192 /// Returns a list of experiments available to the user. The experiments
4193 /// are returned as a vector of Experiment objects. If name is provided
4194 /// then only experiments containing this string are returned.
4195 ///
4196 /// Results are sorted by match quality: exact matches first, then
4197 /// case-insensitive exact matches, then shorter names (more specific),
4198 /// then alphabetically.
4199 ///
4200 /// Experiments provide a method of organizing training and validation
4201 /// sessions together and are akin to an Experiment in MLFlow terminology.
4202 /// Each experiment can have multiple trainer sessions associated with it,
4203 /// these would be akin to runs in MLFlow terminology.
4204 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4205 pub async fn experiments(
4206 &self,
4207 project_id: ProjectID,
4208 name: Option<&str>,
4209 ) -> Result<Vec<Experiment>, Error> {
4210 let params = HashMap::from([("project_id", project_id)]);
4211 let experiments: Vec<Experiment> =
4212 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
4213 if let Some(name) = name {
4214 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
4215 } else {
4216 Ok(experiments)
4217 }
4218 }
4219
4220 /// Return the experiment with the specified experiment ID. If the
4221 /// experiment does not exist, an error is returned.
4222 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4223 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
4224 let params = HashMap::from([("trainer_id", experiment_id)]);
4225 self.rpc("trainer.get".to_owned(), Some(params)).await
4226 }
4227
4228 /// Returns a list of trainer sessions available to the user. The trainer
4229 /// sessions are returned as a vector of TrainingSession objects. If name
4230 /// is provided then only trainer sessions containing this string are
4231 /// returned.
4232 ///
4233 /// Results are sorted by match quality: exact matches first, then
4234 /// case-insensitive exact matches, then shorter names (more specific),
4235 /// then alphabetically.
4236 ///
4237 /// Trainer sessions are akin to runs in MLFlow terminology. These
4238 /// represent an actual training session which will produce metrics and
4239 /// model artifacts.
4240 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4241 pub async fn training_sessions(
4242 &self,
4243 experiment_id: ExperimentID,
4244 name: Option<&str>,
4245 ) -> Result<Vec<TrainingSession>, Error> {
4246 let params = HashMap::from([("trainer_id", experiment_id)]);
4247 let sessions: Vec<TrainingSession> = self
4248 .rpc("trainer.session.list".to_owned(), Some(params))
4249 .await?;
4250 if let Some(name) = name {
4251 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
4252 } else {
4253 Ok(sessions)
4254 }
4255 }
4256
4257 /// Return the trainer session with the specified trainer session ID. If
4258 /// the trainer session does not exist, an error is returned.
4259 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4260 pub async fn training_session(
4261 &self,
4262 session_id: TrainingSessionID,
4263 ) -> Result<TrainingSession, Error> {
4264 let params = HashMap::from([("trainer_session_id", session_id)]);
4265 self.rpc("trainer.session.get".to_owned(), Some(params))
4266 .await
4267 }
4268
4269 /// List validation sessions for the given project.
4270 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4271 pub async fn validation_sessions(
4272 &self,
4273 project_id: ProjectID,
4274 ) -> Result<Vec<ValidationSession>, Error> {
4275 let params = HashMap::from([("project_id", project_id)]);
4276 self.rpc("validate.session.list".to_owned(), Some(params))
4277 .await
4278 }
4279
4280 /// Retrieve a specific validation session.
4281 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4282 pub async fn validation_session(
4283 &self,
4284 session_id: ValidationSessionID,
4285 ) -> Result<ValidationSession, Error> {
4286 let params = HashMap::from([("validate_session_id", session_id)]);
4287 self.rpc("validate.session.get".to_owned(), Some(params))
4288 .await
4289 }
4290
4291 /// Create a new validation session via Studio's `cloud.server.start`.
4292 ///
4293 /// Pass `is_local: true` in the [`StartValidationRequest`] to create
4294 /// a **user-managed** session: the database row is created and the
4295 /// session is fully usable for data uploads / downloads / metrics,
4296 /// but no EC2 instance is provisioned and no automated validator
4297 /// pipeline is started. That is the mode our integration tests use
4298 /// — they create a session, exercise the wrapper APIs against it,
4299 /// then call [`Client::delete_validation_sessions`] in teardown so
4300 /// no stray sessions accumulate on the test account.
4301 ///
4302 /// Returns a [`NewValidationSession`] carrying the backing task id
4303 /// and the freshly-minted validation session id.
4304 ///
4305 /// # Errors
4306 ///
4307 /// Surfaces any RPC error from `cloud.server.start`. Common cases:
4308 /// `RpcError(101, …)` if a required entity is missing (project,
4309 /// training session, dataset, …); `PermissionDenied` if the caller
4310 /// can't write to the target project.
4311 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, req)))]
4312 pub async fn start_validation_session(
4313 &self,
4314 req: StartValidationRequest,
4315 ) -> Result<NewValidationSession, Error> {
4316 // Build the params shape the server expects. `cloud.server.start`
4317 // is intentionally generic — different server types pull
4318 // different fields out of `params` — so we serialize manually to
4319 // match the JS frontend's call site verbatim (see
4320 // `dve-frontend/src/components/ValidationPage/StartValidatorModal.vue`).
4321 let mut body = serde_json::Map::new();
4322 body.insert(
4323 "type".into(),
4324 serde_json::Value::String("validation".into()),
4325 );
4326 body.insert("name".into(), serde_json::Value::String(req.name));
4327 body.insert("project_id".into(), serde_json::to_value(req.project_id)?);
4328 body.insert(
4329 "training_session_id".into(),
4330 serde_json::to_value(req.training_session_id)?,
4331 );
4332 body.insert(
4333 "model_file".into(),
4334 serde_json::Value::String(req.model_file),
4335 );
4336 body.insert("val_type".into(), serde_json::Value::String(req.val_type));
4337 body.insert("is_local".into(), serde_json::Value::Bool(req.is_local));
4338 body.insert(
4339 "is_kubernetes".into(),
4340 serde_json::Value::Bool(req.is_kubernetes),
4341 );
4342
4343 // `validate.session` reads its config from `params.params` (one
4344 // extra envelope level). The outer `params` wrapper is required
4345 // even when the inner map is empty.
4346 let inner = serde_json::to_value(req.params)?;
4347 let mut outer = serde_json::Map::new();
4348 outer.insert("params".into(), inner);
4349 body.insert("params".into(), serde_json::Value::Object(outer));
4350
4351 if let Some(d) = req.description {
4352 body.insert("description".into(), serde_json::Value::String(d));
4353 }
4354 if let Some(id) = req.dataset_id {
4355 body.insert("dataset_id".into(), serde_json::to_value(id)?);
4356 }
4357 if let Some(id) = req.annotation_set_id {
4358 body.insert("annotation_set_id".into(), serde_json::to_value(id)?);
4359 }
4360 if let Some(id) = req.snapshot_id {
4361 body.insert("snapshot_id".into(), serde_json::to_value(id)?);
4362 }
4363
4364 self.rpc("cloud.server.start".to_owned(), Some(body)).await
4365 }
4366
4367 /// Delete one or more validation sessions via
4368 /// `validate.session.delete`.
4369 ///
4370 /// Used by integration tests to tear down sessions they created
4371 /// with [`Client::start_validation_session`]; idempotent against
4372 /// already-deleted ids on the server side (the RPC accepts the
4373 /// list, deletes what it can, and surfaces an error only if none
4374 /// of the ids were resolvable).
4375 ///
4376 /// # Errors
4377 ///
4378 /// Surfaces any RPC error from `validate.session.delete`. A
4379 /// `PermissionDenied` indicates the caller lacks
4380 /// `TrainerWrite` on at least one of the listed sessions.
4381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4382 pub async fn delete_validation_sessions(
4383 &self,
4384 session_ids: &[ValidationSessionID],
4385 ) -> Result<(), Error> {
4386 let mut body = serde_json::Map::new();
4387 body.insert("session_ids".into(), serde_json::to_value(session_ids)?);
4388 let _: serde_json::Value = self
4389 .rpc("validate.session.delete".to_owned(), Some(body))
4390 .await?;
4391 Ok(())
4392 }
4393
4394 /// List the artifacts for the specified trainer session. The artifacts
4395 /// are returned as a vector of strings.
4396 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4397 pub async fn artifacts(
4398 &self,
4399 training_session_id: TrainingSessionID,
4400 ) -> Result<Vec<Artifact>, Error> {
4401 let params = HashMap::from([("training_session_id", training_session_id)]);
4402 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
4403 .await
4404 }
4405
4406 /// Download the model artifact for the specified trainer session to the
4407 /// specified file path, if path is not provided it will be downloaded to
4408 /// the current directory with the same filename.
4409 ///
4410 /// # Progress
4411 ///
4412 /// Reports progress with `status: None` as file data is received. Progress
4413 /// unit is bytes downloaded. Total is determined from the HTTP
4414 /// Content-Length header (may be 0 if server doesn't provide it).
4415 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4416 pub async fn download_artifact(
4417 &self,
4418 training_session_id: TrainingSessionID,
4419 modelname: &str,
4420 filename: Option<PathBuf>,
4421 progress: Option<Sender<Progress>>,
4422 ) -> Result<(), Error> {
4423 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
4424 let resp = self
4425 .bulk_http
4426 .get(format!(
4427 "{}/download_model?training_session_id={}&file={}",
4428 self.url,
4429 training_session_id.value(),
4430 modelname
4431 ))
4432 .header("Authorization", format!("Bearer {}", self.token().await))
4433 .send()
4434 .await?;
4435 if !resp.status().is_success() {
4436 let err = resp.error_for_status_ref().unwrap_err();
4437 return Err(Error::HttpError(err));
4438 }
4439
4440 if let Some(parent) = filename.parent() {
4441 fs::create_dir_all(parent).await?;
4442 }
4443
4444 stream_response_to_file(resp, &filename, progress).await
4445 }
4446
4447 /// Download the model checkpoint associated with the specified trainer
4448 /// session to the specified file path, if path is not provided it will be
4449 /// downloaded to the current directory with the same filename.
4450 ///
4451 /// There is no API for listing checkpoints it is expected that trainers are
4452 /// aware of possible checkpoints and their names within the checkpoint
4453 /// folder on the server.
4454 ///
4455 /// # Progress
4456 ///
4457 /// Reports progress with `status: None` as file data is received. Progress
4458 /// unit is bytes downloaded. Total is determined from the HTTP
4459 /// Content-Length header (may be 0 if server doesn't provide it).
4460 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4461 pub async fn download_checkpoint(
4462 &self,
4463 training_session_id: TrainingSessionID,
4464 checkpoint: &str,
4465 filename: Option<PathBuf>,
4466 progress: Option<Sender<Progress>>,
4467 ) -> Result<(), Error> {
4468 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
4469 let resp = self
4470 .bulk_http
4471 .get(format!(
4472 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
4473 self.url,
4474 training_session_id.value(),
4475 checkpoint
4476 ))
4477 .header("Authorization", format!("Bearer {}", self.token().await))
4478 .send()
4479 .await?;
4480 if !resp.status().is_success() {
4481 let err = resp.error_for_status_ref().unwrap_err();
4482 return Err(Error::HttpError(err));
4483 }
4484
4485 if let Some(parent) = filename.parent() {
4486 fs::create_dir_all(parent).await?;
4487 }
4488
4489 stream_response_to_file(resp, &filename, progress).await
4490 }
4491
4492 /// Return a list of tasks for the current user.
4493 ///
4494 /// # Arguments
4495 ///
4496 /// * `name` - Optional filter for task name (client-side substring match)
4497 /// * `workflow` - Optional filter for workflow/task type. If provided,
4498 /// filters server-side by exact match. Valid values include: "trainer",
4499 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
4500 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
4501 /// "convertor", "twostage"
4502 /// * `status` - Optional filter for task status (e.g., "running",
4503 /// "complete", "error")
4504 /// * `manager` - Optional filter for task manager type (e.g., "aws",
4505 /// "user", "kubernetes")
4506 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4507 pub async fn tasks(
4508 &self,
4509 name: Option<&str>,
4510 workflow: Option<&str>,
4511 status: Option<&str>,
4512 manager: Option<&str>,
4513 ) -> Result<Vec<Task>, Error> {
4514 let mut params = TasksListParams {
4515 continue_token: None,
4516 types: workflow.map(|w| vec![w.to_owned()]),
4517 status: status.map(|s| vec![s.to_owned()]),
4518 manager: manager.map(|m| vec![m.to_owned()]),
4519 };
4520 let mut tasks = Vec::new();
4521
4522 loop {
4523 let result = self
4524 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4525 .await?;
4526 tasks.extend(result.tasks);
4527
4528 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4529 params.continue_token = None;
4530 } else {
4531 params.continue_token = result.continue_token;
4532 }
4533
4534 if params.continue_token.is_none() {
4535 break;
4536 }
4537 }
4538
4539 if let Some(name) = name {
4540 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4541 }
4542
4543 Ok(tasks)
4544 }
4545
4546 /// Submits a job (app run) to the server and returns the resulting `Job`
4547 /// record (which carries the linked task id alongside the cloud-batch
4548 /// metadata).
4549 ///
4550 /// # Arguments
4551 /// * `app_name` - The name of the registered app to run (e.g., `"edgefirst-validator"`).
4552 /// * `job_name` - A user-defined label for this run.
4553 /// * `env` - Environment variables passed to the job (string-string map).
4554 /// * `data` - Job input payload (e.g., session ids, parameters).
4555 ///
4556 /// # Returns
4557 /// The full `Job` record returned by the server (wraps the BK_BATCH object),
4558 /// including AWS Batch job ID, state, and the linked `task_id`. Callers that
4559 /// only need the task ID can call `.task_id()` on the returned `Job`.
4560 pub async fn job_run(
4561 &self,
4562 app_name: &str,
4563 job_name: &str,
4564 env: std::collections::HashMap<String, String>,
4565 data: std::collections::HashMap<String, crate::api::Parameter>,
4566 ) -> Result<crate::api::Job, Error> {
4567 let req = JobRunRequest {
4568 name: app_name.to_owned(),
4569 job_name: job_name.to_owned(),
4570 env,
4571 data,
4572 };
4573 let resp: crate::api::Job = match self.rpc("job.run".to_owned(), Some(&req)).await {
4574 Ok(r) => r,
4575 Err(Error::RpcError(code, msg)) => {
4576 return Err(map_rpc_error("job.run", code, msg, None));
4577 }
4578 Err(e) => return Err(e),
4579 };
4580 Ok(resp)
4581 }
4582
4583 /// Requests a running job task be stopped.
4584 ///
4585 /// Returns `Ok(())` if the stop request was accepted by the server. The
4586 /// task may still take time to fully terminate; poll `task_info` if you
4587 /// need to wait for shutdown.
4588 pub async fn job_stop(&self, task_id: crate::api::TaskID) -> Result<(), Error> {
4589 let req = JobStopRequest {
4590 task_id: task_id.value(),
4591 };
4592 // We don't care about the response body; deserialize as serde_json::Value.
4593 let _resp: serde_json::Value = match self.rpc("job.stop".to_owned(), Some(&req)).await {
4594 Ok(r) => r,
4595 Err(Error::RpcError(code, msg)) => {
4596 return Err(map_rpc_error("job.stop", code, msg, Some(task_id)));
4597 }
4598 Err(e) => return Err(e),
4599 };
4600 Ok(())
4601 }
4602
4603 /// Lists job (app-run) entries visible to the authenticated user.
4604 ///
4605 /// The server returns AWS Batch-wrapper entries (not bare `Task` objects),
4606 /// surfacing cloud-batch state (`RUNNING`/`SUCCEEDED`/...) and the linked
4607 /// `task_id`. Use `Job::task_id()` + `Client::task_info` to fetch the
4608 /// underlying task details.
4609 ///
4610 /// The server does not support server-side filters, so the optional
4611 /// `name` argument is applied client-side as a substring match against
4612 /// each job's `job_name`.
4613 pub async fn jobs(&self, name: Option<&str>) -> Result<Vec<crate::api::Job>, Error> {
4614 let req = JobsListRequest {};
4615 let mut jobs: Vec<crate::api::Job> = match self.rpc("job.list".to_owned(), Some(&req)).await
4616 {
4617 Ok(r) => r,
4618 Err(Error::RpcError(code, msg)) => {
4619 return Err(map_rpc_error("job.list", code, msg, None));
4620 }
4621 Err(e) => return Err(e),
4622 };
4623 if let Some(name) = name {
4624 let needle = name.to_lowercase();
4625 jobs.retain(|j| j.job_name.to_lowercase().contains(&needle));
4626 jobs.sort_by(|a, b| a.job_name.cmp(&b.job_name));
4627 }
4628 Ok(jobs)
4629 }
4630
4631 /// Retrieve the task information and status.
4632 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4633 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4634 self.rpc(
4635 "task.get".to_owned(),
4636 Some(HashMap::from([("id", task_id)])),
4637 )
4638 .await
4639 }
4640
4641 /// Updates the tasks status.
4642 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4643 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4644 let status = TaskStatus {
4645 task_id,
4646 status: status.to_owned(),
4647 };
4648 self.rpc("docker.update.status".to_owned(), Some(status))
4649 .await
4650 }
4651
4652 /// Defines the stages for the task. The stages are defined as a mapping
4653 /// from stage names to their descriptions. Once stages are defined their
4654 /// status can be updated using the update_stage method.
4655 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4656 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4657 let stages: Vec<HashMap<String, String>> = stages
4658 .iter()
4659 .map(|(key, value)| {
4660 let mut stage_map = HashMap::new();
4661 stage_map.insert(key.to_string(), value.to_string());
4662 stage_map
4663 })
4664 .collect();
4665 let params = TaskStages { task_id, stages };
4666 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4667 Ok(())
4668 }
4669
4670 /// Updates the progress of the task for the provided stage and status
4671 /// information.
4672 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4673 pub async fn update_stage(
4674 &self,
4675 task_id: TaskID,
4676 stage: &str,
4677 status: &str,
4678 message: &str,
4679 percentage: u8,
4680 ) -> Result<(), Error> {
4681 let stage = Stage::new(
4682 Some(task_id),
4683 stage.to_owned(),
4684 Some(status.to_owned()),
4685 Some(message.to_owned()),
4686 percentage,
4687 );
4688 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4689 Ok(())
4690 }
4691
4692 /// Authenticated fetch from the Studio server using the bulk HTTP client
4693 /// (no total-request timeout; idle read timeout per chunk).
4694 ///
4695 /// **Buffers the entire response body into memory.** Suitable for small to
4696 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4697 /// checkpoints), prefer a streaming approach that writes directly to disk.
4698 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4699 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4700 let req = self
4701 .bulk_http
4702 .get(format!("{}/{}", self.url, query))
4703 .header("User-Agent", "EdgeFirst Client")
4704 .header("Authorization", format!("Bearer {}", self.token().await));
4705 let resp = req.send().await?;
4706
4707 if resp.status().is_success() {
4708 let body = resp.bytes().await?;
4709
4710 if log_enabled!(Level::Trace) {
4711 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4712 }
4713
4714 Ok(body.to_vec())
4715 } else {
4716 let err = resp.error_for_status_ref().unwrap_err();
4717 Err(Error::HttpError(err))
4718 }
4719 }
4720
4721 /// Sends a multipart post request to the server. This is used by the
4722 /// upload and download APIs which do not use JSON-RPC but instead transfer
4723 /// files using multipart/form-data.
4724 ///
4725 /// The result field is deserialized as `serde_json::Value` rather than
4726 /// `String` because different server endpoints return different shapes —
4727 /// `val.data.upload` returns a plain string while `task.data.upload`
4728 /// returns an object `{"message":…,"path":…,"size":…}`. All current
4729 /// callers discard the return value so this is backwards-compatible.
4730 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4731 pub async fn post_multipart(
4732 &self,
4733 method: &str,
4734 form: Form,
4735 ) -> Result<serde_json::Value, Error> {
4736 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4737 .ok()
4738 .and_then(|s| s.parse().ok())
4739 .unwrap_or(600u64);
4740
4741 let req = self
4742 .http
4743 .post(format!("{}/api?method={}", self.url, method))
4744 .header("Accept", "application/json")
4745 .header("User-Agent", "EdgeFirst Client")
4746 .header("Authorization", format!("Bearer {}", self.token().await))
4747 .timeout(Duration::from_secs(upload_timeout_secs))
4748 .multipart(form);
4749 let resp = req.send().await?;
4750
4751 if resp.status().is_success() {
4752 let body = resp.bytes().await?;
4753
4754 if log_enabled!(Level::Trace) {
4755 trace!(
4756 "POST Multipart Response: {}",
4757 String::from_utf8_lossy(&body)
4758 );
4759 }
4760
4761 let response: RpcResponse<serde_json::Value> = match serde_json::from_slice(&body) {
4762 Ok(response) => response,
4763 Err(err) => {
4764 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4765 return Err(err.into());
4766 }
4767 };
4768
4769 if let Some(error) = response.error {
4770 Err(Error::RpcError(error.code, error.message))
4771 } else if let Some(result) = response.result {
4772 Ok(result)
4773 } else {
4774 Err(Error::InvalidResponse)
4775 }
4776 } else {
4777 // HTTP-level failure on the multipart upload. Map 413 to the
4778 // typed `PayloadTooLarge` variant so callers see the same error
4779 // type from both single-file rpc_download paths and multipart
4780 // upload paths; everything else falls through to HttpError.
4781 let status = resp.status();
4782 if status.as_u16() == 413 {
4783 return Err(Error::PayloadTooLarge {
4784 method: method.to_string(),
4785 size_hint: None,
4786 });
4787 }
4788 let err = resp.error_for_status_ref().unwrap_err();
4789 Err(Error::HttpError(err))
4790 }
4791 }
4792
4793 /// Internal helper: POST a JSON-RPC request and stream the binary response
4794 /// to `output_path`. The response is assumed to be raw binary (not a JSON
4795 /// envelope). Use for endpoints that return file contents directly.
4796 ///
4797 /// On HTTP non-success, the response body is read as text and surfaced
4798 /// via `Error::RpcError(status_code, body)`.
4799 pub(crate) async fn rpc_download<P: Serialize>(
4800 &self,
4801 method: &str,
4802 params: &P,
4803 output_path: &std::path::Path,
4804 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
4805 ) -> Result<(), Error> {
4806 let envelope = serde_json::json!({
4807 "jsonrpc": "2.0",
4808 "id": 0,
4809 "method": method,
4810 "params": params,
4811 });
4812
4813 let url = format!("{}/api", self.url);
4814 let resp = self
4815 .bulk_http
4816 .post(&url)
4817 .header("Authorization", format!("Bearer {}", self.token().await))
4818 .json(&envelope)
4819 .send()
4820 .await?;
4821
4822 let status = resp.status();
4823 if !status.is_success() {
4824 if status.as_u16() == 413 {
4825 return Err(Error::PayloadTooLarge {
4826 method: method.to_string(),
4827 size_hint: None,
4828 });
4829 }
4830 let body = resp.text().await.unwrap_or_default();
4831 return Err(Error::RpcError(status.as_u16() as i32, body));
4832 }
4833
4834 // HTTP 200 with Content-Type: application/json can mean two things:
4835 // (a) a JSON-RPC error envelope when the server failed mid-way
4836 // (e.g. {"jsonrpc":"2.0","error":{"code":N,"message":"..."}}),
4837 // (b) a legitimate JSON file payload — validation traces, chart
4838 // bodies, metrics, etc., are typically served with this MIME.
4839 //
4840 // Disambiguate structurally: a JSON-RPC 2.0 envelope is required to
4841 // carry a `jsonrpc` member, and an *error* envelope further requires
4842 // an `error.code` integer (per RFC 8259 + JSON-RPC 2.0 §5). Only
4843 // decode the body as an error if both markers are present. This is
4844 // strict enough to leave legitimate JSON artifacts that happen to
4845 // contain a free-form `error` field (metrics, diagnostics, log
4846 // dumps) untouched, while still catching every real server
4847 // failure.
4848 let content_type = resp
4849 .headers()
4850 .get(reqwest::header::CONTENT_TYPE)
4851 .and_then(|v| v.to_str().ok())
4852 .unwrap_or("")
4853 .to_owned();
4854 if content_type.contains("application/json") {
4855 let body = resp.bytes().await?;
4856 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&body)
4857 && is_jsonrpc_error_envelope(&val)
4858 && let Some(err_obj) = val.get("error")
4859 {
4860 let code = err_obj.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
4861 let message = err_obj
4862 .get("message")
4863 .and_then(|m| m.as_str())
4864 .unwrap_or("unknown error")
4865 .to_string();
4866 return Err(Error::RpcError(code, message));
4867 }
4868 // Not an error envelope — body is a JSON file. Write it to disk
4869 // and emit a single completion progress event so callers (e.g.,
4870 // Python download_data progress callbacks) see the download
4871 // finish.
4872 //
4873 // `Path::parent` returns `Some("")` for a bare filename like
4874 // "metrics.json"; `create_dir_all("")` errors out with
4875 // `NotFound`, so only create the parent when it actually names
4876 // a directory.
4877 if let Some(parent) = output_path.parent()
4878 && !parent.as_os_str().is_empty()
4879 {
4880 tokio::fs::create_dir_all(parent).await?;
4881 }
4882 let mut file = tokio::fs::File::create(output_path).await?;
4883 file.write_all(&body).await?;
4884 file.flush().await?;
4885 if let Some(tx) = progress {
4886 let total = body.len();
4887 // Use the awaited send for the final event so completion
4888 // handlers are never silently dropped.
4889 let _ = tx
4890 .send(Progress {
4891 current: total,
4892 total,
4893 status: None,
4894 })
4895 .await;
4896 }
4897 return Ok(());
4898 }
4899
4900 // Same empty-parent guard for the streaming download path: passing
4901 // a bare filename like "metrics.json" must write to the current
4902 // directory rather than failing on `create_dir_all("")`.
4903 if let Some(parent) = output_path.parent()
4904 && !parent.as_os_str().is_empty()
4905 {
4906 tokio::fs::create_dir_all(parent).await?;
4907 }
4908
4909 stream_response_to_file(resp, output_path, progress).await
4910 }
4911
4912 /// Send a JSON-RPC request to the server. The method is the name of the
4913 /// method to call on the server. The params are the parameters to pass to
4914 /// the method. The method and params are serialized into a JSON-RPC
4915 /// request and sent to the server. The response is deserialized into
4916 /// the specified type and returned to the caller.
4917 ///
4918 /// NOTE: This API would generally not be called directly and instead users
4919 /// should use the higher-level methods provided by the client.
4920 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4921 pub async fn rpc<Params, RpcResult>(
4922 &self,
4923 method: String,
4924 params: Option<Params>,
4925 ) -> Result<RpcResult, Error>
4926 where
4927 Params: Serialize,
4928 RpcResult: DeserializeOwned,
4929 {
4930 let auth_expires = self.token_expiration().await?;
4931 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4932 self.renew_token().await?;
4933 }
4934
4935 self.rpc_without_auth(method, params).await
4936 }
4937
4938 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
4939 async fn rpc_without_auth<Params, RpcResult>(
4940 &self,
4941 method: String,
4942 params: Option<Params>,
4943 ) -> Result<RpcResult, Error>
4944 where
4945 Params: Serialize,
4946 RpcResult: DeserializeOwned,
4947 {
4948 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4949 .ok()
4950 .and_then(|s| s.parse().ok())
4951 .unwrap_or(5usize);
4952
4953 let url = format!("{}/api", self.url);
4954
4955 // Serialize request body once before retry loop to avoid Clone bound on Params
4956 let request = RpcRequest {
4957 method: method.clone(),
4958 params,
4959 ..Default::default()
4960 };
4961
4962 // Log request for debugging (log crate) and profiling (tracing crate)
4963 let request_json = if method == "auth.login" {
4964 // Redact auth.login params (contains password)
4965 serde_json::json!({
4966 "jsonrpc": "2.0",
4967 "method": &method,
4968 "params": "[REDACTED - contains credentials]",
4969 "id": request.id
4970 })
4971 .to_string()
4972 } else {
4973 serde_json::to_string(&request)?
4974 };
4975
4976 if log_enabled!(Level::Trace) {
4977 trace!("RPC Request: {}", request_json);
4978 }
4979
4980 // Record request on current span for Perfetto when profiling is enabled
4981 #[cfg(feature = "profiling")]
4982 tracing::Span::current().record("request", &request_json);
4983
4984 let request_body = serde_json::to_vec(&request)?;
4985 let mut last_error: Option<Error> = None;
4986
4987 for attempt in 0..=max_retries {
4988 if attempt > 0 {
4989 // Exponential backoff with jitter: base delay * 2^attempt, capped at 30s
4990 // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
4991 // while ensuring we never retry faster than the base delay
4992 let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
4993 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
4994 let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
4995 let delay = Duration::from_millis(delay_ms);
4996 warn!(
4997 "Retry {}/{} for RPC '{}' after {:?}",
4998 attempt, max_retries, method, delay
4999 );
5000 tokio::time::sleep(delay).await;
5001 }
5002
5003 let result = self
5004 .http
5005 .post(&url)
5006 .header("Accept", "application/json")
5007 .header("Content-Type", "application/json")
5008 .header("User-Agent", "EdgeFirst Client")
5009 .header("Authorization", format!("Bearer {}", self.token().await))
5010 .body(request_body.clone())
5011 .send()
5012 .await;
5013
5014 match result {
5015 Ok(res) => {
5016 let status = res.status();
5017 let status_code = status.as_u16();
5018
5019 // Check for retryable HTTP status codes before processing response
5020 if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
5021 && attempt < max_retries
5022 {
5023 warn!(
5024 "RPC '{}' failed with HTTP {} (retrying)",
5025 method, status_code
5026 );
5027 last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
5028 continue;
5029 }
5030
5031 // Process the response
5032 match self.process_rpc_response(res).await {
5033 Ok(result) => {
5034 if attempt > 0 {
5035 debug!("RPC '{}' succeeded on retry {}", method, attempt);
5036 }
5037 return Ok(result);
5038 }
5039 Err(e) => {
5040 // Don't retry client errors (4xx except 408, 429)
5041 if attempt > 0 {
5042 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
5043 }
5044 return Err(e);
5045 }
5046 }
5047 }
5048 Err(e) => {
5049 // Transport error (timeout, connection failure, etc.)
5050 let is_timeout = e.is_timeout();
5051 let is_connect = e.is_connect();
5052
5053 if (is_timeout || is_connect) && attempt < max_retries {
5054 warn!(
5055 "RPC '{}' transport error (retrying): {}",
5056 method,
5057 if is_timeout {
5058 "timeout"
5059 } else {
5060 "connection failed"
5061 }
5062 );
5063 last_error = Some(Error::HttpError(e));
5064 continue;
5065 }
5066
5067 if attempt > 0 {
5068 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
5069 }
5070 return Err(Error::HttpError(e));
5071 }
5072 }
5073 }
5074
5075 // Should not reach here
5076 Err(last_error.unwrap_or_else(|| {
5077 Error::InvalidParameters(format!(
5078 "RPC '{}' failed after {} retries",
5079 method, max_retries
5080 ))
5081 }))
5082 }
5083
5084 async fn process_rpc_response<RpcResult>(
5085 &self,
5086 res: reqwest::Response,
5087 ) -> Result<RpcResult, Error>
5088 where
5089 RpcResult: DeserializeOwned,
5090 {
5091 let body = res.bytes().await?;
5092 let response_str = String::from_utf8_lossy(&body);
5093
5094 if log_enabled!(Level::Trace) {
5095 trace!("RPC Response: {}", response_str);
5096 }
5097
5098 // Record response on current span for Perfetto when profiling is enabled
5099 // Truncate large responses to avoid bloating trace files
5100 #[cfg(feature = "profiling")]
5101 {
5102 const MAX_RESPONSE_LEN: usize = 4096;
5103 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
5104 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
5105 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
5106 format!(
5107 "{}...[truncated {} bytes]",
5108 &response_str[..safe_end],
5109 response_str.len() - safe_end
5110 )
5111 } else {
5112 response_str.to_string()
5113 };
5114 tracing::Span::current().record("response", &truncated);
5115 }
5116
5117 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
5118 Ok(response) => response,
5119 Err(err) => {
5120 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
5121 return Err(err.into());
5122 }
5123 };
5124
5125 // FIXME: Studio Server always returns 999 as the id.
5126 // if request.id.to_string() != response.id {
5127 // return Err(Error::InvalidRpcId(response.id));
5128 // }
5129
5130 if let Some(error) = response.error {
5131 Err(Error::RpcError(error.code, error.message))
5132 } else if let Some(result) = response.result {
5133 Ok(result)
5134 } else {
5135 Err(Error::InvalidResponse)
5136 }
5137 }
5138}
5139
5140/// Process items in parallel with semaphore concurrency control and progress
5141/// tracking.
5142///
5143/// This helper eliminates boilerplate for parallel item processing with:
5144/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
5145/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
5146/// - Atomic progress counter with automatic item-level updates
5147/// - Progress updates sent after each item completes (not byte-level streaming)
5148/// - Proper error propagation from spawned tasks
5149///
5150/// Note: This is optimized for discrete items with post-completion progress
5151/// updates. For byte-level streaming progress or custom retry logic, use
5152/// specialized implementations.
5153///
5154/// # Arguments
5155///
5156/// * `items` - Collection of items to process in parallel
5157/// * `progress` - Optional progress channel for tracking completion
5158/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
5159/// * `work_fn` - Async function to execute for each item
5160///
5161/// # Examples
5162///
5163/// ```rust,ignore
5164/// // Use default concurrency
5165/// parallel_foreach_items(samples, progress, None, |sample| async move {
5166/// sample.download(&client, file_type).await?;
5167/// Ok(())
5168/// }).await?;
5169/// ```
5170async fn parallel_foreach_items<T, F, Fut>(
5171 items: Vec<T>,
5172 progress: Option<Sender<Progress>>,
5173 concurrency: Option<usize>,
5174 work_fn: F,
5175) -> Result<(), Error>
5176where
5177 T: Send + 'static,
5178 F: Fn(T) -> Fut + Send + Sync + 'static,
5179 Fut: Future<Output = Result<(), Error>> + Send + 'static,
5180{
5181 let total = items.len();
5182 let current = Arc::new(AtomicUsize::new(0));
5183 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
5184 let work_fn = Arc::new(work_fn);
5185
5186 let tasks = items
5187 .into_iter()
5188 .map(|item| {
5189 let sem = sem.clone();
5190 let current = current.clone();
5191 let progress = progress.clone();
5192 let work_fn = work_fn.clone();
5193
5194 tokio::spawn(async move {
5195 let _permit = sem.acquire().await.map_err(|_| {
5196 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5197 })?;
5198
5199 // Execute the actual work
5200 work_fn(item).await?;
5201
5202 // Update progress
5203 if let Some(progress) = &progress {
5204 let current = current.fetch_add(1, Ordering::SeqCst);
5205 let _ = progress
5206 .send(Progress {
5207 current: current + 1,
5208 total,
5209 status: None,
5210 })
5211 .await;
5212 }
5213
5214 Ok::<(), Error>(())
5215 })
5216 })
5217 .collect::<Vec<_>>();
5218
5219 join_all(tasks)
5220 .await
5221 .into_iter()
5222 .collect::<Result<Vec<_>, _>>()?
5223 .into_iter()
5224 .collect::<Result<Vec<_>, _>>()?;
5225
5226 if let Some(progress) = progress {
5227 drop(progress);
5228 }
5229
5230 Ok(())
5231}
5232
5233/// Upload a file to S3 using multipart upload with presigned URLs.
5234///
5235/// Splits a file into chunks (100MB each) and uploads them in parallel using
5236/// S3 multipart upload protocol. Returns completion parameters with ETags for
5237/// finalizing the upload.
5238///
5239/// This function handles:
5240/// - Splitting files into parts based on PART_SIZE (100MB)
5241/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
5242/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
5243/// - Retry logic (handled by reqwest client)
5244/// - Progress tracking across all parts
5245///
5246/// # Arguments
5247///
5248/// * `http` - HTTP client for making requests
5249/// * `part` - Snapshot part info with presigned URLs for each chunk
5250/// * `path` - Local file path to upload
5251/// * `total` - Total bytes across all files for progress calculation
5252/// * `current` - Atomic counter tracking bytes uploaded across all operations
5253/// * `progress` - Optional channel for sending progress updates
5254///
5255/// # Returns
5256///
5257/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
5258async fn upload_multipart(
5259 http: reqwest::Client,
5260 part: SnapshotPart,
5261 path: PathBuf,
5262 total: usize,
5263 confirmed_bytes: Arc<AtomicUsize>,
5264 progress: Option<Sender<Progress>>,
5265) -> Result<SnapshotCompleteMultipartParams, Error> {
5266 let filesize = path.metadata()?.len() as usize;
5267 let n_parts = filesize.div_ceil(PART_SIZE);
5268 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
5269
5270 let key = part.key.ok_or(Error::InvalidResponse)?;
5271 let upload_id = part.upload_id;
5272
5273 let urls = part.urls.clone();
5274
5275 // Pre-allocate ETag slots for all parts
5276 let etags = Arc::new(tokio::sync::Mutex::new(vec![
5277 EtagPart {
5278 etag: "".to_owned(),
5279 part_number: 0,
5280 };
5281 n_parts
5282 ]));
5283
5284 // Per-part byte counters for streaming progress (reset on retry)
5285 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
5286 (0..n_parts)
5287 .map(|_| AtomicUsize::new(0))
5288 .collect::<Vec<_>>(),
5289 );
5290
5291 // Upload all parts in parallel with concurrency limiting
5292 let tasks = (0..n_parts)
5293 .map(|part_idx| {
5294 let http = http.clone();
5295 let url = urls[part_idx].clone();
5296 let etags = etags.clone();
5297 let path = path.to_owned();
5298 let sem = sem.clone();
5299 let progress = progress.clone();
5300 let confirmed_bytes = confirmed_bytes.clone();
5301 let part_bytes = part_bytes.clone();
5302
5303 // Calculate this part's size
5304 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5305 filesize % PART_SIZE
5306 } else {
5307 PART_SIZE
5308 };
5309
5310 tokio::spawn(async move {
5311 // Acquire semaphore permit to limit concurrent uploads
5312 let _permit = sem.acquire().await.map_err(|_| {
5313 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5314 })?;
5315
5316 // Upload part with streaming progress and retry logic
5317 let etag = upload_part_with_progress(
5318 http,
5319 url,
5320 path,
5321 part_idx,
5322 n_parts,
5323 part_size,
5324 total,
5325 confirmed_bytes.clone(),
5326 part_bytes.clone(),
5327 progress.clone(),
5328 )
5329 .await?;
5330
5331 // Store ETag for this part (needed to complete multipart upload)
5332 let mut etags_guard = etags.lock().await;
5333 etags_guard[part_idx] = EtagPart {
5334 etag,
5335 part_number: part_idx + 1,
5336 };
5337
5338 // Part completed successfully - add to confirmed bytes
5339 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
5340 // Reset part counter since it's now confirmed
5341 part_bytes[part_idx].store(0, Ordering::SeqCst);
5342
5343 // Send final progress update for this part
5344 if let Some(progress) = &progress {
5345 let current = confirmed_bytes.load(Ordering::SeqCst)
5346 + part_bytes
5347 .iter()
5348 .map(|p| p.load(Ordering::SeqCst))
5349 .sum::<usize>();
5350 let _ = progress
5351 .send(Progress {
5352 current,
5353 total,
5354 status: None,
5355 })
5356 .await;
5357 }
5358
5359 Ok::<(), Error>(())
5360 })
5361 })
5362 .collect::<Vec<_>>();
5363
5364 // Wait for all parts to complete (double collect to handle both JoinError and
5365 // inner Error)
5366 join_all(tasks)
5367 .await
5368 .into_iter()
5369 .collect::<Result<Vec<_>, _>>()?
5370 .into_iter()
5371 .collect::<Result<Vec<_>, _>>()?;
5372
5373 Ok(SnapshotCompleteMultipartParams {
5374 key,
5375 upload_id,
5376 etag_list: etags.lock().await.clone(),
5377 })
5378}
5379
5380/// Upload a single part with streaming progress tracking and retry logic.
5381///
5382/// Progress is reported continuously as bytes are sent. On retry, the part's
5383/// progress counter is reset to avoid over-reporting.
5384#[allow(clippy::too_many_arguments)]
5385async fn upload_part_with_progress(
5386 http: reqwest::Client,
5387 url: String,
5388 path: PathBuf,
5389 part_idx: usize,
5390 n_parts: usize,
5391 part_size: usize,
5392 total: usize,
5393 confirmed_bytes: Arc<AtomicUsize>,
5394 part_bytes: Arc<Vec<AtomicUsize>>,
5395 progress: Option<Sender<Progress>>,
5396) -> Result<String, Error> {
5397 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5398 .ok()
5399 .and_then(|s| s.parse().ok())
5400 .unwrap_or(5usize);
5401
5402 // Per-part total upload timeout. Covers the send phase (request body) where
5403 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
5404 // this bounds how long a stalled upload can block before retrying.
5405 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5406 .ok()
5407 .and_then(|s| s.parse().ok())
5408 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
5409
5410 let mut last_error: Option<Error> = None;
5411
5412 for attempt in 0..=max_retries {
5413 if attempt > 0 {
5414 // Reset this part's progress counter before retry
5415 part_bytes[part_idx].store(0, Ordering::SeqCst);
5416
5417 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5418 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5419 warn!(
5420 "Retry {}/{} for part {} after {:?}",
5421 attempt, max_retries, part_idx, delay
5422 );
5423 tokio::time::sleep(delay).await;
5424 }
5425
5426 match upload_part_streaming(
5427 http.clone(),
5428 url.clone(),
5429 path.clone(),
5430 part_idx,
5431 n_parts,
5432 part_size,
5433 total,
5434 upload_timeout_secs,
5435 confirmed_bytes.clone(),
5436 part_bytes.clone(),
5437 progress.clone(),
5438 )
5439 .await
5440 {
5441 Ok(etag) => return Ok(etag),
5442 Err(e) => {
5443 // Check if error is retryable
5444 let is_retryable = matches!(
5445 &e,
5446 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
5447 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
5448 );
5449
5450 if is_retryable && attempt < max_retries {
5451 last_error = Some(e);
5452 continue;
5453 }
5454
5455 return Err(e);
5456 }
5457 }
5458 }
5459
5460 Err(last_error
5461 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
5462}
5463
5464/// Perform the actual upload with streaming progress.
5465#[allow(clippy::too_many_arguments)]
5466async fn upload_part_streaming(
5467 http: reqwest::Client,
5468 url: String,
5469 path: PathBuf,
5470 part_idx: usize,
5471 n_parts: usize,
5472 _part_size: usize,
5473 total: usize,
5474 upload_timeout_secs: u64,
5475 confirmed_bytes: Arc<AtomicUsize>,
5476 part_bytes: Arc<Vec<AtomicUsize>>,
5477 progress: Option<Sender<Progress>>,
5478) -> Result<String, Error> {
5479 let filesize = path.metadata()?.len() as usize;
5480 let mut file = File::open(&path).await?;
5481 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
5482 .await?;
5483 let file = file.take(PART_SIZE as u64);
5484
5485 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5486 filesize % PART_SIZE
5487 } else {
5488 PART_SIZE
5489 };
5490
5491 // Create stream with progress tracking
5492 let stream = FramedRead::new(file, BytesCodec::new());
5493
5494 // Wrap stream to track bytes sent and report progress
5495 let progress_stream = stream.map(move |result| {
5496 if let Ok(ref bytes) = result {
5497 let bytes_len = bytes.len();
5498 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
5499
5500 // Send progress update (fire-and-forget via try_send to avoid blocking)
5501 if let Some(ref progress) = progress {
5502 let current = confirmed_bytes.load(Ordering::SeqCst)
5503 + part_bytes
5504 .iter()
5505 .map(|p| p.load(Ordering::SeqCst))
5506 .sum::<usize>();
5507 // Best-effort progress reporting: use try_send to avoid blocking.
5508 // If the channel is full or closed, we intentionally skip this update
5509 // to avoid stalling the upload; subsequent updates will still be delivered.
5510 let _ = progress.try_send(Progress {
5511 current,
5512 total,
5513 status: None,
5514 });
5515 }
5516 }
5517 result.map(|b| b.freeze())
5518 });
5519
5520 let body = Body::wrap_stream(progress_stream);
5521
5522 let resp = http
5523 .put(url)
5524 .header(CONTENT_LENGTH, body_length)
5525 .timeout(Duration::from_secs(upload_timeout_secs))
5526 .body(body)
5527 .send()
5528 .await?
5529 .error_for_status()?;
5530
5531 let etag = resp
5532 .headers()
5533 .get("etag")
5534 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
5535 .to_str()
5536 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
5537 .to_owned();
5538
5539 // Studio Server requires etag without the quotes.
5540 let etag = etag
5541 .strip_prefix("\"")
5542 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
5543 let etag = etag
5544 .strip_suffix("\"")
5545 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
5546
5547 Ok(etag.to_owned())
5548}
5549
5550/// Upload a complete file to a presigned S3 URL using HTTP PUT.
5551///
5552/// This is used for populate_samples to upload files to S3 after
5553/// receiving presigned URLs from the server.
5554///
5555/// Includes explicit retry logic with exponential backoff for transient
5556/// failures.
5557/// Classify a reqwest transport error (one where no HTTP response was received)
5558/// as a transient failure worth retrying.
5559///
5560/// Presigned-URL uploads buffer the body in memory and a PUT to the same object
5561/// key is idempotent, so replaying any transport-level failure is safe. Besides
5562/// timeouts and connect failures this covers request/body send errors such as
5563/// hyper's `IncompleteMessage` (a peer closing a keep-alive connection mid-send)
5564/// — transients that pipelined, high-concurrency uploads provoke far more often
5565/// than serial ones, and which the previous `is_timeout() || is_connect()` gate
5566/// missed (aborting the whole upload on a single blip).
5567fn is_retryable_upload_error(e: &reqwest::Error) -> bool {
5568 e.is_timeout() || e.is_connect() || e.is_request() || e.is_body()
5569}
5570
5571/// Reliable, `Instant`-based upload timing accumulators (profiling builds only).
5572///
5573/// Async `tracing` spans cannot measure per-await latency or task concurrency
5574/// under a multi-threaded runtime — a future's span fragments across worker
5575/// threads — so these atomics accumulate real measured durations and byte counts
5576/// for a trustworthy phase breakdown. Durations are summed across concurrent
5577/// batches, so totals can exceed wall-clock; `(rpc + upload) / wall` gives the
5578/// effective parallelism, and `bytes / wall` the effective upload bandwidth.
5579#[cfg(feature = "profiling")]
5580pub mod upload_stats {
5581 use std::sync::atomic::{AtomicU64, Ordering};
5582
5583 static RPC_NANOS: AtomicU64 = AtomicU64::new(0);
5584 static UPLOAD_NANOS: AtomicU64 = AtomicU64::new(0);
5585 static UPLOAD_BYTES: AtomicU64 = AtomicU64::new(0);
5586
5587 pub(crate) fn add_rpc_nanos(n: u64) {
5588 RPC_NANOS.fetch_add(n, Ordering::Relaxed);
5589 }
5590 pub(crate) fn add_upload_nanos(n: u64) {
5591 UPLOAD_NANOS.fetch_add(n, Ordering::Relaxed);
5592 }
5593 pub(crate) fn add_upload_bytes(n: u64) {
5594 UPLOAD_BYTES.fetch_add(n, Ordering::Relaxed);
5595 }
5596
5597 /// Zero all accumulators. Call once before starting an upload.
5598 pub fn reset() {
5599 RPC_NANOS.store(0, Ordering::Relaxed);
5600 UPLOAD_NANOS.store(0, Ordering::Relaxed);
5601 UPLOAD_BYTES.store(0, Ordering::Relaxed);
5602 }
5603
5604 /// Snapshot of `(rpc_nanos, upload_nanos, upload_bytes)` accumulated so far.
5605 pub fn snapshot() -> (u64, u64, u64) {
5606 (
5607 RPC_NANOS.load(Ordering::Relaxed),
5608 UPLOAD_NANOS.load(Ordering::Relaxed),
5609 UPLOAD_BYTES.load(Ordering::Relaxed),
5610 )
5611 }
5612}
5613
5614async fn upload_file_to_presigned_url(
5615 http: reqwest::Client,
5616 url: &str,
5617 path: PathBuf,
5618) -> Result<(), Error> {
5619 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5620 .ok()
5621 .and_then(|s| s.parse().ok())
5622 .unwrap_or(5usize);
5623
5624 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5625 .ok()
5626 .and_then(|s| s.parse().ok())
5627 .unwrap_or(600u64);
5628
5629 // Read the entire file into memory once
5630 let file_data = fs::read(&path).await?;
5631 let file_size = file_data.len();
5632 let filename = path.file_name().unwrap_or_default().to_string_lossy();
5633
5634 let mut last_error: Option<Error> = None;
5635
5636 for attempt in 0..=max_retries {
5637 if attempt > 0 {
5638 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5639 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5640 warn!(
5641 "Retry {}/{} for upload '{}' after {:?}",
5642 attempt, max_retries, filename, delay
5643 );
5644 tokio::time::sleep(delay).await;
5645 }
5646
5647 // Attempt upload
5648 let result = http
5649 .put(url)
5650 .header(CONTENT_LENGTH, file_size)
5651 .timeout(Duration::from_secs(upload_timeout_secs))
5652 .body(file_data.clone())
5653 .send()
5654 .await;
5655
5656 match result {
5657 Ok(resp) => {
5658 if resp.status().is_success() {
5659 if attempt > 0 {
5660 debug!(
5661 "Upload '{}' succeeded on retry {} ({} bytes)",
5662 filename, attempt, file_size
5663 );
5664 } else {
5665 debug!(
5666 "Successfully uploaded file: {} ({} bytes)",
5667 filename, file_size
5668 );
5669 }
5670 #[cfg(feature = "profiling")]
5671 upload_stats::add_upload_bytes(file_size as u64);
5672 return Ok(());
5673 }
5674
5675 let status = resp.status();
5676 let status_code = status.as_u16();
5677
5678 // Check if error is retryable
5679 let is_retryable =
5680 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5681
5682 if is_retryable && attempt < max_retries {
5683 let error_text = resp.text().await.unwrap_or_default();
5684 warn!(
5685 "Upload '{}' failed with HTTP {} (retryable): {}",
5686 filename, status_code, error_text
5687 );
5688 last_error = Some(Error::InvalidParameters(format!(
5689 "Upload failed: HTTP {} - {}",
5690 status, error_text
5691 )));
5692 continue;
5693 }
5694
5695 // Non-retryable error or max retries exceeded
5696 let error_text = resp.text().await.unwrap_or_default();
5697 if attempt > 0 {
5698 error!(
5699 "Upload '{}' failed after {} retries: HTTP {} - {}",
5700 filename, attempt, status, error_text
5701 );
5702 }
5703 return Err(Error::InvalidParameters(format!(
5704 "Upload failed: HTTP {} - {}",
5705 status, error_text
5706 )));
5707 }
5708 Err(e) => {
5709 // Transport error: no HTTP response was received. The body is
5710 // buffered in memory and the PUT is idempotent, so any transient
5711 // transport failure is safe to replay (see
5712 // `is_retryable_upload_error`).
5713 if is_retryable_upload_error(&e) && attempt < max_retries {
5714 warn!("Upload '{}' transport error (retrying): {}", filename, e);
5715 last_error = Some(Error::HttpError(e));
5716 continue;
5717 }
5718
5719 // Non-retryable or max retries exceeded
5720 if attempt > 0 {
5721 error!(
5722 "Upload '{}' failed after {} retries: {}",
5723 filename, attempt, e
5724 );
5725 }
5726 return Err(Error::HttpError(e));
5727 }
5728 }
5729 }
5730
5731 // Should not reach here, but return last error if we do
5732 Err(last_error.unwrap_or_else(|| {
5733 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5734 }))
5735}
5736
5737/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
5738///
5739/// This is used for populate_samples to upload file content from memory
5740/// (e.g., from ZIP archives) without writing to disk first.
5741///
5742/// Includes explicit retry logic with exponential backoff for transient
5743/// failures.
5744async fn upload_bytes_to_presigned_url(
5745 http: reqwest::Client,
5746 url: &str,
5747 file_data: Vec<u8>,
5748 filename: &str,
5749) -> Result<(), Error> {
5750 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5751 .ok()
5752 .and_then(|s| s.parse().ok())
5753 .unwrap_or(5usize);
5754
5755 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5756 .ok()
5757 .and_then(|s| s.parse().ok())
5758 .unwrap_or(600u64);
5759
5760 let file_size = file_data.len();
5761 let mut last_error: Option<Error> = None;
5762
5763 for attempt in 0..=max_retries {
5764 if attempt > 0 {
5765 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5766 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5767 warn!(
5768 "Retry {}/{} for upload '{}' after {:?}",
5769 attempt, max_retries, filename, delay
5770 );
5771 tokio::time::sleep(delay).await;
5772 }
5773
5774 // Attempt upload
5775 let result = http
5776 .put(url)
5777 .header(CONTENT_LENGTH, file_size)
5778 .timeout(Duration::from_secs(upload_timeout_secs))
5779 .body(file_data.clone())
5780 .send()
5781 .await;
5782
5783 match result {
5784 Ok(resp) => {
5785 if resp.status().is_success() {
5786 if attempt > 0 {
5787 debug!(
5788 "Upload '{}' succeeded on retry {} ({} bytes)",
5789 filename, attempt, file_size
5790 );
5791 } else {
5792 debug!(
5793 "Successfully uploaded file: {} ({} bytes)",
5794 filename, file_size
5795 );
5796 }
5797 #[cfg(feature = "profiling")]
5798 upload_stats::add_upload_bytes(file_size as u64);
5799 return Ok(());
5800 }
5801
5802 let status = resp.status();
5803 let status_code = status.as_u16();
5804
5805 // Check if error is retryable
5806 let is_retryable =
5807 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5808
5809 if is_retryable && attempt < max_retries {
5810 let error_text = resp.text().await.unwrap_or_default();
5811 warn!(
5812 "Upload '{}' failed with HTTP {} (retryable): {}",
5813 filename, status_code, error_text
5814 );
5815 last_error = Some(Error::InvalidParameters(format!(
5816 "Upload failed: HTTP {} - {}",
5817 status, error_text
5818 )));
5819 continue;
5820 }
5821
5822 // Non-retryable error or max retries exceeded
5823 let error_text = resp.text().await.unwrap_or_default();
5824 if attempt > 0 {
5825 error!(
5826 "Upload '{}' failed after {} retries: HTTP {} - {}",
5827 filename, attempt, status, error_text
5828 );
5829 }
5830 return Err(Error::InvalidParameters(format!(
5831 "Upload failed: HTTP {} - {}",
5832 status, error_text
5833 )));
5834 }
5835 Err(e) => {
5836 // Transport error: no HTTP response was received. The body is
5837 // buffered in memory and the PUT is idempotent, so any transient
5838 // transport failure is safe to replay (see
5839 // `is_retryable_upload_error`).
5840 if is_retryable_upload_error(&e) && attempt < max_retries {
5841 warn!("Upload '{}' transport error (retrying): {}", filename, e);
5842 last_error = Some(Error::HttpError(e));
5843 continue;
5844 }
5845
5846 // Non-retryable or max retries exceeded
5847 if attempt > 0 {
5848 error!(
5849 "Upload '{}' failed after {} retries: {}",
5850 filename, attempt, e
5851 );
5852 }
5853 return Err(Error::HttpError(e));
5854 }
5855 }
5856 }
5857
5858 // Should not reach here, but return last error if we do
5859 Err(last_error.unwrap_or_else(|| {
5860 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5861 }))
5862}
5863
#[cfg(test)]
mod tests {
    // Unit tests for name filtering/sorting, filename flattening, and the
    // server/URL switching rules on `Client`.
    use super::*;

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly
        let extensions = [
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: both case-insensitive exact matches next, in either
        // order. Sorting makes the check order-independent while still
        // requiring each of them to appear exactly once.
        let mut case_insensitive: Vec<&str> = result[1..3].iter().map(String::as_str).collect();
        case_insensitive.sort_unstable();
        assert_eq!(case_insensitive, ["DEER", "deer"]);
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        // (items, filter, expected_first)
        let scenarios: [(&[&str], &str, &str); 3] = [
            (&["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (&["test", "TEST", "Test Data"], "test", "test"),
            (&["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Sanity check: the token is in storage before switching servers
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }

    #[test]
    fn test_with_server_clears_storage_even_for_full_url() {
        // Regression: `with_server` used to short-circuit to `with_url`
        // when given a full URL, which preserved the bearer token. The
        // contract for `with_server` is that switching servers means
        // the token from the old server is no longer trusted.
        use crate::storage::MemoryTokenStorage;

        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("token-from-old-server").unwrap();
        let client = Client::new().unwrap().with_storage(storage.clone());
        assert_eq!(
            storage.load().unwrap(),
            Some("token-from-old-server".to_string())
        );

        // Switch to a self-hosted Studio (full URL). Storage must be
        // cleared, and the new client must have a blank in-memory token.
        let new_client = client
            .with_server("https://studio.example.com")
            .expect("https full URL through with_server");
        assert_eq!(storage.load().unwrap(), None);
        assert_eq!(new_client.url(), "https://studio.example.com");

        // The new client should not carry the old token in memory either.
        let in_mem = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(async { new_client.token.read().await.clone() });
        assert!(in_mem.is_empty(), "expected blank token, got {in_mem:?}");
    }

    #[test]
    fn test_with_server_rejects_insecure_full_url() {
        // `with_server` validates full URLs through `with_url`, so the
        // HTTPS rule applies uniformly. Plain http to a public host
        // must be rejected — the bearer token would otherwise leak in
        // plaintext when the caller next authenticates.
        let client = Client::new().unwrap();
        let err = client.with_server("http://studio.example.com").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    // ===== with_url HTTPS enforcement =====
    //
    // The bearer token rides in the Authorization header, so plain
    // http:// to a public host would leak it in the clear. The function
    // must reject those URLs, but still let wiremock / local-dev URLs
    // through (loopback addresses, "localhost", "*.localhost").

    #[test]
    fn with_url_accepts_https_public_host() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("https://studio.example.com")
            .expect("https public host must be accepted");
        assert_eq!(out.url(), "https://studio.example.com");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv4() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("http://127.0.0.1:8080")
            .expect("http://127.0.0.1 must be accepted (loopback)");
        assert_eq!(out.url(), "http://127.0.0.1:8080");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv6() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("http://[::1]:8080")
            .expect("http://[::1] must be accepted (loopback)");
        assert!(out.url().starts_with("http://[::1]"));
    }

    #[test]
    fn with_url_accepts_http_localhost() {
        let client = Client::new().unwrap();
        client
            .with_url("http://localhost:8080")
            .expect("http://localhost must be accepted");
        client
            .with_url("http://LOCALHOST")
            .expect("http://LOCALHOST must be accepted (case-insensitive)");
        client
            .with_url("http://wiremock.localhost")
            .expect("http://*.localhost must be accepted");
    }

    #[test]
    fn with_url_rejects_http_public_host() {
        let client = Client::new().unwrap();
        let err = client.with_url("http://studio.example.com").unwrap_err();
        match err {
            Error::InsecureUrl(u) => assert_eq!(u, "http://studio.example.com"),
            other => panic!("expected InsecureUrl, got {other:?}"),
        }
    }

    #[test]
    fn with_url_rejects_http_public_ip() {
        let client = Client::new().unwrap();
        // 8.8.8.8 is not loopback; must be rejected.
        let err = client.with_url("http://8.8.8.8").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    #[test]
    fn with_url_rejects_non_http_scheme() {
        let client = Client::new().unwrap();
        // file:// would otherwise parse, but it's not a transport we
        // can use for RPC and we don't want to silently accept it.
        let err = client.with_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }
}
6296
#[cfg(test)]
mod tests_map_rpc_error {
    // Pins how `map_rpc_error` turns (method, code, message, task) into typed
    // `Error` variants.
    use super::*;
    use crate::api::TaskID;

    /// The task id passed for calls that target a specific task.
    fn known_task() -> Option<TaskID> {
        Some(TaskID::try_from("task-1a2b").unwrap())
    }

    #[test]
    fn maps_not_found_with_task_id_to_typed_variant() {
        // Code 101, a "not found" message and a known task → TaskNotFound.
        let err = map_rpc_error(
            "task.data.list",
            101,
            "task not found".to_string(),
            known_task(),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_cannot_find_phrasing_to_typed_variant() {
        // The DVE server phrases a missing task as "Cannot find task...". An
        // earlier "not found" substring check missed it and surfaced a generic
        // RpcError.
        let err = map_rpc_error(
            "task.data.list",
            101,
            "Cannot find task with id 6789".to_string(),
            known_task(),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "'Cannot find task' should map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_does_not_exist_phrasing_to_typed_variant() {
        let err = map_rpc_error(
            "task.chart.get",
            101,
            "task does not exist".to_string(),
            known_task(),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_code_101_with_unknown_phrasing_when_task_id_supplied() {
        // Code 101 means "resource not found", so even an unfamiliar message
        // must yield the typed variant; callers rely on a stable `match`.
        let err = map_rpc_error(
            "task.data.list",
            101,
            "completely novel server message".to_string(),
            known_task(),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "code 101 + task_id should always map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_permission_codes_to_typed_variant() {
        let permission_codes = [401, 403];
        for &code in permission_codes.iter() {
            let mapped = map_rpc_error("task.chart.add", code, "denied".to_string(), None);
            let is_denied = matches!(mapped, Error::PermissionDenied(_));
            assert!(is_denied, "code {} did not map", code);
        }
    }

    #[test]
    fn permission_denied_records_method_for_diagnostics() {
        // The variant carries the RPC method name so the failure can be traced.
        let err = map_rpc_error("task.data.upload", 403, "forbidden".to_string(), None);
        if let Error::PermissionDenied(method) = err {
            assert_eq!(method, "task.data.upload");
        } else {
            panic!("expected PermissionDenied, got {:?}", err);
        }
    }

    #[test]
    fn maps_payload_too_large_to_typed_variant() {
        let err = map_rpc_error("val.data.upload", 413, "request too large".into(), None);
        if let Error::PayloadTooLarge { method, size_hint } = err {
            assert_eq!(method, "val.data.upload");
            assert!(size_hint.is_none());
        } else {
            panic!("expected PayloadTooLarge, got {:?}", err);
        }
    }

    #[test]
    fn falls_through_to_generic_rpc_error_for_unknown_codes() {
        // An unrecognised code keeps both the raw code and the message.
        let err = map_rpc_error("task.data.list", -99999, "weird".to_string(), None);
        match err {
            Error::RpcError(code, msg) => {
                assert_eq!(msg, "weird");
                assert_eq!(code, -99999);
            }
            other => panic!("expected RpcError, got {:?}", other),
        }
    }

    #[test]
    fn not_found_without_task_id_falls_through() {
        // Without a task id there is nothing to name, so code 101 stays generic.
        let err = map_rpc_error("task.data.list", 101, "not found".to_string(), None);
        assert!(matches!(err, Error::RpcError(101, _)));
    }

    #[test]
    fn code_101_with_task_id_always_maps_even_with_unrelated_message() {
        // An earlier version expected fall-through for non-"not found" text.
        // Code 101 is defined as "resource not found" (see api.go), so with a
        // task id present the typed variant is returned unconditionally.
        let err = map_rpc_error(
            "task.data.list",
            101,
            "permission denied".to_string(),
            known_task(),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }
}
6430
#[cfg(test)]
mod tests_jobs {
    use super::*;

    #[test]
    fn jobs_list_request_serializes_to_empty_object() {
        // The list request has no fields, so its wire form is `{}`.
        let wire = serde_json::to_value(&JobsListRequest {}).unwrap();
        let empty_object = serde_json::json!({});
        assert_eq!(wire, empty_object);
    }

    #[test]
    fn job_deserializes_from_bk_batch_shape() {
        // A BK_BATCH record; `extra_field` checks that unknown keys are tolerated.
        let record = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789,
            "docker_task": {},
            "extra_field": "ignored"
        }"#;
        let parsed: crate::api::Job = serde_json::from_str(record).unwrap();

        assert_eq!(parsed.code, "edgefirst-validator:2.9.5");
        assert_eq!(parsed.state, "RUNNING");
        // The raw field and the typed accessor must agree.
        assert_eq!(parsed.task_id, 6789);
        assert_eq!(parsed.task_id().value(), 6789);
    }
}
6461
#[cfg(test)]
mod tests_job_run {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn job_run_request_serializes_with_expected_fields() {
        let request = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "post-profile-run".into(),
            env: HashMap::from([("LOG_LEVEL".into(), "info".into())]),
            data: HashMap::from([("validation_session_id".into(), Parameter::Integer(2707))]),
        };
        let wire = serde_json::to_value(&request).unwrap();

        // Top-level identifiers.
        assert_eq!(wire["name"], "edgefirst-validator");
        assert_eq!(wire["job_name"], "post-profile-run");
        // Nested maps serialize under their original keys.
        assert_eq!(wire["env"]["LOG_LEVEL"], "info");
        assert_eq!(wire["data"]["validation_session_id"], 2707);
    }

    #[test]
    fn job_run_response_deserializes_as_job() {
        // `job.run` replies with the full BK_BATCH record, which parses as `Job`.
        let reply = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "post-profile-run",
            "job_id": "aws-batch-job-xxx",
            "state": "SUBMITTED",
            "task_id": 6789
        }"#;
        let submitted: crate::api::Job = serde_json::from_str(reply).unwrap();

        assert_eq!(submitted.task_id, 6789);
        assert_eq!(submitted.job_id, "aws-batch-job-xxx");
        assert_eq!(submitted.state, "SUBMITTED");
    }
}
6500
#[cfg(test)]
mod tests_job_stop {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn job_stop_request_serializes_with_task_id() {
        // The stop request sends the task's numeric id under `task_id`.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let wire = serde_json::to_value(&JobStopRequest { task_id: raw_id }).unwrap();
        assert_eq!(wire["task_id"], raw_id);
    }
}
6516
#[cfg(test)]
mod tests_task_data_list_request {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_list_request_serializes_with_task_id() {
        // Listing a task's data only needs its numeric id.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let wire = serde_json::to_value(&TaskDataListRequest { task_id: raw_id }).unwrap();
        assert_eq!(wire["task_id"], raw_id);
    }
}
6532
#[cfg(test)]
mod tests_task_data_download {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_download_request_serializes_with_all_fields() {
        // A download is addressed by task id, folder and file name.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let request = TaskDataDownloadRequest {
            task_id: raw_id,
            folder: "predictions".into(),
            file: "predictions.parquet".into(),
        };
        let wire = serde_json::to_value(&request).unwrap();

        assert_eq!(wire["file"], "predictions.parquet");
        assert_eq!(wire["folder"], "predictions");
        assert_eq!(wire["task_id"], raw_id);
    }
}
6552
#[cfg(test)]
mod tests_task_chart_add {
    use super::*;
    use crate::api::{Parameter, TaskID};

    #[test]
    fn task_chart_add_request_serializes_with_correct_fields() {
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        // Chart payload: a single `type` entry.
        let chart_spec = Parameter::Object(std::collections::HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]));
        let request = TaskChartAddRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: None,
            data: chart_spec,
        };
        let wire = serde_json::to_value(&request).unwrap();

        // Addressing fields.
        assert_eq!(wire["task_id"], raw_id);
        assert_eq!(wire["group_name"], "metrics");
        assert_eq!(wire["chart_name"], "loss");
        // Payload and (absent) params.
        assert_eq!(wire["data"]["type"], "line");
        assert!(wire["params"].is_null());
    }
}
6580
#[cfg(test)]
mod tests_task_chart_list {
    use super::*;
    use crate::api::TaskID;

    /// An empty `group_name` is sent on the wire as `""`, not omitted.
    ///
    /// Indexing a missing key on a `serde_json::Value` yields `Null`, which is
    /// not equal to `""`. The assertion below therefore also proves that the
    /// key is present. The test used to be named `..._omits_empty_group_name`,
    /// which contradicted this assertion.
    #[test]
    fn task_chart_list_request_serializes_empty_group_name() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskChartListRequest {
            task_id: task_id.value(),
            group_name: String::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        assert_eq!(json["group_name"], "");
    }
}
6598
#[cfg(test)]
mod tests_task_chart_get {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_chart_get_request_serializes_with_all_fields() {
        // A chart is addressed by task id, group name and chart name.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let request = TaskChartGetRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
        };
        let wire = serde_json::to_value(&request).unwrap();

        assert_eq!(wire["chart_name"], "loss");
        assert_eq!(wire["group_name"], "metrics");
        assert_eq!(wire["task_id"], raw_id);
    }
}
6618
#[cfg(test)]
mod tests_val_data_download {
    use super::*;

    /// The validation-data download request carries the numeric session id
    /// and the filename, including any directory prefix, unchanged.
    #[test]
    fn val_data_download_request_serializes() {
        let wire = serde_json::to_value(&ValDataDownloadRequest {
            session_id: 2707,
            filename: "trace/imx95.json".into(),
        })
        .unwrap();

        assert_eq!(wire["session_id"], 2707);
        assert_eq!(wire["filename"], "trace/imx95.json");
    }
}
6635#[cfg(test)]
6636mod tests_val_data_list {
6637 use super::*;
6638
6639 #[test]
6640 fn val_data_list_request_serializes() {
6641 let req = ValDataListRequest { session_id: 2707 };
6642 assert_eq!(
6643 serde_json::to_value(&req).unwrap(),
6644 serde_json::json!({"session_id": 2707})
6645 );
6646 }
6647}
6648
#[cfg(test)]
mod tests_jsonrpc_envelope_detection {
    //! Pins which JSON values `is_jsonrpc_error_envelope` classifies as a
    //! JSON-RPC error envelope, and which lookalikes it must reject.
    use super::*;

    #[test]
    fn detects_real_envelope() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": 101, "message": "Cannot find task" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_plain_json_artifact_with_error_field() {
        // A diagnostics file with a free-form `error` object must not be
        // misread as an RPC envelope just because the key collides.
        let v = serde_json::json!({
            "metric": "loss",
            "value": 0.42,
            "error": { "code": "ENV_NOT_FOUND", "message": "missing var" },
        });
        assert!(
            !is_jsonrpc_error_envelope(&v),
            "missing jsonrpc sentinel should mean 'not an envelope'"
        );
    }

    #[test]
    fn rejects_envelope_missing_jsonrpc_sentinel() {
        // Bare `error` block without the protocol-version marker.
        let v = serde_json::json!({
            "id": 0,
            "error": { "code": 101, "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_object_error_field() {
        // A diagnostics file that happens to look like JSON-RPC but uses a
        // string for `error`.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": "something went wrong",
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_without_error_code() {
        // Real envelopes always carry an integer error.code. A missing code
        // is suspicious enough to refuse the envelope classification.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "message": "no code" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_numeric_error_code() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": "ENOENT", "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_non_object_root() {
        // A JSON file whose root is an array (common for metrics dumps) must
        // not be misread.
        let v = serde_json::json!([1, 2, 3]);
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn accepts_unsigned_error_code() {
        // A non-negative code is stored by serde_json as an unsigned number
        // (this holds for `101` and `101u32` alike). It must be recognized.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": 101u32, "message": "x" },
        });
        assert!(v["error"]["code"].is_u64());
        assert!(is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn accepts_negative_error_code() {
        // JSON-RPC 2.0 reserves negative codes (-32768..=-32000) for protocol
        // errors, e.g. -32601 "Method not found". serde_json stores these as
        // signed-only integers, a shape the unsigned test above never covers.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": -32601, "message": "Method not found" },
        });
        assert!(v["error"]["code"].is_i64() && !v["error"]["code"].is_u64());
        assert!(is_jsonrpc_error_envelope(&v));
    }
}
6738
#[cfg(test)]
mod tests_validate_chart_args {
    use super::*;

    /// Asserts that `validate_chart_args` refuses the pair with
    /// `Error::InvalidParameters`.
    fn assert_rejected(group: &str, name: &str) {
        let outcome = validate_chart_args(group, name);
        assert!(
            matches!(outcome, Err(Error::InvalidParameters(_))),
            "expected InvalidParameters for group={group:?}, name={name:?}"
        );
    }

    #[test]
    fn rejects_empty_group() {
        assert_rejected("", "name");
    }

    #[test]
    fn rejects_empty_name() {
        assert_rejected("group", "");
    }

    #[test]
    fn rejects_both_empty() {
        assert_rejected("", "");
    }

    #[test]
    fn accepts_valid_args() {
        let outcome = validate_chart_args("group", "name");
        assert!(outcome.is_ok());
    }

    #[test]
    fn accepts_unicode_args() {
        // Only emptiness is rejected; non-ASCII group and chart names pass.
        let outcome = validate_chart_args("metrics-集合", "损失");
        assert!(outcome.is_ok());
    }
}
6772
6773// ---------------------------------------------------------------------------
6774// Additional offline tests for request shapes + helpers added in DE-2565.
6775//
6776// These focus on the wire-shape and helper logic that does not require a
6777// live Studio server — they significantly boost coverage of client.rs.
6778// ---------------------------------------------------------------------------
6779
#[cfg(test)]
mod tests_job_run_request_shape {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn empty_env_and_data_serialize_as_empty_objects() {
        // Empty maps must go on the wire as `{}`, not be omitted or sent as
        // null.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "smoke".into(),
            env: HashMap::new(),
            data: HashMap::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["env"], serde_json::json!({}));
        assert_eq!(json["data"], serde_json::json!({}));
    }

    #[test]
    fn data_passes_through_parameter_object_payloads() {
        // Serializes the scalar `Parameter` variants (Boolean, Integer, Real,
        // String) plus a nested `Object`. Each value must land under its own
        // key in `data` with the matching JSON type.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "feat".into(),
            env: HashMap::new(),
            data: HashMap::from([
                ("flag".into(), Parameter::Boolean(true)),
                ("epochs".into(), Parameter::Integer(50)),
                ("lr".into(), Parameter::Real(1e-3)),
                ("name".into(), Parameter::String("hello".into())),
                (
                    "chart".into(),
                    Parameter::Object(HashMap::from([(
                        "type".into(),
                        Parameter::String("line".into()),
                    )])),
                ),
            ]),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["data"]["flag"], true);
        assert_eq!(json["data"]["epochs"], 50);
        let lr = json["data"]["lr"]
            .as_f64()
            .expect("lr should serialize as a JSON number");
        // Compare with a tolerance so the check does not depend on the exact
        // float width `Parameter::Real` stores.
        assert!((lr - 1e-3).abs() < 1e-9, "unexpected lr: {lr}");
        assert_eq!(json["data"]["name"], "hello");
        assert_eq!(json["data"]["chart"]["type"], "line");
    }
}
6823
#[cfg(test)]
mod tests_task_data_chart_request_shape {
    use super::*;
    use crate::api::{Parameter, TaskID};

    #[test]
    fn chart_add_request_with_params_serializes_object() {
        // A provided `params` is serialized as a JSON object next to the
        // chart payload in `data`. Neither should clobber the other.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let params = Parameter::Object(std::collections::HashMap::from([(
            "y_axis".into(),
            Parameter::String("log".into()),
        )]));
        let data = Parameter::Object(std::collections::HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]));
        let req = TaskChartAddRequest {
            task_id: task_id.value(),
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: Some(params),
            data,
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["params"]["y_axis"], "log");
        assert_eq!(json["data"]["type"], "line");
    }

    #[test]
    fn task_data_list_request_serializes_exactly() {
        // Serialization only (nothing is read back), checked against the
        // exact JSON text. A single-field struct has only one possible key
        // order, so the string comparison is stable.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskDataListRequest {
            task_id: task_id.value(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert_eq!(json, format!("{{\"task_id\":{}}}", task_id.value()));
    }

    #[test]
    fn task_data_download_request_treats_folder_and_file_independently() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskDataDownloadRequest {
            task_id: task_id.value(),
            folder: "validation/run-01".into(),
            file: "metrics.json".into(),
        };
        let json = serde_json::to_value(&req).unwrap();
        // Server takes folder + file separately (not a single combined path)
        // so callers don't have to escape slashes themselves.
        assert_eq!(json["folder"], "validation/run-01");
        assert_eq!(json["file"], "metrics.json");
    }
}
6878
#[cfg(test)]
mod tests_val_data_request_shape {
    use super::*;

    /// Serializing to JSON text and re-parsing it as a generic value keeps
    /// the session id intact.
    #[test]
    fn val_data_list_round_trips() {
        let text = serde_json::to_string(&ValDataListRequest { session_id: 2707 }).unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&text).unwrap();
        assert_eq!(parsed["session_id"], 2707);
    }

    /// A filename containing a directory separator survives the text
    /// round-trip unchanged, alongside the session id.
    #[test]
    fn val_data_download_round_trips_with_nested_path() {
        let request = ValDataDownloadRequest {
            session_id: 2707,
            filename: "subfolder/imx95.json".into(),
        };
        let text = serde_json::to_string(&request).unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&text).unwrap();

        assert_eq!(parsed["session_id"], 2707);
        assert_eq!(parsed["filename"], "subfolder/imx95.json");
    }
}
6903
#[cfg(test)]
mod tests_progress_struct {
    use super::*;

    /// A `Progress` with a zero total must be representable (e.g. when the
    /// server sends no Content-Length). Guards the public field-level API.
    #[test]
    fn progress_can_be_constructed_with_zero_total() {
        let empty = Progress {
            current: 0,
            total: 0,
            status: None,
        };
        assert_eq!((empty.current, empty.total), (0, 0));
        assert!(empty.status.is_none());
    }

    /// `current` and `total` are stored independently, next to an optional
    /// status label.
    #[test]
    fn progress_tracks_current_independently_of_total() {
        let midway = Progress {
            current: 123,
            total: 456,
            status: Some("Downloading".into()),
        };
        assert_eq!((midway.current, midway.total), (123, 456));
        assert_eq!(midway.status.as_deref(), Some("Downloading"));
    }

    /// Progress sinks may keep their own copy of an event, so `Progress` must
    /// stay `Clone` and the clone must match field-for-field.
    #[test]
    fn progress_can_be_cloned() {
        let original = Progress {
            current: 10,
            total: 20,
            status: Some("phase".into()),
        };
        let copy = original.clone();
        assert_eq!(
            (copy.current, copy.total, &copy.status),
            (original.current, original.total, &original.status)
        );
    }
}
6949
#[cfg(test)]
mod tests_bare_filename_parent {
    // Pins the `Path::parent` behavior behind the empty-parent guard in
    // `rpc_download`. With the guard, a bare filename such as "metrics.json"
    // is downloaded into the current directory instead of failing on
    // `create_dir_all("")`.
    use std::path::Path;

    /// A bare filename yields `Some` parent that is the empty path. If a future
    /// Rust release changed this, the guard would need revisiting.
    #[test]
    fn bare_filename_parent_is_empty_path() {
        let parent = Path::new("metrics.json")
            .parent()
            .expect("bare filename always has Some parent");
        assert!(
            parent.as_os_str().is_empty(),
            "Path::parent for bare filename should be empty, got: {parent:?}"
        );
    }

    /// With a directory component, the parent is that (non-empty) directory,
    /// so `create_dir_all` is expected to run.
    #[test]
    fn path_with_directory_has_non_empty_parent() {
        let parent = Path::new("dir/metrics.json")
            .parent()
            .expect("path-with-dir always has Some parent");
        assert!(!parent.as_os_str().is_empty());
        assert_eq!(parent, Path::new("dir"));
    }
}