edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult,
8 NewValidationSession, Organization, Project, ProjectID, SampleID, SamplesCountResult,
9 SamplesListParams, SamplesListResult, Snapshot, SnapshotCreateFromDataset,
10 SnapshotFromDatasetResult, SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage,
11 StartValidationRequest, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
12 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
13 ValidationSessionID,
14 },
15 dataset::{
16 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
17 },
18 retry::{create_retry_policy, log_retry_configuration},
19 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
20};
21use base64::Engine as _;
22use chrono::{DateTime, Utc};
23use directories::ProjectDirs;
24use futures::{StreamExt as _, future::join_all};
25use log::{Level, debug, error, log_enabled, trace, warn};
26use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::{
29 collections::HashMap,
30 ffi::OsStr,
31 fs::create_dir_all,
32 io::{SeekFrom, Write as _},
33 path::{Path, PathBuf},
34 sync::{
35 Arc,
36 atomic::{AtomicUsize, Ordering},
37 },
38 time::Duration,
39 vec,
40};
41use tokio::{
42 fs::{self, File},
43 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
44 sync::{RwLock, Semaphore, mpsc::Sender},
45};
46use tokio_util::codec::{BytesCodec, FramedRead};
47use walkdir::WalkDir;
48
49#[cfg(feature = "polars")]
50use polars::prelude::*;
51
52/// Maps a JSON-RPC error code to a typed `Error` variant when the code is
53/// well-known; otherwise returns `Error::RpcError(code, message)` unchanged.
54///
55/// Scoped to the new DE-2565 methods. Existing methods continue to return
56/// `Error::RpcError` directly.
57///
58/// Server error codes (from `api.go` via `jrpc.Fail`):
59/// - `1` – generic server error
60/// - `3` – validation / bad request
61/// - `10` – internal server error
62/// - `101` – resource not found (e.g. "Cannot find task...", "not found in DB")
63/// - `401` – unauthenticated
64/// - `403` – forbidden
65/// - `413` – payload too large
66pub(crate) fn map_rpc_error(
67 method: &str,
68 code: i32,
69 message: String,
70 task_id: Option<crate::api::TaskID>,
71) -> Error {
72 // Server emits "Cannot find task...", "not found in DB", and other phrasings
73 // for code 101. Code 101 with a task_id is task-not-found by contract
74 // (see api.go), so we return the typed variant unconditionally when the
75 // caller supplied a task_id — message phrasing is treated as informational
76 // and is preserved by the RPC layer for diagnostic logging upstream.
77 if code == 101
78 && let Some(id) = task_id
79 {
80 return Error::TaskNotFound(id);
81 }
82 match code {
83 401 | 403 => Error::PermissionDenied(method.to_string()),
84 413 => Error::PayloadTooLarge {
85 method: method.to_string(),
86 size_hint: None,
87 },
88 _ => Error::RpcError(code, message),
89 }
90}
91
92/// Returns true if `val` is structurally a JSON-RPC 2.0 *error* envelope.
93///
94/// A real envelope must:
95/// 1. Be a JSON object,
96/// 2. Carry a `"jsonrpc"` member (the protocol-version sentinel — JSON-RPC
97/// 2.0 §5 mandates this on every response object),
98/// 3. Carry an `"error"` object that includes a numeric `"code"` field.
99///
100/// This is intentionally stricter than a "looks for a top-level `error`
101/// key" check so that legitimate JSON file payloads (validation traces,
102/// metrics dumps, diagnostics) which happen to include a free-form `error`
103/// field are *not* misclassified as RPC failures.
104///
105/// Extracted so it can be unit-tested without a live server.
106pub(crate) fn is_jsonrpc_error_envelope(val: &serde_json::Value) -> bool {
107 let Some(obj) = val.as_object() else {
108 return false;
109 };
110 // Protocol-version sentinel — only JSON-RPC envelopes carry this.
111 if !obj.contains_key("jsonrpc") {
112 return false;
113 }
114 let Some(err) = obj.get("error").and_then(|e| e.as_object()) else {
115 return false;
116 };
117 err.get("code")
118 .map(|c| c.is_i64() || c.is_u64())
119 .unwrap_or(false)
120}
121
122/// Validates that `group` and `name` are both non-empty strings for chart
123/// operations (`add_chart`, `get_chart`). Extracted so it can be unit-tested
124/// without a live server.
125pub(crate) fn validate_chart_args(group: &str, name: &str) -> Result<(), Error> {
126 if group.is_empty() || name.is_empty() {
127 return Err(Error::InvalidParameters(
128 "chart: group and name must be non-empty".into(),
129 ));
130 }
131 Ok(())
132}
133
/// Part size in bytes (100 MiB). Presumably the chunk size used when
/// splitting multipart uploads — confirm at the call sites.
///
/// Declared `const` rather than `static`: it is a plain compile-time value
/// that never needs a fixed memory address.
const PART_SIZE: usize = 100 * 1024 * 1024;
135
/// Source for file content during upload - either a local path or raw bytes.
///
/// Note that cloning a `Bytes` value copies the whole owned buffer.
#[derive(Clone)]
enum FileSource {
    /// File content from a local filesystem path.
    Path(PathBuf),
    /// File content as raw bytes (e.g., from a ZIP archive).
    Bytes(Vec<u8>),
}
144
/// Returns the concurrency limit for general parallel work.
///
/// Reads `MAX_TASKS` from the environment. If it is unset, fails to parse
/// as a `usize`, or is `0`, the default is used: half the available CPUs,
/// clamped to `2..=8` (4 CPUs are assumed when parallelism cannot be
/// queried).
///
/// `0` is rejected because a zero-sized concurrency limit (for example a
/// zero-permit semaphore) would never let any work proceed.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
157
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment
/// variable; values that are unset, unparsable, or `0` fall back to the
/// default of 8. Zero is rejected because a limit of zero concurrent uploads
/// would let no part ever start.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8) // Default to 8 concurrent part uploads
}
168
/// Filters items by name and sorts by match quality.
///
/// An item is kept when its name contains `filter`, compared
/// case-insensitively. Survivors are ordered by (best to worst):
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Substring match (shorter names first, by byte length, then
///    alphabetically)
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer". The sort is stable, so items
/// with identical names keep their input order.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();
    let mut filtered: Vec<T> = items
        .into_iter()
        .filter(|item| get_name(item).to_lowercase().contains(&filter_lower))
        .collect();

    // Build each item's rank key once (O(n) lowercase allocations) instead of
    // lowercasing both names inside a comparator on every one of the
    // O(n log n) comparisons. `false < true`, so "is NOT an exact match"
    // places exact matches first.
    filtered.sort_by_cached_key(|item| {
        let name = get_name(item);
        (
            name != filter,
            name.to_lowercase() != filter_lower,
            name.len(),
            name.to_owned(),
        )
    });

    filtered
}
218
219/// Whether `host` refers to a loopback (machine-local) endpoint.
220///
221/// Used by [`Client::with_url`] to decide whether a plain-`http://` URL is
222/// safe to accept. Loopback traffic never leaves the machine, so the
223/// usual concern about leaking the Studio bearer token in plaintext does
224/// not apply — that's how wiremock and local dev servers connect.
225fn is_loopback_host(host: Option<&url::Host<&str>>) -> bool {
226 match host {
227 Some(url::Host::Ipv4(ip)) => ip.is_loopback(),
228 Some(url::Host::Ipv6(ip)) => ip.is_loopback(),
229 // RFC 6761 reserves "localhost" (and `*.localhost`) as a loopback
230 // name. Compare case-insensitively because URL hosts are matched
231 // that way and developers do type capitalized variants.
232 Some(url::Host::Domain(d)) => {
233 d.eq_ignore_ascii_case("localhost") || d.to_ascii_lowercase().ends_with(".localhost")
234 }
235 None => false,
236 }
237}
238
/// Reduces an arbitrary (possibly server- or user-supplied) name to a single
/// path component that is safe to join onto a directory.
///
/// Steps:
/// 1. Trim surrounding whitespace; an empty result becomes `"unnamed"`.
/// 2. Keep only the final path component, so embedded separators cannot
///    create subdirectories.
/// 3. Replace characters that are reserved on common filesystems
///    (`/ \ : * ? " < > |`) and control characters (e.g. NUL) with `_`.
/// 4. Reject `"."` and `".."`, which would otherwise resolve to the target
///    directory itself or its parent (directory traversal), by returning
///    `"unnamed"`.
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // `file_name()` is `None` for paths that end in `.`/`..` or are only a
    // root; in that case fall back to the whole string and let the character
    // mapping below neutralise any separators.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            c if c.is_control() => '_',
            _ => c,
        })
        .collect();

    // "." and ".." survive every step above unchanged (they contain no
    // reserved characters), but joined onto a directory they escape it.
    match sanitized.as_str() {
        "" | "." | ".." => FALLBACK.to_string(),
        _ => sanitized,
    }
}
264
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count, total
/// count, and an optional status string to enable progress reporting in
/// applications.
///
/// # Multi-Stage Progress
///
/// The `status` field enables multi-stage progress tracking. When an operation
/// has multiple phases, the status field changes to indicate the current phase.
/// Applications should detect status changes to reset their progress display.
///
/// # Operation Progress Details
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Basic progress display:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "{}: {:.1}% ({}/{})",
///     progress.status.as_deref().unwrap_or("Progress"),
///     percentage,
///     progress.current,
///     progress.total
/// );
/// ```
///
/// Multi-stage progress handling (e.g., for `download_dataset`):
///
/// ```rust,ignore
/// let mut last_status: Option<String> = None;
///
/// while let Some(progress) = rx.recv().await {
///     // Detect stage change and reset progress bar
///     if progress.status != last_status {
///         if let Some(ref status) = progress.status {
///             println!("\n{}", status);
///         }
///         last_status = progress.status.clone();
///     }
///
///     let pct = (progress.current as f64 / progress.total as f64) * 100.0;
///     print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items or bytes.
    pub current: usize,
    /// Total number of items or bytes to process.
    ///
    /// May be `0` when the total is not known in advance — byte-level
    /// downloads report `0` when the server omits `Content-Length` — so
    /// guard any percentage computation against division by zero.
    pub total: usize,
    /// Optional status describing the current operation phase.
    ///
    /// When this value changes from `None` to `Some(...)` or between different
    /// values, it indicates a new phase has started. Applications should reset
    /// their progress display when the status changes.
    ///
    /// Currently only [`Client::download_dataset`] uses status changes:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// All other operations use `None` throughout.
    pub status: Option<String>,
}
359
/// JSON-RPC 2.0 request envelope.
///
/// `RpcRequest::default()` pre-fills `jsonrpc` with `"2.0"` and `id` with 0;
/// callers set `method` and, usually, `params`. `params: None` serializes
/// as JSON `null` because no `skip_serializing_if` is applied.
#[derive(Serialize)]
struct RpcRequest<Params> {
    /// Request id.
    id: u64,
    /// Protocol version string.
    jsonrpc: String,
    /// Remote method name.
    method: String,
    /// Method parameters, if any.
    params: Option<Params>,
}
367
368impl<T> Default for RpcRequest<T> {
369 fn default() -> Self {
370 RpcRequest {
371 id: 0,
372 jsonrpc: "2.0".to_string(),
373 method: "".to_string(),
374 params: None,
375 }
376 }
377}
378
/// The `error` member of a JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    /// Numeric error code reported by the server.
    code: i32,
    /// Human-readable error message reported by the server.
    message: String,
}

/// JSON-RPC 2.0 response envelope; exactly one of `error` / `result` is
/// expected to be present.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): `RpcRequest::id` is serialized as a `u64`, yet this field
    // expects a JSON string. Presumably the server always replies with a
    // string id — confirm, since a numeric echo would fail to deserialize.
    // `dead_code` is allowed because the field is only needed to accept the
    // wire format, not read.
    #[allow(dead_code)]
    id: String,
    // Present to accept the wire format; never read.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Populated when the call failed.
    error: Option<RpcError>,
    /// Populated when the call succeeded.
    result: Option<RpcResult>,
}

/// Result type for RPC calls whose result payload is not needed; any JSON
/// object deserializes into it (unknown fields are ignored by serde).
// `dead_code` allowed: presumably only ever used as a type parameter —
// confirm at call sites.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
398
/// Parameters for the (non-multipart) snapshot-create call.
// `dead_code` allowed: nothing visible in this part of the file constructs
// this type; presumably kept for API completeness — confirm before removal.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}

/// Result of the (non-multipart) snapshot-create call.
// `dead_code` allowed for the same reason as `SnapshotCreateParams`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for creating a snapshot via multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    /// Object keys to upload.
    keys: Vec<String>,
    /// File sizes in bytes; presumably index-aligned with `keys` — confirm.
    file_sizes: Vec<usize>,
    /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    snapshot_type: Option<String>,
}

/// One value in the multipart-create result: either a numeric id or the
/// upload details for a single file.
///
/// `untagged`: serde tries the variants in declaration order, so any JSON
/// number is read as `Id` before `Part` is attempted.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Parameters for completing one multipart upload.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}

/// An uploaded part's ETag and number, serialized with the capitalised
/// `ETag` / `PartNumber` names used by S3 multipart uploads.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Upload details for one file of a multipart snapshot upload.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    /// Presumably one upload URL per part — confirm against the server API.
    urls: Vec<String>,
}

/// Parameters for updating a snapshot's status.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by the status update. Fields are deserialized
/// but not currently read, hence the per-field `dead_code` allowances.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
471
/// Parameters for an image-listing request.
// `dead_code` allowed: nothing visible in this part of the file constructs
// this type — confirm it is still needed before removing.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    /// Additional string key/value filters on image files.
    image_files_filter: HashMap<String, String>,
    /// Presumably asks the server to return ids only — confirm.
    only_ids: bool,
}

/// Restricts an image listing to one dataset.
// `dead_code` allowed for the same reason as `ImageListParams`.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
485
/// Main client for interacting with EdgeFirst Studio Server.
///
/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
/// and manages authentication, RPC calls, and data operations. It provides
/// methods for managing projects, datasets, experiments, training sessions,
/// and various utility functions for data processing.
///
/// The client supports multiple authentication methods and can work with both
/// SaaS and self-hosted EdgeFirst Studio instances.
///
/// # Features
///
/// - **Authentication**: Token-based authentication with automatic persistence
/// - **Dataset Management**: Upload, download, and manipulate datasets
/// - **Project Operations**: Create and manage projects and experiments
/// - **Training & Validation**: Submit and monitor ML training jobs
/// - **Data Integration**: Convert between EdgeFirst datasets and popular
///   formats
/// - **Progress Tracking**: Real-time progress updates for long-running
///   operations
///
/// # Examples
///
/// ```no_run
/// use edgefirst_client::{Client, DatasetID};
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), edgefirst_client::Error> {
/// // Create a new client and authenticate
/// let mut client = Client::new()?;
/// let client = client
///     .with_login("your-email@example.com", "password")
///     .await?;
///
/// // Or use an existing token
/// let base_client = Client::new()?;
/// let client = base_client.with_token("your-token-here")?;
///
/// // Get organization and projects
/// let org = client.organization().await?;
/// let projects = client.projects(None).await?;
///
/// // Work with datasets
/// let dataset_id = DatasetID::from_str("ds-abc123")?;
/// let dataset = client.dataset(dataset_id).await?;
/// # Ok(())
/// # }
/// ```
// Client is Clone but cannot derive Debug due to dyn TokenStorage; the
// hand-written Debug impl also keeps the bearer token out of debug output.
#[derive(Clone)]
pub struct Client {
    /// HTTP client for regular API calls; `Client::new` gives it a
    /// total-request deadline taken from `EDGEFIRST_TIMEOUT`.
    http: reqwest::Client,
    /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
    /// timeout). An idle read timeout is still configured on the underlying client, and
    /// some operations (such as uploads) may apply additional per-request timeouts.
    bulk_http: reqwest::Client,
    /// Base server URL, e.g. `https://edgefirst.studio`.
    url: String,
    /// Current bearer token (empty when unauthenticated). Because it sits in
    /// an `Arc`, clones of a `Client` share it until a builder method swaps
    /// in a fresh `Arc`.
    token: Arc<RwLock<String>>,
    /// Token storage backend. When set, tokens are automatically persisted.
    storage: Option<Arc<dyn TokenStorage>>,
    /// Legacy token path field for backwards compatibility with
    /// with_token_path(). Deprecated: Use with_storage() instead.
    token_path: Option<PathBuf>,
}
550
551impl std::fmt::Debug for Client {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("Client")
554 .field("url", &self.url)
555 .field("has_storage", &self.storage.is_some())
556 .field("token_path", &self.token_path)
557 .finish()
558 }
559}
560
/// Private context struct for pagination operations: the query inputs that
/// stay fixed while successive pages are fetched.
struct FetchContext<'a> {
    /// Dataset being paged through.
    dataset_id: DatasetID,
    /// Optional annotation set to restrict results to.
    annotation_set_id: Option<AnnotationSetID>,
    /// Group names used as a filter (borrowed from the caller).
    groups: &'a [String],
    /// Type names used as a filter.
    types: Vec<String>,
    /// Presumably label name → numeric label id — confirm with the caller.
    labels: &'a HashMap<String, u64>,
}
569
/// Parameters for listing jobs; carries no fields and serializes as `{}`.
#[derive(Debug, Serialize)]
struct JobsListRequest {}

/// Parameters for starting a job run.
#[derive(Debug, Serialize)]
struct JobRunRequest {
    name: String,
    job_name: String,
    /// Environment variables for the run.
    env: std::collections::HashMap<String, String>,
    /// Named job parameters.
    data: std::collections::HashMap<String, crate::api::Parameter>,
}

/// Parameters for stopping a running job, identified by its numeric task id.
#[derive(Debug, Serialize)]
struct JobStopRequest {
    task_id: u64,
}
585
/// Parameters for listing the data files attached to a task.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataListRequest {
    pub(crate) task_id: u64,
}

/// Parameters for downloading one task data file, addressed by folder and
/// file name.
#[derive(Debug, Serialize)]
pub(crate) struct TaskDataDownloadRequest {
    pub(crate) task_id: u64,
    pub(crate) folder: String,
    pub(crate) file: String,
}

/// Parameters for adding a chart to a task. `params: None` serializes as
/// JSON `null` (no `skip_serializing_if`).
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartAddRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
    pub(crate) chart_name: String,
    pub(crate) params: Option<crate::api::Parameter>,
    pub(crate) data: crate::api::Parameter,
}

/// Parameters for listing the charts in one group of a task.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartListRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
}

/// Parameters for fetching one chart of a task.
#[derive(Debug, Serialize)]
pub(crate) struct TaskChartGetRequest {
    pub(crate) task_id: u64,
    pub(crate) group_name: String,
    pub(crate) chart_name: String,
}

/// Parameters for downloading one file from a validation session.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataDownloadRequest {
    pub(crate) session_id: u64,
    pub(crate) filename: String,
}

/// Parameters for listing the files of a validation session.
#[derive(Debug, Serialize)]
pub(crate) struct ValDataListRequest {
    pub(crate) session_id: u64,
}
630
631/// Streams the body of a successful `reqwest` response to a file on disk,
632/// emitting optional progress events.
633///
634/// Both `download_artifact` and `rpc_download` share this logic. The caller is
635/// responsible for creating any required parent directories before calling this
636/// function.
637///
638/// # Arguments
639/// * `resp` - A successful (HTTP 2xx) `reqwest::Response` whose body will
640/// be streamed to `path`.
641/// * `path` - Destination file path (created or truncated).
642/// * `progress` - Optional channel; events carry bytes received and
643/// `Content-Length` total (0 if the server omits it).
644///
645/// # Errors
646/// Returns `Error::IoError` on file I/O failures or propagates stream errors.
647async fn stream_response_to_file(
648 resp: reqwest::Response,
649 path: &std::path::Path,
650 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
651) -> Result<(), Error> {
652 use tokio::io::AsyncWriteExt as _;
653 let total = resp.content_length().unwrap_or(0) as usize;
654 let mut stream = resp.bytes_stream();
655 let mut file = tokio::fs::File::create(path).await?;
656 let mut current = 0usize;
657
658 if let Some(ref tx) = progress {
659 let _ = tx
660 .send(Progress {
661 current: 0,
662 total,
663 status: None,
664 })
665 .await;
666 }
667
668 while let Some(chunk) = stream.next().await {
669 let chunk = chunk?;
670 file.write_all(&chunk).await?;
671 current += chunk.len();
672 if let Some(ref tx) = progress {
673 let _ = tx
674 .send(Progress {
675 current,
676 total,
677 status: None,
678 })
679 .await;
680 }
681 }
682
683 // Flush tokio's internal write buffer to the OS before returning.
684 // tokio::fs::File buffers writes internally; without this, the buffer
685 // may not reach the filesystem before the caller reads the file.
686 file.flush().await?;
687 Ok(())
688}
689
690impl Client {
691 /// Create a new unauthenticated client with the default saas server.
692 ///
693 /// By default, the client uses [`FileTokenStorage`] for token persistence.
694 /// Use [`with_storage`][Self::with_storage],
695 /// [`with_memory_storage`][Self::with_memory_storage],
696 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
697 /// behavior.
698 ///
699 /// To connect to a different server, use [`with_server`][Self::with_server]
700 /// or [`with_token`][Self::with_token] (tokens include the server
701 /// instance).
702 ///
703 /// This client is created without a token and will need to authenticate
704 /// before using methods that require authentication.
705 ///
706 /// # Examples
707 ///
708 /// ```rust,no_run
709 /// use edgefirst_client::Client;
710 ///
711 /// # fn main() -> Result<(), edgefirst_client::Error> {
712 /// // Create client with default file storage
713 /// let client = Client::new()?;
714 ///
715 /// // Create client without token persistence
716 /// let client = Client::new()?.with_memory_storage();
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub fn new() -> Result<Self, Error> {
721 log_retry_configuration();
722
723 // Get timeout from environment or use default
724 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
725 .ok()
726 .and_then(|s| s.parse().ok())
727 .unwrap_or(30); // Default 30s total deadline for API calls
728
729 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
730 // arrive for this duration. Resets after every received chunk, so a
731 // healthy multi-GB transfer will never be interrupted.
732 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
733 .ok()
734 .and_then(|s| s.parse().ok())
735 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
736
737 // Create single HTTP client with URL-based retry policy
738 //
739 // The retry policy classifies requests into two categories:
740 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
741 // errors
742 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
743 //
744 // This allows the same client to handle both API calls and file operations
745 // with appropriate retry behavior for each. See retry.rs for details.
746 let http = reqwest::Client::builder()
747 .connect_timeout(Duration::from_secs(10))
748 .timeout(Duration::from_secs(timeout_secs))
749 .pool_idle_timeout(Duration::from_secs(90))
750 .pool_max_idle_per_host(10)
751 .retry(create_retry_policy())
752 .build()?;
753
754 // Separate HTTP client for bulk transfers (uploads and downloads).
755 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
756 // Uses read_timeout instead: resets after every received chunk, so a
757 // healthy large transfer is never interrupted, but a truly stalled
758 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
759 let bulk_http = reqwest::Client::builder()
760 .connect_timeout(Duration::from_secs(30))
761 .read_timeout(Duration::from_secs(read_timeout_secs))
762 .pool_idle_timeout(Duration::from_secs(90))
763 .pool_max_idle_per_host(10)
764 .retry(create_retry_policy())
765 .build()?;
766
767 // Default to file storage, loading any existing token
768 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
769 Ok(file_storage) => Arc::new(file_storage),
770 Err(e) => {
771 warn!(
772 "Could not initialize file token storage: {}. Using memory storage.",
773 e
774 );
775 Arc::new(MemoryTokenStorage::new())
776 }
777 };
778
779 // Try to load existing token from storage
780 let token = match storage.load() {
781 Ok(Some(t)) => t,
782 Ok(None) => String::new(),
783 Err(e) => {
784 warn!(
785 "Failed to load token from storage: {}. Starting with empty token.",
786 e
787 );
788 String::new()
789 }
790 };
791
792 // Extract server from token if available
793 let url = if !token.is_empty() {
794 match Self::extract_server_from_token(&token) {
795 Ok(server) => format!("https://{}.edgefirst.studio", server),
796 Err(e) => {
797 warn!(
798 "Failed to extract server from token: {}. Using default server.",
799 e
800 );
801 "https://edgefirst.studio".to_string()
802 }
803 }
804 } else {
805 "https://edgefirst.studio".to_string()
806 };
807
808 Ok(Client {
809 http,
810 bulk_http,
811 url,
812 token: Arc::new(tokio::sync::RwLock::new(token)),
813 storage: Some(storage),
814 token_path: None,
815 })
816 }
817
818 /// Returns a new client connected to the specified server instance.
819 ///
820 /// The server parameter is an instance name that maps to a URL:
821 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
822 /// server)
823 /// - `"test"` → `https://test.edgefirst.studio`
824 /// - `"stage"` → `https://stage.edgefirst.studio`
825 /// - `"dev"` → `https://dev.edgefirst.studio`
826 /// - `"{name}"` → `https://{name}.edgefirst.studio`
827 ///
828 /// # Server Selection Priority
829 ///
830 /// When using the CLI or Python API, server selection follows this
831 /// priority:
832 ///
833 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
834 /// they were issued for. If you have a valid token, its server is used.
835 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
836 /// token is available. If a token exists with a different server, a
837 /// warning is emitted and the token's server takes priority.
838 /// 3. **Default `"saas"`** - If no token and no server specified, the
839 /// production server (`https://edgefirst.studio`) is used.
840 ///
841 /// # Important Notes
842 ///
843 /// - If a token is already set in the client, calling this method will
844 /// **drop the token** as tokens are specific to the server instance.
845 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
846 /// token's server before calling this method.
847 /// - For login operations, call `with_server()` first, then authenticate.
848 ///
849 /// # Examples
850 ///
851 /// ```rust,no_run
852 /// use edgefirst_client::Client;
853 ///
854 /// # fn main() -> Result<(), edgefirst_client::Error> {
855 /// let client = Client::new()?.with_server("test")?;
856 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
857 /// # Ok(())
858 /// # }
859 /// ```
860 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
861 // Resolve the target URL. Full URLs (self-hosted Studio,
862 // wiremock) are validated through `with_url` so the HTTPS rules
863 // there apply uniformly. Short names map to the SaaS pattern.
864 // We extract only the URL string and rebuild the Client below,
865 // because `with_url` preserves the in-memory token (the contract
866 // for self-hosted deployments) whereas `with_server` deliberately
867 // clears it (a different server means a stale token).
868 let url = if server.starts_with("http://") || server.starts_with("https://") {
869 self.with_url(server)?.url().to_string()
870 } else {
871 match server {
872 "" | "saas" => "https://edgefirst.studio".to_string(),
873 name => format!("https://{}.edgefirst.studio", name),
874 }
875 };
876
877 // Clear token from storage when changing servers to prevent
878 // authentication issues with stale tokens from different
879 // instances. This runs whether the caller passed a short name
880 // or a full URL — both reach a new server.
881 if let Some(ref storage) = self.storage
882 && let Err(e) = storage.clear()
883 {
884 warn!(
885 "Failed to clear token from storage when changing servers: {}",
886 e
887 );
888 }
889
890 Ok(Client {
891 url,
892 token: Arc::new(tokio::sync::RwLock::new(String::new())),
893 ..self.clone()
894 })
895 }
896
897 /// Returns a new client pointed at an explicit URL.
898 ///
899 /// Used for self-hosted Studio deployments (e.g.
900 /// `https://studio.example.com`) and for offline integration tests
901 /// against a mock HTTP server (e.g. `http://127.0.0.1:8080`). The
902 /// token is preserved so callers can chain
903 /// `Client::new()?.with_url(...)?.with_token(...)`.
904 ///
905 /// # Errors
906 ///
907 /// Returns [`Error::UrlParseError`] for syntactically invalid URLs and
908 /// [`Error::InsecureUrl`] for plain `http://` URLs that resolve to a
909 /// non-loopback host: the Studio bearer token rides in the
910 /// `Authorization` header, and plain HTTP would leak it in the clear.
911 /// Loopback URLs (`127.0.0.1`, `::1`, `localhost`, `*.localhost`) are
912 /// permitted because traffic never leaves the machine — wiremock and
913 /// local dev servers go through that path.
914 pub fn with_url(&self, url: &str) -> Result<Self, Error> {
915 // Reject malformed inputs early so test failures point at the test
916 // rather than a downstream reqwest send.
917 let parsed = url::Url::parse(url)?;
918 let scheme = parsed.scheme();
919 if scheme == "http" {
920 if !is_loopback_host(parsed.host().as_ref()) {
921 return Err(Error::InsecureUrl(url.to_string()));
922 }
923 } else if scheme != "https" {
924 return Err(Error::InsecureUrl(url.to_string()));
925 }
926 Ok(Client {
927 url: url.trim_end_matches('/').to_string(),
928 ..self.clone()
929 })
930 }
931
932 /// Returns a new client with the specified token storage backend.
933 ///
934 /// Use this to configure custom token storage, such as platform-specific
935 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
936 ///
937 /// # Examples
938 ///
939 /// ```rust,no_run
940 /// use edgefirst_client::{Client, FileTokenStorage};
941 /// use std::{path::PathBuf, sync::Arc};
942 ///
943 /// # fn main() -> Result<(), edgefirst_client::Error> {
944 /// // Use a custom file path for token storage
945 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
946 /// let client = Client::new()?.with_storage(Arc::new(storage));
947 /// # Ok(())
948 /// # }
949 /// ```
950 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
951 // Try to load existing token from the new storage
952 let token = match storage.load() {
953 Ok(Some(t)) => t,
954 Ok(None) => String::new(),
955 Err(e) => {
956 warn!(
957 "Failed to load token from storage: {}. Starting with empty token.",
958 e
959 );
960 String::new()
961 }
962 };
963
964 Client {
965 token: Arc::new(tokio::sync::RwLock::new(token)),
966 storage: Some(storage),
967 token_path: None,
968 ..self
969 }
970 }
971
972 /// Returns a new client with in-memory token storage (no persistence).
973 ///
974 /// Tokens are stored in memory only and lost when the application exits.
975 /// This is useful for testing or when you want to manage token persistence
976 /// externally.
977 ///
978 /// # Examples
979 ///
980 /// ```rust,no_run
981 /// use edgefirst_client::Client;
982 ///
983 /// # fn main() -> Result<(), edgefirst_client::Error> {
984 /// let client = Client::new()?.with_memory_storage();
985 /// # Ok(())
986 /// # }
987 /// ```
988 pub fn with_memory_storage(self) -> Self {
989 Client {
990 token: Arc::new(tokio::sync::RwLock::new(String::new())),
991 storage: Some(Arc::new(MemoryTokenStorage::new())),
992 token_path: None,
993 ..self
994 }
995 }
996
997 /// Returns a new client with no token storage.
998 ///
999 /// Tokens are not persisted. Use this when you want to manage tokens
1000 /// entirely manually.
1001 ///
1002 /// # Examples
1003 ///
1004 /// ```rust,no_run
1005 /// use edgefirst_client::Client;
1006 ///
1007 /// # fn main() -> Result<(), edgefirst_client::Error> {
1008 /// let client = Client::new()?.with_no_storage();
1009 /// # Ok(())
1010 /// # }
1011 /// ```
1012 pub fn with_no_storage(self) -> Self {
1013 Client {
1014 storage: None,
1015 token_path: None,
1016 ..self
1017 }
1018 }
1019
1020 /// Returns a new client authenticated with the provided username and
1021 /// password.
1022 ///
1023 /// The token is automatically persisted to storage (if configured).
1024 ///
1025 /// # Examples
1026 ///
1027 /// ```rust,no_run
1028 /// use edgefirst_client::Client;
1029 ///
1030 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1031 /// let client = Client::new()?
1032 /// .with_server("test")?
1033 /// .with_login("user@example.com", "password")
1034 /// .await?;
1035 /// # Ok(())
1036 /// # }
1037 /// ```
1038 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
1039 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
1040 let params = HashMap::from([("username", username), ("password", password)]);
1041 let login: LoginResult = self
1042 .rpc_without_auth("auth.login".to_owned(), Some(params))
1043 .await?;
1044
1045 // Validate that the server returned a non-empty token
1046 if login.token.is_empty() {
1047 return Err(Error::EmptyToken);
1048 }
1049
1050 // Persist token to storage if configured
1051 if let Some(ref storage) = self.storage
1052 && let Err(e) = storage.store(&login.token)
1053 {
1054 warn!("Failed to persist token to storage: {}", e);
1055 }
1056
1057 Ok(Client {
1058 token: Arc::new(tokio::sync::RwLock::new(login.token)),
1059 ..self.clone()
1060 })
1061 }
1062
1063 /// Returns a new client which will load and save the token to the specified
1064 /// path.
1065 ///
1066 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
1067 /// [`FileTokenStorage`] instead for more flexible token management.
1068 ///
1069 /// This method is maintained for backwards compatibility with existing
1070 /// code. It disables the default storage and uses file-based storage at
1071 /// the specified path.
1072 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
1073 let token_path = match token_path {
1074 Some(path) => path.to_path_buf(),
1075 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1076 .ok_or_else(|| {
1077 Error::IoError(std::io::Error::new(
1078 std::io::ErrorKind::NotFound,
1079 "Could not determine user config directory",
1080 ))
1081 })?
1082 .config_dir()
1083 .join("token"),
1084 };
1085
1086 debug!("Using token path (legacy): {:?}", token_path);
1087
1088 let token = match token_path.exists() {
1089 true => std::fs::read_to_string(&token_path)?,
1090 false => "".to_string(),
1091 };
1092
1093 if !token.is_empty() {
1094 match self.with_token(&token) {
1095 Ok(client) => Ok(Client {
1096 token_path: Some(token_path),
1097 storage: None, // Disable new storage when using legacy token_path
1098 ..client
1099 }),
1100 Err(e) => {
1101 // Token is corrupted or invalid - remove it and continue with no token
1102 warn!(
1103 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
1104 token_path, e
1105 );
1106 if let Err(remove_err) = std::fs::remove_file(&token_path) {
1107 warn!("Failed to remove corrupted token file: {:?}", remove_err);
1108 }
1109 // Clear any token from default storage to ensure we don't use it
1110 Ok(Client {
1111 token_path: Some(token_path),
1112 storage: None,
1113 token: Arc::new(RwLock::new("".to_string())),
1114 ..self.clone()
1115 })
1116 }
1117 }
1118 } else {
1119 // No token in the legacy file - clear any token from default storage
1120 Ok(Client {
1121 token_path: Some(token_path),
1122 storage: None,
1123 token: Arc::new(RwLock::new("".to_string())),
1124 ..self.clone()
1125 })
1126 }
1127 }
1128
1129 /// Returns a new client authenticated with the provided token.
1130 ///
1131 /// The token is automatically persisted to storage (if configured).
1132 /// The server URL is extracted from the token payload.
1133 ///
1134 /// # Examples
1135 ///
1136 /// ```rust,no_run
1137 /// use edgefirst_client::Client;
1138 ///
1139 /// # fn main() -> Result<(), edgefirst_client::Error> {
1140 /// let client = Client::new()?.with_token("your-jwt-token")?;
1141 /// # Ok(())
1142 /// # }
1143 /// ```
1144 /// Extract server name from JWT token payload.
1145 ///
1146 /// Helper method to parse the JWT token and extract the "server" field
1147 /// from the payload. Returns the server name (e.g., "test", "stage", "")
1148 /// or an error if the token is invalid.
1149 fn extract_server_from_token(token: &str) -> Result<String, Error> {
1150 let token_parts: Vec<&str> = token.split('.').collect();
1151 if token_parts.len() != 3 {
1152 return Err(Error::InvalidToken);
1153 }
1154
1155 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1156 .decode(token_parts[1])
1157 .map_err(|_| Error::InvalidToken)?;
1158 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1159 let server = match payload.get("server") {
1160 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
1161 None => return Err(Error::InvalidToken),
1162 };
1163
1164 Ok(server)
1165 }
1166
1167 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
1168 if token.is_empty() {
1169 return Ok(self.clone());
1170 }
1171
1172 let server = Self::extract_server_from_token(token)?;
1173
1174 // Persist token to storage if configured
1175 if let Some(ref storage) = self.storage
1176 && let Err(e) = storage.store(token)
1177 {
1178 warn!("Failed to persist token to storage: {}", e);
1179 }
1180
1181 Ok(Client {
1182 url: format!("https://{}.edgefirst.studio", server),
1183 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
1184 ..self.clone()
1185 })
1186 }
1187
1188 /// Persist the current token to storage.
1189 ///
1190 /// This is automatically called when using [`with_login`][Self::with_login]
1191 /// or [`with_token`][Self::with_token], so you typically don't need to call
1192 /// this directly.
1193 ///
1194 /// If using the legacy `token_path` configuration, saves to the file path.
1195 /// If using the new storage abstraction, saves to the configured storage.
1196 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1197 pub async fn save_token(&self) -> Result<(), Error> {
1198 let token = self.token.read().await;
1199
1200 // Try new storage first
1201 if let Some(ref storage) = self.storage {
1202 storage.store(&token)?;
1203 debug!("Token saved to storage");
1204 return Ok(());
1205 }
1206
1207 // Fall back to legacy token_path behavior
1208 let path = self.token_path.clone().unwrap_or_else(|| {
1209 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1210 .map(|dirs| dirs.config_dir().join("token"))
1211 .unwrap_or_else(|| PathBuf::from(".token"))
1212 });
1213
1214 create_dir_all(path.parent().ok_or_else(|| {
1215 Error::IoError(std::io::Error::new(
1216 std::io::ErrorKind::InvalidInput,
1217 "Token path has no parent directory",
1218 ))
1219 })?)?;
1220 let mut file = std::fs::File::create(&path)?;
1221 file.write_all(token.as_bytes())?;
1222
1223 debug!("Saved token to {:?}", path);
1224
1225 Ok(())
1226 }
1227
1228 /// Return the version of the EdgeFirst Studio server for the current
1229 /// client connection.
1230 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1231 pub async fn version(&self) -> Result<String, Error> {
1232 let version: HashMap<String, String> = self
1233 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
1234 .await?;
1235 let version = version.get("version").ok_or(Error::InvalidResponse)?;
1236 Ok(version.to_owned())
1237 }
1238
1239 /// Clear the token used to authenticate the client with the server.
1240 ///
1241 /// Clears the token from memory and from storage (if configured).
1242 /// If using the legacy `token_path` configuration, removes the token file.
1243 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1244 pub async fn logout(&self) -> Result<(), Error> {
1245 {
1246 let mut token = self.token.write().await;
1247 *token = "".to_string();
1248 }
1249
1250 // Clear from new storage if configured
1251 if let Some(ref storage) = self.storage
1252 && let Err(e) = storage.clear()
1253 {
1254 warn!("Failed to clear token from storage: {}", e);
1255 }
1256
1257 // Also clear legacy token_path if configured
1258 if let Some(path) = &self.token_path
1259 && path.exists()
1260 {
1261 fs::remove_file(path).await?;
1262 }
1263
1264 Ok(())
1265 }
1266
1267 /// Return the token used to authenticate the client with the server. When
1268 /// logging into the server using a username and password, the token is
1269 /// returned by the server and stored in the client for future interactions.
1270 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1271 pub async fn token(&self) -> String {
1272 self.token.read().await.clone()
1273 }
1274
1275 /// Verify the token used to authenticate the client with the server. This
1276 /// method is used to ensure that the token is still valid and has not
1277 /// expired. If the token is invalid, the server will return an error and
1278 /// the client will need to login again.
1279 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1280 pub async fn verify_token(&self) -> Result<(), Error> {
1281 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1282 .await?;
1283 Ok::<(), Error>(())
1284 }
1285
1286 /// Renew the token used to authenticate the client with the server.
1287 ///
1288 /// Refreshes the token before it expires. If the token has already expired,
1289 /// the server will return an error and you will need to login again.
1290 ///
1291 /// The new token is automatically persisted to storage (if configured).
1292 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1293 pub async fn renew_token(&self) -> Result<(), Error> {
1294 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1295 let result: LoginResult = self
1296 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1297 .await?;
1298
1299 {
1300 let mut token = self.token.write().await;
1301 *token = result.token.clone();
1302 }
1303
1304 // Persist to new storage if configured
1305 if let Some(ref storage) = self.storage
1306 && let Err(e) = storage.store(&result.token)
1307 {
1308 warn!("Failed to persist renewed token to storage: {}", e);
1309 }
1310
1311 // Also persist to legacy token_path if configured
1312 if self.token_path.is_some() {
1313 self.save_token().await?;
1314 }
1315
1316 Ok(())
1317 }
1318
1319 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1320 let token = self.token.read().await;
1321 if token.is_empty() {
1322 return Err(Error::EmptyToken);
1323 }
1324
1325 let token_parts: Vec<&str> = token.split('.').collect();
1326 if token_parts.len() != 3 {
1327 return Err(Error::InvalidToken);
1328 }
1329
1330 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1331 .decode(token_parts[1])
1332 .map_err(|_| Error::InvalidToken)?;
1333 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1334 match payload.get(field) {
1335 Some(value) => Ok(value.to_owned()),
1336 None => Err(Error::InvalidToken),
1337 }
1338 }
1339
1340 /// Returns the URL of the EdgeFirst Studio server for the current client.
1341 pub fn url(&self) -> &str {
1342 &self.url
1343 }
1344
1345 /// Returns the server name for the current client.
1346 ///
1347 /// This extracts the server name from the client's URL:
1348 /// - `https://edgefirst.studio` → `"saas"`
1349 /// - `https://test.edgefirst.studio` → `"test"`
1350 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1351 ///
1352 /// # Examples
1353 ///
1354 /// ```rust,no_run
1355 /// use edgefirst_client::Client;
1356 ///
1357 /// # fn main() -> Result<(), edgefirst_client::Error> {
1358 /// let client = Client::new()?.with_server("test")?;
1359 /// assert_eq!(client.server(), "test");
1360 ///
1361 /// let client = Client::new()?; // default
1362 /// assert_eq!(client.server(), "saas");
1363 /// # Ok(())
1364 /// # }
1365 /// ```
1366 pub fn server(&self) -> &str {
1367 if self.url == "https://edgefirst.studio" {
1368 "saas"
1369 } else if let Some(name) = self.url.strip_prefix("https://") {
1370 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1371 } else {
1372 "saas"
1373 }
1374 }
1375
1376 /// Returns the username associated with the current token.
1377 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1378 pub async fn username(&self) -> Result<String, Error> {
1379 match self.token_field("username").await? {
1380 serde_json::Value::String(username) => Ok(username),
1381 _ => Err(Error::InvalidToken),
1382 }
1383 }
1384
1385 /// Returns the expiration time for the current token.
1386 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1387 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1388 let ts = match self.token_field("exp").await? {
1389 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1390 _ => return Err(Error::InvalidToken),
1391 };
1392
1393 match DateTime::<Utc>::from_timestamp(ts, 0) {
1394 Some(dt) => Ok(dt),
1395 None => Err(Error::InvalidToken),
1396 }
1397 }
1398
1399 /// Returns the organization information for the current user.
1400 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1401 pub async fn organization(&self) -> Result<Organization, Error> {
1402 self.rpc::<(), Organization>("org.get".to_owned(), None)
1403 .await
1404 }
1405
1406 /// Returns a list of projects available to the user. The projects are
1407 /// returned as a vector of Project objects. If a name filter is
1408 /// provided, only projects matching the filter are returned.
1409 ///
1410 /// Results are sorted by match quality: exact matches first, then
1411 /// case-insensitive exact matches, then shorter names (more specific),
1412 /// then alphabetically.
1413 ///
1414 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1415 /// Projects contain datasets, trainers, and trainer sessions. Projects
1416 /// are used to group related datasets and trainers together.
1417 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1418 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1419 let projects = self
1420 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1421 .await?;
1422 if let Some(name) = name {
1423 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1424 } else {
1425 Ok(projects)
1426 }
1427 }
1428
1429 /// Return the project with the specified project ID. If the project does
1430 /// not exist, an error is returned.
1431 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1432 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1433 let params = HashMap::from([("project_id", project_id)]);
1434 self.rpc("project.get".to_owned(), Some(params)).await
1435 }
1436
1437 /// Returns a list of datasets available to the user. The datasets are
1438 /// returned as a vector of Dataset objects. If a name filter is
1439 /// provided, only datasets matching the filter are returned.
1440 ///
1441 /// Results are sorted by match quality: exact matches first, then
1442 /// case-insensitive exact matches, then shorter names (more specific),
1443 /// then alphabetically. This ensures "Deer" returns before "Deer
1444 /// Roundtrip".
1445 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1446 pub async fn datasets(
1447 &self,
1448 project_id: ProjectID,
1449 name: Option<&str>,
1450 ) -> Result<Vec<Dataset>, Error> {
1451 let params = HashMap::from([("project_id", project_id)]);
1452 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1453 if let Some(name) = name {
1454 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1455 } else {
1456 Ok(datasets)
1457 }
1458 }
1459
1460 /// Return the dataset with the specified dataset ID. If the dataset does
1461 /// not exist, an error is returned.
1462 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1463 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1464 let params = HashMap::from([("dataset_id", dataset_id)]);
1465 self.rpc("dataset.get".to_owned(), Some(params)).await
1466 }
1467
1468 /// Lists the labels for the specified dataset.
1469 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1470 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1471 let params = HashMap::from([("dataset_id", dataset_id)]);
1472 self.rpc("label.list".to_owned(), Some(params)).await
1473 }
1474
1475 /// Add a new label to the dataset with the specified name.
1476 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1477 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1478 let new_label = NewLabel {
1479 dataset_id,
1480 labels: vec![NewLabelObject {
1481 name: name.to_owned(),
1482 }],
1483 };
1484 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1485 Ok(())
1486 }
1487
1488 /// Removes the label with the specified ID from the dataset. Label IDs are
1489 /// globally unique so the dataset_id is not required.
1490 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1491 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1492 let params = HashMap::from([("label_id", label_id)]);
1493 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1494 Ok(())
1495 }
1496
1497 /// Creates a new dataset in the specified project.
1498 ///
1499 /// # Arguments
1500 ///
1501 /// * `project_id` - The ID of the project to create the dataset in
1502 /// * `name` - The name of the new dataset
1503 /// * `description` - Optional description for the dataset
1504 ///
1505 /// # Returns
1506 ///
1507 /// Returns the dataset ID of the newly created dataset.
1508 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1509 pub async fn create_dataset(
1510 &self,
1511 project_id: &str,
1512 name: &str,
1513 description: Option<&str>,
1514 ) -> Result<DatasetID, Error> {
1515 let mut params = HashMap::new();
1516 params.insert("project_id", project_id);
1517 params.insert("name", name);
1518 if let Some(desc) = description {
1519 params.insert("description", desc);
1520 }
1521
1522 #[derive(Deserialize)]
1523 struct CreateDatasetResult {
1524 id: DatasetID,
1525 }
1526
1527 let result: CreateDatasetResult =
1528 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1529 Ok(result.id)
1530 }
1531
1532 /// Deletes a dataset by marking it as deleted.
1533 ///
1534 /// # Arguments
1535 ///
1536 /// * `dataset_id` - The ID of the dataset to delete
1537 ///
1538 /// # Returns
1539 ///
1540 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1541 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1542 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1543 let params = HashMap::from([("id", dataset_id)]);
1544 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1545 Ok(())
1546 }
1547
1548 /// Updates the label with the specified ID to have the new name or index.
1549 /// Label IDs cannot be changed. Label IDs are globally unique so the
1550 /// dataset_id is not required.
1551 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1552 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1553 #[derive(Serialize)]
1554 struct Params {
1555 dataset_id: DatasetID,
1556 label_id: u64,
1557 label_name: String,
1558 label_index: u64,
1559 }
1560
1561 let _: String = self
1562 .rpc(
1563 "label.update".to_owned(),
1564 Some(Params {
1565 dataset_id: label.dataset_id(),
1566 label_id: label.id(),
1567 label_name: label.name().to_owned(),
1568 label_index: label.index(),
1569 }),
1570 )
1571 .await?;
1572 Ok(())
1573 }
1574
1575 /// Lists the groups for the specified dataset.
1576 ///
1577 /// Groups are used to organize samples into logical subsets such as
1578 /// "train", "val", "test", etc. Each sample can belong to at most one
1579 /// group at a time.
1580 ///
1581 /// # Arguments
1582 ///
1583 /// * `dataset_id` - The ID of the dataset to list groups for
1584 ///
1585 /// # Returns
1586 ///
1587 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1588 /// empty vector if no groups have been created yet.
1589 ///
1590 /// # Errors
1591 ///
1592 /// Returns an error if the dataset does not exist or cannot be accessed.
1593 ///
1594 /// # Example
1595 ///
1596 /// ```rust,no_run
1597 /// # use edgefirst_client::{Client, DatasetID};
1598 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1599 /// let client = Client::new()?.with_token_path(None)?;
1600 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1601 ///
1602 /// let groups = client.groups(dataset_id).await?;
1603 /// for group in groups {
1604 /// println!("{}: {}", group.id, group.name);
1605 /// }
1606 /// # Ok(())
1607 /// # }
1608 /// ```
1609 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1610 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1611 let params = HashMap::from([("dataset_id", dataset_id)]);
1612 self.rpc("groups.list".to_owned(), Some(params)).await
1613 }
1614
1615 /// Gets an existing group by name or creates a new one.
1616 ///
1617 /// This is a convenience method that first checks if a group with the
1618 /// specified name exists, and creates it if not. This is useful when
1619 /// you need to ensure a group exists before assigning samples to it.
1620 ///
1621 /// # Arguments
1622 ///
1623 /// * `dataset_id` - The ID of the dataset
1624 /// * `name` - The name of the group (e.g., "train", "val", "test")
1625 ///
1626 /// # Returns
1627 ///
1628 /// Returns the group ID (either existing or newly created).
1629 ///
1630 /// # Errors
1631 ///
1632 /// Returns an error if:
1633 /// - The dataset does not exist or cannot be accessed
1634 /// - The group creation fails
1635 ///
1636 /// # Concurrency
1637 ///
1638 /// This method handles concurrent creation attempts gracefully. If another
1639 /// process creates the group between the existence check and creation,
1640 /// this method will return the existing group's ID.
1641 ///
1642 /// # Example
1643 ///
1644 /// ```rust,no_run
1645 /// # use edgefirst_client::{Client, DatasetID};
1646 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1647 /// let client = Client::new()?.with_token_path(None)?;
1648 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1649 ///
1650 /// // Get or create a "train" group
1651 /// let train_group_id = client
1652 /// .get_or_create_group(dataset_id.clone(), "train")
1653 /// .await?;
1654 /// println!("Train group ID: {}", train_group_id);
1655 ///
1656 /// // Calling again returns the same ID
1657 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1658 /// assert_eq!(train_group_id, same_id);
1659 /// # Ok(())
1660 /// # }
1661 /// ```
1662 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1663 pub async fn get_or_create_group(
1664 &self,
1665 dataset_id: DatasetID,
1666 name: &str,
1667 ) -> Result<u64, Error> {
1668 // First check if the group already exists
1669 let groups = self.groups(dataset_id).await?;
1670 if let Some(group) = groups.iter().find(|g| g.name == name) {
1671 return Ok(group.id);
1672 }
1673
1674 // Create the group
1675 #[derive(Serialize)]
1676 struct CreateGroupParams {
1677 dataset_id: DatasetID,
1678 group_names: Vec<String>,
1679 group_splits: Vec<i64>,
1680 }
1681
1682 let params = CreateGroupParams {
1683 dataset_id,
1684 group_names: vec![name.to_string()],
1685 group_splits: vec![0], // No automatic splitting
1686 };
1687
1688 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1689 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1690 Ok(group.id)
1691 } else {
1692 // Group might have been created by concurrent call, try fetching again
1693 let groups = self.groups(dataset_id).await?;
1694 groups
1695 .iter()
1696 .find(|g| g.name == name)
1697 .map(|g| g.id)
1698 .ok_or_else(|| {
1699 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1700 })
1701 }
1702 }
1703
1704 /// Sets the group for a sample.
1705 ///
1706 /// Assigns a sample to a specific group. Each sample can belong to at most
1707 /// one group at a time. Setting a new group replaces any existing group
1708 /// assignment.
1709 ///
1710 /// # Arguments
1711 ///
1712 /// * `sample_id` - The ID of the sample (image) to update
1713 /// * `group_id` - The ID of the group to assign. Use
1714 /// [`get_or_create_group`] to obtain a group ID from a name.
1715 ///
1716 /// # Returns
1717 ///
1718 /// Returns `Ok(())` on success.
1719 ///
1720 /// # Errors
1721 ///
1722 /// Returns an error if:
1723 /// - The sample does not exist
1724 /// - The group does not exist
1725 /// - Insufficient permissions to modify the sample
1726 ///
1727 /// # Example
1728 ///
1729 /// ```rust,no_run
1730 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1731 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1732 /// let client = Client::new()?.with_token_path(None)?;
1733 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1734 /// let sample_id: SampleID = 12345.into();
1735 ///
1736 /// // Get or create the "val" group
1737 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1738 ///
1739 /// // Assign the sample to the "val" group
1740 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1741 /// # Ok(())
1742 /// # }
1743 /// ```
1744 ///
1745 /// [`get_or_create_group`]: Self::get_or_create_group
1746 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1747 pub async fn set_sample_group_id(
1748 &self,
1749 sample_id: SampleID,
1750 group_id: u64,
1751 ) -> Result<(), Error> {
1752 #[derive(Serialize)]
1753 struct SetGroupParams {
1754 image_id: SampleID,
1755 group_id: u64,
1756 }
1757
1758 let params = SetGroupParams {
1759 image_id: sample_id,
1760 group_id,
1761 };
1762 let _: String = self
1763 .rpc("image.set_group_id".to_owned(), Some(params))
1764 .await?;
1765 Ok(())
1766 }
1767
1768 /// Downloads dataset samples to the local filesystem.
1769 ///
1770 /// # Arguments
1771 ///
1772 /// * `dataset_id` - The unique identifier of the dataset
1773 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1774 /// * `file_types` - File types to download. Supported types:
1775 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1776 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1777 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1778 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1779 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1780 /// - `FileType::RadarCube` - Radar cube data (.png format)
1781 /// - `FileType::All` - All sensor types (expands to all of the above)
1782 /// * `output` - Local directory to save downloaded files
1783 /// * `flatten` - If true, download all files to output root without
1784 /// sequence subdirectories. When flattening, filenames are prefixed with
1785 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1786 /// unavailable) unless the filename already starts with
1787 /// `{sequence_name}_`, to avoid conflicts between sequences.
1788 /// * `progress` - Optional channel for progress updates
1789 ///
1790 /// # Progress
1791 ///
1792 /// This operation has two phases with distinct progress reporting:
1793 ///
1794 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1795 /// from the server. Progress counts samples fetched.
1796 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1797 /// files to disk. Progress counts samples completed (each sample may
1798 /// have multiple files for different sensor types).
1799 ///
1800 /// Applications should detect the status change from `None` to
1801 /// `"Downloading"` to reset their progress bar for the second phase.
1802 ///
1803 /// # Returns
1804 ///
1805 /// Returns `Ok(())` on success or an error if download fails.
1806 ///
1807 /// # Example
1808 ///
1809 /// ```rust,no_run
1810 /// # use edgefirst_client::{Client, DatasetID, FileType};
1811 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1812 /// let client = Client::new()?.with_token_path(None)?;
1813 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1814 ///
1815 /// // Download with sequence subdirectories (default)
1816 /// client
1817 /// .download_dataset(
1818 /// dataset_id,
1819 /// &[],
1820 /// &[FileType::Image],
1821 /// "./data".into(),
1822 /// false,
1823 /// None,
1824 /// )
1825 /// .await?;
1826 ///
1827 /// // Download flattened (all files in one directory)
1828 /// client
1829 /// .download_dataset(
1830 /// dataset_id,
1831 /// &[],
1832 /// &[FileType::Image],
1833 /// "./data".into(),
1834 /// true,
1835 /// None,
1836 /// )
1837 /// .await?;
1838 ///
1839 /// // Download all sensor types
1840 /// client
1841 /// .download_dataset(
1842 /// dataset_id,
1843 /// &[],
1844 /// &FileType::expand_types(&[FileType::All]),
1845 /// "./data".into(),
1846 /// false,
1847 /// None,
1848 /// )
1849 /// .await?;
1850 /// # Ok(())
1851 /// # }
1852 /// ```
1853 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1854 pub async fn download_dataset(
1855 &self,
1856 dataset_id: DatasetID,
1857 groups: &[String],
1858 file_types: &[FileType],
1859 output: PathBuf,
1860 flatten: bool,
1861 progress: Option<Sender<Progress>>,
1862 ) -> Result<(), Error> {
1863 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1864 let samples = self
1865 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1866 .await?;
1867 fs::create_dir_all(&output).await?;
1868
1869 // Phase 2: Download actual files using direct semaphore pattern
1870 let total = samples.len();
1871 let current = Arc::new(AtomicUsize::new(0));
1872 let sem = Arc::new(Semaphore::new(max_tasks()));
1873
1874 // Send initial progress for download phase
1875 if let Some(ref progress) = progress {
1876 let _ = progress
1877 .send(Progress {
1878 current: 0,
1879 total,
1880 status: Some("Downloading".to_string()),
1881 })
1882 .await;
1883 }
1884
1885 let tasks = samples
1886 .into_iter()
1887 .map(|sample| {
1888 let client = self.clone();
1889 let file_types = file_types.to_vec();
1890 let output = output.clone();
1891 let progress = progress.clone();
1892 let current = current.clone();
1893 let sem = sem.clone();
1894
1895 tokio::spawn(async move {
1896 let _permit = sem.acquire().await.map_err(|_| {
1897 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1898 })?;
1899
1900 for file_type in &file_types {
1901 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1902 let (file_ext, is_image) = match file_type {
1903 FileType::Image => (
1904 infer::get(&data)
1905 .expect("Failed to identify image file format for sample")
1906 .extension()
1907 .to_string(),
1908 true,
1909 ),
1910 other => (other.file_extension().to_string(), false),
1911 };
1912
1913 // Determine target directory based on sequence membership and
1914 // flatten option
1915 // - flatten=false + sequence_name: dataset/sequence_name/
1916 // - flatten=false + no sequence: dataset/ (root level)
1917 // - flatten=true: dataset/ (all files in output root)
1918 // NOTE: group (train/val/test) is NOT used for directory structure
1919 let sequence_dir = sample
1920 .sequence_name()
1921 .map(|name| sanitize_path_component(name));
1922
1923 let target_dir = if flatten {
1924 output.clone()
1925 } else {
1926 sequence_dir
1927 .as_ref()
1928 .map(|seq| output.join(seq))
1929 .unwrap_or_else(|| output.clone())
1930 };
1931 fs::create_dir_all(&target_dir).await?;
1932
1933 let sanitized_sample_name = sample
1934 .name()
1935 .map(|name| sanitize_path_component(&name))
1936 .unwrap_or_else(|| "unknown".to_string());
1937
1938 let image_name = sample.image_name().map(sanitize_path_component);
1939
1940 // Construct filename with smart prefixing for flatten mode
1941 // When flatten=true and sample belongs to a sequence:
1942 // - Check if filename already starts with "{sequence_name}_"
1943 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1944 // - If yes, use filename as-is (already uniquely named)
1945 let file_name = if is_image {
1946 if let Some(img_name) = image_name {
1947 Client::build_filename(
1948 &img_name,
1949 flatten,
1950 sequence_dir.as_ref(),
1951 sample.frame_number(),
1952 )
1953 } else {
1954 format!("{}.{}", sanitized_sample_name, file_ext)
1955 }
1956 } else {
1957 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1958 Client::build_filename(
1959 &base_name,
1960 flatten,
1961 sequence_dir.as_ref(),
1962 sample.frame_number(),
1963 )
1964 };
1965
1966 let file_path = target_dir.join(&file_name);
1967
1968 let mut file = File::create(&file_path).await?;
1969 file.write_all(&data).await?;
1970 }
1971 }
1972
1973 // Update progress after sample completes
1974 if let Some(progress) = &progress {
1975 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
1976 let _ = progress
1977 .send(Progress {
1978 current: completed,
1979 total,
1980 status: Some("Downloading".to_string()),
1981 })
1982 .await;
1983 }
1984
1985 Ok::<(), Error>(())
1986 })
1987 })
1988 .collect::<Vec<_>>();
1989
1990 join_all(tasks)
1991 .await
1992 .into_iter()
1993 .collect::<Result<Vec<_>, _>>()?
1994 .into_iter()
1995 .collect::<Result<Vec<_>, _>>()?;
1996
1997 Ok(())
1998 }
1999
2000 /// Builds a filename with smart prefixing for flatten mode.
2001 ///
2002 /// When flattening sequences into a single directory, this function ensures
2003 /// unique filenames by checking if the sequence prefix already exists and
2004 /// adding it if necessary.
2005 ///
2006 /// # Logic
2007 ///
2008 /// - If `flatten=false`: returns `base_name` unchanged
2009 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
2010 /// - If `flatten=true` and in sequence:
2011 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
2012 /// unchanged
2013 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
2014 /// `{sequence_name}_{base_name}`
2015 fn build_filename(
2016 base_name: &str,
2017 flatten: bool,
2018 sequence_name: Option<&String>,
2019 frame_number: Option<u32>,
2020 ) -> String {
2021 if !flatten || sequence_name.is_none() {
2022 return base_name.to_string();
2023 }
2024
2025 let seq_name = sequence_name.unwrap();
2026 let prefix = format!("{}_", seq_name);
2027
2028 // Check if already prefixed with sequence name
2029 if base_name.starts_with(&prefix) {
2030 base_name.to_string()
2031 } else {
2032 // Add sequence (and optionally frame) prefix
2033 match frame_number {
2034 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
2035 None => format!("{}{}", prefix, base_name),
2036 }
2037 }
2038 }
2039
2040 /// List available annotation sets for the specified dataset.
2041 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2042 pub async fn annotation_sets(
2043 &self,
2044 dataset_id: DatasetID,
2045 ) -> Result<Vec<AnnotationSet>, Error> {
2046 let params = HashMap::from([("dataset_id", dataset_id)]);
2047 self.rpc("annset.list".to_owned(), Some(params)).await
2048 }
2049
2050 /// Create a new annotation set for the specified dataset.
2051 ///
2052 /// # Arguments
2053 ///
2054 /// * `dataset_id` - The ID of the dataset to create the annotation set in
2055 /// * `name` - The name of the new annotation set
2056 /// * `description` - Optional description for the annotation set
2057 ///
2058 /// # Returns
2059 ///
2060 /// Returns the annotation set ID of the newly created annotation set.
2061 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2062 pub async fn create_annotation_set(
2063 &self,
2064 dataset_id: DatasetID,
2065 name: &str,
2066 description: Option<&str>,
2067 ) -> Result<AnnotationSetID, Error> {
2068 #[derive(Serialize)]
2069 struct Params<'a> {
2070 dataset_id: DatasetID,
2071 name: &'a str,
2072 operator: &'a str,
2073 #[serde(skip_serializing_if = "Option::is_none")]
2074 description: Option<&'a str>,
2075 }
2076
2077 #[derive(Deserialize)]
2078 struct CreateAnnotationSetResult {
2079 id: AnnotationSetID,
2080 }
2081
2082 let username = self.username().await?;
2083 let result: CreateAnnotationSetResult = self
2084 .rpc(
2085 "annset.add".to_owned(),
2086 Some(Params {
2087 dataset_id,
2088 name,
2089 operator: &username,
2090 description,
2091 }),
2092 )
2093 .await?;
2094 Ok(result.id)
2095 }
2096
2097 /// Deletes an annotation set by marking it as deleted.
2098 ///
2099 /// # Arguments
2100 ///
2101 /// * `annotation_set_id` - The ID of the annotation set to delete
2102 ///
2103 /// # Returns
2104 ///
2105 /// Returns `Ok(())` if the annotation set was successfully marked as
2106 /// deleted.
2107 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2108 pub async fn delete_annotation_set(
2109 &self,
2110 annotation_set_id: AnnotationSetID,
2111 ) -> Result<(), Error> {
2112 let params = HashMap::from([("id", annotation_set_id)]);
2113 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
2114 Ok(())
2115 }
2116
2117 /// Retrieve the annotation set with the specified ID.
2118 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2119 pub async fn annotation_set(
2120 &self,
2121 annotation_set_id: AnnotationSetID,
2122 ) -> Result<AnnotationSet, Error> {
2123 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
2124 self.rpc("annset.get".to_owned(), Some(params)).await
2125 }
2126
2127 /// Get the annotations for the specified annotation set with the
2128 /// requested annotation types. The annotation types are used to filter
2129 /// the annotations returned. The groups parameter is used to filter for
2130 /// dataset groups (train, val, test). Images which do not have any
2131 /// annotations are also included in the result as long as they are in the
2132 /// requested groups (when specified).
2133 ///
2134 /// The result is a vector of Annotations objects which contain the
2135 /// full dataset along with the annotations for the specified types.
2136 ///
2137 /// # Progress
2138 ///
2139 /// Reports progress with `status: None` as samples are fetched and
2140 /// processed for their annotations. Progress unit is samples processed
2141 /// (not individual annotations).
2142 ///
2143 /// To get the annotations as a DataFrame, use the `samples_dataframe`
2144 /// method instead.
2145 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2146 pub async fn annotations(
2147 &self,
2148 annotation_set_id: AnnotationSetID,
2149 groups: &[String],
2150 annotation_types: &[AnnotationType],
2151 progress: Option<Sender<Progress>>,
2152 ) -> Result<Vec<Annotation>, Error> {
2153 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
2154 let labels = self
2155 .labels(dataset_id)
2156 .await?
2157 .into_iter()
2158 .map(|label| (label.name().to_string(), label.index()))
2159 .collect::<HashMap<_, _>>();
2160 let total = self
2161 .samples_count(
2162 dataset_id,
2163 Some(annotation_set_id),
2164 annotation_types,
2165 groups,
2166 &[],
2167 )
2168 .await?
2169 .total as usize;
2170
2171 if total == 0 {
2172 return Ok(vec![]);
2173 }
2174
2175 let context = FetchContext {
2176 dataset_id,
2177 annotation_set_id: Some(annotation_set_id),
2178 groups,
2179 types: annotation_types.iter().map(|t| t.to_string()).collect(),
2180 labels: &labels,
2181 };
2182
2183 self.fetch_annotations_paginated(context, total, progress)
2184 .await
2185 }
2186
2187 async fn fetch_annotations_paginated(
2188 &self,
2189 context: FetchContext<'_>,
2190 total: usize,
2191 progress: Option<Sender<Progress>>,
2192 ) -> Result<Vec<Annotation>, Error> {
2193 let mut annotations = vec![];
2194 let mut continue_token: Option<String> = None;
2195 let mut current = 0;
2196
2197 loop {
2198 let params = SamplesListParams {
2199 dataset_id: context.dataset_id,
2200 annotation_set_id: context.annotation_set_id,
2201 types: context.types.clone(),
2202 group_names: context.groups.to_vec(),
2203 continue_token,
2204 };
2205
2206 let result: SamplesListResult =
2207 self.rpc("samples.list".to_owned(), Some(params)).await?;
2208 current += result.samples.len();
2209 continue_token = result.continue_token;
2210
2211 if result.samples.is_empty() {
2212 break;
2213 }
2214
2215 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
2216
2217 if let Some(progress) = &progress {
2218 let _ = progress
2219 .send(Progress {
2220 current,
2221 total,
2222 status: None,
2223 })
2224 .await;
2225 }
2226
2227 match &continue_token {
2228 Some(token) if !token.is_empty() => continue,
2229 _ => break,
2230 }
2231 }
2232
2233 drop(progress);
2234 Ok(annotations)
2235 }
2236
2237 fn process_sample_annotations(
2238 &self,
2239 samples: &[Sample],
2240 labels: &HashMap<String, u64>,
2241 annotations: &mut Vec<Annotation>,
2242 ) {
2243 for sample in samples {
2244 if sample.annotations().is_empty() {
2245 let mut annotation = Annotation::new();
2246 annotation.set_sample_id(sample.id());
2247 annotation.set_name(sample.name());
2248 annotation.set_sequence_name(sample.sequence_name().cloned());
2249 annotation.set_frame_number(sample.frame_number());
2250 annotation.set_group(sample.group().cloned());
2251 annotations.push(annotation);
2252 continue;
2253 }
2254
2255 for annotation in sample.annotations() {
2256 let mut annotation = annotation.clone();
2257 annotation.set_sample_id(sample.id());
2258 annotation.set_name(sample.name());
2259 annotation.set_sequence_name(sample.sequence_name().cloned());
2260 annotation.set_frame_number(sample.frame_number());
2261 annotation.set_group(sample.group().cloned());
2262 Self::set_label_index_from_map(&mut annotation, labels);
2263 annotations.push(annotation);
2264 }
2265 }
2266 }
2267
2268 /// Delete annotations in bulk from specified samples.
2269 ///
2270 /// This method calls the `annotation.bulk.del` API to efficiently remove
2271 /// annotations from multiple samples at once. Useful for clearing
2272 /// annotations before re-importing updated data.
2273 ///
2274 /// # Arguments
2275 /// * `annotation_set_id` - The annotation set containing the annotations
2276 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2277 /// for masks
2278 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2279 ///
2280 /// # Example
2281 /// ```no_run
2282 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2283 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2284 /// # let client = Client::new()?.with_login("user", "pass").await?;
2285 /// let annotation_set_id = AnnotationSetID::from(123);
2286 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2287 ///
2288 /// client
2289 /// .delete_annotations_bulk(
2290 /// annotation_set_id,
2291 /// &["box".to_string(), "seg".to_string()],
2292 /// &sample_ids,
2293 /// )
2294 /// .await?;
2295 /// # Ok(())
2296 /// # }
2297 /// ```
2298 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2299 pub async fn delete_annotations_bulk(
2300 &self,
2301 annotation_set_id: AnnotationSetID,
2302 annotation_types: &[String],
2303 sample_ids: &[SampleID],
2304 ) -> Result<(), Error> {
2305 use crate::api::AnnotationBulkDeleteParams;
2306
2307 let params = AnnotationBulkDeleteParams {
2308 annotation_set_id: annotation_set_id.into(),
2309 annotation_types: annotation_types.to_vec(),
2310 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2311 delete_all: None,
2312 };
2313
2314 let _: String = self
2315 .rpc("annotation.bulk.del".to_owned(), Some(params))
2316 .await?;
2317 Ok(())
2318 }
2319
2320 /// Add annotations in bulk.
2321 ///
2322 /// This method calls the `annotation.add_bulk` API to efficiently add
2323 /// multiple annotations at once. The annotations must be in server format
2324 /// with image_id references.
2325 ///
2326 /// # Arguments
2327 /// * `annotation_set_id` - The annotation set to add annotations to
2328 /// * `annotations` - Vector of server-format annotations to add
2329 ///
2330 /// # Returns
2331 /// Vector of created annotation records from the server.
2332 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2333 pub async fn add_annotations_bulk(
2334 &self,
2335 annotation_set_id: AnnotationSetID,
2336 annotations: Vec<crate::api::ServerAnnotation>,
2337 ) -> Result<Vec<serde_json::Value>, Error> {
2338 use crate::api::AnnotationAddBulkParams;
2339
2340 let params = AnnotationAddBulkParams {
2341 annotation_set_id: annotation_set_id.into(),
2342 annotations,
2343 };
2344
2345 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2346 .await
2347 }
2348
2349 /// Helper to parse frame number from image_name when sequence_name is
2350 /// present. This ensures frame_number is always derived from the image
2351 /// filename, not from the server's frame_number field (which may be
2352 /// inconsistent).
2353 ///
2354 /// Returns Some(frame_number) if sequence_name is present and frame can be
2355 /// parsed, otherwise None.
2356 fn parse_frame_from_image_name(
2357 image_name: Option<&String>,
2358 sequence_name: Option<&String>,
2359 ) -> Option<u32> {
2360 use std::path::Path;
2361
2362 let sequence = sequence_name?;
2363 let name = image_name?;
2364
2365 // Extract stem (remove extension)
2366 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2367
2368 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2369 stem.strip_prefix(sequence)
2370 .and_then(|suffix| suffix.strip_prefix('_'))
2371 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2372 }
2373
2374 /// Helper to set label index from a label map
2375 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2376 if let Some(label) = annotation.label() {
2377 annotation.set_label_index(Some(labels[label.as_str()]));
2378 }
2379 }
2380
2381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2382 pub async fn samples_count(
2383 &self,
2384 dataset_id: DatasetID,
2385 annotation_set_id: Option<AnnotationSetID>,
2386 annotation_types: &[AnnotationType],
2387 groups: &[String],
2388 types: &[FileType],
2389 ) -> Result<SamplesCountResult, Error> {
2390 // Use server type names for API calls (e.g., "box" instead of "box2d")
2391 let types = annotation_types
2392 .iter()
2393 .map(|t| t.as_server_type().to_string())
2394 .chain(types.iter().map(|t| t.to_string()))
2395 .collect::<Vec<_>>();
2396
2397 let params = SamplesListParams {
2398 dataset_id,
2399 annotation_set_id,
2400 group_names: groups.to_vec(),
2401 types,
2402 continue_token: None,
2403 };
2404
2405 self.rpc("samples.count".to_owned(), Some(params)).await
2406 }
2407
2408 /// Fetches samples from a dataset with optional annotation and file type
2409 /// filters.
2410 ///
2411 /// # Arguments
2412 ///
2413 /// * `dataset_id` - The dataset to fetch samples from
2414 /// * `annotation_set_id` - Optional annotation set to include annotations
2415 /// from
2416 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2417 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2418 /// * `types` - File types to include metadata for
2419 /// * `progress` - Optional channel for progress updates
2420 ///
2421 /// # Progress
2422 ///
2423 /// Reports progress with `status: None` as samples are fetched from the
2424 /// server in paginated batches. Progress unit is samples fetched.
2425 ///
2426 /// # Returns
2427 ///
2428 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2429 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2430 pub async fn samples(
2431 &self,
2432 dataset_id: DatasetID,
2433 annotation_set_id: Option<AnnotationSetID>,
2434 annotation_types: &[AnnotationType],
2435 groups: &[String],
2436 types: &[FileType],
2437 progress: Option<Sender<Progress>>,
2438 ) -> Result<Vec<Sample>, Error> {
2439 // Use server type names for API calls (e.g., "box" instead of "box2d")
2440 let types_vec = annotation_types
2441 .iter()
2442 .map(|t| t.as_server_type().to_string())
2443 .chain(types.iter().map(|t| t.to_string()))
2444 .collect::<Vec<_>>();
2445 let labels = self
2446 .labels(dataset_id)
2447 .await?
2448 .into_iter()
2449 .map(|label| (label.name().to_string(), label.index()))
2450 .collect::<HashMap<_, _>>();
2451 let total = self
2452 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2453 .await?
2454 .total as usize;
2455
2456 if total == 0 {
2457 return Ok(vec![]);
2458 }
2459
2460 let context = FetchContext {
2461 dataset_id,
2462 annotation_set_id,
2463 groups,
2464 types: types_vec,
2465 labels: &labels,
2466 };
2467
2468 self.fetch_samples_paginated(context, total, progress).await
2469 }
2470
2471 /// Get all sample names in a dataset.
2472 ///
2473 /// This is an efficient method for checking which samples already exist,
2474 /// useful for resuming interrupted imports. It only retrieves sample names
2475 /// without loading full annotation data.
2476 ///
2477 /// # Arguments
2478 ///
2479 /// * `dataset_id` - The dataset to query
2480 /// * `groups` - Optional group filter (empty = all groups)
2481 /// * `progress` - Optional progress channel
2482 ///
2483 /// # Progress
2484 ///
2485 /// Reports progress with `status: None` as sample names are fetched from
2486 /// the server in paginated batches. Progress unit is samples fetched.
2487 ///
2488 /// # Returns
2489 ///
2490 /// A HashSet of sample names (image_name field) that exist in the dataset.
2491 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2492 pub async fn sample_names(
2493 &self,
2494 dataset_id: DatasetID,
2495 groups: &[String],
2496 progress: Option<Sender<Progress>>,
2497 ) -> Result<std::collections::HashSet<String>, Error> {
2498 use std::collections::HashSet;
2499
2500 let total = self
2501 .samples_count(dataset_id, None, &[], groups, &[])
2502 .await?
2503 .total as usize;
2504
2505 if total == 0 {
2506 return Ok(HashSet::new());
2507 }
2508
2509 let mut names = HashSet::with_capacity(total);
2510 let mut continue_token: Option<String> = None;
2511 let mut current = 0;
2512
2513 loop {
2514 let params = SamplesListParams {
2515 dataset_id,
2516 annotation_set_id: None,
2517 types: vec![], // No type filter - we just want names
2518 group_names: groups.to_vec(),
2519 continue_token: continue_token.clone(),
2520 };
2521
2522 let result: SamplesListResult =
2523 self.rpc("samples.list".to_owned(), Some(params)).await?;
2524 current += result.samples.len();
2525 continue_token = result.continue_token;
2526
2527 if result.samples.is_empty() {
2528 break;
2529 }
2530
2531 // Extract sample names (normalized without extension)
2532 for sample in result.samples {
2533 if let Some(name) = sample.name() {
2534 names.insert(name);
2535 }
2536 }
2537
2538 if let Some(ref p) = progress {
2539 let _ = p
2540 .send(Progress {
2541 current,
2542 total,
2543 status: None,
2544 })
2545 .await;
2546 }
2547
2548 match &continue_token {
2549 Some(token) if !token.is_empty() => continue,
2550 _ => break,
2551 }
2552 }
2553
2554 Ok(names)
2555 }
2556
2557 async fn fetch_samples_paginated(
2558 &self,
2559 context: FetchContext<'_>,
2560 total: usize,
2561 progress: Option<Sender<Progress>>,
2562 ) -> Result<Vec<Sample>, Error> {
2563 let mut samples = vec![];
2564 let mut continue_token: Option<String> = None;
2565 let mut current = 0;
2566
2567 loop {
2568 let params = SamplesListParams {
2569 dataset_id: context.dataset_id,
2570 annotation_set_id: context.annotation_set_id,
2571 types: context.types.clone(),
2572 group_names: context.groups.to_vec(),
2573 continue_token: continue_token.clone(),
2574 };
2575
2576 let result: SamplesListResult =
2577 self.rpc("samples.list".to_owned(), Some(params)).await?;
2578 current += result.samples.len();
2579 continue_token = result.continue_token;
2580
2581 if result.samples.is_empty() {
2582 break;
2583 }
2584
2585 samples.append(
2586 &mut result
2587 .samples
2588 .into_iter()
2589 .map(|s| {
2590 // Use server's frame_number if valid (>= 0 after deserialization)
2591 // Otherwise parse from image_name as fallback
2592 // This ensures we respect explicit frame_number from uploads
2593 // while still handling legacy data that only has filename encoding
2594 let frame_number = s.frame_number.or_else(|| {
2595 Self::parse_frame_from_image_name(
2596 s.image_name.as_ref(),
2597 s.sequence_name.as_ref(),
2598 )
2599 });
2600
2601 let mut anns = s.annotations().to_vec();
2602 for ann in &mut anns {
2603 // Set annotation fields from parent sample
2604 ann.set_name(s.name());
2605 ann.set_group(s.group().cloned());
2606 ann.set_sequence_name(s.sequence_name().cloned());
2607 ann.set_frame_number(frame_number);
2608 Self::set_label_index_from_map(ann, context.labels);
2609 }
2610 s.with_annotations(anns).with_frame_number(frame_number)
2611 })
2612 .collect::<Vec<_>>(),
2613 );
2614
2615 if let Some(progress) = &progress {
2616 let _ = progress
2617 .send(Progress {
2618 current,
2619 total,
2620 status: None,
2621 })
2622 .await;
2623 }
2624
2625 match &continue_token {
2626 Some(token) if !token.is_empty() => continue,
2627 _ => break,
2628 }
2629 }
2630
2631 drop(progress);
2632 Ok(samples)
2633 }
2634
2635 /// Populates (imports) samples into a dataset using the `samples.populate2`
2636 /// API.
2637 ///
2638 /// This method creates new samples in the specified dataset, optionally
2639 /// with annotations and sensor data files. For each sample, the `files`
2640 /// field is checked for local file paths. If a filename is a valid path
2641 /// to an existing file, the file will be automatically uploaded to S3
2642 /// using presigned URLs returned by the server. The filename in the
2643 /// request is replaced with the basename (path removed) before sending
2644 /// to the server.
2645 ///
2646 /// # Important Notes
2647 ///
2648 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2649 /// annotations. Without it, the server will accept the request but will
2650 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2651 /// available annotation sets for a dataset, or create a new one via the
2652 /// Studio UI.
2653 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2654 /// boxes. Divide pixel coordinates by image width/height before creating
2655 /// [`Box2d`](crate::Box2d) annotations.
2656 /// - **Files are uploaded automatically** when the filename is a valid
2657 /// local path. The method will replace the full path with just the
2658 /// basename before sending to the server.
2659 /// - **Image dimensions are extracted automatically** for image files using
2660 /// the `imagesize` crate. The width/height are sent to the server and
2661 /// stored in the `image_files` table. These dimensions are returned by
2662 /// `samples.list` and used in [`samples_dataframe`](crate::samples_dataframe)
2663 /// to populate the `size` column.
2664 /// - **UUIDs are generated automatically** if not provided. If you need
2665 /// deterministic UUIDs, set `sample.uuid` explicitly before calling.
2666 ///
2667 /// # Arguments
2668 ///
2669 /// * `dataset_id` - The ID of the dataset to populate
2670 /// * `annotation_set_id` - **Required** if samples contain annotations,
2671 /// otherwise they will be ignored. Query with
2672 /// [`Client::annotation_sets`].
2673 /// * `samples` - Vector of samples to import with metadata and file
2674 /// references. For files, use the full local path - it will be uploaded
2675 /// automatically. UUIDs and image dimensions will be
2676 /// auto-generated/extracted if not provided.
2677 /// * `progress` - Optional channel for progress updates
2678 ///
2679 /// # Progress
2680 ///
2681 /// Reports progress with `status: None` as each sample's files are
2682 /// uploaded. Progress unit is samples (not individual files). Each
2683 /// sample may contain multiple files (image, lidar, radar, etc.) which
2684 /// are all uploaded before the sample is counted as complete.
2685 ///
2686 /// # Returns
2687 ///
2688 /// Returns the API result with sample UUIDs and upload status.
2689 ///
2690 /// # Example
2691 ///
2692 /// ```no_run
2693 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2694 ///
2695 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2696 /// # let client = Client::new()?.with_login("user", "pass").await?;
2697 /// # let dataset_id = DatasetID::from(1);
2698 /// // Query available annotation sets for the dataset
2699 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2700 /// let annotation_set_id = annotation_sets
2701 /// .first()
2702 /// .ok_or_else(|| {
2703 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2704 /// })?
2705 /// .id();
2706 ///
2707 /// // Create sample with annotation (UUID will be auto-generated)
2708 /// let mut sample = Sample::new();
2709 /// sample.width = Some(1920);
2710 /// sample.height = Some(1080);
2711 /// sample.group = Some("train".to_string());
2712 ///
2713 /// // Add file - use full path to local file, it will be uploaded automatically
2714 /// sample.files = vec![SampleFile::with_filename(
2715 /// "image".to_string(),
2716 /// "/path/to/image.jpg".to_string(),
2717 /// )];
2718 ///
2719 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2720 /// let mut annotation = Annotation::new();
2721 /// annotation.set_label(Some("person".to_string()));
2722 /// // Normalize pixel coordinates by dividing by image dimensions
2723 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2724 /// annotation.set_box2d(Some(bbox));
2725 /// sample.annotations = vec![annotation];
2726 ///
2727 /// // Populate with annotation_set_id (REQUIRED for annotations)
2728 /// let result = client
2729 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2730 /// .await?;
2731 /// # Ok(())
2732 /// # }
2733 /// ```
2734 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2735 pub async fn populate_samples(
2736 &self,
2737 dataset_id: DatasetID,
2738 annotation_set_id: Option<AnnotationSetID>,
2739 samples: Vec<Sample>,
2740 progress: Option<Sender<Progress>>,
2741 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2742 self.populate_samples_with_concurrency(
2743 dataset_id,
2744 annotation_set_id,
2745 samples,
2746 progress,
2747 None,
2748 )
2749 .await
2750 }
2751
2752 /// Populate samples with custom upload concurrency.
2753 ///
2754 /// Same as [`populate_samples`](Self::populate_samples) but allows
2755 /// specifying the maximum number of concurrent file uploads. Use this
2756 /// for bulk imports where higher concurrency can significantly reduce
2757 /// upload time.
2758 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2759 pub async fn populate_samples_with_concurrency(
2760 &self,
2761 dataset_id: DatasetID,
2762 annotation_set_id: Option<AnnotationSetID>,
2763 samples: Vec<Sample>,
2764 progress: Option<Sender<Progress>>,
2765 concurrency: Option<usize>,
2766 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2767 use crate::api::SamplesPopulateParams;
2768
2769 // Track which files need to be uploaded
2770 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2771
2772 // Process samples to detect local files and generate UUIDs
2773 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2774
2775 let has_files_to_upload = !files_to_upload.is_empty();
2776
2777 // Call populate API with presigned_urls=true if we have files to upload
2778 let params = SamplesPopulateParams {
2779 dataset_id,
2780 annotation_set_id,
2781 presigned_urls: Some(has_files_to_upload),
2782 samples,
2783 };
2784
2785 let results: Vec<crate::SamplesPopulateResult> = self
2786 .rpc("samples.populate2".to_owned(), Some(params))
2787 .await?;
2788
2789 // Upload files if we have any
2790 if has_files_to_upload {
2791 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2792 .await?;
2793 }
2794
2795 Ok(results)
2796 }
2797
2798 fn prepare_samples_for_upload(
2799 &self,
2800 samples: Vec<Sample>,
2801 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2802 ) -> Result<Vec<Sample>, Error> {
2803 Ok(samples
2804 .into_iter()
2805 .map(|mut sample| {
2806 // Generate UUID if not provided
2807 if sample.uuid.is_none() {
2808 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2809 }
2810
2811 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2812
2813 // Process files: detect local paths and queue for upload
2814 let files_copy = sample.files.clone();
2815 let updated_files: Vec<crate::SampleFile> = files_copy
2816 .iter()
2817 .map(|file| {
2818 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2819 })
2820 .collect();
2821
2822 sample.files = updated_files;
2823 sample
2824 })
2825 .collect())
2826 }
2827
2828 fn process_sample_file(
2829 &self,
2830 file: &crate::SampleFile,
2831 sample_uuid: &str,
2832 sample: &mut Sample,
2833 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2834 ) -> crate::SampleFile {
2835 use std::path::Path;
2836
2837 // Handle files with raw bytes (e.g., from ZIP archives)
2838 if let Some(bytes) = file.bytes()
2839 && let Some(filename) = file.filename()
2840 {
2841 // For image files with bytes, try to extract dimensions if not already set
2842 if file.file_type() == "image"
2843 && (sample.width.is_none() || sample.height.is_none())
2844 && let Ok(size) = imagesize::blob_size(bytes)
2845 {
2846 sample.width = Some(size.width as u32);
2847 sample.height = Some(size.height as u32);
2848 }
2849
2850 // Store the bytes for later upload
2851 files_to_upload.push((
2852 sample_uuid.to_string(),
2853 file.file_type().to_string(),
2854 FileSource::Bytes(bytes.to_vec()),
2855 filename.to_string(),
2856 ));
2857
2858 // Return SampleFile with just the filename
2859 return crate::SampleFile::with_filename(
2860 file.file_type().to_string(),
2861 filename.to_string(),
2862 );
2863 }
2864
2865 // Handle files with local paths
2866 if let Some(filename) = file.filename() {
2867 let path = Path::new(filename);
2868
2869 // Check if this is a valid local file path
2870 if path.exists()
2871 && path.is_file()
2872 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2873 {
2874 // For image files, try to extract dimensions if not already set
2875 if file.file_type() == "image"
2876 && (sample.width.is_none() || sample.height.is_none())
2877 && let Ok(size) = imagesize::size(path)
2878 {
2879 sample.width = Some(size.width as u32);
2880 sample.height = Some(size.height as u32);
2881 }
2882
2883 // Store the full path for later upload
2884 files_to_upload.push((
2885 sample_uuid.to_string(),
2886 file.file_type().to_string(),
2887 FileSource::Path(path.to_path_buf()),
2888 basename.to_string(),
2889 ));
2890
2891 // Return SampleFile with just the basename
2892 return crate::SampleFile::with_filename(
2893 file.file_type().to_string(),
2894 basename.to_string(),
2895 );
2896 }
2897 }
2898 // Return the file unchanged if not a local path
2899 file.clone()
2900 }
2901
2902 async fn upload_sample_files(
2903 &self,
2904 results: &[crate::SamplesPopulateResult],
2905 files_to_upload: Vec<(String, String, FileSource, String)>,
2906 progress: Option<Sender<Progress>>,
2907 concurrency: Option<usize>,
2908 ) -> Result<(), Error> {
2909 // Build a map from (sample_uuid, basename) -> file source
2910 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2911 for (uuid, _file_type, source, basename) in files_to_upload {
2912 upload_map.insert((uuid, basename), source);
2913 }
2914
2915 let http = self.bulk_http.clone();
2916
2917 // Extract the data we need for parallel upload
2918 let upload_tasks: Vec<_> = results
2919 .iter()
2920 .map(|result| (result.uuid.clone(), result.urls.clone()))
2921 .collect();
2922
2923 parallel_foreach_items(
2924 upload_tasks,
2925 progress.clone(),
2926 concurrency,
2927 move |(uuid, urls)| {
2928 let http = http.clone();
2929 let upload_map = upload_map.clone();
2930
2931 async move {
2932 // Upload all files for this sample
2933 for url_info in &urls {
2934 if let Some(source) =
2935 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2936 {
2937 match source {
2938 FileSource::Path(path) => {
2939 upload_file_to_presigned_url(
2940 http.clone(),
2941 &url_info.url,
2942 path.clone(),
2943 )
2944 .await?;
2945 }
2946 FileSource::Bytes(bytes) => {
2947 upload_bytes_to_presigned_url(
2948 http.clone(),
2949 &url_info.url,
2950 bytes.clone(),
2951 &url_info.filename,
2952 )
2953 .await?;
2954 }
2955 }
2956 }
2957 }
2958
2959 Ok(())
2960 }
2961 },
2962 )
2963 .await
2964 }
2965
2966 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2967 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2968 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
2969 if !url.starts_with("http://") && !url.starts_with("https://") {
2970 return Err(Error::InvalidParameters(format!(
2971 "Invalid URL (must be absolute): {}",
2972 url
2973 )));
2974 }
2975
2976 let resp = self.bulk_http.get(url).send().await?;
2977
2978 if !resp.status().is_success() {
2979 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2980 }
2981
2982 let bytes = resp.bytes().await?;
2983 Ok(bytes.to_vec())
2984 }
2985
2986 /// Get samples as a DataFrame with complete 2025.10 schema.
2987 ///
2988 /// This is the recommended method for obtaining dataset annotations in
2989 /// DataFrame format. It includes all sample metadata (size, location,
2990 /// pose, degradation) as optional columns.
2991 ///
2992 /// # Arguments
2993 ///
2994 /// * `dataset_id` - Dataset identifier
2995 /// * `annotation_set_id` - Optional annotation set filter
2996 /// * `groups` - Dataset groups to include (train, val, test)
2997 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2998 /// * `progress` - Optional progress callback
2999 ///
3000 /// # Progress
3001 ///
3002 /// Reports progress with `status: None` as samples are fetched from the
3003 /// server in paginated batches. Progress unit is samples fetched. This
3004 /// method delegates to [`samples()`](Self::samples) and shares its
3005 /// progress behavior.
3006 ///
3007 /// # Example
3008 ///
3009 /// ```rust,no_run
3010 /// use edgefirst_client::Client;
3011 ///
3012 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3013 /// # let client = Client::new()?;
3014 /// # let dataset_id = 1.into();
3015 /// # let annotation_set_id = 1.into();
3016 /// let df = client
3017 /// .samples_dataframe(
3018 /// dataset_id,
3019 /// Some(annotation_set_id),
3020 /// &["train".to_string()],
3021 /// &[],
3022 /// None,
3023 /// )
3024 /// .await?;
3025 /// println!("DataFrame shape: {:?}", df.shape());
3026 /// # Ok(())
3027 /// # }
3028 /// ```
3029 #[cfg(feature = "polars")]
3030 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3031 pub async fn samples_dataframe(
3032 &self,
3033 dataset_id: DatasetID,
3034 annotation_set_id: Option<AnnotationSetID>,
3035 groups: &[String],
3036 types: &[AnnotationType],
3037 progress: Option<Sender<Progress>>,
3038 ) -> Result<DataFrame, Error> {
3039 use crate::dataset::samples_dataframe;
3040
3041 let samples = self
3042 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
3043 .await?;
3044 samples_dataframe(&samples)
3045 }
3046
3047 /// Update image dimensions for existing samples in a dataset.
3048 ///
3049 /// This is useful for backfilling width/height data on samples that were
3050 /// uploaded before dimension extraction was added, or where dimensions
3051 /// could not be determined at upload time.
3052 ///
3053 /// # Arguments
3054 ///
3055 /// * `dataset_id` - The dataset containing the samples
3056 /// * `updates` - List of dimension updates (sample ID, width, height)
3057 ///
3058 /// # Returns
3059 ///
3060 /// The number of samples that were successfully updated.
3061 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, updates), fields(dataset_id = %dataset_id, count = updates.len())))]
3062 pub async fn update_sample_dimensions(
3063 &self,
3064 dataset_id: DatasetID,
3065 updates: Vec<crate::SampleDimensionUpdate>,
3066 ) -> Result<u64, Error> {
3067 use crate::api::SamplesUpdateDimensionsParams;
3068
3069 if updates.is_empty() {
3070 return Ok(0);
3071 }
3072
3073 // Batch in groups of 500 to stay within server limits
3074 let mut total_updated = 0u64;
3075 for chunk in updates.chunks(500) {
3076 let params = SamplesUpdateDimensionsParams {
3077 dataset_id,
3078 samples: chunk.to_vec(),
3079 };
3080 let result: crate::SamplesUpdateDimensionsResult = self
3081 .rpc("samples.update_dimensions".to_owned(), Some(params))
3082 .await?;
3083 total_updated += result.updated;
3084 }
3085 Ok(total_updated)
3086 }
3087
3088 /// Backfill missing image dimensions for a dataset.
3089 ///
3090 /// Downloads image data for samples that are missing width/height,
3091 /// extracts the dimensions using the `imagesize` crate, and updates
3092 /// the server with the computed values.
3093 ///
3094 /// This is a one-time repair operation for datasets that were uploaded
3095 /// before the client added automatic dimension extraction.
3096 ///
3097 /// # Arguments
3098 ///
3099 /// * `dataset_id` - The dataset to backfill
3100 /// * `progress` - Optional progress channel
3101 ///
3102 /// # Returns
3103 ///
3104 /// The number of samples whose dimensions were updated.
3105 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(dataset_id = %dataset_id)))]
3106 pub async fn backfill_sample_dimensions(
3107 &self,
3108 dataset_id: DatasetID,
3109 progress: Option<Sender<Progress>>,
3110 ) -> Result<u64, Error> {
3111 // Fetch all samples; listing progress is not forwarded to the caller
3112 // since it would interleave with the dimension-computing phase.
3113 let samples = self.samples(dataset_id, None, &[], &[], &[], None).await?;
3114
3115 // Filter to samples missing dimensions
3116 let missing: Vec<&Sample> = samples
3117 .iter()
3118 .filter(|s| s.width.is_none() || s.height.is_none())
3119 .collect();
3120
3121 if missing.is_empty() {
3122 return Ok(0);
3123 }
3124
3125 let total = missing.len();
3126 let mut updates: Vec<crate::SampleDimensionUpdate> = Vec::with_capacity(total);
3127
3128 for (i, sample) in missing.into_iter().enumerate() {
3129 let current = i + 1;
3130
3131 let Some(id) = sample.id() else {
3132 Self::send_progress(&progress, current, total).await;
3133 continue;
3134 };
3135
3136 let Some(url) = sample.image_url() else {
3137 #[cfg(feature = "profiling")]
3138 tracing::warn!(sample_id = %id, "skipping sample: no image URL");
3139 Self::send_progress(&progress, current, total).await;
3140 continue;
3141 };
3142
3143 // Download image data to determine dimensions
3144 let resp = self.bulk_http.get(url).send().await;
3145 let Ok(resp) = resp else {
3146 #[cfg(feature = "profiling")]
3147 tracing::warn!(sample_id = %id, "skipping sample: download failed");
3148 Self::send_progress(&progress, current, total).await;
3149 continue;
3150 };
3151
3152 // Skip non-success responses (e.g. 404, 500) rather than parsing error pages
3153 if !resp.status().is_success() {
3154 #[cfg(feature = "profiling")]
3155 tracing::warn!(sample_id = %id, status = %resp.status(), "skipping sample: non-success HTTP status");
3156 Self::send_progress(&progress, current, total).await;
3157 continue;
3158 }
3159
3160 let Ok(bytes) = resp.bytes().await else {
3161 #[cfg(feature = "profiling")]
3162 tracing::warn!(sample_id = %id, "skipping sample: failed to read response body");
3163 Self::send_progress(&progress, current, total).await;
3164 continue;
3165 };
3166
3167 // Extract dimensions from the downloaded image
3168 let Ok(size) = imagesize::blob_size(&bytes) else {
3169 #[cfg(feature = "profiling")]
3170 tracing::warn!(sample_id = %id, "skipping sample: could not determine dimensions");
3171 Self::send_progress(&progress, current, total).await;
3172 continue;
3173 };
3174
3175 let (Ok(width), Ok(height)) = (u32::try_from(size.width), u32::try_from(size.height))
3176 else {
3177 #[cfg(feature = "profiling")]
3178 tracing::warn!(sample_id = %id, width = size.width, height = size.height, "skipping sample: dimensions overflow u32");
3179 Self::send_progress(&progress, current, total).await;
3180 continue;
3181 };
3182
3183 updates.push(crate::SampleDimensionUpdate { id, width, height });
3184 Self::send_progress(&progress, current, total).await;
3185 }
3186
3187 // Send updates to server
3188 self.update_sample_dimensions(dataset_id, updates).await
3189 }
3190
3191 /// Emit a progress event if a progress channel is provided.
3192 async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
3193 if let Some(tx) = progress {
3194 let _ = tx
3195 .send(Progress {
3196 current,
3197 total,
3198 status: Some("Computing dimensions".to_string()),
3199 })
3200 .await;
3201 }
3202 }
3203
3204 /// List available snapshots. If a name is provided, only snapshots
3205 /// containing that name are returned.
3206 ///
3207 /// Results are sorted by match quality: exact matches first, then
3208 /// case-insensitive exact matches, then shorter descriptions (more
3209 /// specific), then alphabetically.
3210 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3211 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
3212 let snapshots: Vec<Snapshot> = self
3213 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
3214 .await?;
3215 if let Some(name) = name {
3216 Ok(filter_and_sort_by_name(snapshots, name, |s| {
3217 s.description()
3218 }))
3219 } else {
3220 Ok(snapshots)
3221 }
3222 }
3223
3224 /// Get the snapshot with the specified id.
3225 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3226 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
3227 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3228 self.rpc("snapshots.get".to_owned(), Some(params)).await
3229 }
3230
3231 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
3232 ///
3233 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
3234 /// pairs) that serve two primary purposes:
3235 ///
3236 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
3237 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
3238 /// restored with AGTG (Automatic Ground Truth Generation) and optional
3239 /// auto-depth processing.
3240 ///
3241 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
3242 /// migration between EdgeFirst Studio instances using the create →
3243 /// download → upload → restore workflow.
3244 ///
3245 /// Large files are automatically chunked into 100MB parts and uploaded
3246 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
3247 /// is streamed without loading into memory, maintaining constant memory
3248 /// usage.
3249 ///
3250 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3251 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
3252 /// better for large files to avoid timeout issues. Higher values (16-32)
3253 /// are better for many small files.
3254 ///
3255 /// # Arguments
3256 ///
3257 /// * `path` - Local file path to MCAP file or directory containing
3258 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
3259 /// * `progress` - Optional channel to receive upload progress updates
3260 ///
3261 /// # Progress
3262 ///
3263 /// Reports progress with `status: None` as file data is uploaded. Progress
3264 /// unit is bytes uploaded. For single files, total is the file size. For
3265 /// directories, total is the combined size of all files.
3266 ///
3267 /// # Returns
3268 ///
3269 /// Returns a `Snapshot` object with ID, description, status, path, and
3270 /// creation timestamp on success.
3271 ///
3272 /// # Errors
3273 ///
3274 /// Returns an error if:
3275 /// * Path doesn't exist or contains invalid UTF-8
3276 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
3277 /// * Upload fails or network error occurs
3278 /// * Server rejects the snapshot
3279 ///
3280 /// # Example
3281 ///
3282 /// ```no_run
3283 /// # use edgefirst_client::{Client, Progress};
3284 /// # use tokio::sync::mpsc;
3285 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3286 /// let client = Client::new()?.with_token_path(None)?;
3287 ///
3288 /// // Upload MCAP file with progress tracking
3289 /// let (tx, mut rx) = mpsc::channel(1);
3290 /// tokio::spawn(async move {
3291 /// while let Some(Progress {
3292 /// current,
3293 /// total,
3294 /// status,
3295 /// }) = rx.recv().await
3296 /// {
3297 /// println!(
3298 /// "{}: {}/{} bytes ({:.1}%)",
3299 /// status.as_deref().unwrap_or("Upload"),
3300 /// current,
3301 /// total,
3302 /// (current as f64 / total as f64) * 100.0
3303 /// );
3304 /// }
3305 /// });
3306 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
3307 /// println!("Created snapshot: {:?}", snapshot.id());
3308 ///
3309 /// // Upload dataset directory (no progress)
3310 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
3311 /// # Ok(())
3312 /// # }
3313 /// ```
3314 ///
3315 /// # See Also
3316 ///
3317 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3318 /// dataset
3319 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3320 /// data
3321 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3322 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3323 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
3324 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3325 pub async fn create_snapshot(
3326 &self,
3327 path: &str,
3328 progress: Option<Sender<Progress>>,
3329 ) -> Result<Snapshot, Error> {
3330 let path = Path::new(path);
3331
3332 if path.is_dir() {
3333 let path_str = path.to_str().ok_or_else(|| {
3334 Error::IoError(std::io::Error::new(
3335 std::io::ErrorKind::InvalidInput,
3336 "Path contains invalid UTF-8",
3337 ))
3338 })?;
3339 return self.create_snapshot_folder(path_str, progress).await;
3340 }
3341
3342 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3343 Error::IoError(std::io::Error::new(
3344 std::io::ErrorKind::InvalidInput,
3345 "Invalid filename",
3346 ))
3347 })?;
3348 let total = path.metadata()?.len() as usize;
3349 let current = Arc::new(AtomicUsize::new(0));
3350
3351 if let Some(progress) = &progress {
3352 let _ = progress
3353 .send(Progress {
3354 current: 0,
3355 total,
3356 status: None,
3357 })
3358 .await;
3359 }
3360
3361 let params = SnapshotCreateMultipartParams {
3362 snapshot_name: name.to_owned(),
3363 keys: vec![name.to_owned()],
3364 file_sizes: vec![total],
3365 snapshot_type: None,
3366 };
3367 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3368 .rpc(
3369 "snapshots.create_upload_url_multipart".to_owned(),
3370 Some(params),
3371 )
3372 .await?;
3373
3374 let snapshot_id = match multipart.get("snapshot_id") {
3375 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3376 _ => return Err(Error::InvalidResponse),
3377 };
3378
3379 let snapshot = self.snapshot(snapshot_id).await?;
3380 let part_prefix = snapshot
3381 .path()
3382 .split("::/")
3383 .last()
3384 .ok_or(Error::InvalidResponse)?
3385 .to_owned();
3386 let part_key = format!("{}/{}", part_prefix, name);
3387 let mut part = match multipart.get(&part_key) {
3388 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3389 _ => return Err(Error::InvalidResponse),
3390 }
3391 .clone();
3392 part.key = Some(part_key);
3393
3394 let params = upload_multipart(
3395 self.bulk_http.clone(),
3396 part.clone(),
3397 path.to_path_buf(),
3398 total,
3399 current,
3400 progress.clone(),
3401 )
3402 .await?;
3403
3404 let complete: String = self
3405 .rpc(
3406 "snapshots.complete_multipart_upload".to_owned(),
3407 Some(params),
3408 )
3409 .await?;
3410 debug!("Snapshot Multipart Complete: {:?}", complete);
3411
3412 let params: SnapshotStatusParams = SnapshotStatusParams {
3413 snapshot_id,
3414 status: "available".to_owned(),
3415 };
3416 let _: SnapshotStatusResult = self
3417 .rpc("snapshots.update".to_owned(), Some(params))
3418 .await?;
3419
3420 if let Some(progress) = progress {
3421 drop(progress);
3422 }
3423
3424 self.snapshot(snapshot_id).await
3425 }
3426
3427 async fn create_snapshot_folder(
3428 &self,
3429 path: &str,
3430 progress: Option<Sender<Progress>>,
3431 ) -> Result<Snapshot, Error> {
3432 let path = Path::new(path);
3433 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3434 Error::IoError(std::io::Error::new(
3435 std::io::ErrorKind::InvalidInput,
3436 "Invalid directory name",
3437 ))
3438 })?;
3439
3440 let files = WalkDir::new(path)
3441 .into_iter()
3442 .filter_map(|entry| entry.ok())
3443 .filter(|entry| entry.file_type().is_file())
3444 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3445 .collect::<Vec<_>>();
3446
3447 let total: usize = files
3448 .iter()
3449 .filter_map(|file| path.join(file).metadata().ok())
3450 .map(|metadata| metadata.len() as usize)
3451 .sum();
3452 let current = Arc::new(AtomicUsize::new(0));
3453
3454 if let Some(progress) = &progress {
3455 let _ = progress
3456 .send(Progress {
3457 current: 0,
3458 total,
3459 status: None,
3460 })
3461 .await;
3462 }
3463
3464 let keys = files
3465 .iter()
3466 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3467 .collect::<Vec<_>>();
3468 let file_sizes = files
3469 .iter()
3470 .filter_map(|key| path.join(key).metadata().ok())
3471 .map(|metadata| metadata.len() as usize)
3472 .collect::<Vec<_>>();
3473
3474 let params = SnapshotCreateMultipartParams {
3475 snapshot_name: name.to_owned(),
3476 keys,
3477 file_sizes,
3478 snapshot_type: None,
3479 };
3480
3481 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3482 .rpc(
3483 "snapshots.create_upload_url_multipart".to_owned(),
3484 Some(params),
3485 )
3486 .await?;
3487
3488 let snapshot_id = match multipart.get("snapshot_id") {
3489 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3490 _ => return Err(Error::InvalidResponse),
3491 };
3492
3493 let snapshot = self.snapshot(snapshot_id).await?;
3494 let part_prefix = snapshot
3495 .path()
3496 .split("::/")
3497 .last()
3498 .ok_or(Error::InvalidResponse)?
3499 .to_owned();
3500
3501 for file in files {
3502 let file_str = file.to_str().ok_or_else(|| {
3503 Error::IoError(std::io::Error::new(
3504 std::io::ErrorKind::InvalidInput,
3505 "File path contains invalid UTF-8",
3506 ))
3507 })?;
3508 let part_key = format!("{}/{}", part_prefix, file_str);
3509 let mut part = match multipart.get(&part_key) {
3510 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3511 _ => return Err(Error::InvalidResponse),
3512 }
3513 .clone();
3514 part.key = Some(part_key);
3515
3516 let params = upload_multipart(
3517 self.bulk_http.clone(),
3518 part.clone(),
3519 path.join(file),
3520 total,
3521 current.clone(),
3522 progress.clone(),
3523 )
3524 .await?;
3525
3526 let complete: String = self
3527 .rpc(
3528 "snapshots.complete_multipart_upload".to_owned(),
3529 Some(params),
3530 )
3531 .await?;
3532 debug!("Snapshot Part Complete: {:?}", complete);
3533 }
3534
3535 let params = SnapshotStatusParams {
3536 snapshot_id,
3537 status: "available".to_owned(),
3538 };
3539 let _: SnapshotStatusResult = self
3540 .rpc("snapshots.update".to_owned(), Some(params))
3541 .await?;
3542
3543 if let Some(progress) = progress {
3544 drop(progress);
3545 }
3546
3547 self.snapshot(snapshot_id).await
3548 }
3549
3550 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3551 ///
3552 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3553 /// This format is the native EdgeFirst Dataset Format used for efficient
3554 /// dataset storage and transfer.
3555 ///
3556 /// # Arguments
3557 ///
3558 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3559 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3560 /// * `description` - Optional description for the snapshot
3561 /// * `progress` - Optional progress channel for upload tracking
3562 ///
3563 /// # File Requirements
3564 ///
3565 /// - Arrow file must have `.arrow` extension
3566 /// - ZIP file must have `.zip` extension
3567 /// - Both files must exist and be readable
3568 ///
3569 /// # Example
3570 ///
3571 /// ```no_run
3572 /// # use edgefirst_client::Client;
3573 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3574 /// let client = Client::new()?.with_token_path(None)?;
3575 ///
3576 /// let snapshot = client
3577 /// .create_snapshot_edgefirst_format(
3578 /// "dataset.arrow",
3579 /// "dataset.zip",
3580 /// Some("My Dataset Snapshot"),
3581 /// None,
3582 /// )
3583 /// .await?;
3584 /// println!("Created snapshot: {}", snapshot.id());
3585 /// # Ok(())
3586 /// # }
3587 /// ```
3588 ///
3589 /// # See Also
3590 ///
3591 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3592 /// folder
3593 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3594 /// dataset
3595 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3596 pub async fn create_snapshot_edgefirst_format(
3597 &self,
3598 arrow_path: &str,
3599 zip_path: &str,
3600 description: Option<&str>,
3601 progress: Option<Sender<Progress>>,
3602 ) -> Result<Snapshot, Error> {
3603 let arrow_path = Path::new(arrow_path);
3604 let zip_path = Path::new(zip_path);
3605
3606 // Validate files exist
3607 if !arrow_path.exists() {
3608 return Err(Error::IoError(std::io::Error::new(
3609 std::io::ErrorKind::NotFound,
3610 format!("Arrow file not found: {}", arrow_path.display()),
3611 )));
3612 }
3613 if !zip_path.exists() {
3614 return Err(Error::IoError(std::io::Error::new(
3615 std::io::ErrorKind::NotFound,
3616 format!("ZIP file not found: {}", zip_path.display()),
3617 )));
3618 }
3619
3620 // Get file names
3621 let arrow_name = arrow_path
3622 .file_name()
3623 .and_then(|n| n.to_str())
3624 .ok_or_else(|| {
3625 Error::IoError(std::io::Error::new(
3626 std::io::ErrorKind::InvalidInput,
3627 "Invalid Arrow filename",
3628 ))
3629 })?;
3630 let zip_name = zip_path
3631 .file_name()
3632 .and_then(|n| n.to_str())
3633 .ok_or_else(|| {
3634 Error::IoError(std::io::Error::new(
3635 std::io::ErrorKind::InvalidInput,
3636 "Invalid ZIP filename",
3637 ))
3638 })?;
3639
3640 // Generate snapshot name from arrow file (without extension)
3641 let snapshot_name = description
3642 .map(|s| s.to_string())
3643 .or_else(|| {
3644 arrow_path
3645 .file_stem()
3646 .and_then(|s| s.to_str())
3647 .map(|s| s.to_string())
3648 })
3649 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3650
3651 // Calculate file sizes
3652 let arrow_size = arrow_path.metadata()?.len() as usize;
3653 let zip_size = zip_path.metadata()?.len() as usize;
3654 let total = arrow_size + zip_size;
3655 let current = Arc::new(AtomicUsize::new(0));
3656
3657 if let Some(progress) = &progress {
3658 let _ = progress
3659 .send(Progress {
3660 current: 0,
3661 total,
3662 status: None,
3663 })
3664 .await;
3665 }
3666
3667 // Create multipart upload request with "ziparrow" type
3668 let params = SnapshotCreateMultipartParams {
3669 snapshot_name,
3670 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3671 file_sizes: vec![arrow_size, zip_size],
3672 snapshot_type: Some("ziparrow".to_string()),
3673 };
3674
3675 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3676 .rpc(
3677 "snapshots.create_upload_url_multipart".to_owned(),
3678 Some(params),
3679 )
3680 .await?;
3681
3682 let snapshot_id = match multipart.get("snapshot_id") {
3683 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3684 _ => return Err(Error::InvalidResponse),
3685 };
3686
3687 let snapshot = self.snapshot(snapshot_id).await?;
3688 let part_prefix = snapshot
3689 .path()
3690 .split("::/")
3691 .last()
3692 .ok_or(Error::InvalidResponse)?
3693 .to_owned();
3694
3695 // Upload Arrow file
3696 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3697 let mut arrow_part = match multipart.get(&arrow_key) {
3698 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3699 _ => return Err(Error::InvalidResponse),
3700 };
3701 arrow_part.key = Some(arrow_key);
3702
3703 let params = upload_multipart(
3704 self.bulk_http.clone(),
3705 arrow_part,
3706 arrow_path.to_path_buf(),
3707 total,
3708 current.clone(),
3709 progress.clone(),
3710 )
3711 .await?;
3712
3713 let _: String = self
3714 .rpc(
3715 "snapshots.complete_multipart_upload".to_owned(),
3716 Some(params),
3717 )
3718 .await?;
3719 debug!("Arrow file upload complete");
3720
3721 // Upload ZIP file
3722 let zip_key = format!("{}/{}", part_prefix, zip_name);
3723 let mut zip_part = match multipart.get(&zip_key) {
3724 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3725 _ => return Err(Error::InvalidResponse),
3726 };
3727 zip_part.key = Some(zip_key);
3728
3729 let params = upload_multipart(
3730 self.bulk_http.clone(),
3731 zip_part,
3732 zip_path.to_path_buf(),
3733 total,
3734 current.clone(),
3735 progress.clone(),
3736 )
3737 .await?;
3738
3739 let _: String = self
3740 .rpc(
3741 "snapshots.complete_multipart_upload".to_owned(),
3742 Some(params),
3743 )
3744 .await?;
3745 debug!("ZIP file upload complete");
3746
3747 // Mark snapshot as available
3748 let params = SnapshotStatusParams {
3749 snapshot_id,
3750 status: "available".to_owned(),
3751 };
3752 let _: SnapshotStatusResult = self
3753 .rpc("snapshots.update".to_owned(), Some(params))
3754 .await?;
3755
3756 if let Some(progress) = progress {
3757 drop(progress);
3758 }
3759
3760 self.snapshot(snapshot_id).await
3761 }
3762
3763 /// Delete a snapshot from EdgeFirst Studio.
3764 ///
3765 /// Permanently removes a snapshot and its associated data. This operation
3766 /// cannot be undone.
3767 ///
3768 /// # Arguments
3769 ///
3770 /// * `snapshot_id` - The snapshot ID to delete
3771 ///
3772 /// # Errors
3773 ///
3774 /// Returns an error if:
3775 /// * Snapshot doesn't exist
3776 /// * User lacks permission to delete the snapshot
3777 /// * Server error occurs
3778 ///
3779 /// # Example
3780 ///
3781 /// ```no_run
3782 /// # use edgefirst_client::{Client, SnapshotID};
3783 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3784 /// let client = Client::new()?.with_token_path(None)?;
3785 /// let snapshot_id = SnapshotID::from(123);
3786 /// client.delete_snapshot(snapshot_id).await?;
3787 /// # Ok(())
3788 /// # }
3789 /// ```
3790 ///
3791 /// # See Also
3792 ///
3793 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3794 /// * [`snapshots`](Self::snapshots) - List all snapshots
3795 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3796 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3797 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3798 let _: serde_json::Value = self
3799 .rpc("snapshots.delete".to_owned(), Some(params))
3800 .await?;
3801 Ok(())
3802 }
3803
3804 /// Create a snapshot from an existing dataset on the server.
3805 ///
3806 /// Triggers server-side snapshot generation which exports the dataset's
3807 /// images and annotations into a downloadable EdgeFirst Dataset Format
3808 /// snapshot.
3809 ///
3810 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3811 /// while restore creates a dataset from a snapshot, this method creates a
3812 /// snapshot from a dataset.
3813 ///
3814 /// # Arguments
3815 ///
3816 /// * `dataset_id` - The dataset ID to create snapshot from
3817 /// * `description` - Description for the created snapshot
3818 ///
3819 /// # Returns
3820 ///
3821 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3822 /// for monitoring progress.
3823 ///
3824 /// # Errors
3825 ///
3826 /// Returns an error if:
3827 /// * Dataset doesn't exist
3828 /// * User lacks permission to access the dataset
3829 /// * Server rejects the request
3830 ///
3831 /// # Example
3832 ///
3833 /// ```no_run
3834 /// # use edgefirst_client::{Client, DatasetID};
3835 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3836 /// let client = Client::new()?.with_token_path(None)?;
3837 /// let dataset_id = DatasetID::from(123);
3838 ///
3839 /// // Create snapshot from dataset (all annotation sets)
3840 /// let result = client
3841 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3842 /// .await?;
3843 /// println!("Created snapshot: {:?}", result.id);
3844 ///
3845 /// // Monitor progress via task ID
3846 /// if let Some(task_id) = result.task_id {
3847 /// println!("Task: {}", task_id);
3848 /// }
3849 /// # Ok(())
3850 /// # }
3851 /// ```
3852 ///
3853 /// # See Also
3854 ///
3855 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3856 /// snapshot
3857 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3858 /// dataset
3859 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3860 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3861 pub async fn create_snapshot_from_dataset(
3862 &self,
3863 dataset_id: DatasetID,
3864 description: &str,
3865 annotation_set_id: Option<AnnotationSetID>,
3866 ) -> Result<SnapshotFromDatasetResult, Error> {
3867 // Resolve annotation_set_id: use provided value or fetch default
3868 let annotation_set_id = match annotation_set_id {
3869 Some(id) => id,
3870 None => {
3871 // Fetch annotation sets and find default ("annotations") or use first
3872 let sets = self.annotation_sets(dataset_id).await?;
3873 if sets.is_empty() {
3874 return Err(Error::InvalidParameters(
3875 "No annotation sets available for dataset".to_owned(),
3876 ));
3877 }
3878 // Look for "annotations" set (default), otherwise use first
3879 sets.iter()
3880 .find(|s| s.name() == "annotations")
3881 .unwrap_or(&sets[0])
3882 .id()
3883 }
3884 };
3885 let params = SnapshotCreateFromDataset {
3886 description: description.to_owned(),
3887 dataset_id,
3888 annotation_set_id,
3889 };
3890 self.rpc("snapshots.create".to_owned(), Some(params)).await
3891 }
3892
3893 /// Download a snapshot from EdgeFirst Studio to local storage.
3894 ///
3895 /// Downloads all files in a snapshot (single MCAP file or directory of
3896 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3897 /// downloaded concurrently with progress tracking.
3898 ///
3899 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3900 /// downloads (default: half of CPU cores, min 2, max 8).
3901 ///
3902 /// # Arguments
3903 ///
3904 /// * `snapshot_id` - The snapshot ID to download
3905 /// * `output` - Local directory path to save downloaded files
3906 /// * `progress` - Optional channel to receive download progress updates
3907 ///
3908 /// # Progress
3909 ///
3910 /// Reports progress with `status: None` as file data is received. Progress
3911 /// unit is bytes downloaded across all files combined. The total
3912 /// accumulates as file sizes become known (from HTTP Content-Length
3913 /// headers), so both `current` and `total` may increase during
3914 /// download.
3915 ///
3916 /// # Errors
3917 ///
3918 /// Returns an error if:
3919 /// * Snapshot doesn't exist
3920 /// * Output directory cannot be created
3921 /// * Download fails or network error occurs
3922 ///
3923 /// # Example
3924 ///
3925 /// ```no_run
3926 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3927 /// # use tokio::sync::mpsc;
3928 /// # use std::path::PathBuf;
3929 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3930 /// let client = Client::new()?.with_token_path(None)?;
3931 /// let snapshot_id = SnapshotID::from(123);
3932 ///
3933 /// // Download with progress tracking
3934 /// let (tx, mut rx) = mpsc::channel(1);
3935 /// tokio::spawn(async move {
3936 /// while let Some(Progress {
3937 /// current,
3938 /// total,
3939 /// status,
3940 /// }) = rx.recv().await
3941 /// {
3942 /// println!(
3943 /// "{}: {}/{} bytes",
3944 /// status.as_deref().unwrap_or("Download"),
3945 /// current,
3946 /// total
3947 /// );
3948 /// }
3949 /// });
3950 /// client
3951 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
3952 /// .await?;
3953 /// # Ok(())
3954 /// # }
3955 /// ```
3956 ///
3957 /// # See Also
3958 ///
3959 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3960 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3961 /// dataset
3962 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3963 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
3964 pub async fn download_snapshot(
3965 &self,
3966 snapshot_id: SnapshotID,
3967 output: PathBuf,
3968 progress: Option<Sender<Progress>>,
3969 ) -> Result<(), Error> {
3970 fs::create_dir_all(&output).await?;
3971
3972 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3973 let items: HashMap<String, String> = self
3974 .rpc("snapshots.create_download_url".to_owned(), Some(params))
3975 .await?;
3976
3977 // Single-phase: each task holds its semaphore permit for the full
3978 // lifetime of the request (GET → headers → stream → disk). This bounds
3979 // the number of simultaneously-open connections to max_tasks() and
3980 // avoids accumulating all responses in memory before streaming.
3981 //
3982 // total is updated atomically as each response's Content-Length header
3983 // arrives, so progress tracking is accurate without a separate phase.
3984 let http = self.bulk_http.clone();
3985 let current = Arc::new(AtomicUsize::new(0));
3986 let total = Arc::new(AtomicUsize::new(0));
3987 let sem = Arc::new(Semaphore::new(max_tasks()));
3988
3989 let tasks = items
3990 .into_iter()
3991 .map(|(key, url)| {
3992 let http = http.clone();
3993 let output = output.clone();
3994 let progress = progress.clone();
3995 let current = current.clone();
3996 let total = total.clone();
3997 let sem = sem.clone();
3998
3999 tokio::spawn(async move {
4000 let _permit = sem.acquire().await.map_err(|_| {
4001 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4002 })?;
4003
4004 let res = http.get(url).send().await?;
4005 let res = res.error_for_status()?;
4006
4007 // Contribute this file's size to the running total so the
4008 // caller's progress bar knows the overall scope.
4009 if let Some(len) = res.content_length() {
4010 total.fetch_add(len as usize, Ordering::SeqCst);
4011 }
4012
4013 let mut file = File::create(output.join(key)).await?;
4014 let mut stream = res.bytes_stream();
4015
4016 while let Some(chunk) = stream.next().await {
4017 let chunk = chunk?;
4018 file.write_all(&chunk).await?;
4019 let len = chunk.len();
4020
4021 if let Some(progress) = &progress {
4022 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
4023 let tot = total.load(Ordering::SeqCst);
4024 let _ = progress
4025 .send(Progress {
4026 current: cur,
4027 total: tot,
4028 status: None,
4029 })
4030 .await;
4031 }
4032 }
4033
4034 Ok::<(), Error>(())
4035 })
4036 })
4037 .collect::<Vec<_>>();
4038
4039 join_all(tasks)
4040 .await
4041 .into_iter()
4042 .collect::<Result<Vec<_>, _>>()?
4043 .into_iter()
4044 .collect::<Result<Vec<_>, _>>()?;
4045
4046 Ok(())
4047 }
4048
4049 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
4050 ///
4051 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
4052 /// the specified project. For MCAP files, supports:
4053 ///
4054 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
4055 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
4056 /// present)
4057 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
4058 /// * **Topic filtering**: Select specific MCAP topics to restore
4059 ///
4060 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
4061 /// dataset structure.
4062 ///
4063 /// # Arguments
4064 ///
4065 /// * `project_id` - Target project ID
4066 /// * `snapshot_id` - Snapshot ID to restore
4067 /// * `topics` - MCAP topics to include (empty = all topics)
4068 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
4069 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
4070 /// * `dataset_name` - Optional custom dataset name
4071 /// * `dataset_description` - Optional dataset description
4072 ///
4073 /// # Returns
4074 ///
4075 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
4076 ///
4077 /// # Errors
4078 ///
4079 /// Returns an error if:
4080 /// * Snapshot or project doesn't exist
4081 /// * Snapshot format is invalid
4082 /// * Server rejects restoration parameters
4083 ///
4084 /// # Example
4085 ///
4086 /// ```no_run
4087 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
4088 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4089 /// let client = Client::new()?.with_token_path(None)?;
4090 /// let project_id = ProjectID::from(1);
4091 /// let snapshot_id = SnapshotID::from(123);
4092 ///
4093 /// // Restore MCAP with AGTG for "person" and "car" detection
4094 /// let result = client
4095 /// .restore_snapshot(
4096 /// project_id,
4097 /// snapshot_id,
4098 /// &[], // All topics
4099 /// &["person".to_string(), "car".to_string()], // AGTG labels
4100 /// true, // Auto-depth
4101 /// Some("Highway Dataset"),
4102 /// Some("Collected on I-95"),
4103 /// )
4104 /// .await?;
4105 /// println!("Restored to dataset: {:?}", result.dataset_id);
4106 /// # Ok(())
4107 /// # }
4108 /// ```
4109 ///
4110 /// # See Also
4111 ///
4112 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4113 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
4114 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
4115 #[allow(clippy::too_many_arguments)]
4116 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4117 pub async fn restore_snapshot(
4118 &self,
4119 project_id: ProjectID,
4120 snapshot_id: SnapshotID,
4121 topics: &[String],
4122 autolabel: &[String],
4123 autodepth: bool,
4124 dataset_name: Option<&str>,
4125 dataset_description: Option<&str>,
4126 ) -> Result<SnapshotRestoreResult, Error> {
4127 let params = SnapshotRestore {
4128 project_id,
4129 snapshot_id,
4130 fps: 1,
4131 autodepth,
4132 agtg_pipeline: !autolabel.is_empty(),
4133 autolabel: autolabel.to_vec(),
4134 topics: topics.to_vec(),
4135 dataset_name: dataset_name.map(|s| s.to_owned()),
4136 dataset_description: dataset_description.map(|s| s.to_owned()),
4137 };
4138 self.rpc("snapshots.restore".to_owned(), Some(params)).await
4139 }
4140
4141 /// Returns a list of experiments available to the user. The experiments
4142 /// are returned as a vector of Experiment objects. If name is provided
4143 /// then only experiments containing this string are returned.
4144 ///
4145 /// Results are sorted by match quality: exact matches first, then
4146 /// case-insensitive exact matches, then shorter names (more specific),
4147 /// then alphabetically.
4148 ///
4149 /// Experiments provide a method of organizing training and validation
4150 /// sessions together and are akin to an Experiment in MLFlow terminology.
4151 /// Each experiment can have multiple trainer sessions associated with it,
4152 /// these would be akin to runs in MLFlow terminology.
4153 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4154 pub async fn experiments(
4155 &self,
4156 project_id: ProjectID,
4157 name: Option<&str>,
4158 ) -> Result<Vec<Experiment>, Error> {
4159 let params = HashMap::from([("project_id", project_id)]);
4160 let experiments: Vec<Experiment> =
4161 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
4162 if let Some(name) = name {
4163 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
4164 } else {
4165 Ok(experiments)
4166 }
4167 }
4168
4169 /// Return the experiment with the specified experiment ID. If the
4170 /// experiment does not exist, an error is returned.
4171 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4172 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
4173 let params = HashMap::from([("trainer_id", experiment_id)]);
4174 self.rpc("trainer.get".to_owned(), Some(params)).await
4175 }
4176
4177 /// Returns a list of trainer sessions available to the user. The trainer
4178 /// sessions are returned as a vector of TrainingSession objects. If name
4179 /// is provided then only trainer sessions containing this string are
4180 /// returned.
4181 ///
4182 /// Results are sorted by match quality: exact matches first, then
4183 /// case-insensitive exact matches, then shorter names (more specific),
4184 /// then alphabetically.
4185 ///
4186 /// Trainer sessions are akin to runs in MLFlow terminology. These
4187 /// represent an actual training session which will produce metrics and
4188 /// model artifacts.
4189 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4190 pub async fn training_sessions(
4191 &self,
4192 experiment_id: ExperimentID,
4193 name: Option<&str>,
4194 ) -> Result<Vec<TrainingSession>, Error> {
4195 let params = HashMap::from([("trainer_id", experiment_id)]);
4196 let sessions: Vec<TrainingSession> = self
4197 .rpc("trainer.session.list".to_owned(), Some(params))
4198 .await?;
4199 if let Some(name) = name {
4200 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
4201 } else {
4202 Ok(sessions)
4203 }
4204 }
4205
4206 /// Return the trainer session with the specified trainer session ID. If
4207 /// the trainer session does not exist, an error is returned.
4208 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4209 pub async fn training_session(
4210 &self,
4211 session_id: TrainingSessionID,
4212 ) -> Result<TrainingSession, Error> {
4213 let params = HashMap::from([("trainer_session_id", session_id)]);
4214 self.rpc("trainer.session.get".to_owned(), Some(params))
4215 .await
4216 }
4217
4218 /// List validation sessions for the given project.
4219 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4220 pub async fn validation_sessions(
4221 &self,
4222 project_id: ProjectID,
4223 ) -> Result<Vec<ValidationSession>, Error> {
4224 let params = HashMap::from([("project_id", project_id)]);
4225 self.rpc("validate.session.list".to_owned(), Some(params))
4226 .await
4227 }
4228
4229 /// Retrieve a specific validation session.
4230 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4231 pub async fn validation_session(
4232 &self,
4233 session_id: ValidationSessionID,
4234 ) -> Result<ValidationSession, Error> {
4235 let params = HashMap::from([("validate_session_id", session_id)]);
4236 self.rpc("validate.session.get".to_owned(), Some(params))
4237 .await
4238 }
4239
4240 /// Create a new validation session via Studio's `cloud.server.start`.
4241 ///
4242 /// Pass `is_local: true` in the [`StartValidationRequest`] to create
4243 /// a **user-managed** session: the database row is created and the
4244 /// session is fully usable for data uploads / downloads / metrics,
4245 /// but no EC2 instance is provisioned and no automated validator
4246 /// pipeline is started. That is the mode our integration tests use
4247 /// — they create a session, exercise the wrapper APIs against it,
4248 /// then call [`Client::delete_validation_sessions`] in teardown so
4249 /// no stray sessions accumulate on the test account.
4250 ///
4251 /// Returns a [`NewValidationSession`] carrying the backing task id
4252 /// and the freshly-minted validation session id.
4253 ///
4254 /// # Errors
4255 ///
4256 /// Surfaces any RPC error from `cloud.server.start`. Common cases:
4257 /// `RpcError(101, …)` if a required entity is missing (project,
4258 /// training session, dataset, …); `PermissionDenied` if the caller
4259 /// can't write to the target project.
4260 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, req)))]
4261 pub async fn start_validation_session(
4262 &self,
4263 req: StartValidationRequest,
4264 ) -> Result<NewValidationSession, Error> {
4265 // Build the params shape the server expects. `cloud.server.start`
4266 // is intentionally generic — different server types pull
4267 // different fields out of `params` — so we serialize manually to
4268 // match the JS frontend's call site verbatim (see
4269 // `dve-frontend/src/components/ValidationPage/StartValidatorModal.vue`).
4270 let mut body = serde_json::Map::new();
4271 body.insert(
4272 "type".into(),
4273 serde_json::Value::String("validation".into()),
4274 );
4275 body.insert("name".into(), serde_json::Value::String(req.name));
4276 body.insert("project_id".into(), serde_json::to_value(req.project_id)?);
4277 body.insert(
4278 "training_session_id".into(),
4279 serde_json::to_value(req.training_session_id)?,
4280 );
4281 body.insert(
4282 "model_file".into(),
4283 serde_json::Value::String(req.model_file),
4284 );
4285 body.insert("val_type".into(), serde_json::Value::String(req.val_type));
4286 body.insert("is_local".into(), serde_json::Value::Bool(req.is_local));
4287 body.insert(
4288 "is_kubernetes".into(),
4289 serde_json::Value::Bool(req.is_kubernetes),
4290 );
4291
4292 // `validate.session` reads its config from `params.params` (one
4293 // extra envelope level). The outer `params` wrapper is required
4294 // even when the inner map is empty.
4295 let inner = serde_json::to_value(req.params)?;
4296 let mut outer = serde_json::Map::new();
4297 outer.insert("params".into(), inner);
4298 body.insert("params".into(), serde_json::Value::Object(outer));
4299
4300 if let Some(d) = req.description {
4301 body.insert("description".into(), serde_json::Value::String(d));
4302 }
4303 if let Some(id) = req.dataset_id {
4304 body.insert("dataset_id".into(), serde_json::to_value(id)?);
4305 }
4306 if let Some(id) = req.annotation_set_id {
4307 body.insert("annotation_set_id".into(), serde_json::to_value(id)?);
4308 }
4309 if let Some(id) = req.snapshot_id {
4310 body.insert("snapshot_id".into(), serde_json::to_value(id)?);
4311 }
4312
4313 self.rpc("cloud.server.start".to_owned(), Some(body)).await
4314 }
4315
4316 /// Delete one or more validation sessions via
4317 /// `validate.session.delete`.
4318 ///
4319 /// Used by integration tests to tear down sessions they created
4320 /// with [`Client::start_validation_session`]; idempotent against
4321 /// already-deleted ids on the server side (the RPC accepts the
4322 /// list, deletes what it can, and surfaces an error only if none
4323 /// of the ids were resolvable).
4324 ///
4325 /// # Errors
4326 ///
4327 /// Surfaces any RPC error from `validate.session.delete`. A
4328 /// `PermissionDenied` indicates the caller lacks
4329 /// `TrainerWrite` on at least one of the listed sessions.
4330 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4331 pub async fn delete_validation_sessions(
4332 &self,
4333 session_ids: &[ValidationSessionID],
4334 ) -> Result<(), Error> {
4335 let mut body = serde_json::Map::new();
4336 body.insert("session_ids".into(), serde_json::to_value(session_ids)?);
4337 let _: serde_json::Value = self
4338 .rpc("validate.session.delete".to_owned(), Some(body))
4339 .await?;
4340 Ok(())
4341 }
4342
4343 /// List the artifacts for the specified trainer session. The artifacts
4344 /// are returned as a vector of strings.
4345 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4346 pub async fn artifacts(
4347 &self,
4348 training_session_id: TrainingSessionID,
4349 ) -> Result<Vec<Artifact>, Error> {
4350 let params = HashMap::from([("training_session_id", training_session_id)]);
4351 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
4352 .await
4353 }
4354
4355 /// Download the model artifact for the specified trainer session to the
4356 /// specified file path, if path is not provided it will be downloaded to
4357 /// the current directory with the same filename.
4358 ///
4359 /// # Progress
4360 ///
4361 /// Reports progress with `status: None` as file data is received. Progress
4362 /// unit is bytes downloaded. Total is determined from the HTTP
4363 /// Content-Length header (may be 0 if server doesn't provide it).
4364 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4365 pub async fn download_artifact(
4366 &self,
4367 training_session_id: TrainingSessionID,
4368 modelname: &str,
4369 filename: Option<PathBuf>,
4370 progress: Option<Sender<Progress>>,
4371 ) -> Result<(), Error> {
4372 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
4373 let resp = self
4374 .bulk_http
4375 .get(format!(
4376 "{}/download_model?training_session_id={}&file={}",
4377 self.url,
4378 training_session_id.value(),
4379 modelname
4380 ))
4381 .header("Authorization", format!("Bearer {}", self.token().await))
4382 .send()
4383 .await?;
4384 if !resp.status().is_success() {
4385 let err = resp.error_for_status_ref().unwrap_err();
4386 return Err(Error::HttpError(err));
4387 }
4388
4389 if let Some(parent) = filename.parent() {
4390 fs::create_dir_all(parent).await?;
4391 }
4392
4393 stream_response_to_file(resp, &filename, progress).await
4394 }
4395
4396 /// Download the model checkpoint associated with the specified trainer
4397 /// session to the specified file path, if path is not provided it will be
4398 /// downloaded to the current directory with the same filename.
4399 ///
4400 /// There is no API for listing checkpoints it is expected that trainers are
4401 /// aware of possible checkpoints and their names within the checkpoint
4402 /// folder on the server.
4403 ///
4404 /// # Progress
4405 ///
4406 /// Reports progress with `status: None` as file data is received. Progress
4407 /// unit is bytes downloaded. Total is determined from the HTTP
4408 /// Content-Length header (may be 0 if server doesn't provide it).
4409 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4410 pub async fn download_checkpoint(
4411 &self,
4412 training_session_id: TrainingSessionID,
4413 checkpoint: &str,
4414 filename: Option<PathBuf>,
4415 progress: Option<Sender<Progress>>,
4416 ) -> Result<(), Error> {
4417 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
4418 let resp = self
4419 .bulk_http
4420 .get(format!(
4421 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
4422 self.url,
4423 training_session_id.value(),
4424 checkpoint
4425 ))
4426 .header("Authorization", format!("Bearer {}", self.token().await))
4427 .send()
4428 .await?;
4429 if !resp.status().is_success() {
4430 let err = resp.error_for_status_ref().unwrap_err();
4431 return Err(Error::HttpError(err));
4432 }
4433
4434 if let Some(parent) = filename.parent() {
4435 fs::create_dir_all(parent).await?;
4436 }
4437
4438 stream_response_to_file(resp, &filename, progress).await
4439 }
4440
4441 /// Return a list of tasks for the current user.
4442 ///
4443 /// # Arguments
4444 ///
4445 /// * `name` - Optional filter for task name (client-side substring match)
4446 /// * `workflow` - Optional filter for workflow/task type. If provided,
4447 /// filters server-side by exact match. Valid values include: "trainer",
4448 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
4449 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
4450 /// "convertor", "twostage"
4451 /// * `status` - Optional filter for task status (e.g., "running",
4452 /// "complete", "error")
4453 /// * `manager` - Optional filter for task manager type (e.g., "aws",
4454 /// "user", "kubernetes")
4455 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4456 pub async fn tasks(
4457 &self,
4458 name: Option<&str>,
4459 workflow: Option<&str>,
4460 status: Option<&str>,
4461 manager: Option<&str>,
4462 ) -> Result<Vec<Task>, Error> {
4463 let mut params = TasksListParams {
4464 continue_token: None,
4465 types: workflow.map(|w| vec![w.to_owned()]),
4466 status: status.map(|s| vec![s.to_owned()]),
4467 manager: manager.map(|m| vec![m.to_owned()]),
4468 };
4469 let mut tasks = Vec::new();
4470
4471 loop {
4472 let result = self
4473 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4474 .await?;
4475 tasks.extend(result.tasks);
4476
4477 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4478 params.continue_token = None;
4479 } else {
4480 params.continue_token = result.continue_token;
4481 }
4482
4483 if params.continue_token.is_none() {
4484 break;
4485 }
4486 }
4487
4488 if let Some(name) = name {
4489 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4490 }
4491
4492 Ok(tasks)
4493 }
4494
4495 /// Submits a job (app run) to the server and returns the resulting `Job`
4496 /// record (which carries the linked task id alongside the cloud-batch
4497 /// metadata).
4498 ///
4499 /// # Arguments
4500 /// * `app_name` - The name of the registered app to run (e.g., `"edgefirst-validator"`).
4501 /// * `job_name` - A user-defined label for this run.
4502 /// * `env` - Environment variables passed to the job (string-string map).
4503 /// * `data` - Job input payload (e.g., session ids, parameters).
4504 ///
4505 /// # Returns
4506 /// The full `Job` record returned by the server (wraps the BK_BATCH object),
4507 /// including AWS Batch job ID, state, and the linked `task_id`. Callers that
4508 /// only need the task ID can call `.task_id()` on the returned `Job`.
4509 pub async fn job_run(
4510 &self,
4511 app_name: &str,
4512 job_name: &str,
4513 env: std::collections::HashMap<String, String>,
4514 data: std::collections::HashMap<String, crate::api::Parameter>,
4515 ) -> Result<crate::api::Job, Error> {
4516 let req = JobRunRequest {
4517 name: app_name.to_owned(),
4518 job_name: job_name.to_owned(),
4519 env,
4520 data,
4521 };
4522 let resp: crate::api::Job = match self.rpc("job.run".to_owned(), Some(&req)).await {
4523 Ok(r) => r,
4524 Err(Error::RpcError(code, msg)) => {
4525 return Err(map_rpc_error("job.run", code, msg, None));
4526 }
4527 Err(e) => return Err(e),
4528 };
4529 Ok(resp)
4530 }
4531
4532 /// Requests a running job task be stopped.
4533 ///
4534 /// Returns `Ok(())` if the stop request was accepted by the server. The
4535 /// task may still take time to fully terminate; poll `task_info` if you
4536 /// need to wait for shutdown.
4537 pub async fn job_stop(&self, task_id: crate::api::TaskID) -> Result<(), Error> {
4538 let req = JobStopRequest {
4539 task_id: task_id.value(),
4540 };
4541 // We don't care about the response body; deserialize as serde_json::Value.
4542 let _resp: serde_json::Value = match self.rpc("job.stop".to_owned(), Some(&req)).await {
4543 Ok(r) => r,
4544 Err(Error::RpcError(code, msg)) => {
4545 return Err(map_rpc_error("job.stop", code, msg, Some(task_id)));
4546 }
4547 Err(e) => return Err(e),
4548 };
4549 Ok(())
4550 }
4551
4552 /// Lists job (app-run) entries visible to the authenticated user.
4553 ///
4554 /// The server returns AWS Batch-wrapper entries (not bare `Task` objects),
4555 /// surfacing cloud-batch state (`RUNNING`/`SUCCEEDED`/...) and the linked
4556 /// `task_id`. Use `Job::task_id()` + `Client::task_info` to fetch the
4557 /// underlying task details.
4558 ///
4559 /// The server does not support server-side filters, so the optional
4560 /// `name` argument is applied client-side as a substring match against
4561 /// each job's `job_name`.
4562 pub async fn jobs(&self, name: Option<&str>) -> Result<Vec<crate::api::Job>, Error> {
4563 let req = JobsListRequest {};
4564 let mut jobs: Vec<crate::api::Job> = match self.rpc("job.list".to_owned(), Some(&req)).await
4565 {
4566 Ok(r) => r,
4567 Err(Error::RpcError(code, msg)) => {
4568 return Err(map_rpc_error("job.list", code, msg, None));
4569 }
4570 Err(e) => return Err(e),
4571 };
4572 if let Some(name) = name {
4573 let needle = name.to_lowercase();
4574 jobs.retain(|j| j.job_name.to_lowercase().contains(&needle));
4575 jobs.sort_by(|a, b| a.job_name.cmp(&b.job_name));
4576 }
4577 Ok(jobs)
4578 }
4579
4580 /// Retrieve the task information and status.
4581 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4582 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4583 self.rpc(
4584 "task.get".to_owned(),
4585 Some(HashMap::from([("id", task_id)])),
4586 )
4587 .await
4588 }
4589
4590 /// Updates the tasks status.
4591 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4592 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4593 let status = TaskStatus {
4594 task_id,
4595 status: status.to_owned(),
4596 };
4597 self.rpc("docker.update.status".to_owned(), Some(status))
4598 .await
4599 }
4600
4601 /// Defines the stages for the task. The stages are defined as a mapping
4602 /// from stage names to their descriptions. Once stages are defined their
4603 /// status can be updated using the update_stage method.
4604 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4605 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4606 let stages: Vec<HashMap<String, String>> = stages
4607 .iter()
4608 .map(|(key, value)| {
4609 let mut stage_map = HashMap::new();
4610 stage_map.insert(key.to_string(), value.to_string());
4611 stage_map
4612 })
4613 .collect();
4614 let params = TaskStages { task_id, stages };
4615 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4616 Ok(())
4617 }
4618
4619 /// Updates the progress of the task for the provided stage and status
4620 /// information.
4621 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4622 pub async fn update_stage(
4623 &self,
4624 task_id: TaskID,
4625 stage: &str,
4626 status: &str,
4627 message: &str,
4628 percentage: u8,
4629 ) -> Result<(), Error> {
4630 let stage = Stage::new(
4631 Some(task_id),
4632 stage.to_owned(),
4633 Some(status.to_owned()),
4634 Some(message.to_owned()),
4635 percentage,
4636 );
4637 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4638 Ok(())
4639 }
4640
4641 /// Authenticated fetch from the Studio server using the bulk HTTP client
4642 /// (no total-request timeout; idle read timeout per chunk).
4643 ///
4644 /// **Buffers the entire response body into memory.** Suitable for small to
4645 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4646 /// checkpoints), prefer a streaming approach that writes directly to disk.
4647 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4648 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4649 let req = self
4650 .bulk_http
4651 .get(format!("{}/{}", self.url, query))
4652 .header("User-Agent", "EdgeFirst Client")
4653 .header("Authorization", format!("Bearer {}", self.token().await));
4654 let resp = req.send().await?;
4655
4656 if resp.status().is_success() {
4657 let body = resp.bytes().await?;
4658
4659 if log_enabled!(Level::Trace) {
4660 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4661 }
4662
4663 Ok(body.to_vec())
4664 } else {
4665 let err = resp.error_for_status_ref().unwrap_err();
4666 Err(Error::HttpError(err))
4667 }
4668 }
4669
4670 /// Sends a multipart post request to the server. This is used by the
4671 /// upload and download APIs which do not use JSON-RPC but instead transfer
4672 /// files using multipart/form-data.
4673 ///
4674 /// The result field is deserialized as `serde_json::Value` rather than
4675 /// `String` because different server endpoints return different shapes —
4676 /// `val.data.upload` returns a plain string while `task.data.upload`
4677 /// returns an object `{"message":…,"path":…,"size":…}`. All current
4678 /// callers discard the return value so this is backwards-compatible.
4679 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4680 pub async fn post_multipart(
4681 &self,
4682 method: &str,
4683 form: Form,
4684 ) -> Result<serde_json::Value, Error> {
4685 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4686 .ok()
4687 .and_then(|s| s.parse().ok())
4688 .unwrap_or(600u64);
4689
4690 let req = self
4691 .http
4692 .post(format!("{}/api?method={}", self.url, method))
4693 .header("Accept", "application/json")
4694 .header("User-Agent", "EdgeFirst Client")
4695 .header("Authorization", format!("Bearer {}", self.token().await))
4696 .timeout(Duration::from_secs(upload_timeout_secs))
4697 .multipart(form);
4698 let resp = req.send().await?;
4699
4700 if resp.status().is_success() {
4701 let body = resp.bytes().await?;
4702
4703 if log_enabled!(Level::Trace) {
4704 trace!(
4705 "POST Multipart Response: {}",
4706 String::from_utf8_lossy(&body)
4707 );
4708 }
4709
4710 let response: RpcResponse<serde_json::Value> = match serde_json::from_slice(&body) {
4711 Ok(response) => response,
4712 Err(err) => {
4713 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4714 return Err(err.into());
4715 }
4716 };
4717
4718 if let Some(error) = response.error {
4719 Err(Error::RpcError(error.code, error.message))
4720 } else if let Some(result) = response.result {
4721 Ok(result)
4722 } else {
4723 Err(Error::InvalidResponse)
4724 }
4725 } else {
4726 // HTTP-level failure on the multipart upload. Map 413 to the
4727 // typed `PayloadTooLarge` variant so callers see the same error
4728 // type from both single-file rpc_download paths and multipart
4729 // upload paths; everything else falls through to HttpError.
4730 let status = resp.status();
4731 if status.as_u16() == 413 {
4732 return Err(Error::PayloadTooLarge {
4733 method: method.to_string(),
4734 size_hint: None,
4735 });
4736 }
4737 let err = resp.error_for_status_ref().unwrap_err();
4738 Err(Error::HttpError(err))
4739 }
4740 }
4741
4742 /// Internal helper: POST a JSON-RPC request and stream the binary response
4743 /// to `output_path`. The response is assumed to be raw binary (not a JSON
4744 /// envelope). Use for endpoints that return file contents directly.
4745 ///
4746 /// On HTTP non-success, the response body is read as text and surfaced
4747 /// via `Error::RpcError(status_code, body)`.
4748 pub(crate) async fn rpc_download<P: Serialize>(
4749 &self,
4750 method: &str,
4751 params: &P,
4752 output_path: &std::path::Path,
4753 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
4754 ) -> Result<(), Error> {
4755 let envelope = serde_json::json!({
4756 "jsonrpc": "2.0",
4757 "id": 0,
4758 "method": method,
4759 "params": params,
4760 });
4761
4762 let url = format!("{}/api", self.url);
4763 let resp = self
4764 .bulk_http
4765 .post(&url)
4766 .header("Authorization", format!("Bearer {}", self.token().await))
4767 .json(&envelope)
4768 .send()
4769 .await?;
4770
4771 let status = resp.status();
4772 if !status.is_success() {
4773 if status.as_u16() == 413 {
4774 return Err(Error::PayloadTooLarge {
4775 method: method.to_string(),
4776 size_hint: None,
4777 });
4778 }
4779 let body = resp.text().await.unwrap_or_default();
4780 return Err(Error::RpcError(status.as_u16() as i32, body));
4781 }
4782
4783 // HTTP 200 with Content-Type: application/json can mean two things:
4784 // (a) a JSON-RPC error envelope when the server failed mid-way
4785 // (e.g. {"jsonrpc":"2.0","error":{"code":N,"message":"..."}}),
4786 // (b) a legitimate JSON file payload — validation traces, chart
4787 // bodies, metrics, etc., are typically served with this MIME.
4788 //
4789 // Disambiguate structurally: a JSON-RPC 2.0 envelope is required to
4790 // carry a `jsonrpc` member, and an *error* envelope further requires
4791 // an `error.code` integer (per RFC 8259 + JSON-RPC 2.0 §5). Only
4792 // decode the body as an error if both markers are present. This is
4793 // strict enough to leave legitimate JSON artifacts that happen to
4794 // contain a free-form `error` field (metrics, diagnostics, log
4795 // dumps) untouched, while still catching every real server
4796 // failure.
4797 let content_type = resp
4798 .headers()
4799 .get(reqwest::header::CONTENT_TYPE)
4800 .and_then(|v| v.to_str().ok())
4801 .unwrap_or("")
4802 .to_owned();
4803 if content_type.contains("application/json") {
4804 let body = resp.bytes().await?;
4805 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&body)
4806 && is_jsonrpc_error_envelope(&val)
4807 && let Some(err_obj) = val.get("error")
4808 {
4809 let code = err_obj.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
4810 let message = err_obj
4811 .get("message")
4812 .and_then(|m| m.as_str())
4813 .unwrap_or("unknown error")
4814 .to_string();
4815 return Err(Error::RpcError(code, message));
4816 }
4817 // Not an error envelope — body is a JSON file. Write it to disk
4818 // and emit a single completion progress event so callers (e.g.,
4819 // Python download_data progress callbacks) see the download
4820 // finish.
4821 //
4822 // `Path::parent` returns `Some("")` for a bare filename like
4823 // "metrics.json"; `create_dir_all("")` errors out with
4824 // `NotFound`, so only create the parent when it actually names
4825 // a directory.
4826 if let Some(parent) = output_path.parent()
4827 && !parent.as_os_str().is_empty()
4828 {
4829 tokio::fs::create_dir_all(parent).await?;
4830 }
4831 let mut file = tokio::fs::File::create(output_path).await?;
4832 file.write_all(&body).await?;
4833 file.flush().await?;
4834 if let Some(tx) = progress {
4835 let total = body.len();
4836 // Use the awaited send for the final event so completion
4837 // handlers are never silently dropped.
4838 let _ = tx
4839 .send(Progress {
4840 current: total,
4841 total,
4842 status: None,
4843 })
4844 .await;
4845 }
4846 return Ok(());
4847 }
4848
4849 // Same empty-parent guard for the streaming download path: passing
4850 // a bare filename like "metrics.json" must write to the current
4851 // directory rather than failing on `create_dir_all("")`.
4852 if let Some(parent) = output_path.parent()
4853 && !parent.as_os_str().is_empty()
4854 {
4855 tokio::fs::create_dir_all(parent).await?;
4856 }
4857
4858 stream_response_to_file(resp, output_path, progress).await
4859 }
4860
4861 /// Send a JSON-RPC request to the server. The method is the name of the
4862 /// method to call on the server. The params are the parameters to pass to
4863 /// the method. The method and params are serialized into a JSON-RPC
4864 /// request and sent to the server. The response is deserialized into
4865 /// the specified type and returned to the caller.
4866 ///
4867 /// NOTE: This API would generally not be called directly and instead users
4868 /// should use the higher-level methods provided by the client.
4869 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4870 pub async fn rpc<Params, RpcResult>(
4871 &self,
4872 method: String,
4873 params: Option<Params>,
4874 ) -> Result<RpcResult, Error>
4875 where
4876 Params: Serialize,
4877 RpcResult: DeserializeOwned,
4878 {
4879 let auth_expires = self.token_expiration().await?;
4880 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4881 self.renew_token().await?;
4882 }
4883
4884 self.rpc_without_auth(method, params).await
4885 }
4886
    /// Sends a JSON-RPC request with retries, skipping the token-expiry check
    /// that `rpc` performs (the current token is still sent as a bearer token).
    ///
    /// The request is serialized once and re-sent on every attempt. Up to
    /// `EDGEFIRST_MAX_RETRIES` retries (default 5) are made for HTTP 408, 429,
    /// 500, 502, 503 and 504 responses and for transport timeouts/connection
    /// failures, with jittered exponential backoff between attempts. Any other
    /// outcome, including JSON-RPC error envelopes and malformed bodies
    /// reported by `process_rpc_response`, is returned immediately.
    ///
    /// On the final attempt a retryable HTTP status is not short-circuited: the
    /// response is passed to `process_rpc_response` like any other.
    ///
    /// `auth.login` parameters are redacted in trace logs and profiling spans.
    #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
    async fn rpc_without_auth<Params, RpcResult>(
        &self,
        method: String,
        params: Option<Params>,
    ) -> Result<RpcResult, Error>
    where
        Params: Serialize,
        RpcResult: DeserializeOwned,
    {
        let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(5usize);

        let url = format!("{}/api", self.url);

        // Build the request once before the retry loop. Its serialized bytes
        // (`request_body` below) are reused by every attempt, which avoids a
        // Clone bound on Params.
        let request = RpcRequest {
            method: method.clone(),
            params,
            ..Default::default()
        };

        // JSON rendering used only for logging (log crate) and profiling
        // (tracing crate); the wire body is serialized separately below.
        let request_json = if method == "auth.login" {
            // Redact auth.login params (contains password)
            serde_json::json!({
                "jsonrpc": "2.0",
                "method": &method,
                "params": "[REDACTED - contains credentials]",
                "id": request.id
            })
            .to_string()
        } else {
            serde_json::to_string(&request)?
        };

        if log_enabled!(Level::Trace) {
            trace!("RPC Request: {}", request_json);
        }

        // Record request on current span for Perfetto when profiling is enabled
        #[cfg(feature = "profiling")]
        tracing::Span::current().record("request", &request_json);

        let request_body = serde_json::to_vec(&request)?;
        let mut last_error: Option<Error> = None;

        for attempt in 0..=max_retries {
            if attempt > 0 {
                // Exponential backoff with jitter. The base delay is
                // 2^(attempt-1) seconds (1, 2, 4, 8, 16) and is capped at 30s
                // from the sixth retry on. The jitter factor is a random value
                // in [1.0, 1.5), which avoids a thundering herd while never
                // retrying faster than the base delay.
                let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
                let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
                let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
                let delay = Duration::from_millis(delay_ms);
                warn!(
                    "Retry {}/{} for RPC '{}' after {:?}",
                    attempt, max_retries, method, delay
                );
                tokio::time::sleep(delay).await;
            }

            let result = self
                .http
                .post(&url)
                .header("Accept", "application/json")
                .header("Content-Type", "application/json")
                .header("User-Agent", "EdgeFirst Client")
                .header("Authorization", format!("Bearer {}", self.token().await))
                .body(request_body.clone())
                .send()
                .await;

            match result {
                Ok(res) => {
                    let status = res.status();
                    let status_code = status.as_u16();

                    // Check for retryable HTTP status codes before processing
                    // the response. All listed codes are 4xx/5xx, so
                    // `error_for_status()` always yields Err here. On the last
                    // attempt this is skipped and the response is processed.
                    if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
                        && attempt < max_retries
                    {
                        warn!(
                            "RPC '{}' failed with HTTP {} (retrying)",
                            method, status_code
                        );
                        last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
                        continue;
                    }

                    // Process the response
                    match self.process_rpc_response(res).await {
                        Ok(result) => {
                            if attempt > 0 {
                                debug!("RPC '{}' succeeded on retry {}", method, attempt);
                            }
                            return Ok(result);
                        }
                        Err(e) => {
                            // Errors from response processing (JSON-RPC error
                            // envelopes, malformed bodies, body read failures)
                            // are never retried.
                            if attempt > 0 {
                                error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                            }
                            return Err(e);
                        }
                    }
                }
                Err(e) => {
                    // Transport error (timeout, connection failure, etc.)
                    let is_timeout = e.is_timeout();
                    let is_connect = e.is_connect();

                    if (is_timeout || is_connect) && attempt < max_retries {
                        warn!(
                            "RPC '{}' transport error (retrying): {}",
                            method,
                            if is_timeout {
                                "timeout"
                            } else {
                                "connection failed"
                            }
                        );
                        last_error = Some(Error::HttpError(e));
                        continue;
                    }

                    if attempt > 0 {
                        error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
                    }
                    return Err(Error::HttpError(e));
                }
            }
        }

        // Unreachable in practice: every branch on the final attempt
        // (`attempt == max_retries`) returns. Kept as a defensive fallback.
        Err(last_error.unwrap_or_else(|| {
            Error::InvalidParameters(format!(
                "RPC '{}' failed after {} retries",
                method, max_retries
            ))
        }))
    }
5032
5033 async fn process_rpc_response<RpcResult>(
5034 &self,
5035 res: reqwest::Response,
5036 ) -> Result<RpcResult, Error>
5037 where
5038 RpcResult: DeserializeOwned,
5039 {
5040 let body = res.bytes().await?;
5041 let response_str = String::from_utf8_lossy(&body);
5042
5043 if log_enabled!(Level::Trace) {
5044 trace!("RPC Response: {}", response_str);
5045 }
5046
5047 // Record response on current span for Perfetto when profiling is enabled
5048 // Truncate large responses to avoid bloating trace files
5049 #[cfg(feature = "profiling")]
5050 {
5051 const MAX_RESPONSE_LEN: usize = 4096;
5052 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
5053 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
5054 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
5055 format!(
5056 "{}...[truncated {} bytes]",
5057 &response_str[..safe_end],
5058 response_str.len() - safe_end
5059 )
5060 } else {
5061 response_str.to_string()
5062 };
5063 tracing::Span::current().record("response", &truncated);
5064 }
5065
5066 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
5067 Ok(response) => response,
5068 Err(err) => {
5069 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
5070 return Err(err.into());
5071 }
5072 };
5073
5074 // FIXME: Studio Server always returns 999 as the id.
5075 // if request.id.to_string() != response.id {
5076 // return Err(Error::InvalidRpcId(response.id));
5077 // }
5078
5079 if let Some(error) = response.error {
5080 Err(Error::RpcError(error.code, error.message))
5081 } else if let Some(result) = response.result {
5082 Ok(result)
5083 } else {
5084 Err(Error::InvalidResponse)
5085 }
5086 }
5087}
5088
5089/// Process items in parallel with semaphore concurrency control and progress
5090/// tracking.
5091///
5092/// This helper eliminates boilerplate for parallel item processing with:
5093/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
5094/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
5095/// - Atomic progress counter with automatic item-level updates
5096/// - Progress updates sent after each item completes (not byte-level streaming)
5097/// - Proper error propagation from spawned tasks
5098///
5099/// Note: This is optimized for discrete items with post-completion progress
5100/// updates. For byte-level streaming progress or custom retry logic, use
5101/// specialized implementations.
5102///
5103/// # Arguments
5104///
5105/// * `items` - Collection of items to process in parallel
5106/// * `progress` - Optional progress channel for tracking completion
5107/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
5108/// * `work_fn` - Async function to execute for each item
5109///
5110/// # Examples
5111///
5112/// ```rust,ignore
5113/// // Use default concurrency
5114/// parallel_foreach_items(samples, progress, None, |sample| async move {
5115/// sample.download(&client, file_type).await?;
5116/// Ok(())
5117/// }).await?;
5118/// ```
5119async fn parallel_foreach_items<T, F, Fut>(
5120 items: Vec<T>,
5121 progress: Option<Sender<Progress>>,
5122 concurrency: Option<usize>,
5123 work_fn: F,
5124) -> Result<(), Error>
5125where
5126 T: Send + 'static,
5127 F: Fn(T) -> Fut + Send + Sync + 'static,
5128 Fut: Future<Output = Result<(), Error>> + Send + 'static,
5129{
5130 let total = items.len();
5131 let current = Arc::new(AtomicUsize::new(0));
5132 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
5133 let work_fn = Arc::new(work_fn);
5134
5135 let tasks = items
5136 .into_iter()
5137 .map(|item| {
5138 let sem = sem.clone();
5139 let current = current.clone();
5140 let progress = progress.clone();
5141 let work_fn = work_fn.clone();
5142
5143 tokio::spawn(async move {
5144 let _permit = sem.acquire().await.map_err(|_| {
5145 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5146 })?;
5147
5148 // Execute the actual work
5149 work_fn(item).await?;
5150
5151 // Update progress
5152 if let Some(progress) = &progress {
5153 let current = current.fetch_add(1, Ordering::SeqCst);
5154 let _ = progress
5155 .send(Progress {
5156 current: current + 1,
5157 total,
5158 status: None,
5159 })
5160 .await;
5161 }
5162
5163 Ok::<(), Error>(())
5164 })
5165 })
5166 .collect::<Vec<_>>();
5167
5168 join_all(tasks)
5169 .await
5170 .into_iter()
5171 .collect::<Result<Vec<_>, _>>()?
5172 .into_iter()
5173 .collect::<Result<Vec<_>, _>>()?;
5174
5175 if let Some(progress) = progress {
5176 drop(progress);
5177 }
5178
5179 Ok(())
5180}
5181
5182/// Upload a file to S3 using multipart upload with presigned URLs.
5183///
5184/// Splits a file into chunks (100MB each) and uploads them in parallel using
5185/// S3 multipart upload protocol. Returns completion parameters with ETags for
5186/// finalizing the upload.
5187///
5188/// This function handles:
5189/// - Splitting files into parts based on PART_SIZE (100MB)
5190/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
5191/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
5192/// - Retry logic (handled by reqwest client)
5193/// - Progress tracking across all parts
5194///
5195/// # Arguments
5196///
5197/// * `http` - HTTP client for making requests
5198/// * `part` - Snapshot part info with presigned URLs for each chunk
5199/// * `path` - Local file path to upload
5200/// * `total` - Total bytes across all files for progress calculation
5201/// * `current` - Atomic counter tracking bytes uploaded across all operations
5202/// * `progress` - Optional channel for sending progress updates
5203///
5204/// # Returns
5205///
5206/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
5207async fn upload_multipart(
5208 http: reqwest::Client,
5209 part: SnapshotPart,
5210 path: PathBuf,
5211 total: usize,
5212 confirmed_bytes: Arc<AtomicUsize>,
5213 progress: Option<Sender<Progress>>,
5214) -> Result<SnapshotCompleteMultipartParams, Error> {
5215 let filesize = path.metadata()?.len() as usize;
5216 let n_parts = filesize.div_ceil(PART_SIZE);
5217 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
5218
5219 let key = part.key.ok_or(Error::InvalidResponse)?;
5220 let upload_id = part.upload_id;
5221
5222 let urls = part.urls.clone();
5223
5224 // Pre-allocate ETag slots for all parts
5225 let etags = Arc::new(tokio::sync::Mutex::new(vec![
5226 EtagPart {
5227 etag: "".to_owned(),
5228 part_number: 0,
5229 };
5230 n_parts
5231 ]));
5232
5233 // Per-part byte counters for streaming progress (reset on retry)
5234 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
5235 (0..n_parts)
5236 .map(|_| AtomicUsize::new(0))
5237 .collect::<Vec<_>>(),
5238 );
5239
5240 // Upload all parts in parallel with concurrency limiting
5241 let tasks = (0..n_parts)
5242 .map(|part_idx| {
5243 let http = http.clone();
5244 let url = urls[part_idx].clone();
5245 let etags = etags.clone();
5246 let path = path.to_owned();
5247 let sem = sem.clone();
5248 let progress = progress.clone();
5249 let confirmed_bytes = confirmed_bytes.clone();
5250 let part_bytes = part_bytes.clone();
5251
5252 // Calculate this part's size
5253 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5254 filesize % PART_SIZE
5255 } else {
5256 PART_SIZE
5257 };
5258
5259 tokio::spawn(async move {
5260 // Acquire semaphore permit to limit concurrent uploads
5261 let _permit = sem.acquire().await.map_err(|_| {
5262 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5263 })?;
5264
5265 // Upload part with streaming progress and retry logic
5266 let etag = upload_part_with_progress(
5267 http,
5268 url,
5269 path,
5270 part_idx,
5271 n_parts,
5272 part_size,
5273 total,
5274 confirmed_bytes.clone(),
5275 part_bytes.clone(),
5276 progress.clone(),
5277 )
5278 .await?;
5279
5280 // Store ETag for this part (needed to complete multipart upload)
5281 let mut etags_guard = etags.lock().await;
5282 etags_guard[part_idx] = EtagPart {
5283 etag,
5284 part_number: part_idx + 1,
5285 };
5286
5287 // Part completed successfully - add to confirmed bytes
5288 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
5289 // Reset part counter since it's now confirmed
5290 part_bytes[part_idx].store(0, Ordering::SeqCst);
5291
5292 // Send final progress update for this part
5293 if let Some(progress) = &progress {
5294 let current = confirmed_bytes.load(Ordering::SeqCst)
5295 + part_bytes
5296 .iter()
5297 .map(|p| p.load(Ordering::SeqCst))
5298 .sum::<usize>();
5299 let _ = progress
5300 .send(Progress {
5301 current,
5302 total,
5303 status: None,
5304 })
5305 .await;
5306 }
5307
5308 Ok::<(), Error>(())
5309 })
5310 })
5311 .collect::<Vec<_>>();
5312
5313 // Wait for all parts to complete (double collect to handle both JoinError and
5314 // inner Error)
5315 join_all(tasks)
5316 .await
5317 .into_iter()
5318 .collect::<Result<Vec<_>, _>>()?
5319 .into_iter()
5320 .collect::<Result<Vec<_>, _>>()?;
5321
5322 Ok(SnapshotCompleteMultipartParams {
5323 key,
5324 upload_id,
5325 etag_list: etags.lock().await.clone(),
5326 })
5327}
5328
5329/// Upload a single part with streaming progress tracking and retry logic.
5330///
5331/// Progress is reported continuously as bytes are sent. On retry, the part's
5332/// progress counter is reset to avoid over-reporting.
5333#[allow(clippy::too_many_arguments)]
5334async fn upload_part_with_progress(
5335 http: reqwest::Client,
5336 url: String,
5337 path: PathBuf,
5338 part_idx: usize,
5339 n_parts: usize,
5340 part_size: usize,
5341 total: usize,
5342 confirmed_bytes: Arc<AtomicUsize>,
5343 part_bytes: Arc<Vec<AtomicUsize>>,
5344 progress: Option<Sender<Progress>>,
5345) -> Result<String, Error> {
5346 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5347 .ok()
5348 .and_then(|s| s.parse().ok())
5349 .unwrap_or(5usize);
5350
5351 // Per-part total upload timeout. Covers the send phase (request body) where
5352 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
5353 // this bounds how long a stalled upload can block before retrying.
5354 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5355 .ok()
5356 .and_then(|s| s.parse().ok())
5357 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
5358
5359 let mut last_error: Option<Error> = None;
5360
5361 for attempt in 0..=max_retries {
5362 if attempt > 0 {
5363 // Reset this part's progress counter before retry
5364 part_bytes[part_idx].store(0, Ordering::SeqCst);
5365
5366 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5367 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5368 warn!(
5369 "Retry {}/{} for part {} after {:?}",
5370 attempt, max_retries, part_idx, delay
5371 );
5372 tokio::time::sleep(delay).await;
5373 }
5374
5375 match upload_part_streaming(
5376 http.clone(),
5377 url.clone(),
5378 path.clone(),
5379 part_idx,
5380 n_parts,
5381 part_size,
5382 total,
5383 upload_timeout_secs,
5384 confirmed_bytes.clone(),
5385 part_bytes.clone(),
5386 progress.clone(),
5387 )
5388 .await
5389 {
5390 Ok(etag) => return Ok(etag),
5391 Err(e) => {
5392 // Check if error is retryable
5393 let is_retryable = matches!(
5394 &e,
5395 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
5396 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
5397 );
5398
5399 if is_retryable && attempt < max_retries {
5400 last_error = Some(e);
5401 continue;
5402 }
5403
5404 return Err(e);
5405 }
5406 }
5407 }
5408
5409 Err(last_error
5410 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
5411}
5412
5413/// Perform the actual upload with streaming progress.
5414#[allow(clippy::too_many_arguments)]
5415async fn upload_part_streaming(
5416 http: reqwest::Client,
5417 url: String,
5418 path: PathBuf,
5419 part_idx: usize,
5420 n_parts: usize,
5421 _part_size: usize,
5422 total: usize,
5423 upload_timeout_secs: u64,
5424 confirmed_bytes: Arc<AtomicUsize>,
5425 part_bytes: Arc<Vec<AtomicUsize>>,
5426 progress: Option<Sender<Progress>>,
5427) -> Result<String, Error> {
5428 let filesize = path.metadata()?.len() as usize;
5429 let mut file = File::open(&path).await?;
5430 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
5431 .await?;
5432 let file = file.take(PART_SIZE as u64);
5433
5434 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5435 filesize % PART_SIZE
5436 } else {
5437 PART_SIZE
5438 };
5439
5440 // Create stream with progress tracking
5441 let stream = FramedRead::new(file, BytesCodec::new());
5442
5443 // Wrap stream to track bytes sent and report progress
5444 let progress_stream = stream.map(move |result| {
5445 if let Ok(ref bytes) = result {
5446 let bytes_len = bytes.len();
5447 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
5448
5449 // Send progress update (fire-and-forget via try_send to avoid blocking)
5450 if let Some(ref progress) = progress {
5451 let current = confirmed_bytes.load(Ordering::SeqCst)
5452 + part_bytes
5453 .iter()
5454 .map(|p| p.load(Ordering::SeqCst))
5455 .sum::<usize>();
5456 // Best-effort progress reporting: use try_send to avoid blocking.
5457 // If the channel is full or closed, we intentionally skip this update
5458 // to avoid stalling the upload; subsequent updates will still be delivered.
5459 let _ = progress.try_send(Progress {
5460 current,
5461 total,
5462 status: None,
5463 });
5464 }
5465 }
5466 result.map(|b| b.freeze())
5467 });
5468
5469 let body = Body::wrap_stream(progress_stream);
5470
5471 let resp = http
5472 .put(url)
5473 .header(CONTENT_LENGTH, body_length)
5474 .timeout(Duration::from_secs(upload_timeout_secs))
5475 .body(body)
5476 .send()
5477 .await?
5478 .error_for_status()?;
5479
5480 let etag = resp
5481 .headers()
5482 .get("etag")
5483 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
5484 .to_str()
5485 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
5486 .to_owned();
5487
5488 // Studio Server requires etag without the quotes.
5489 let etag = etag
5490 .strip_prefix("\"")
5491 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
5492 let etag = etag
5493 .strip_suffix("\"")
5494 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
5495
5496 Ok(etag.to_owned())
5497}
5498
5499/// Upload a complete file to a presigned S3 URL using HTTP PUT.
5500///
5501/// This is used for populate_samples to upload files to S3 after
5502/// receiving presigned URLs from the server.
5503///
5504/// Includes explicit retry logic with exponential backoff for transient
5505/// failures.
5506async fn upload_file_to_presigned_url(
5507 http: reqwest::Client,
5508 url: &str,
5509 path: PathBuf,
5510) -> Result<(), Error> {
5511 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5512 .ok()
5513 .and_then(|s| s.parse().ok())
5514 .unwrap_or(5usize);
5515
5516 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5517 .ok()
5518 .and_then(|s| s.parse().ok())
5519 .unwrap_or(600u64);
5520
5521 // Read the entire file into memory once
5522 let file_data = fs::read(&path).await?;
5523 let file_size = file_data.len();
5524 let filename = path.file_name().unwrap_or_default().to_string_lossy();
5525
5526 let mut last_error: Option<Error> = None;
5527
5528 for attempt in 0..=max_retries {
5529 if attempt > 0 {
5530 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5531 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5532 warn!(
5533 "Retry {}/{} for upload '{}' after {:?}",
5534 attempt, max_retries, filename, delay
5535 );
5536 tokio::time::sleep(delay).await;
5537 }
5538
5539 // Attempt upload
5540 let result = http
5541 .put(url)
5542 .header(CONTENT_LENGTH, file_size)
5543 .timeout(Duration::from_secs(upload_timeout_secs))
5544 .body(file_data.clone())
5545 .send()
5546 .await;
5547
5548 match result {
5549 Ok(resp) => {
5550 if resp.status().is_success() {
5551 if attempt > 0 {
5552 debug!(
5553 "Upload '{}' succeeded on retry {} ({} bytes)",
5554 filename, attempt, file_size
5555 );
5556 } else {
5557 debug!(
5558 "Successfully uploaded file: {} ({} bytes)",
5559 filename, file_size
5560 );
5561 }
5562 return Ok(());
5563 }
5564
5565 let status = resp.status();
5566 let status_code = status.as_u16();
5567
5568 // Check if error is retryable
5569 let is_retryable =
5570 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5571
5572 if is_retryable && attempt < max_retries {
5573 let error_text = resp.text().await.unwrap_or_default();
5574 warn!(
5575 "Upload '{}' failed with HTTP {} (retryable): {}",
5576 filename, status_code, error_text
5577 );
5578 last_error = Some(Error::InvalidParameters(format!(
5579 "Upload failed: HTTP {} - {}",
5580 status, error_text
5581 )));
5582 continue;
5583 }
5584
5585 // Non-retryable error or max retries exceeded
5586 let error_text = resp.text().await.unwrap_or_default();
5587 if attempt > 0 {
5588 error!(
5589 "Upload '{}' failed after {} retries: HTTP {} - {}",
5590 filename, attempt, status, error_text
5591 );
5592 }
5593 return Err(Error::InvalidParameters(format!(
5594 "Upload failed: HTTP {} - {}",
5595 status, error_text
5596 )));
5597 }
5598 Err(e) => {
5599 // Transport error (timeout, connection failure, etc.)
5600 let is_timeout = e.is_timeout();
5601 let is_connect = e.is_connect();
5602
5603 if (is_timeout || is_connect) && attempt < max_retries {
5604 warn!(
5605 "Upload '{}' transport error (retrying): {}",
5606 filename,
5607 if is_timeout {
5608 "timeout"
5609 } else {
5610 "connection failed"
5611 }
5612 );
5613 last_error = Some(Error::HttpError(e));
5614 continue;
5615 }
5616
5617 // Non-retryable or max retries exceeded
5618 if attempt > 0 {
5619 error!(
5620 "Upload '{}' failed after {} retries: {}",
5621 filename, attempt, e
5622 );
5623 }
5624 return Err(Error::HttpError(e));
5625 }
5626 }
5627 }
5628
5629 // Should not reach here, but return last error if we do
5630 Err(last_error.unwrap_or_else(|| {
5631 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5632 }))
5633}
5634
5635/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
5636///
5637/// This is used for populate_samples to upload file content from memory
5638/// (e.g., from ZIP archives) without writing to disk first.
5639///
5640/// Includes explicit retry logic with exponential backoff for transient
5641/// failures.
5642async fn upload_bytes_to_presigned_url(
5643 http: reqwest::Client,
5644 url: &str,
5645 file_data: Vec<u8>,
5646 filename: &str,
5647) -> Result<(), Error> {
5648 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5649 .ok()
5650 .and_then(|s| s.parse().ok())
5651 .unwrap_or(5usize);
5652
5653 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5654 .ok()
5655 .and_then(|s| s.parse().ok())
5656 .unwrap_or(600u64);
5657
5658 let file_size = file_data.len();
5659 let mut last_error: Option<Error> = None;
5660
5661 for attempt in 0..=max_retries {
5662 if attempt > 0 {
5663 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5664 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5665 warn!(
5666 "Retry {}/{} for upload '{}' after {:?}",
5667 attempt, max_retries, filename, delay
5668 );
5669 tokio::time::sleep(delay).await;
5670 }
5671
5672 // Attempt upload
5673 let result = http
5674 .put(url)
5675 .header(CONTENT_LENGTH, file_size)
5676 .timeout(Duration::from_secs(upload_timeout_secs))
5677 .body(file_data.clone())
5678 .send()
5679 .await;
5680
5681 match result {
5682 Ok(resp) => {
5683 if resp.status().is_success() {
5684 if attempt > 0 {
5685 debug!(
5686 "Upload '{}' succeeded on retry {} ({} bytes)",
5687 filename, attempt, file_size
5688 );
5689 } else {
5690 debug!(
5691 "Successfully uploaded file: {} ({} bytes)",
5692 filename, file_size
5693 );
5694 }
5695 return Ok(());
5696 }
5697
5698 let status = resp.status();
5699 let status_code = status.as_u16();
5700
5701 // Check if error is retryable
5702 let is_retryable =
5703 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5704
5705 if is_retryable && attempt < max_retries {
5706 let error_text = resp.text().await.unwrap_or_default();
5707 warn!(
5708 "Upload '{}' failed with HTTP {} (retryable): {}",
5709 filename, status_code, error_text
5710 );
5711 last_error = Some(Error::InvalidParameters(format!(
5712 "Upload failed: HTTP {} - {}",
5713 status, error_text
5714 )));
5715 continue;
5716 }
5717
5718 // Non-retryable error or max retries exceeded
5719 let error_text = resp.text().await.unwrap_or_default();
5720 if attempt > 0 {
5721 error!(
5722 "Upload '{}' failed after {} retries: HTTP {} - {}",
5723 filename, attempt, status, error_text
5724 );
5725 }
5726 return Err(Error::InvalidParameters(format!(
5727 "Upload failed: HTTP {} - {}",
5728 status, error_text
5729 )));
5730 }
5731 Err(e) => {
5732 // Transport error (timeout, connection failure, etc.)
5733 let is_timeout = e.is_timeout();
5734 let is_connect = e.is_connect();
5735
5736 if (is_timeout || is_connect) && attempt < max_retries {
5737 warn!(
5738 "Upload '{}' transport error (retrying): {}",
5739 filename,
5740 if is_timeout {
5741 "timeout"
5742 } else {
5743 "connection failed"
5744 }
5745 );
5746 last_error = Some(Error::HttpError(e));
5747 continue;
5748 }
5749
5750 // Non-retryable or max retries exceeded
5751 if attempt > 0 {
5752 error!(
5753 "Upload '{}' failed after {} retries: {}",
5754 filename, attempt, e
5755 );
5756 }
5757 return Err(Error::HttpError(e));
5758 }
5759 }
5760 }
5761
5762 // Should not reach here, but return last error if we do
5763 Err(last_error.unwrap_or_else(|| {
5764 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5765 }))
5766}
5767
// Offline unit tests for helpers in this file that need no live server:
// `filter_and_sort_by_name` ranking, `Client::build_filename` flattening,
// and the URL / token-storage handling of `Client::with_server` and
// `Client::with_url`.
#[cfg(test)]
mod tests {
    use super::*;

    // ===== filter_and_sort_by_name =====
    //
    // As pinned by the tests below: filtering is a case-insensitive substring
    // match; results are ordered case-sensitive exact match first, then
    // case-insensitive exact matches, then shorter names before longer ones,
    // with alphabetical order breaking ties between names of equal length.

    #[test]
    fn test_filter_and_sort_by_name_exact_match_first() {
        // Test that exact matches come first
        let items = vec![
            "Deer Roundtrip 123".to_string(),
            "Deer".to_string(),
            "Reindeer".to_string(),
            "DEER".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
        assert_eq!(result[0], "Deer"); // Exact match first
        assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
    }

    #[test]
    fn test_filter_and_sort_by_name_shorter_names_preferred() {
        // Test that shorter names (more specific) come before longer ones
        let items = vec![
            "Test Dataset ABC".to_string(),
            "Test".to_string(),
            "Test Dataset".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result[0], "Test"); // Exact match first
        assert_eq!(result[1], "Test Dataset"); // Shorter substring match
        assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
    }

    #[test]
    fn test_filter_and_sort_by_name_case_insensitive_filter() {
        // Test that filtering is case-insensitive
        let items = vec![
            "UPPERCASE".to_string(),
            "lowercase".to_string(),
            "MixedCase".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
        assert_eq!(result.len(), 3); // All items should match
    }

    #[test]
    fn test_filter_and_sort_by_name_no_matches() {
        // Test that empty result is returned when no matches
        let items = vec!["Apple".to_string(), "Banana".to_string()];
        let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
        // Test alphabetical ordering for same-length names
        let items = vec![
            "TestC".to_string(),
            "TestA".to_string(),
            "TestB".to_string(),
        ];
        let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
        assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
    }

    // ===== Client::build_filename =====
    //
    // With `flatten = true` and a sequence name, the file name gets a
    // `{sequence}_{frame}_` prefix (or `{sequence}_` when there is no frame)
    // unless it already starts with `{sequence}_`. With `flatten = false`, or
    // without a sequence, the base name is returned unchanged.

    #[test]
    fn test_build_filename_no_flatten() {
        // When flatten=false, should return base_name unchanged
        let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
        assert_eq!(result, "image.jpg");

        let result = Client::build_filename("test.png", false, None, None);
        assert_eq!(result, "test.png");
    }

    #[test]
    fn test_build_filename_flatten_no_sequence() {
        // When flatten=true but no sequence, should return base_name unchanged
        let result = Client::build_filename("standalone.jpg", true, None, None);
        assert_eq!(result, "standalone.jpg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_not_prefixed() {
        // When flatten=true, in sequence, filename not prefixed → add prefix
        let result = Client::build_filename(
            "image.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_with_sequence_no_frame() {
        // When flatten=true, in sequence, no frame number → prefix with sequence only
        let result =
            Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
        assert_eq!(result, "sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed() {
        // When flatten=true, filename already starts with sequence_ → return unchanged
        let result = Client::build_filename(
            "deer_sequence_042.camera.jpeg",
            true,
            Some(&"deer_sequence".to_string()),
            Some(42),
        );
        assert_eq!(result, "deer_sequence_042.camera.jpeg");
    }

    #[test]
    fn test_build_filename_flatten_already_prefixed_different_frame() {
        // Edge case: filename has sequence prefix but we're adding different frame
        // Should still respect existing prefix
        let result = Client::build_filename(
            "sequence_A_001.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(2),
        );
        assert_eq!(result, "sequence_A_001.jpg");
    }

    #[test]
    fn test_build_filename_flatten_partial_match() {
        // Edge case: filename contains sequence name but not as prefix
        let result = Client::build_filename(
            "test_sequence_A_image.jpg",
            true,
            Some(&"sequence_A".to_string()),
            Some(5),
        );
        // Should add prefix because it doesn't START with "sequence_A_"
        assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
    }

    #[test]
    fn test_build_filename_flatten_preserves_extension() {
        // Verify that file extensions are preserved correctly, including
        // multi-part suffixes such as "camera.jpeg".
        let extensions = vec![
            "jpeg",
            "jpg",
            "png",
            "camera.jpeg",
            "lidar.pcd",
            "depth.png",
        ];

        for ext in extensions {
            let filename = format!("image.{}", ext);
            let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
            assert!(
                result.ends_with(&format!(".{}", ext)),
                "Extension .{} not preserved in {}",
                ext,
                result
            );
        }
    }

    #[test]
    fn test_build_filename_flatten_sanitization_compatibility() {
        // Test with sanitized path components (no special chars)
        let result = Client::build_filename(
            "sample_001.jpg",
            true,
            Some(&"seq_name_with_underscores".to_string()),
            Some(10),
        );
        assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
    }

    // =========================================================================
    // Additional filter_and_sort_by_name tests for exact match determinism
    // =========================================================================

    #[test]
    fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
        // Test that searching for "Deer" always returns "Deer" first, not
        // "Deer Roundtrip 20251129" or similar
        let items = vec![
            "Deer Roundtrip 20251129".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer".to_string(),
            "Deer Snapshot Test".to_string(),
            "Reindeer Dataset".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // CRITICAL: First result must be exact match "Deer"
        assert_eq!(
            result.first().map(|s| s.as_str()),
            Some("Deer"),
            "Expected exact match 'Deer' first, got: {:?}",
            result.first()
        );

        // Verify all items containing "Deer" are present (case-insensitive,
        // so "Reindeer Dataset" counts too)
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
        // Verify case-sensitive exact match takes priority over case-insensitive
        let items = vec![
            "DEER".to_string(),
            "deer".to_string(),
            "Deer".to_string(),
            "Deer Test".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Priority 1: Case-sensitive exact match "Deer" first
        assert_eq!(result[0], "Deer");
        // Priority 2: Case-insensitive exact matches next (their relative
        // order is not pinned here)
        assert!(result[1] == "DEER" || result[1] == "deer");
        assert!(result[2] == "DEER" || result[2] == "deer");
    }

    #[test]
    fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
        // Realistic scenario: User searches for snapshot "Deer" and multiple
        // snapshots exist with similar names
        let items = vec![
            "Unit Testing - Deer Dataset Backup".to_string(),
            "Deer".to_string(),
            "Deer Snapshot 2025-01-15".to_string(),
            "Original Deer".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // MUST return exact match first for deterministic test behavior
        assert_eq!(
            result[0], "Deer",
            "Searching for 'Deer' should return exact 'Deer' first"
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
        // Realistic scenario: User searches for dataset "Deer" but multiple
        // datasets have "Deer" in their name
        let items = vec![
            "Deer Roundtrip".to_string(),
            "Deer".to_string(),
            "deer".to_string(),
            "White-Tailed Deer".to_string(),
            "Deer-V2".to_string(),
        ];

        let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());

        // Exact case-sensitive match must be first
        assert_eq!(result[0], "Deer");
        // Case-insensitive exact match should be second
        assert_eq!(result[1], "deer");
        // Shorter names should come before longer names
        assert!(
            result.iter().position(|s| s == "Deer-V2").unwrap()
                < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
        );
    }

    #[test]
    fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
        // CRITICAL: The first result should ALWAYS be the best match
        // This is essential for deterministic test behavior
        let scenarios = vec![
            // (items, filter, expected_first)
            (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
            (vec!["test", "TEST", "Test Data"], "test", "test"),
            (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
        ];

        for (items, filter, expected_first) in scenarios {
            let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
            let result = filter_and_sort_by_name(items, filter, |s| s.as_str());

            assert_eq!(
                result.first().map(|s| s.as_str()),
                Some(expected_first),
                "For filter '{}', expected first result '{}', got: {:?}",
                filter,
                expected_first,
                result.first()
            );
        }
    }

    // ===== with_server token-storage handling =====

    #[test]
    fn test_with_server_clears_storage() {
        use crate::storage::MemoryTokenStorage;

        // Create client with memory storage and a token
        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("test-token").unwrap();

        let client = Client::new().unwrap().with_storage(storage.clone());

        // Sanity check: attaching the storage left the stored token intact
        assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));

        // Change server - should clear storage
        let _new_client = client.with_server("test").unwrap();

        // Verify storage was cleared
        assert_eq!(storage.load().unwrap(), None);
    }

    #[test]
    fn test_with_server_clears_storage_even_for_full_url() {
        // Regression: `with_server` used to short-circuit to `with_url`
        // when given a full URL, which preserved the bearer token. The
        // contract for `with_server` is that switching servers means
        // the token from the old server is no longer trusted.
        use crate::storage::MemoryTokenStorage;

        let storage = Arc::new(MemoryTokenStorage::new());
        storage.store("token-from-old-server").unwrap();
        let client = Client::new().unwrap().with_storage(storage.clone());
        assert_eq!(
            storage.load().unwrap(),
            Some("token-from-old-server".to_string())
        );

        // Switch to a self-hosted Studio (full URL). Storage must be
        // cleared, and the new client must have a blank in-memory token.
        let new_client = client
            .with_server("https://studio.example.com")
            .expect("https full URL through with_server");
        assert_eq!(storage.load().unwrap(), None);
        assert_eq!(new_client.url(), "https://studio.example.com");

        // The new client should not carry the old token in memory either.
        // `token` is behind an async lock, so a throwaway runtime is used
        // to read it from this synchronous test.
        let in_mem = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(async { new_client.token.read().await.clone() });
        assert!(in_mem.is_empty(), "expected blank token, got {in_mem:?}");
    }

    #[test]
    fn test_with_server_rejects_insecure_full_url() {
        // `with_server` validates full URLs through `with_url`, so the
        // HTTPS rule applies uniformly. Plain http to a public host
        // must be rejected — the bearer token would otherwise leak in
        // plaintext when the caller next authenticates.
        let client = Client::new().unwrap();
        let err = client.with_server("http://studio.example.com").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    // ===== with_url HTTPS enforcement =====
    //
    // The bearer token rides in the Authorization header, so plain
    // http:// to a public host would leak it in the clear. The function
    // must reject those URLs, but still let wiremock / local-dev URLs
    // through (loopback addresses, "localhost", "*.localhost").

    #[test]
    fn with_url_accepts_https_public_host() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("https://studio.example.com")
            .expect("https public host must be accepted");
        assert_eq!(out.url(), "https://studio.example.com");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv4() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("http://127.0.0.1:8080")
            .expect("http://127.0.0.1 must be accepted (loopback)");
        assert_eq!(out.url(), "http://127.0.0.1:8080");
    }

    #[test]
    fn with_url_accepts_http_loopback_ipv6() {
        let client = Client::new().unwrap();
        let out = client
            .with_url("http://[::1]:8080")
            .expect("http://[::1] must be accepted (loopback)");
        assert!(out.url().starts_with("http://[::1]"));
    }

    #[test]
    fn with_url_accepts_http_localhost() {
        let client = Client::new().unwrap();
        client
            .with_url("http://localhost:8080")
            .expect("http://localhost must be accepted");
        client
            .with_url("http://LOCALHOST")
            .expect("http://LOCALHOST must be accepted (case-insensitive)");
        client
            .with_url("http://wiremock.localhost")
            .expect("http://*.localhost must be accepted");
    }

    #[test]
    fn with_url_rejects_http_public_host() {
        let client = Client::new().unwrap();
        let err = client.with_url("http://studio.example.com").unwrap_err();
        match err {
            Error::InsecureUrl(u) => assert_eq!(u, "http://studio.example.com"),
            other => panic!("expected InsecureUrl, got {other:?}"),
        }
    }

    #[test]
    fn with_url_rejects_http_public_ip() {
        let client = Client::new().unwrap();
        // 8.8.8.8 is not loopback; must be rejected.
        let err = client.with_url("http://8.8.8.8").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }

    #[test]
    fn with_url_rejects_non_http_scheme() {
        let client = Client::new().unwrap();
        // file:// would otherwise parse, but it's not a transport we
        // can use for RPC and we don't want to silently accept it.
        let err = client.with_url("file:///etc/passwd").unwrap_err();
        assert!(matches!(err, Error::InsecureUrl(_)));
    }
}
6200
// Unit tests for `map_rpc_error`, pinning how JSON-RPC error codes become
// typed `Error` variants:
// * 101 with a task id -> `TaskNotFound`, whatever the message says;
//   101 without a task id -> `RpcError(101, _)`.
// * 401 / 403 -> `PermissionDenied(method)`.
// * 413 -> `PayloadTooLarge { method, size_hint: None }`.
// * any other code -> `RpcError(code, message)` unchanged.
#[cfg(test)]
mod tests_map_rpc_error {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn maps_not_found_with_task_id_to_typed_variant() {
        // Server code 101 + "not found" message + task_id present → TaskNotFound
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "task not found".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_cannot_find_phrasing_to_typed_variant() {
        // The DVE server emits "Cannot find task..." — the original "not found"
        // substring match missed this and the caller saw a generic RpcError.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "Cannot find task with id 6789".to_string(),
            Some(task_id),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "'Cannot find task' should map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_does_not_exist_phrasing_to_typed_variant() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.chart.get",
            101,
            "task does not exist".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_code_101_with_unknown_phrasing_when_task_id_supplied() {
        // Server contract for code 101 is "resource not found"; even if the
        // phrasing is novel, the typed variant should be returned so callers
        // can write a stable `match`.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "completely novel server message".to_string(),
            Some(task_id),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "code 101 + task_id should always map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_permission_codes_to_typed_variant() {
        // Both 401 and 403 collapse into the same typed variant.
        for code in [401, 403] {
            let err = map_rpc_error("task.chart.add", code, "denied".to_string(), None);
            assert!(
                matches!(err, Error::PermissionDenied(_)),
                "code {} did not map",
                code
            );
        }
    }

    #[test]
    fn permission_denied_records_method_for_diagnostics() {
        // The variant's payload is the RPC method name, not the server message.
        let err = map_rpc_error("task.data.upload", 403, "forbidden".to_string(), None);
        match err {
            Error::PermissionDenied(method) => assert_eq!(method, "task.data.upload"),
            other => panic!("expected PermissionDenied, got {:?}", other),
        }
    }

    #[test]
    fn maps_payload_too_large_to_typed_variant() {
        let err = map_rpc_error("val.data.upload", 413, "request too large".into(), None);
        match err {
            Error::PayloadTooLarge { method, size_hint } => {
                assert_eq!(method, "val.data.upload");
                assert!(size_hint.is_none());
            }
            other => panic!("expected PayloadTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn falls_through_to_generic_rpc_error_for_unknown_codes() {
        let err = map_rpc_error("task.data.list", -99999, "weird".to_string(), None);
        match err {
            Error::RpcError(code, msg) => {
                assert_eq!(code, -99999);
                assert_eq!(msg, "weird");
            }
            other => panic!("expected RpcError, got {:?}", other),
        }
    }

    #[test]
    fn not_found_without_task_id_falls_through() {
        // Code 101 without task_id → generic RpcError (no task to name)
        let err = map_rpc_error("task.data.list", 101, "not found".to_string(), None);
        assert!(matches!(err, Error::RpcError(101, _)));
    }

    #[test]
    fn code_101_with_task_id_always_maps_even_with_unrelated_message() {
        // Previously the test asserted fall-through for non-"not found"
        // messages, but the contract for code 101 is "resource not found"
        // (see api.go), so when a task_id is present the typed variant is
        // returned unconditionally to give callers a stable error type.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "permission denied".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }
}
6334
// Wire-shape tests for `JobsListRequest` and for deserializing a `Job` record.
#[cfg(test)]
mod tests_jobs {
    use super::*;

    #[test]
    fn jobs_list_request_serializes_to_empty_object() {
        // The request carries no parameters, so it must serialize as `{}`.
        let req = JobsListRequest {};
        assert_eq!(serde_json::to_value(&req).unwrap(), serde_json::json!({}));
    }

    #[test]
    fn job_deserializes_from_bk_batch_shape() {
        // Unknown keys such as `extra_field` must not break deserialization,
        // and the integer `task_id` is reachable both as the raw field and
        // through the `task_id()` accessor.
        let json = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789,
            "docker_task": {},
            "extra_field": "ignored"
        }"#;
        let job: crate::api::Job = serde_json::from_str(json).unwrap();
        assert_eq!(job.code, "edgefirst-validator:2.9.5");
        assert_eq!(job.state, "RUNNING");
        assert_eq!(job.task_id, 6789);
        assert_eq!(job.task_id().value(), 6789);
    }
}
6365
// Wire-shape tests for `JobRunRequest` and for decoding the job-run response
// as a `Job`.
#[cfg(test)]
mod tests_job_run {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn job_run_request_serializes_with_expected_fields() {
        // `env` and `data` serialize as nested JSON objects keyed by name;
        // a `Parameter::Integer` shows up as a plain JSON number.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "post-profile-run".into(),
            env: HashMap::from([("LOG_LEVEL".into(), "info".into())]),
            data: HashMap::from([("validation_session_id".into(), Parameter::Integer(2707))]),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["job_name"], "post-profile-run");
        assert_eq!(json["env"]["LOG_LEVEL"], "info");
        assert_eq!(json["data"]["validation_session_id"], 2707);
    }

    #[test]
    fn job_run_response_deserializes_as_job() {
        // job.run now returns the full BK_BATCH record; deserialize as Job.
        // This payload omits keys such as `launch` and `docker_task`, so
        // they must not be required for deserialization.
        let json = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "post-profile-run",
            "job_id": "aws-batch-job-xxx",
            "state": "SUBMITTED",
            "task_id": 6789
        }"#;
        let job: crate::api::Job = serde_json::from_str(json).unwrap();
        assert_eq!(job.task_id, 6789);
        assert_eq!(job.job_id, "aws-batch-job-xxx");
        assert_eq!(job.state, "SUBMITTED");
    }
}
6404
// Wire-shape test for `JobStopRequest`.
#[cfg(test)]
mod tests_job_stop {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn job_stop_request_serializes_with_task_id() {
        // The request carries the numeric task id (`TaskID::value()`), not
        // the "task-…" string form used to construct it.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = JobStopRequest {
            task_id: task_id.value(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
    }
}
6420
// Wire-shape test for `TaskDataListRequest`.
#[cfg(test)]
mod tests_task_data_list_request {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_list_request_serializes_with_task_id() {
        // `task_id` is sent as the numeric value of the parsed `TaskID`.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskDataListRequest {
            task_id: task_id.value(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
    }
}
6436
// Wire-shape test for `TaskDataDownloadRequest`.
#[cfg(test)]
mod tests_task_data_download {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_download_request_serializes_with_all_fields() {
        // All three fields (`task_id`, `folder`, `file`) appear under their
        // own names in the serialized object.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskDataDownloadRequest {
            task_id: task_id.value(),
            folder: "predictions".into(),
            file: "predictions.parquet".into(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        assert_eq!(json["folder"], "predictions");
        assert_eq!(json["file"], "predictions.parquet");
    }
}
6456
// Wire-shape test for `TaskChartAddRequest`.
#[cfg(test)]
mod tests_task_chart_add {
    use super::*;
    use crate::api::{Parameter, TaskID};

    #[test]
    fn task_chart_add_request_serializes_with_correct_fields() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        // A `Parameter::Object` payload serializes as a nested JSON object.
        let data = Parameter::Object(std::collections::HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]));
        let req = TaskChartAddRequest {
            task_id: task_id.value(),
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: None,
            data,
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        assert_eq!(json["group_name"], "metrics");
        assert_eq!(json["chart_name"], "loss");
        assert_eq!(json["data"]["type"], "line");
        // Indexing a missing key also yields `Value::Null`, so this accepts
        // both an omitted `params` and an explicit `null`.
        assert!(json["params"].is_null());
    }
}
6484
// Wire-shape test for `TaskChartListRequest`.
#[cfg(test)]
mod tests_task_chart_list {
    use super::*;
    use crate::api::TaskID;

    // The name reflects what the assertions prove: an empty group name is
    // serialized as `""`. `serde_json::Value` indexing returns `Null` for a
    // missing key, and `Null != ""`, so an omitted field would fail this test.
    #[test]
    fn task_chart_list_request_keeps_empty_group_name() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskChartListRequest {
            task_id: task_id.value(),
            group_name: String::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        assert_eq!(json["group_name"], "");
    }
}
6502
// Wire-shape test for `TaskChartGetRequest`.
#[cfg(test)]
mod tests_task_chart_get {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_chart_get_request_serializes_with_all_fields() {
        // The chart is addressed by task id plus (group_name, chart_name).
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskChartGetRequest {
            task_id: task_id.value(),
            group_name: "metrics".into(),
            chart_name: "loss".into(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["task_id"], task_id.value());
        assert_eq!(json["group_name"], "metrics");
        assert_eq!(json["chart_name"], "loss");
    }
}
6522
// Wire-shape test for `ValDataDownloadRequest`.
#[cfg(test)]
mod tests_val_data_download {
    use super::*;

    #[test]
    fn val_data_download_request_serializes() {
        // `filename` may contain a sub-path ("trace/…") and is sent verbatim.
        let req = ValDataDownloadRequest {
            session_id: 2707,
            filename: "trace/imx95.json".into(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["session_id"], 2707);
        assert_eq!(json["filename"], "trace/imx95.json");
    }
}
6538
// Wire-shape test for `ValDataListRequest`.
#[cfg(test)]
mod tests_val_data_list {
    use super::*;

    #[test]
    fn val_data_list_request_serializes() {
        // Exact-object comparison: no fields beyond `session_id` are emitted.
        let req = ValDataListRequest { session_id: 2707 };
        assert_eq!(
            serde_json::to_value(&req).unwrap(),
            serde_json::json!({"session_id": 2707})
        );
    }
}
6552
// Tests for `is_jsonrpc_error_envelope`. A value counts as a JSON-RPC error
// envelope only when it is an object with a `jsonrpc` member and an `error`
// object carrying a numeric `code`. Downloaded artifacts that merely happen
// to contain an `error` key must not be misclassified.
#[cfg(test)]
mod tests_jsonrpc_envelope_detection {
    use super::*;

    #[test]
    fn detects_real_envelope() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": 101, "message": "Cannot find task" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_plain_json_artifact_with_error_field() {
        // A diagnostics file with a free-form `error` object — must not be
        // misread as an RPC envelope just because the key collides.
        let v = serde_json::json!({
            "metric": "loss",
            "value": 0.42,
            "error": { "code": "ENV_NOT_FOUND", "message": "missing var" },
        });
        assert!(
            !is_jsonrpc_error_envelope(&v),
            "missing jsonrpc sentinel should mean 'not an envelope'"
        );
    }

    #[test]
    fn rejects_envelope_missing_jsonrpc_sentinel() {
        // Bare `error` block without the protocol-version marker.
        let v = serde_json::json!({
            "id": 0,
            "error": { "code": 101, "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_object_error_field() {
        // A diagnostics file shaped like JSON-RPC accidentally but using
        // a string for `error`.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": "something went wrong",
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_without_error_code() {
        // Real envelopes always carry an integer error.code; missing one
        // is suspicious enough to refuse the envelope classification.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "message": "no code" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_numeric_error_code() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": "ENOENT", "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_non_object_root() {
        // A JSON file whose root is an array — common for metrics dumps —
        // must not be misread.
        let v = serde_json::json!([1, 2, 3]);
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn accepts_unsigned_error_code() {
        // The server's code is technically i32 but JSON has no signed/
        // unsigned distinction — accept both shapes.
        // NOTE(review): `json!` stores a non-negative `101u32` exactly like
        // `101`, so this case overlaps `detects_real_envelope`. A negative
        // code, or one above `i64::MAX`, would be needed to exercise a
        // distinct integer representation — confirm what coverage is intended.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": 101u32, "message": "x" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }
}
6642
// Tests for `validate_chart_args(group, name)`. An empty group or chart name
// is rejected with `Error::InvalidParameters`; any non-empty strings,
// including non-ASCII ones, are accepted.
#[cfg(test)]
mod tests_validate_chart_args {
    use super::*;

    #[test]
    fn rejects_empty_group() {
        let err = validate_chart_args("", "name").unwrap_err();
        assert!(matches!(err, Error::InvalidParameters(_)));
    }

    #[test]
    fn rejects_empty_name() {
        let err = validate_chart_args("group", "").unwrap_err();
        assert!(matches!(err, Error::InvalidParameters(_)));
    }

    #[test]
    fn rejects_both_empty() {
        let err = validate_chart_args("", "").unwrap_err();
        assert!(matches!(err, Error::InvalidParameters(_)));
    }

    #[test]
    fn accepts_valid_args() {
        assert!(validate_chart_args("group", "name").is_ok());
    }

    #[test]
    fn accepts_unicode_args() {
        // Unicode names are allowed; only emptiness is rejected.
        assert!(validate_chart_args("metrics-集合", "损失").is_ok());
    }
}
6676
6677// ---------------------------------------------------------------------------
6678// Additional offline tests for request shapes + helpers added in DE-2565.
6679//
6680// These focus on the wire-shape and helper logic that does not require a
6681// live Studio server — they significantly boost coverage of client.rs.
6682// ---------------------------------------------------------------------------
6683
#[cfg(test)]
mod tests_job_run_request_shape {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn empty_env_and_data_serialize_as_empty_objects() {
        // Empty maps must serialize as `{}` (not `null` and not omitted).
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "smoke".into(),
            env: HashMap::new(),
            data: HashMap::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["env"], serde_json::json!({}));
        assert_eq!(json["data"], serde_json::json!({}));
    }

    #[test]
    fn data_passes_through_parameter_object_payloads() {
        // Checks that the scalar `Parameter` variants (Boolean, Integer,
        // Real, String) and a nested `Object` all serialize inside `data` as
        // plain JSON values under their map keys. Array is not covered here.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "feat".into(),
            env: HashMap::new(),
            data: HashMap::from([
                ("flag".into(), Parameter::Boolean(true)),
                ("epochs".into(), Parameter::Integer(50)),
                ("lr".into(), Parameter::Real(1e-3)),
                ("name".into(), Parameter::String("hello".into())),
                (
                    "opts".into(),
                    Parameter::Object(HashMap::from([(
                        "mode".into(),
                        Parameter::String("fast".into()),
                    )])),
                ),
            ]),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["data"]["flag"], true);
        assert_eq!(json["data"]["epochs"], 50);
        // Compare with a tolerance instead of only checking the sign.
        // NOTE(review): the tolerance keeps this valid whether `Real` wraps
        // f32 or f64; confirm the inner type.
        let lr = json["data"]["lr"]
            .as_f64()
            .expect("lr must serialize as a JSON number");
        assert!((lr - 1e-3).abs() < 1e-6, "unexpected lr: {lr}");
        assert_eq!(json["data"]["name"], "hello");
        // A nested Object must come through as a JSON object, not be flattened.
        assert_eq!(json["data"]["opts"]["mode"], "fast");
    }
}
6727
#[cfg(test)]
mod tests_task_data_chart_request_shape {
    use super::*;
    use crate::api::{Parameter, TaskID};
    use std::collections::HashMap;

    /// Builds the fixed task identifier shared by every test in this module.
    fn sample_task_id() -> TaskID {
        TaskID::try_from("task-1a2b").unwrap()
    }

    #[test]
    fn chart_add_request_with_params_serializes_object() {
        let params = Parameter::Object(HashMap::from([(
            "y_axis".into(),
            Parameter::String("log".into()),
        )]));
        let chart_data = Parameter::Object(HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]));
        let request = TaskChartAddRequest {
            task_id: sample_task_id().value(),
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: Some(params),
            data: chart_data,
        };
        let serialized = serde_json::to_value(&request).unwrap();
        assert_eq!(serialized["params"]["y_axis"], "log");
    }

    #[test]
    fn task_data_list_request_round_trips() {
        let task_id = sample_task_id();
        // A single-field struct has only one possible field order, so the
        // exact string comparison is reliable.
        let expected = format!("{{\"task_id\":{}}}", task_id.value());
        let actual = serde_json::to_string(&TaskDataListRequest {
            task_id: task_id.value(),
        })
        .unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn task_data_download_request_treats_folder_and_file_independently() {
        let request = TaskDataDownloadRequest {
            task_id: sample_task_id().value(),
            folder: "validation/run-01".into(),
            file: "metrics.json".into(),
        };
        let serialized = serde_json::to_value(&request).unwrap();
        // Folder and file are sent as two fields rather than one joined path,
        // so callers never have to escape the separator themselves.
        for (key, expected) in [("folder", "validation/run-01"), ("file", "metrics.json")] {
            assert_eq!(serialized[key], expected);
        }
    }
}
6782
#[cfg(test)]
mod tests_val_data_request_shape {
    use super::*;

    /// Serializes `value` to a JSON string, then parses it back into an
    /// untyped `serde_json::Value` so individual fields can be inspected.
    fn reparse<T: serde::Serialize>(value: &T) -> serde_json::Value {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[test]
    fn val_data_list_round_trips() {
        let back = reparse(&ValDataListRequest { session_id: 2707 });
        assert_eq!(back["session_id"], 2707);
    }

    #[test]
    fn val_data_download_round_trips_with_nested_path() {
        let back = reparse(&ValDataDownloadRequest {
            session_id: 2707,
            filename: "subfolder/imx95.json".into(),
        });
        assert_eq!(back["session_id"], 2707);
        assert_eq!(back["filename"], "subfolder/imx95.json");
    }
}
6807
#[cfg(test)]
mod tests_progress_struct {
    use super::*;

    #[test]
    fn progress_can_be_constructed_with_zero_total() {
        // Some servers omit Content-Length. A progress event with zero
        // counters and no status must still be expressible through the
        // public fields.
        let Progress {
            current,
            total,
            status,
        } = Progress {
            current: 0,
            total: 0,
            status: None,
        };
        assert_eq!((current, total), (0, 0));
        assert!(status.is_none());
    }

    #[test]
    fn progress_tracks_current_independently_of_total() {
        let Progress {
            current,
            total,
            status,
        } = Progress {
            current: 123,
            total: 456,
            status: Some("Downloading".into()),
        };
        assert_eq!((current, total), (123, 456));
        assert_eq!(status.as_deref(), Some("Downloading"));
    }

    #[test]
    fn progress_can_be_cloned() {
        // Progress sinks may need to keep their own copy of an event apart
        // from the channel, so `Progress` has to remain `Clone`.
        let original = Progress {
            current: 10,
            total: 20,
            status: Some("phase".into()),
        };
        let copy = original.clone();
        assert_eq!(copy.current, original.current);
        assert_eq!(copy.total, original.total);
        assert_eq!(copy.status, original.status);
    }
}
6853
#[cfg(test)]
mod tests_bare_filename_parent {
    // Pins the `Path::parent` behaviour that the empty-parent guard in
    // `rpc_download` depends on. That guard lets a bare filename such as
    // "metrics.json" download into the current directory instead of failing
    // on `create_dir_all("")`.
    use std::path::Path;

    #[test]
    fn bare_filename_parent_is_empty_path() {
        // The guard relies on this invariant. If a future Rust release stops
        // returning an empty parent for bare filenames, revisit the guard.
        let parent = Path::new("metrics.json")
            .parent()
            .expect("bare filename always has Some parent");
        assert!(
            parent.as_os_str().is_empty(),
            "Path::parent for bare filename should be empty, got: {parent:?}"
        );
    }

    #[test]
    fn path_with_directory_has_non_empty_parent() {
        // Counterpart case: when a directory component is present, the
        // parent is non-empty and `create_dir_all` should run.
        let parent = Path::new("dir/metrics.json")
            .parent()
            .expect("path-with-dir always has Some parent");
        assert_eq!(parent, Path::new("dir"));
        assert!(!parent.as_os_str().is_empty());
    }
}