edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult,
8 NewValidationSession, Organization, Project, ProjectID, SampleID, SamplesCountResult,
9 SamplesListParams, SamplesListResult, Snapshot, SnapshotCreateFromDataset,
10 SnapshotFromDatasetResult, SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage,
11 StartValidationRequest, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
12 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
13 ValidationSessionID,
14 },
15 dataset::{
16 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
17 },
18 retry::{create_retry_policy, log_retry_configuration},
19 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
20};
21use base64::Engine as _;
22use chrono::{DateTime, Utc};
23use directories::ProjectDirs;
24use futures::{StreamExt as _, future::join_all};
25use log::{Level, debug, error, log_enabled, trace, warn};
26use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::{
29 collections::HashMap,
30 ffi::OsStr,
31 fs::create_dir_all,
32 io::{SeekFrom, Write as _},
33 path::{Path, PathBuf},
34 sync::{
35 Arc,
36 atomic::{AtomicUsize, Ordering},
37 },
38 time::Duration,
39 vec,
40};
41use tokio::{
42 fs::{self, File},
43 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
44 sync::{RwLock, Semaphore, mpsc::Sender},
45};
46use tokio_util::codec::{BytesCodec, FramedRead};
47use walkdir::WalkDir;
48
49#[cfg(feature = "polars")]
50use polars::prelude::*;
51
52/// Maps a JSON-RPC error code to a typed `Error` variant when the code is
53/// well-known; otherwise returns `Error::RpcError(code, message)` unchanged.
54///
55/// Scoped to the new DE-2565 methods. Existing methods continue to return
56/// `Error::RpcError` directly.
57///
58/// Server error codes (from `api.go` via `jrpc.Fail`):
59/// - `1` – generic server error
60/// - `3` – validation / bad request
61/// - `10` – internal server error
62/// - `101` – resource not found (e.g. "Cannot find task...", "not found in DB")
63/// - `401` – unauthenticated
64/// - `403` – forbidden
65/// - `413` – payload too large
66pub(crate) fn map_rpc_error(
67 method: &str,
68 code: i32,
69 message: String,
70 task_id: Option<crate::api::TaskID>,
71) -> Error {
72 // Server emits "Cannot find task...", "not found in DB", and other phrasings
73 // for code 101. Code 101 with a task_id is task-not-found by contract
74 // (see api.go), so we return the typed variant unconditionally when the
75 // caller supplied a task_id — message phrasing is treated as informational
76 // and is preserved by the RPC layer for diagnostic logging upstream.
77 if code == 101
78 && let Some(id) = task_id
79 {
80 return Error::TaskNotFound(id);
81 }
82 match code {
83 401 | 403 => Error::PermissionDenied(method.to_string()),
84 413 => Error::PayloadTooLarge {
85 method: method.to_string(),
86 size_hint: None,
87 },
88 _ => Error::RpcError(code, message),
89 }
90}
91
92/// Returns true if `val` is structurally a JSON-RPC 2.0 *error* envelope.
93///
94/// A real envelope must:
95/// 1. Be a JSON object,
96/// 2. Carry a `"jsonrpc"` member (the protocol-version sentinel — JSON-RPC
97/// 2.0 §5 mandates this on every response object),
98/// 3. Carry an `"error"` object that includes a numeric `"code"` field.
99///
100/// This is intentionally stricter than a "looks for a top-level `error`
101/// key" check so that legitimate JSON file payloads (validation traces,
102/// metrics dumps, diagnostics) which happen to include a free-form `error`
103/// field are *not* misclassified as RPC failures.
104///
105/// Extracted so it can be unit-tested without a live server.
106pub(crate) fn is_jsonrpc_error_envelope(val: &serde_json::Value) -> bool {
107 let Some(obj) = val.as_object() else {
108 return false;
109 };
110 // Protocol-version sentinel — only JSON-RPC envelopes carry this.
111 if !obj.contains_key("jsonrpc") {
112 return false;
113 }
114 let Some(err) = obj.get("error").and_then(|e| e.as_object()) else {
115 return false;
116 };
117 err.get("code")
118 .map(|c| c.is_i64() || c.is_u64())
119 .unwrap_or(false)
120}
121
122/// Validates that `group` and `name` are both non-empty strings for chart
123/// operations (`add_chart`, `get_chart`). Extracted so it can be unit-tested
124/// without a live server.
125pub(crate) fn validate_chart_args(group: &str, name: &str) -> Result<(), Error> {
126 if group.is_empty() || name.is_empty() {
127 return Err(Error::InvalidParameters(
128 "chart: group and name must be non-empty".into(),
129 ));
130 }
131 Ok(())
132}
133
/// Chunk size in bytes (100 MiB). Presumably the part size for multipart
/// uploads — TODO confirm at the use sites.
///
/// Declared as a `const` rather than a `static`: it is a plain compile-time
/// number that needs neither a fixed address nor interior mutability.
const PART_SIZE: usize = 100 * 1024 * 1024;
135
/// Where the bytes of a file being uploaded come from.
#[derive(Clone)]
enum FileSource {
    /// Read the content from this path on the local filesystem.
    Path(PathBuf),
    /// Use these in-memory bytes directly (e.g. an entry taken from a ZIP
    /// archive).
    Bytes(Vec<u8>),
}
144
/// Returns the concurrency limit for general parallel work.
///
/// Reads `MAX_TASKS` from the environment (surrounding whitespace ignored).
/// When the variable is unset, is not a number, or is `0`, the default is
/// half the available CPUs clamped to `2..=8`. If the CPU count cannot be
/// determined, 4 CPUs are assumed.
///
/// `0` is treated as invalid because a limit of zero can never admit any
/// work, so everything waiting on it would stall.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
157
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment
/// variable (surrounding whitespace ignored). Unset, non-numeric, and `0`
/// values all fall back to the default of 8. A limit of zero could never
/// start an upload, so it is treated like any other invalid value
/// (consistent with `max_tasks`).
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8) // Default to 8 concurrent part uploads
}
168
/// Filters items by name and sorts by match quality.
///
/// An item is kept when its name contains `filter` case-insensitively
/// (Unicode lowercase comparison). Kept items are ordered by, best to worst:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Shorter names first (length in bytes)
/// 4. Byte-wise alphabetical order, as a final tie-breaker
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer".
///
/// Each name is lowercased once, and its rank key is computed once, so the
/// sort itself performs no per-comparison allocation.
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    // Pair each surviving item with its rank key. The flags record "is NOT an
    // exact match", so `false < true` puts exact matches first.
    let mut ranked: Vec<((bool, bool, usize), T)> = items
        .into_iter()
        .filter_map(|item| {
            let name = get_name(&item);
            let name_lower = name.to_lowercase();
            if !name_lower.contains(&filter_lower) {
                return None;
            }
            let key = (name != filter, name_lower != filter_lower, name.len());
            Some((key, item))
        })
        .collect();

    // Stable sort; items whose keys tie are ordered by their raw names.
    ranked.sort_by(|(key_a, a), (key_b, b)| {
        key_a.cmp(key_b).then_with(|| get_name(a).cmp(get_name(b)))
    });

    ranked.into_iter().map(|(_, item)| item).collect()
}
218
219/// Whether `host` refers to a loopback (machine-local) endpoint.
220///
221/// Used by [`Client::with_url`] to decide whether a plain-`http://` URL is
222/// safe to accept. Loopback traffic never leaves the machine, so the
223/// usual concern about leaking the Studio bearer token in plaintext does
224/// not apply — that's how wiremock and local dev servers connect.
225fn is_loopback_host(host: Option<&url::Host<&str>>) -> bool {
226 match host {
227 Some(url::Host::Ipv4(ip)) => ip.is_loopback(),
228 Some(url::Host::Ipv6(ip)) => ip.is_loopback(),
229 // RFC 6761 reserves "localhost" (and `*.localhost`) as a loopback
230 // name. Compare case-insensitively because URL hosts are matched
231 // that way and developers do type capitalized variants.
232 Some(url::Host::Domain(d)) => {
233 d.eq_ignore_ascii_case("localhost") || d.to_ascii_lowercase().ends_with(".localhost")
234 }
235 None => false,
236 }
237}
238
/// Turns an arbitrary, possibly untrusted name into a single safe path
/// component.
///
/// - Surrounding whitespace is trimmed; a blank name becomes `"unnamed"`.
/// - Only the last path component is kept (`"a/b"` → `"b"`).
/// - Path separators, characters reserved on Windows
///   (`/ \ : * ? " < > |`), and control characters (e.g. NUL) are replaced
///   with `_`.
/// - `"."` and `".."` become `"unnamed"`. They name the current and parent
///   directory, so joining `".."` onto a destination would escape it.
///   `Path::file_name` returns `None` for them, so without this check they
///   would pass through unchanged via the fallback below.
fn sanitize_path_component(name: &str) -> String {
    let trimmed = name.trim();
    if trimmed.is_empty() {
        return "unnamed".to_string();
    }

    // `file_name` is `None` for inputs such as "..", "." or "/"; fall back to
    // the whole trimmed string, which the character filter then neutralizes.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            c if c.is_control() => '_',
            _ => c,
        })
        .collect();

    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        "unnamed".to_string()
    } else {
        sanitized
    }
}
264
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count, total
/// count, and an optional status string to enable progress reporting in
/// applications.
///
/// # Unknown totals
///
/// `total` may be `0` when the size is not known in advance. For example,
/// byte-level downloads report `0` when the server omits `Content-Length`.
/// Always guard against a zero `total` before computing a percentage.
///
/// # Multi-Stage Progress
///
/// The `status` field enables multi-stage progress tracking. When an operation
/// has multiple phases, the status field changes to indicate the current phase.
/// Applications should detect status changes to reset their progress display.
///
/// # Operation Progress Details
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Basic progress display:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// let label = progress.status.as_deref().unwrap_or("Progress");
/// if progress.total > 0 {
///     let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
///     println!(
///         "{}: {:.1}% ({}/{})",
///         label, percentage, progress.current, progress.total
///     );
/// } else {
///     // Total unknown: report the raw count only.
///     println!("{}: {}", label, progress.current);
/// }
/// ```
///
/// Multi-stage progress handling (e.g., for `download_dataset`):
///
/// ```rust,ignore
/// let mut last_status: Option<String> = None;
///
/// while let Some(progress) = rx.recv().await {
///     // Detect stage change and reset progress bar
///     if progress.status != last_status {
///         if let Some(ref status) = progress.status {
///             println!("\n{}", status);
///         }
///         last_status = progress.status.clone();
///     }
///
///     if progress.total > 0 {
///         let pct = (progress.current as f64 / progress.total as f64) * 100.0;
///         print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
///     } else {
///         print!("\r{}", progress.current);
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items or bytes.
    pub current: usize,
    /// Total number of items or bytes to process.
    ///
    /// `0` means the total is unknown (for example, a download whose server
    /// did not send `Content-Length`).
    pub total: usize,
    /// Optional status describing the current operation phase.
    ///
    /// When this value changes from `None` to `Some(...)` or between different
    /// values, it indicates a new phase has started. Applications should reset
    /// their progress display when the status changes.
    ///
    /// Currently only [`Client::download_dataset`] uses status changes:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// All other operations use `None` throughout.
    pub status: Option<String>,
}
359
360#[derive(Serialize)]
361struct RpcRequest<Params> {
362 id: u64,
363 jsonrpc: String,
364 method: String,
365 params: Option<Params>,
366}
367
368impl<T> Default for RpcRequest<T> {
369 fn default() -> Self {
370 RpcRequest {
371 id: 0,
372 jsonrpc: "2.0".to_string(),
373 method: "".to_string(),
374 params: None,
375 }
376 }
377}
378
379#[derive(Deserialize)]
380struct RpcError {
381 code: i32,
382 message: String,
383}
384
385#[derive(Deserialize)]
386struct RpcResponse<RpcResult> {
387 #[allow(dead_code)]
388 id: String,
389 #[allow(dead_code)]
390 jsonrpc: String,
391 error: Option<RpcError>,
392 result: Option<RpcResult>,
393}
394
395#[derive(Deserialize)]
396#[allow(dead_code)]
397struct EmptyResult {}
398
399#[derive(Debug, Serialize)]
400#[allow(dead_code)]
401struct SnapshotCreateParams {
402 snapshot_name: String,
403 keys: Vec<String>,
404}
405
406#[derive(Debug, Deserialize)]
407#[allow(dead_code)]
408struct SnapshotCreateResult {
409 snapshot_id: SnapshotID,
410 urls: Vec<String>,
411}
412
413#[derive(Debug, Serialize)]
414struct SnapshotCreateMultipartParams {
415 snapshot_name: String,
416 keys: Vec<String>,
417 file_sizes: Vec<usize>,
418 /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
419 #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
420 snapshot_type: Option<String>,
421}
422
423#[derive(Debug, Deserialize)]
424#[serde(untagged)]
425enum SnapshotCreateMultipartResultField {
426 Id(u64),
427 Part(SnapshotPart),
428}
429
430#[derive(Debug, Serialize)]
431struct SnapshotCompleteMultipartParams {
432 key: String,
433 upload_id: String,
434 etag_list: Vec<EtagPart>,
435}
436
437#[derive(Debug, Clone, Serialize)]
438struct EtagPart {
439 #[serde(rename = "ETag")]
440 etag: String,
441 #[serde(rename = "PartNumber")]
442 part_number: usize,
443}
444
445#[derive(Debug, Clone, Deserialize)]
446struct SnapshotPart {
447 key: Option<String>,
448 upload_id: String,
449 urls: Vec<String>,
450}
451
452#[derive(Debug, Serialize)]
453struct SnapshotStatusParams {
454 snapshot_id: SnapshotID,
455 status: String,
456}
457
458#[derive(Deserialize, Debug)]
459struct SnapshotStatusResult {
460 #[allow(dead_code)]
461 pub id: SnapshotID,
462 #[allow(dead_code)]
463 pub uid: String,
464 #[allow(dead_code)]
465 pub description: String,
466 #[allow(dead_code)]
467 pub date: String,
468 #[allow(dead_code)]
469 pub status: String,
470}
471
472#[derive(Serialize)]
473#[allow(dead_code)]
474struct ImageListParams {
475 images_filter: ImagesFilter,
476 image_files_filter: HashMap<String, String>,
477 only_ids: bool,
478}
479
480#[derive(Serialize)]
481#[allow(dead_code)]
482struct ImagesFilter {
483 dataset_id: DatasetID,
484}
485
486/// Main client for interacting with EdgeFirst Studio Server.
487///
488/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
489/// and manages authentication, RPC calls, and data operations. It provides
490/// methods for managing projects, datasets, experiments, training sessions,
491/// and various utility functions for data processing.
492///
493/// The client supports multiple authentication methods and can work with both
494/// SaaS and self-hosted EdgeFirst Studio instances.
495///
496/// # Features
497///
498/// - **Authentication**: Token-based authentication with automatic persistence
499/// - **Dataset Management**: Upload, download, and manipulate datasets
500/// - **Project Operations**: Create and manage projects and experiments
501/// - **Training & Validation**: Submit and monitor ML training jobs
502/// - **Data Integration**: Convert between EdgeFirst datasets and popular
503/// formats
504/// - **Progress Tracking**: Real-time progress updates for long-running
505/// operations
506///
507/// # Examples
508///
509/// ```no_run
510/// use edgefirst_client::{Client, DatasetID};
511/// use std::str::FromStr;
512///
513/// # async fn example() -> Result<(), edgefirst_client::Error> {
514/// // Create a new client and authenticate
515/// let mut client = Client::new()?;
516/// let client = client
517/// .with_login("your-email@example.com", "password")
518/// .await?;
519///
520/// // Or use an existing token
521/// let base_client = Client::new()?;
522/// let client = base_client.with_token("your-token-here")?;
523///
524/// // Get organization and projects
525/// let org = client.organization().await?;
526/// let projects = client.projects(None).await?;
527///
528/// // Work with datasets
529/// let dataset_id = DatasetID::from_str("ds-abc123")?;
530/// let dataset = client.dataset(dataset_id).await?;
531/// # Ok(())
532/// # }
533/// ```
534/// Client is Clone but cannot derive Debug due to dyn TokenStorage
535#[derive(Clone)]
536pub struct Client {
537 http: reqwest::Client,
538 /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
539 /// timeout). An idle read timeout is still configured on the underlying client, and
540 /// some operations (such as uploads) may apply additional per-request timeouts.
541 bulk_http: reqwest::Client,
542 url: String,
543 token: Arc<RwLock<String>>,
544 /// Token storage backend. When set, tokens are automatically persisted.
545 storage: Option<Arc<dyn TokenStorage>>,
546 /// Legacy token path field for backwards compatibility with
547 /// with_token_path(). Deprecated: Use with_storage() instead.
548 token_path: Option<PathBuf>,
549}
550
551impl std::fmt::Debug for Client {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("Client")
554 .field("url", &self.url)
555 .field("has_storage", &self.storage.is_some())
556 .field("token_path", &self.token_path)
557 .finish()
558 }
559}
560
561/// Private context struct for pagination operations
562struct FetchContext<'a> {
563 dataset_id: DatasetID,
564 annotation_set_id: Option<AnnotationSetID>,
565 groups: &'a [String],
566 types: Vec<String>,
567 labels: &'a HashMap<String, u64>,
568}
569
570#[derive(Debug, Serialize)]
571struct JobsListRequest {}
572
573#[derive(Debug, Serialize)]
574struct JobRunRequest {
575 name: String,
576 job_name: String,
577 env: std::collections::HashMap<String, String>,
578 data: std::collections::HashMap<String, crate::api::Parameter>,
579}
580
581#[derive(Debug, Serialize)]
582struct JobStopRequest {
583 task_id: u64,
584}
585
586#[derive(Debug, Serialize)]
587pub(crate) struct TaskDataListRequest {
588 pub(crate) task_id: u64,
589}
590
591#[derive(Debug, Serialize)]
592pub(crate) struct TaskDataDownloadRequest {
593 pub(crate) task_id: u64,
594 pub(crate) folder: String,
595 pub(crate) file: String,
596}
597
598#[derive(Debug, Serialize)]
599pub(crate) struct TaskChartAddRequest {
600 pub(crate) task_id: u64,
601 pub(crate) group_name: String,
602 pub(crate) chart_name: String,
603 pub(crate) params: Option<crate::api::Parameter>,
604 pub(crate) data: crate::api::Parameter,
605}
606
607#[derive(Debug, Serialize)]
608pub(crate) struct TaskChartListRequest {
609 pub(crate) task_id: u64,
610 pub(crate) group_name: String,
611}
612
613#[derive(Debug, Serialize)]
614pub(crate) struct TaskChartGetRequest {
615 pub(crate) task_id: u64,
616 pub(crate) group_name: String,
617 pub(crate) chart_name: String,
618}
619
620#[derive(Debug, Serialize)]
621pub(crate) struct ValDataDownloadRequest {
622 pub(crate) session_id: u64,
623 pub(crate) filename: String,
624}
625
626#[derive(Debug, Serialize)]
627pub(crate) struct ValDataListRequest {
628 pub(crate) session_id: u64,
629}
630
631/// Streams the body of a successful `reqwest` response to a file on disk,
632/// emitting optional progress events.
633///
634/// Both `download_artifact` and `rpc_download` share this logic. The caller is
635/// responsible for creating any required parent directories before calling this
636/// function.
637///
638/// # Arguments
639/// * `resp` - A successful (HTTP 2xx) `reqwest::Response` whose body will
640/// be streamed to `path`.
641/// * `path` - Destination file path (created or truncated).
642/// * `progress` - Optional channel; events carry bytes received and
643/// `Content-Length` total (0 if the server omits it).
644///
645/// # Errors
646/// Returns `Error::IoError` on file I/O failures or propagates stream errors.
647async fn stream_response_to_file(
648 resp: reqwest::Response,
649 path: &std::path::Path,
650 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
651) -> Result<(), Error> {
652 use tokio::io::AsyncWriteExt as _;
653 let total = resp.content_length().unwrap_or(0) as usize;
654 let mut stream = resp.bytes_stream();
655 let mut file = tokio::fs::File::create(path).await?;
656 let mut current = 0usize;
657
658 if let Some(ref tx) = progress {
659 let _ = tx
660 .send(Progress {
661 current: 0,
662 total,
663 status: None,
664 })
665 .await;
666 }
667
668 while let Some(chunk) = stream.next().await {
669 let chunk = chunk?;
670 file.write_all(&chunk).await?;
671 current += chunk.len();
672 if let Some(ref tx) = progress {
673 let _ = tx
674 .send(Progress {
675 current,
676 total,
677 status: None,
678 })
679 .await;
680 }
681 }
682
683 // Flush tokio's internal write buffer to the OS before returning.
684 // tokio::fs::File buffers writes internally; without this, the buffer
685 // may not reach the filesystem before the caller reads the file.
686 file.flush().await?;
687 Ok(())
688}
689
690impl Client {
691 /// Create a new unauthenticated client with the default saas server.
692 ///
693 /// By default, the client uses [`FileTokenStorage`] for token persistence.
694 /// Use [`with_storage`][Self::with_storage],
695 /// [`with_memory_storage`][Self::with_memory_storage],
696 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
697 /// behavior.
698 ///
699 /// To connect to a different server, use [`with_server`][Self::with_server]
700 /// or [`with_token`][Self::with_token] (tokens include the server
701 /// instance).
702 ///
703 /// This client is created without a token and will need to authenticate
704 /// before using methods that require authentication.
705 ///
706 /// # Examples
707 ///
708 /// ```rust,no_run
709 /// use edgefirst_client::Client;
710 ///
711 /// # fn main() -> Result<(), edgefirst_client::Error> {
712 /// // Create client with default file storage
713 /// let client = Client::new()?;
714 ///
715 /// // Create client without token persistence
716 /// let client = Client::new()?.with_memory_storage();
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub fn new() -> Result<Self, Error> {
721 log_retry_configuration();
722
723 // Get timeout from environment or use default
724 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
725 .ok()
726 .and_then(|s| s.parse().ok())
727 .unwrap_or(30); // Default 30s total deadline for API calls
728
729 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
730 // arrive for this duration. Resets after every received chunk, so a
731 // healthy multi-GB transfer will never be interrupted.
732 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
733 .ok()
734 .and_then(|s| s.parse().ok())
735 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
736
737 // Create single HTTP client with URL-based retry policy
738 //
739 // The retry policy classifies requests into two categories:
740 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
741 // errors
742 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
743 //
744 // This allows the same client to handle both API calls and file operations
745 // with appropriate retry behavior for each. See retry.rs for details.
746 let http = reqwest::Client::builder()
747 .connect_timeout(Duration::from_secs(10))
748 .timeout(Duration::from_secs(timeout_secs))
749 .pool_idle_timeout(Duration::from_secs(90))
750 .pool_max_idle_per_host(10)
751 .retry(create_retry_policy())
752 .build()?;
753
754 // Separate HTTP client for bulk transfers (uploads and downloads).
755 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
756 // Uses read_timeout instead: resets after every received chunk, so a
757 // healthy large transfer is never interrupted, but a truly stalled
758 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
759 let bulk_http = reqwest::Client::builder()
760 .connect_timeout(Duration::from_secs(30))
761 .read_timeout(Duration::from_secs(read_timeout_secs))
762 .pool_idle_timeout(Duration::from_secs(90))
763 .pool_max_idle_per_host(10)
764 .retry(create_retry_policy())
765 .build()?;
766
767 // Default to file storage, loading any existing token
768 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
769 Ok(file_storage) => Arc::new(file_storage),
770 Err(e) => {
771 warn!(
772 "Could not initialize file token storage: {}. Using memory storage.",
773 e
774 );
775 Arc::new(MemoryTokenStorage::new())
776 }
777 };
778
779 // Try to load existing token from storage
780 let token = match storage.load() {
781 Ok(Some(t)) => t,
782 Ok(None) => String::new(),
783 Err(e) => {
784 warn!(
785 "Failed to load token from storage: {}. Starting with empty token.",
786 e
787 );
788 String::new()
789 }
790 };
791
792 // Extract server from token if available
793 let url = if !token.is_empty() {
794 match Self::extract_server_from_token(&token) {
795 Ok(server) => format!("https://{}.edgefirst.studio", server),
796 Err(e) => {
797 warn!(
798 "Failed to extract server from token: {}. Using default server.",
799 e
800 );
801 "https://edgefirst.studio".to_string()
802 }
803 }
804 } else {
805 "https://edgefirst.studio".to_string()
806 };
807
808 Ok(Client {
809 http,
810 bulk_http,
811 url,
812 token: Arc::new(tokio::sync::RwLock::new(token)),
813 storage: Some(storage),
814 token_path: None,
815 })
816 }
817
818 /// Returns a new client connected to the specified server instance.
819 ///
820 /// The server parameter is an instance name that maps to a URL:
821 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
822 /// server)
823 /// - `"test"` → `https://test.edgefirst.studio`
824 /// - `"stage"` → `https://stage.edgefirst.studio`
825 /// - `"dev"` → `https://dev.edgefirst.studio`
826 /// - `"{name}"` → `https://{name}.edgefirst.studio`
827 ///
828 /// # Server Selection Priority
829 ///
830 /// When using the CLI or Python API, server selection follows this
831 /// priority:
832 ///
833 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
834 /// they were issued for. If you have a valid token, its server is used.
835 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
836 /// token is available. If a token exists with a different server, a
837 /// warning is emitted and the token's server takes priority.
838 /// 3. **Default `"saas"`** - If no token and no server specified, the
839 /// production server (`https://edgefirst.studio`) is used.
840 ///
841 /// # Important Notes
842 ///
843 /// - If a token is already set in the client, calling this method will
844 /// **drop the token** as tokens are specific to the server instance.
845 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
846 /// token's server before calling this method.
847 /// - For login operations, call `with_server()` first, then authenticate.
848 ///
849 /// # Examples
850 ///
851 /// ```rust,no_run
852 /// use edgefirst_client::Client;
853 ///
854 /// # fn main() -> Result<(), edgefirst_client::Error> {
855 /// let client = Client::new()?.with_server("test")?;
856 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
857 /// # Ok(())
858 /// # }
859 /// ```
860 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
861 // Resolve the target URL. Full URLs (self-hosted Studio,
862 // wiremock) are validated through `with_url` so the HTTPS rules
863 // there apply uniformly. Short names map to the SaaS pattern.
864 // We extract only the URL string and rebuild the Client below,
865 // because `with_url` preserves the in-memory token (the contract
866 // for self-hosted deployments) whereas `with_server` deliberately
867 // clears it (a different server means a stale token).
868 let url = if server.starts_with("http://") || server.starts_with("https://") {
869 self.with_url(server)?.url().to_string()
870 } else {
871 match server {
872 "" | "saas" => "https://edgefirst.studio".to_string(),
873 name => format!("https://{}.edgefirst.studio", name),
874 }
875 };
876
877 // Clear token from storage when changing servers to prevent
878 // authentication issues with stale tokens from different
879 // instances. This runs whether the caller passed a short name
880 // or a full URL — both reach a new server.
881 if let Some(ref storage) = self.storage
882 && let Err(e) = storage.clear()
883 {
884 warn!(
885 "Failed to clear token from storage when changing servers: {}",
886 e
887 );
888 }
889
890 Ok(Client {
891 url,
892 token: Arc::new(tokio::sync::RwLock::new(String::new())),
893 ..self.clone()
894 })
895 }
896
897 /// Returns a new client pointed at an explicit URL.
898 ///
899 /// Used for self-hosted Studio deployments (e.g.
900 /// `https://studio.example.com`) and for offline integration tests
901 /// against a mock HTTP server (e.g. `http://127.0.0.1:8080`). The
902 /// token is preserved so callers can chain
903 /// `Client::new()?.with_url(...)?.with_token(...)`.
904 ///
905 /// # Errors
906 ///
907 /// Returns [`Error::UrlParseError`] for syntactically invalid URLs and
908 /// [`Error::InsecureUrl`] for plain `http://` URLs that resolve to a
909 /// non-loopback host: the Studio bearer token rides in the
910 /// `Authorization` header, and plain HTTP would leak it in the clear.
911 /// Loopback URLs (`127.0.0.1`, `::1`, `localhost`, `*.localhost`) are
912 /// permitted because traffic never leaves the machine — wiremock and
913 /// local dev servers go through that path.
914 pub fn with_url(&self, url: &str) -> Result<Self, Error> {
915 // Reject malformed inputs early so test failures point at the test
916 // rather than a downstream reqwest send.
917 let parsed = url::Url::parse(url)?;
918 let scheme = parsed.scheme();
919 if scheme == "http" {
920 if !is_loopback_host(parsed.host().as_ref()) {
921 return Err(Error::InsecureUrl(url.to_string()));
922 }
923 } else if scheme != "https" {
924 return Err(Error::InsecureUrl(url.to_string()));
925 }
926 Ok(Client {
927 url: url.trim_end_matches('/').to_string(),
928 ..self.clone()
929 })
930 }
931
932 /// Returns a new client with the specified token storage backend.
933 ///
934 /// Use this to configure custom token storage, such as platform-specific
935 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
936 ///
937 /// # Examples
938 ///
939 /// ```rust,no_run
940 /// use edgefirst_client::{Client, FileTokenStorage};
941 /// use std::{path::PathBuf, sync::Arc};
942 ///
943 /// # fn main() -> Result<(), edgefirst_client::Error> {
944 /// // Use a custom file path for token storage
945 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
946 /// let client = Client::new()?.with_storage(Arc::new(storage));
947 /// # Ok(())
948 /// # }
949 /// ```
950 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
951 // Try to load existing token from the new storage
952 let token = match storage.load() {
953 Ok(Some(t)) => t,
954 Ok(None) => String::new(),
955 Err(e) => {
956 warn!(
957 "Failed to load token from storage: {}. Starting with empty token.",
958 e
959 );
960 String::new()
961 }
962 };
963
964 Client {
965 token: Arc::new(tokio::sync::RwLock::new(token)),
966 storage: Some(storage),
967 token_path: None,
968 ..self
969 }
970 }
971
972 /// Returns a new client with in-memory token storage (no persistence).
973 ///
974 /// Tokens are stored in memory only and lost when the application exits.
975 /// This is useful for testing or when you want to manage token persistence
976 /// externally.
977 ///
978 /// # Examples
979 ///
980 /// ```rust,no_run
981 /// use edgefirst_client::Client;
982 ///
983 /// # fn main() -> Result<(), edgefirst_client::Error> {
984 /// let client = Client::new()?.with_memory_storage();
985 /// # Ok(())
986 /// # }
987 /// ```
988 pub fn with_memory_storage(self) -> Self {
989 Client {
990 token: Arc::new(tokio::sync::RwLock::new(String::new())),
991 storage: Some(Arc::new(MemoryTokenStorage::new())),
992 token_path: None,
993 ..self
994 }
995 }
996
997 /// Returns a new client with no token storage.
998 ///
999 /// Tokens are not persisted. Use this when you want to manage tokens
1000 /// entirely manually.
1001 ///
1002 /// # Examples
1003 ///
1004 /// ```rust,no_run
1005 /// use edgefirst_client::Client;
1006 ///
1007 /// # fn main() -> Result<(), edgefirst_client::Error> {
1008 /// let client = Client::new()?.with_no_storage();
1009 /// # Ok(())
1010 /// # }
1011 /// ```
1012 pub fn with_no_storage(self) -> Self {
1013 Client {
1014 storage: None,
1015 token_path: None,
1016 ..self
1017 }
1018 }
1019
1020 /// Returns a new client authenticated with the provided username and
1021 /// password.
1022 ///
1023 /// The token is automatically persisted to storage (if configured).
1024 ///
1025 /// # Examples
1026 ///
1027 /// ```rust,no_run
1028 /// use edgefirst_client::Client;
1029 ///
1030 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1031 /// let client = Client::new()?
1032 /// .with_server("test")?
1033 /// .with_login("user@example.com", "password")
1034 /// .await?;
1035 /// # Ok(())
1036 /// # }
1037 /// ```
1038 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
1039 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
1040 let params = HashMap::from([("username", username), ("password", password)]);
1041 let login: LoginResult = self
1042 .rpc_without_auth("auth.login".to_owned(), Some(params))
1043 .await?;
1044
1045 // Validate that the server returned a non-empty token
1046 if login.token.is_empty() {
1047 return Err(Error::EmptyToken);
1048 }
1049
1050 // Persist token to storage if configured
1051 if let Some(ref storage) = self.storage
1052 && let Err(e) = storage.store(&login.token)
1053 {
1054 warn!("Failed to persist token to storage: {}", e);
1055 }
1056
1057 Ok(Client {
1058 token: Arc::new(tokio::sync::RwLock::new(login.token)),
1059 ..self.clone()
1060 })
1061 }
1062
1063 /// Returns a new client which will load and save the token to the specified
1064 /// path.
1065 ///
1066 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
1067 /// [`FileTokenStorage`] instead for more flexible token management.
1068 ///
1069 /// This method is maintained for backwards compatibility with existing
1070 /// code. It disables the default storage and uses file-based storage at
1071 /// the specified path.
1072 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
1073 let token_path = match token_path {
1074 Some(path) => path.to_path_buf(),
1075 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1076 .ok_or_else(|| {
1077 Error::IoError(std::io::Error::new(
1078 std::io::ErrorKind::NotFound,
1079 "Could not determine user config directory",
1080 ))
1081 })?
1082 .config_dir()
1083 .join("token"),
1084 };
1085
1086 debug!("Using token path (legacy): {:?}", token_path);
1087
1088 let token = match token_path.exists() {
1089 true => std::fs::read_to_string(&token_path)?,
1090 false => "".to_string(),
1091 };
1092
1093 if !token.is_empty() {
1094 match self.with_token(&token) {
1095 Ok(client) => Ok(Client {
1096 token_path: Some(token_path),
1097 storage: None, // Disable new storage when using legacy token_path
1098 ..client
1099 }),
1100 Err(e) => {
1101 // Token is corrupted or invalid - remove it and continue with no token
1102 warn!(
1103 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
1104 token_path, e
1105 );
1106 if let Err(remove_err) = std::fs::remove_file(&token_path) {
1107 warn!("Failed to remove corrupted token file: {:?}", remove_err);
1108 }
1109 // Clear any token from default storage to ensure we don't use it
1110 Ok(Client {
1111 token_path: Some(token_path),
1112 storage: None,
1113 token: Arc::new(RwLock::new("".to_string())),
1114 ..self.clone()
1115 })
1116 }
1117 }
1118 } else {
1119 // No token in the legacy file - clear any token from default storage
1120 Ok(Client {
1121 token_path: Some(token_path),
1122 storage: None,
1123 token: Arc::new(RwLock::new("".to_string())),
1124 ..self.clone()
1125 })
1126 }
1127 }
1128
1129 /// Returns a new client authenticated with the provided token.
1130 ///
1131 /// The token is automatically persisted to storage (if configured).
1132 /// The server URL is extracted from the token payload.
1133 ///
1134 /// # Examples
1135 ///
1136 /// ```rust,no_run
1137 /// use edgefirst_client::Client;
1138 ///
1139 /// # fn main() -> Result<(), edgefirst_client::Error> {
1140 /// let client = Client::new()?.with_token("your-jwt-token")?;
1141 /// # Ok(())
1142 /// # }
1143 /// ```
1144 /// Extract server name from JWT token payload.
1145 ///
1146 /// Helper method to parse the JWT token and extract the "server" field
1147 /// from the payload. Returns the server name (e.g., "test", "stage", "")
1148 /// or an error if the token is invalid.
1149 fn extract_server_from_token(token: &str) -> Result<String, Error> {
1150 let token_parts: Vec<&str> = token.split('.').collect();
1151 if token_parts.len() != 3 {
1152 return Err(Error::InvalidToken);
1153 }
1154
1155 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1156 .decode(token_parts[1])
1157 .map_err(|_| Error::InvalidToken)?;
1158 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1159 let server = match payload.get("server") {
1160 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
1161 None => return Err(Error::InvalidToken),
1162 };
1163
1164 Ok(server)
1165 }
1166
1167 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
1168 if token.is_empty() {
1169 return Ok(self.clone());
1170 }
1171
1172 let server = Self::extract_server_from_token(token)?;
1173
1174 // Persist token to storage if configured
1175 if let Some(ref storage) = self.storage
1176 && let Err(e) = storage.store(token)
1177 {
1178 warn!("Failed to persist token to storage: {}", e);
1179 }
1180
1181 Ok(Client {
1182 url: format!("https://{}.edgefirst.studio", server),
1183 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
1184 ..self.clone()
1185 })
1186 }
1187
1188 /// Persist the current token to storage.
1189 ///
1190 /// This is automatically called when using [`with_login`][Self::with_login]
1191 /// or [`with_token`][Self::with_token], so you typically don't need to call
1192 /// this directly.
1193 ///
1194 /// If using the legacy `token_path` configuration, saves to the file path.
1195 /// If using the new storage abstraction, saves to the configured storage.
1196 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1197 pub async fn save_token(&self) -> Result<(), Error> {
1198 let token = self.token.read().await;
1199
1200 // Try new storage first
1201 if let Some(ref storage) = self.storage {
1202 storage.store(&token)?;
1203 debug!("Token saved to storage");
1204 return Ok(());
1205 }
1206
1207 // Fall back to legacy token_path behavior
1208 let path = self.token_path.clone().unwrap_or_else(|| {
1209 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1210 .map(|dirs| dirs.config_dir().join("token"))
1211 .unwrap_or_else(|| PathBuf::from(".token"))
1212 });
1213
1214 create_dir_all(path.parent().ok_or_else(|| {
1215 Error::IoError(std::io::Error::new(
1216 std::io::ErrorKind::InvalidInput,
1217 "Token path has no parent directory",
1218 ))
1219 })?)?;
1220 let mut file = std::fs::File::create(&path)?;
1221 file.write_all(token.as_bytes())?;
1222
1223 debug!("Saved token to {:?}", path);
1224
1225 Ok(())
1226 }
1227
1228 /// Return the version of the EdgeFirst Studio server for the current
1229 /// client connection.
1230 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1231 pub async fn version(&self) -> Result<String, Error> {
1232 let version: HashMap<String, String> = self
1233 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
1234 .await?;
1235 let version = version.get("version").ok_or(Error::InvalidResponse)?;
1236 Ok(version.to_owned())
1237 }
1238
1239 /// Clear the token used to authenticate the client with the server.
1240 ///
1241 /// Clears the token from memory and from storage (if configured).
1242 /// If using the legacy `token_path` configuration, removes the token file.
1243 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1244 pub async fn logout(&self) -> Result<(), Error> {
1245 {
1246 let mut token = self.token.write().await;
1247 *token = "".to_string();
1248 }
1249
1250 // Clear from new storage if configured
1251 if let Some(ref storage) = self.storage
1252 && let Err(e) = storage.clear()
1253 {
1254 warn!("Failed to clear token from storage: {}", e);
1255 }
1256
1257 // Also clear legacy token_path if configured
1258 if let Some(path) = &self.token_path
1259 && path.exists()
1260 {
1261 fs::remove_file(path).await?;
1262 }
1263
1264 Ok(())
1265 }
1266
1267 /// Return the token used to authenticate the client with the server. When
1268 /// logging into the server using a username and password, the token is
1269 /// returned by the server and stored in the client for future interactions.
1270 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1271 pub async fn token(&self) -> String {
1272 self.token.read().await.clone()
1273 }
1274
1275 /// Verify the token used to authenticate the client with the server. This
1276 /// method is used to ensure that the token is still valid and has not
1277 /// expired. If the token is invalid, the server will return an error and
1278 /// the client will need to login again.
1279 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1280 pub async fn verify_token(&self) -> Result<(), Error> {
1281 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1282 .await?;
1283 Ok::<(), Error>(())
1284 }
1285
1286 /// Renew the token used to authenticate the client with the server.
1287 ///
1288 /// Refreshes the token before it expires. If the token has already expired,
1289 /// the server will return an error and you will need to login again.
1290 ///
1291 /// The new token is automatically persisted to storage (if configured).
1292 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1293 pub async fn renew_token(&self) -> Result<(), Error> {
1294 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1295 let result: LoginResult = self
1296 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1297 .await?;
1298
1299 {
1300 let mut token = self.token.write().await;
1301 *token = result.token.clone();
1302 }
1303
1304 // Persist to new storage if configured
1305 if let Some(ref storage) = self.storage
1306 && let Err(e) = storage.store(&result.token)
1307 {
1308 warn!("Failed to persist renewed token to storage: {}", e);
1309 }
1310
1311 // Also persist to legacy token_path if configured
1312 if self.token_path.is_some() {
1313 self.save_token().await?;
1314 }
1315
1316 Ok(())
1317 }
1318
1319 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1320 let token = self.token.read().await;
1321 if token.is_empty() {
1322 return Err(Error::EmptyToken);
1323 }
1324
1325 let token_parts: Vec<&str> = token.split('.').collect();
1326 if token_parts.len() != 3 {
1327 return Err(Error::InvalidToken);
1328 }
1329
1330 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1331 .decode(token_parts[1])
1332 .map_err(|_| Error::InvalidToken)?;
1333 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1334 match payload.get(field) {
1335 Some(value) => Ok(value.to_owned()),
1336 None => Err(Error::InvalidToken),
1337 }
1338 }
1339
1340 /// Returns the URL of the EdgeFirst Studio server for the current client.
1341 pub fn url(&self) -> &str {
1342 &self.url
1343 }
1344
1345 /// Returns the server name for the current client.
1346 ///
1347 /// This extracts the server name from the client's URL:
1348 /// - `https://edgefirst.studio` → `"saas"`
1349 /// - `https://test.edgefirst.studio` → `"test"`
1350 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1351 ///
1352 /// # Examples
1353 ///
1354 /// ```rust,no_run
1355 /// use edgefirst_client::Client;
1356 ///
1357 /// # fn main() -> Result<(), edgefirst_client::Error> {
1358 /// let client = Client::new()?.with_server("test")?;
1359 /// assert_eq!(client.server(), "test");
1360 ///
1361 /// let client = Client::new()?; // default
1362 /// assert_eq!(client.server(), "saas");
1363 /// # Ok(())
1364 /// # }
1365 /// ```
1366 pub fn server(&self) -> &str {
1367 if self.url == "https://edgefirst.studio" {
1368 "saas"
1369 } else if let Some(name) = self.url.strip_prefix("https://") {
1370 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1371 } else {
1372 "saas"
1373 }
1374 }
1375
1376 /// Returns the username associated with the current token.
1377 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1378 pub async fn username(&self) -> Result<String, Error> {
1379 match self.token_field("username").await? {
1380 serde_json::Value::String(username) => Ok(username),
1381 _ => Err(Error::InvalidToken),
1382 }
1383 }
1384
1385 /// Returns the expiration time for the current token.
1386 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1387 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1388 let ts = match self.token_field("exp").await? {
1389 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1390 _ => return Err(Error::InvalidToken),
1391 };
1392
1393 match DateTime::<Utc>::from_timestamp(ts, 0) {
1394 Some(dt) => Ok(dt),
1395 None => Err(Error::InvalidToken),
1396 }
1397 }
1398
1399 /// Returns the organization information for the current user.
1400 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1401 pub async fn organization(&self) -> Result<Organization, Error> {
1402 self.rpc::<(), Organization>("org.get".to_owned(), None)
1403 .await
1404 }
1405
1406 /// Returns a list of projects available to the user. The projects are
1407 /// returned as a vector of Project objects. If a name filter is
1408 /// provided, only projects matching the filter are returned.
1409 ///
1410 /// Results are sorted by match quality: exact matches first, then
1411 /// case-insensitive exact matches, then shorter names (more specific),
1412 /// then alphabetically.
1413 ///
1414 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1415 /// Projects contain datasets, trainers, and trainer sessions. Projects
1416 /// are used to group related datasets and trainers together.
1417 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1418 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1419 let projects = self
1420 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1421 .await?;
1422 if let Some(name) = name {
1423 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1424 } else {
1425 Ok(projects)
1426 }
1427 }
1428
1429 /// Return the project with the specified project ID. If the project does
1430 /// not exist, an error is returned.
1431 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1432 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1433 let params = HashMap::from([("project_id", project_id)]);
1434 self.rpc("project.get".to_owned(), Some(params)).await
1435 }
1436
1437 /// Returns a list of datasets available to the user. The datasets are
1438 /// returned as a vector of Dataset objects. If a name filter is
1439 /// provided, only datasets matching the filter are returned.
1440 ///
1441 /// Results are sorted by match quality: exact matches first, then
1442 /// case-insensitive exact matches, then shorter names (more specific),
1443 /// then alphabetically. This ensures "Deer" returns before "Deer
1444 /// Roundtrip".
1445 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1446 pub async fn datasets(
1447 &self,
1448 project_id: ProjectID,
1449 name: Option<&str>,
1450 ) -> Result<Vec<Dataset>, Error> {
1451 let params = HashMap::from([("project_id", project_id)]);
1452 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1453 if let Some(name) = name {
1454 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1455 } else {
1456 Ok(datasets)
1457 }
1458 }
1459
1460 /// Return the dataset with the specified dataset ID. If the dataset does
1461 /// not exist, an error is returned.
1462 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1463 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1464 let params = HashMap::from([("dataset_id", dataset_id)]);
1465 self.rpc("dataset.get".to_owned(), Some(params)).await
1466 }
1467
1468 /// Lists the labels for the specified dataset.
1469 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1470 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1471 let params = HashMap::from([("dataset_id", dataset_id)]);
1472 self.rpc("label.list".to_owned(), Some(params)).await
1473 }
1474
1475 /// Add a new label to the dataset with the specified name.
1476 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1477 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1478 let new_label = NewLabel {
1479 dataset_id,
1480 labels: vec![NewLabelObject {
1481 name: name.to_owned(),
1482 }],
1483 };
1484 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1485 Ok(())
1486 }
1487
1488 /// Removes the label with the specified ID from the dataset. Label IDs are
1489 /// globally unique so the dataset_id is not required.
1490 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1491 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1492 let params = HashMap::from([("label_id", label_id)]);
1493 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1494 Ok(())
1495 }
1496
1497 /// Creates a new dataset in the specified project.
1498 ///
1499 /// # Arguments
1500 ///
1501 /// * `project_id` - The ID of the project to create the dataset in
1502 /// * `name` - The name of the new dataset
1503 /// * `description` - Optional description for the dataset
1504 ///
1505 /// # Returns
1506 ///
1507 /// Returns the dataset ID of the newly created dataset.
1508 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1509 pub async fn create_dataset(
1510 &self,
1511 project_id: &str,
1512 name: &str,
1513 description: Option<&str>,
1514 ) -> Result<DatasetID, Error> {
1515 let mut params = HashMap::new();
1516 params.insert("project_id", project_id);
1517 params.insert("name", name);
1518 if let Some(desc) = description {
1519 params.insert("description", desc);
1520 }
1521
1522 #[derive(Deserialize)]
1523 struct CreateDatasetResult {
1524 id: DatasetID,
1525 }
1526
1527 let result: CreateDatasetResult =
1528 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1529 Ok(result.id)
1530 }
1531
1532 /// Deletes a dataset by marking it as deleted.
1533 ///
1534 /// # Arguments
1535 ///
1536 /// * `dataset_id` - The ID of the dataset to delete
1537 ///
1538 /// # Returns
1539 ///
1540 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1541 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1542 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1543 let params = HashMap::from([("id", dataset_id)]);
1544 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1545 Ok(())
1546 }
1547
1548 /// Updates the label with the specified ID to have the new name or index.
1549 /// Label IDs cannot be changed. Label IDs are globally unique so the
1550 /// dataset_id is not required.
1551 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1552 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1553 #[derive(Serialize)]
1554 struct Params {
1555 dataset_id: DatasetID,
1556 label_id: u64,
1557 label_name: String,
1558 label_index: u64,
1559 }
1560
1561 let _: String = self
1562 .rpc(
1563 "label.update".to_owned(),
1564 Some(Params {
1565 dataset_id: label.dataset_id(),
1566 label_id: label.id(),
1567 label_name: label.name().to_owned(),
1568 label_index: label.index(),
1569 }),
1570 )
1571 .await?;
1572 Ok(())
1573 }
1574
1575 /// Lists the groups for the specified dataset.
1576 ///
1577 /// Groups are used to organize samples into logical subsets such as
1578 /// "train", "val", "test", etc. Each sample can belong to at most one
1579 /// group at a time.
1580 ///
1581 /// # Arguments
1582 ///
1583 /// * `dataset_id` - The ID of the dataset to list groups for
1584 ///
1585 /// # Returns
1586 ///
1587 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1588 /// empty vector if no groups have been created yet.
1589 ///
1590 /// # Errors
1591 ///
1592 /// Returns an error if the dataset does not exist or cannot be accessed.
1593 ///
1594 /// # Example
1595 ///
1596 /// ```rust,no_run
1597 /// # use edgefirst_client::{Client, DatasetID};
1598 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1599 /// let client = Client::new()?.with_token_path(None)?;
1600 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1601 ///
1602 /// let groups = client.groups(dataset_id).await?;
1603 /// for group in groups {
1604 /// println!("{}: {}", group.id, group.name);
1605 /// }
1606 /// # Ok(())
1607 /// # }
1608 /// ```
1609 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1610 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1611 let params = HashMap::from([("dataset_id", dataset_id)]);
1612 self.rpc("groups.list".to_owned(), Some(params)).await
1613 }
1614
1615 /// Gets an existing group by name or creates a new one.
1616 ///
1617 /// This is a convenience method that first checks if a group with the
1618 /// specified name exists, and creates it if not. This is useful when
1619 /// you need to ensure a group exists before assigning samples to it.
1620 ///
1621 /// # Arguments
1622 ///
1623 /// * `dataset_id` - The ID of the dataset
1624 /// * `name` - The name of the group (e.g., "train", "val", "test")
1625 ///
1626 /// # Returns
1627 ///
1628 /// Returns the group ID (either existing or newly created).
1629 ///
1630 /// # Errors
1631 ///
1632 /// Returns an error if:
1633 /// - The dataset does not exist or cannot be accessed
1634 /// - The group creation fails
1635 ///
1636 /// # Concurrency
1637 ///
1638 /// This method handles concurrent creation attempts gracefully. If another
1639 /// process creates the group between the existence check and creation,
1640 /// this method will return the existing group's ID.
1641 ///
1642 /// # Example
1643 ///
1644 /// ```rust,no_run
1645 /// # use edgefirst_client::{Client, DatasetID};
1646 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1647 /// let client = Client::new()?.with_token_path(None)?;
1648 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1649 ///
1650 /// // Get or create a "train" group
1651 /// let train_group_id = client
1652 /// .get_or_create_group(dataset_id.clone(), "train")
1653 /// .await?;
1654 /// println!("Train group ID: {}", train_group_id);
1655 ///
1656 /// // Calling again returns the same ID
1657 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1658 /// assert_eq!(train_group_id, same_id);
1659 /// # Ok(())
1660 /// # }
1661 /// ```
1662 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1663 pub async fn get_or_create_group(
1664 &self,
1665 dataset_id: DatasetID,
1666 name: &str,
1667 ) -> Result<u64, Error> {
1668 // First check if the group already exists
1669 let groups = self.groups(dataset_id).await?;
1670 if let Some(group) = groups.iter().find(|g| g.name == name) {
1671 return Ok(group.id);
1672 }
1673
1674 // Create the group
1675 #[derive(Serialize)]
1676 struct CreateGroupParams {
1677 dataset_id: DatasetID,
1678 group_names: Vec<String>,
1679 group_splits: Vec<i64>,
1680 }
1681
1682 let params = CreateGroupParams {
1683 dataset_id,
1684 group_names: vec![name.to_string()],
1685 group_splits: vec![0], // No automatic splitting
1686 };
1687
1688 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
1689 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
1690 Ok(group.id)
1691 } else {
1692 // Group might have been created by concurrent call, try fetching again
1693 let groups = self.groups(dataset_id).await?;
1694 groups
1695 .iter()
1696 .find(|g| g.name == name)
1697 .map(|g| g.id)
1698 .ok_or_else(|| {
1699 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
1700 })
1701 }
1702 }
1703
1704 /// Sets the group for a sample.
1705 ///
1706 /// Assigns a sample to a specific group. Each sample can belong to at most
1707 /// one group at a time. Setting a new group replaces any existing group
1708 /// assignment.
1709 ///
1710 /// # Arguments
1711 ///
1712 /// * `sample_id` - The ID of the sample (image) to update
1713 /// * `group_id` - The ID of the group to assign. Use
1714 /// [`get_or_create_group`] to obtain a group ID from a name.
1715 ///
1716 /// # Returns
1717 ///
1718 /// Returns `Ok(())` on success.
1719 ///
1720 /// # Errors
1721 ///
1722 /// Returns an error if:
1723 /// - The sample does not exist
1724 /// - The group does not exist
1725 /// - Insufficient permissions to modify the sample
1726 ///
1727 /// # Example
1728 ///
1729 /// ```rust,no_run
1730 /// # use edgefirst_client::{Client, DatasetID, SampleID};
1731 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1732 /// let client = Client::new()?.with_token_path(None)?;
1733 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1734 /// let sample_id: SampleID = 12345.into();
1735 ///
1736 /// // Get or create the "val" group
1737 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
1738 ///
1739 /// // Assign the sample to the "val" group
1740 /// client.set_sample_group_id(sample_id, val_group_id).await?;
1741 /// # Ok(())
1742 /// # }
1743 /// ```
1744 ///
1745 /// [`get_or_create_group`]: Self::get_or_create_group
1746 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1747 pub async fn set_sample_group_id(
1748 &self,
1749 sample_id: SampleID,
1750 group_id: u64,
1751 ) -> Result<(), Error> {
1752 #[derive(Serialize)]
1753 struct SetGroupParams {
1754 image_id: SampleID,
1755 group_id: u64,
1756 }
1757
1758 let params = SetGroupParams {
1759 image_id: sample_id,
1760 group_id,
1761 };
1762 let _: String = self
1763 .rpc("image.set_group_id".to_owned(), Some(params))
1764 .await?;
1765 Ok(())
1766 }
1767
1768 /// Downloads dataset samples to the local filesystem.
1769 ///
1770 /// # Arguments
1771 ///
1772 /// * `dataset_id` - The unique identifier of the dataset
1773 /// * `groups` - Dataset groups to include (e.g., "train", "val")
1774 /// * `file_types` - File types to download. Supported types:
1775 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
1776 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
1777 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
1778 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
1779 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
1780 /// - `FileType::RadarCube` - Radar cube data (.png format)
1781 /// - `FileType::All` - All sensor types (expands to all of the above)
1782 /// * `output` - Local directory to save downloaded files
1783 /// * `flatten` - If true, download all files to output root without
1784 /// sequence subdirectories. When flattening, filenames are prefixed with
1785 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
1786 /// unavailable) unless the filename already starts with
1787 /// `{sequence_name}_`, to avoid conflicts between sequences.
1788 /// * `progress` - Optional channel for progress updates
1789 ///
1790 /// # Progress
1791 ///
1792 /// This operation has two phases with distinct progress reporting:
1793 ///
1794 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
1795 /// from the server. Progress counts samples fetched.
1796 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
1797 /// files to disk. Progress counts samples completed (each sample may
1798 /// have multiple files for different sensor types).
1799 ///
1800 /// Applications should detect the status change from `None` to
1801 /// `"Downloading"` to reset their progress bar for the second phase.
1802 ///
1803 /// # Returns
1804 ///
1805 /// Returns `Ok(())` on success or an error if download fails.
1806 ///
1807 /// # Example
1808 ///
1809 /// ```rust,no_run
1810 /// # use edgefirst_client::{Client, DatasetID, FileType};
1811 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1812 /// let client = Client::new()?.with_token_path(None)?;
1813 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1814 ///
1815 /// // Download with sequence subdirectories (default)
1816 /// client
1817 /// .download_dataset(
1818 /// dataset_id,
1819 /// &[],
1820 /// &[FileType::Image],
1821 /// "./data".into(),
1822 /// false,
1823 /// None,
1824 /// )
1825 /// .await?;
1826 ///
1827 /// // Download flattened (all files in one directory)
1828 /// client
1829 /// .download_dataset(
1830 /// dataset_id,
1831 /// &[],
1832 /// &[FileType::Image],
1833 /// "./data".into(),
1834 /// true,
1835 /// None,
1836 /// )
1837 /// .await?;
1838 ///
1839 /// // Download all sensor types
1840 /// client
1841 /// .download_dataset(
1842 /// dataset_id,
1843 /// &[],
1844 /// &FileType::expand_types(&[FileType::All]),
1845 /// "./data".into(),
1846 /// false,
1847 /// None,
1848 /// )
1849 /// .await?;
1850 /// # Ok(())
1851 /// # }
1852 /// ```
1853 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
1854 pub async fn download_dataset(
1855 &self,
1856 dataset_id: DatasetID,
1857 groups: &[String],
1858 file_types: &[FileType],
1859 output: PathBuf,
1860 flatten: bool,
1861 progress: Option<Sender<Progress>>,
1862 ) -> Result<(), Error> {
1863 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
1864 let samples = self
1865 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
1866 .await?;
1867 fs::create_dir_all(&output).await?;
1868
1869 // Phase 2: Download actual files using direct semaphore pattern
1870 let total = samples.len();
1871 let current = Arc::new(AtomicUsize::new(0));
1872 let sem = Arc::new(Semaphore::new(max_tasks()));
1873
1874 // Send initial progress for download phase
1875 if let Some(ref progress) = progress {
1876 let _ = progress
1877 .send(Progress {
1878 current: 0,
1879 total,
1880 status: Some("Downloading".to_string()),
1881 })
1882 .await;
1883 }
1884
1885 let tasks = samples
1886 .into_iter()
1887 .map(|sample| {
1888 let client = self.clone();
1889 let file_types = file_types.to_vec();
1890 let output = output.clone();
1891 let progress = progress.clone();
1892 let current = current.clone();
1893 let sem = sem.clone();
1894
1895 tokio::spawn(async move {
1896 let _permit = sem.acquire().await.map_err(|_| {
1897 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
1898 })?;
1899
1900 for file_type in &file_types {
1901 if let Some(data) = sample.download(&client, file_type.clone()).await? {
1902 let (file_ext, is_image) = match file_type {
1903 FileType::Image => (
1904 infer::get(&data)
1905 .expect("Failed to identify image file format for sample")
1906 .extension()
1907 .to_string(),
1908 true,
1909 ),
1910 other => (other.file_extension().to_string(), false),
1911 };
1912
1913 // Determine target directory based on sequence membership and
1914 // flatten option
1915 // - flatten=false + sequence_name: dataset/sequence_name/
1916 // - flatten=false + no sequence: dataset/ (root level)
1917 // - flatten=true: dataset/ (all files in output root)
1918 // NOTE: group (train/val/test) is NOT used for directory structure
1919 let sequence_dir = sample
1920 .sequence_name()
1921 .map(|name| sanitize_path_component(name));
1922
1923 let target_dir = if flatten {
1924 output.clone()
1925 } else {
1926 sequence_dir
1927 .as_ref()
1928 .map(|seq| output.join(seq))
1929 .unwrap_or_else(|| output.clone())
1930 };
1931 fs::create_dir_all(&target_dir).await?;
1932
1933 let sanitized_sample_name = sample
1934 .name()
1935 .map(|name| sanitize_path_component(&name))
1936 .unwrap_or_else(|| "unknown".to_string());
1937
1938 let image_name = sample.image_name().map(sanitize_path_component);
1939
1940 // Construct filename with smart prefixing for flatten mode
1941 // When flatten=true and sample belongs to a sequence:
1942 // - Check if filename already starts with "{sequence_name}_"
1943 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
1944 // - If yes, use filename as-is (already uniquely named)
1945 let file_name = if is_image {
1946 if let Some(img_name) = image_name {
1947 Client::build_filename(
1948 &img_name,
1949 flatten,
1950 sequence_dir.as_ref(),
1951 sample.frame_number(),
1952 )
1953 } else {
1954 format!("{}.{}", sanitized_sample_name, file_ext)
1955 }
1956 } else {
1957 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
1958 Client::build_filename(
1959 &base_name,
1960 flatten,
1961 sequence_dir.as_ref(),
1962 sample.frame_number(),
1963 )
1964 };
1965
1966 let file_path = target_dir.join(&file_name);
1967
1968 let mut file = File::create(&file_path).await?;
1969 file.write_all(&data).await?;
1970 }
1971 }
1972
1973 // Update progress after sample completes
1974 if let Some(progress) = &progress {
1975 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
1976 let _ = progress
1977 .send(Progress {
1978 current: completed,
1979 total,
1980 status: Some("Downloading".to_string()),
1981 })
1982 .await;
1983 }
1984
1985 Ok::<(), Error>(())
1986 })
1987 })
1988 .collect::<Vec<_>>();
1989
1990 join_all(tasks)
1991 .await
1992 .into_iter()
1993 .collect::<Result<Vec<_>, _>>()?
1994 .into_iter()
1995 .collect::<Result<Vec<_>, _>>()?;
1996
1997 Ok(())
1998 }
1999
2000 /// Builds a filename with smart prefixing for flatten mode.
2001 ///
2002 /// When flattening sequences into a single directory, this function ensures
2003 /// unique filenames by checking if the sequence prefix already exists and
2004 /// adding it if necessary.
2005 ///
2006 /// # Logic
2007 ///
2008 /// - If `flatten=false`: returns `base_name` unchanged
2009 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
2010 /// - If `flatten=true` and in sequence:
2011 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
2012 /// unchanged
2013 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
2014 /// `{sequence_name}_{base_name}`
2015 fn build_filename(
2016 base_name: &str,
2017 flatten: bool,
2018 sequence_name: Option<&String>,
2019 frame_number: Option<u32>,
2020 ) -> String {
2021 if !flatten || sequence_name.is_none() {
2022 return base_name.to_string();
2023 }
2024
2025 let seq_name = sequence_name.unwrap();
2026 let prefix = format!("{}_", seq_name);
2027
2028 // Check if already prefixed with sequence name
2029 if base_name.starts_with(&prefix) {
2030 base_name.to_string()
2031 } else {
2032 // Add sequence (and optionally frame) prefix
2033 match frame_number {
2034 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
2035 None => format!("{}{}", prefix, base_name),
2036 }
2037 }
2038 }
2039
2040 /// List available annotation sets for the specified dataset.
2041 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2042 pub async fn annotation_sets(
2043 &self,
2044 dataset_id: DatasetID,
2045 ) -> Result<Vec<AnnotationSet>, Error> {
2046 let params = HashMap::from([("dataset_id", dataset_id)]);
2047 self.rpc("annset.list".to_owned(), Some(params)).await
2048 }
2049
2050 /// Create a new annotation set for the specified dataset.
2051 ///
2052 /// # Arguments
2053 ///
2054 /// * `dataset_id` - The ID of the dataset to create the annotation set in
2055 /// * `name` - The name of the new annotation set
2056 /// * `description` - Optional description for the annotation set
2057 ///
2058 /// # Returns
2059 ///
2060 /// Returns the annotation set ID of the newly created annotation set.
2061 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2062 pub async fn create_annotation_set(
2063 &self,
2064 dataset_id: DatasetID,
2065 name: &str,
2066 description: Option<&str>,
2067 ) -> Result<AnnotationSetID, Error> {
2068 #[derive(Serialize)]
2069 struct Params<'a> {
2070 dataset_id: DatasetID,
2071 name: &'a str,
2072 operator: &'a str,
2073 #[serde(skip_serializing_if = "Option::is_none")]
2074 description: Option<&'a str>,
2075 }
2076
2077 #[derive(Deserialize)]
2078 struct CreateAnnotationSetResult {
2079 id: AnnotationSetID,
2080 }
2081
2082 let username = self.username().await?;
2083 let result: CreateAnnotationSetResult = self
2084 .rpc(
2085 "annset.add".to_owned(),
2086 Some(Params {
2087 dataset_id,
2088 name,
2089 operator: &username,
2090 description,
2091 }),
2092 )
2093 .await?;
2094 Ok(result.id)
2095 }
2096
2097 /// Deletes an annotation set by marking it as deleted.
2098 ///
2099 /// # Arguments
2100 ///
2101 /// * `annotation_set_id` - The ID of the annotation set to delete
2102 ///
2103 /// # Returns
2104 ///
2105 /// Returns `Ok(())` if the annotation set was successfully marked as
2106 /// deleted.
2107 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2108 pub async fn delete_annotation_set(
2109 &self,
2110 annotation_set_id: AnnotationSetID,
2111 ) -> Result<(), Error> {
2112 let params = HashMap::from([("id", annotation_set_id)]);
2113 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
2114 Ok(())
2115 }
2116
2117 /// Retrieve the annotation set with the specified ID.
2118 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2119 pub async fn annotation_set(
2120 &self,
2121 annotation_set_id: AnnotationSetID,
2122 ) -> Result<AnnotationSet, Error> {
2123 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
2124 self.rpc("annset.get".to_owned(), Some(params)).await
2125 }
2126
2127 /// Get the annotations for the specified annotation set with the
2128 /// requested annotation types. The annotation types are used to filter
2129 /// the annotations returned. The groups parameter is used to filter for
2130 /// dataset groups (train, val, test). Images which do not have any
2131 /// annotations are also included in the result as long as they are in the
2132 /// requested groups (when specified).
2133 ///
2134 /// The result is a vector of Annotations objects which contain the
2135 /// full dataset along with the annotations for the specified types.
2136 ///
2137 /// # Progress
2138 ///
2139 /// Reports progress with `status: None` as samples are fetched and
2140 /// processed for their annotations. Progress unit is samples processed
2141 /// (not individual annotations).
2142 ///
2143 /// To get the annotations as a DataFrame, use the `samples_dataframe`
2144 /// method instead.
2145 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2146 pub async fn annotations(
2147 &self,
2148 annotation_set_id: AnnotationSetID,
2149 groups: &[String],
2150 annotation_types: &[AnnotationType],
2151 progress: Option<Sender<Progress>>,
2152 ) -> Result<Vec<Annotation>, Error> {
2153 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
2154 let labels = self
2155 .labels(dataset_id)
2156 .await?
2157 .into_iter()
2158 .map(|label| (label.name().to_string(), label.index()))
2159 .collect::<HashMap<_, _>>();
2160 let total = self
2161 .samples_count(
2162 dataset_id,
2163 Some(annotation_set_id),
2164 annotation_types,
2165 groups,
2166 &[],
2167 )
2168 .await?
2169 .total as usize;
2170
2171 if total == 0 {
2172 return Ok(vec![]);
2173 }
2174
2175 let context = FetchContext {
2176 dataset_id,
2177 annotation_set_id: Some(annotation_set_id),
2178 groups,
2179 types: annotation_types.iter().map(|t| t.to_string()).collect(),
2180 labels: &labels,
2181 };
2182
2183 self.fetch_annotations_paginated(context, total, progress)
2184 .await
2185 }
2186
2187 async fn fetch_annotations_paginated(
2188 &self,
2189 context: FetchContext<'_>,
2190 total: usize,
2191 progress: Option<Sender<Progress>>,
2192 ) -> Result<Vec<Annotation>, Error> {
2193 let mut annotations = vec![];
2194 let mut continue_token: Option<String> = None;
2195 let mut current = 0;
2196
2197 loop {
2198 let params = SamplesListParams {
2199 dataset_id: context.dataset_id,
2200 annotation_set_id: context.annotation_set_id,
2201 types: context.types.clone(),
2202 group_names: context.groups.to_vec(),
2203 continue_token,
2204 };
2205
2206 let result: SamplesListResult =
2207 self.rpc("samples.list".to_owned(), Some(params)).await?;
2208 current += result.samples.len();
2209 continue_token = result.continue_token;
2210
2211 if result.samples.is_empty() {
2212 break;
2213 }
2214
2215 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
2216
2217 if let Some(progress) = &progress {
2218 let _ = progress
2219 .send(Progress {
2220 current,
2221 total,
2222 status: None,
2223 })
2224 .await;
2225 }
2226
2227 match &continue_token {
2228 Some(token) if !token.is_empty() => continue,
2229 _ => break,
2230 }
2231 }
2232
2233 drop(progress);
2234 Ok(annotations)
2235 }
2236
2237 fn process_sample_annotations(
2238 &self,
2239 samples: &[Sample],
2240 labels: &HashMap<String, u64>,
2241 annotations: &mut Vec<Annotation>,
2242 ) {
2243 for sample in samples {
2244 if sample.annotations().is_empty() {
2245 let mut annotation = Annotation::new();
2246 annotation.set_sample_id(sample.id());
2247 annotation.set_name(sample.name());
2248 annotation.set_sequence_name(sample.sequence_name().cloned());
2249 annotation.set_frame_number(sample.frame_number());
2250 annotation.set_group(sample.group().cloned());
2251 annotations.push(annotation);
2252 continue;
2253 }
2254
2255 for annotation in sample.annotations() {
2256 let mut annotation = annotation.clone();
2257 annotation.set_sample_id(sample.id());
2258 annotation.set_name(sample.name());
2259 annotation.set_sequence_name(sample.sequence_name().cloned());
2260 annotation.set_frame_number(sample.frame_number());
2261 annotation.set_group(sample.group().cloned());
2262 Self::set_label_index_from_map(&mut annotation, labels);
2263 annotations.push(annotation);
2264 }
2265 }
2266 }
2267
2268 /// Delete annotations in bulk from specified samples.
2269 ///
2270 /// This method calls the `annotation.bulk.del` API to efficiently remove
2271 /// annotations from multiple samples at once. Useful for clearing
2272 /// annotations before re-importing updated data.
2273 ///
2274 /// # Arguments
2275 /// * `annotation_set_id` - The annotation set containing the annotations
2276 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2277 /// for masks
2278 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2279 ///
2280 /// # Example
2281 /// ```no_run
2282 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2283 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2284 /// # let client = Client::new()?.with_login("user", "pass").await?;
2285 /// let annotation_set_id = AnnotationSetID::from(123);
2286 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2287 ///
2288 /// client
2289 /// .delete_annotations_bulk(
2290 /// annotation_set_id,
2291 /// &["box".to_string(), "seg".to_string()],
2292 /// &sample_ids,
2293 /// )
2294 /// .await?;
2295 /// # Ok(())
2296 /// # }
2297 /// ```
2298 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2299 pub async fn delete_annotations_bulk(
2300 &self,
2301 annotation_set_id: AnnotationSetID,
2302 annotation_types: &[String],
2303 sample_ids: &[SampleID],
2304 ) -> Result<(), Error> {
2305 use crate::api::AnnotationBulkDeleteParams;
2306
2307 let params = AnnotationBulkDeleteParams {
2308 annotation_set_id: annotation_set_id.into(),
2309 annotation_types: annotation_types.to_vec(),
2310 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2311 delete_all: None,
2312 };
2313
2314 let _: String = self
2315 .rpc("annotation.bulk.del".to_owned(), Some(params))
2316 .await?;
2317 Ok(())
2318 }
2319
2320 /// Add annotations in bulk.
2321 ///
2322 /// This method calls the `annotation.add_bulk` API to efficiently add
2323 /// multiple annotations at once. The annotations must be in server format
2324 /// with image_id references.
2325 ///
2326 /// # Arguments
2327 /// * `annotation_set_id` - The annotation set to add annotations to
2328 /// * `annotations` - Vector of server-format annotations to add
2329 ///
2330 /// # Returns
2331 /// Vector of created annotation records from the server.
2332 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2333 pub async fn add_annotations_bulk(
2334 &self,
2335 annotation_set_id: AnnotationSetID,
2336 annotations: Vec<crate::api::ServerAnnotation>,
2337 ) -> Result<Vec<serde_json::Value>, Error> {
2338 use crate::api::AnnotationAddBulkParams;
2339
2340 let params = AnnotationAddBulkParams {
2341 annotation_set_id: annotation_set_id.into(),
2342 annotations,
2343 };
2344
2345 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2346 .await
2347 }
2348
2349 /// Helper to parse frame number from image_name when sequence_name is
2350 /// present. This ensures frame_number is always derived from the image
2351 /// filename, not from the server's frame_number field (which may be
2352 /// inconsistent).
2353 ///
2354 /// Returns Some(frame_number) if sequence_name is present and frame can be
2355 /// parsed, otherwise None.
2356 fn parse_frame_from_image_name(
2357 image_name: Option<&String>,
2358 sequence_name: Option<&String>,
2359 ) -> Option<u32> {
2360 use std::path::Path;
2361
2362 let sequence = sequence_name?;
2363 let name = image_name?;
2364
2365 // Extract stem (remove extension)
2366 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2367
2368 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2369 stem.strip_prefix(sequence)
2370 .and_then(|suffix| suffix.strip_prefix('_'))
2371 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2372 }
2373
2374 /// Helper to set label index from a label map
2375 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2376 if let Some(label) = annotation.label() {
2377 annotation.set_label_index(Some(labels[label.as_str()]));
2378 }
2379 }
2380
2381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2382 pub async fn samples_count(
2383 &self,
2384 dataset_id: DatasetID,
2385 annotation_set_id: Option<AnnotationSetID>,
2386 annotation_types: &[AnnotationType],
2387 groups: &[String],
2388 types: &[FileType],
2389 ) -> Result<SamplesCountResult, Error> {
2390 // Use server type names for API calls (e.g., "box" instead of "box2d")
2391 let types = annotation_types
2392 .iter()
2393 .map(|t| t.as_server_type().to_string())
2394 .chain(types.iter().map(|t| t.to_string()))
2395 .collect::<Vec<_>>();
2396
2397 let params = SamplesListParams {
2398 dataset_id,
2399 annotation_set_id,
2400 group_names: groups.to_vec(),
2401 types,
2402 continue_token: None,
2403 };
2404
2405 self.rpc("samples.count".to_owned(), Some(params)).await
2406 }
2407
2408 /// Fetches samples from a dataset with optional annotation and file type
2409 /// filters.
2410 ///
2411 /// # Arguments
2412 ///
2413 /// * `dataset_id` - The dataset to fetch samples from
2414 /// * `annotation_set_id` - Optional annotation set to include annotations
2415 /// from
2416 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2417 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2418 /// * `types` - File types to include metadata for
2419 /// * `progress` - Optional channel for progress updates
2420 ///
2421 /// # Progress
2422 ///
2423 /// Reports progress with `status: None` as samples are fetched from the
2424 /// server in paginated batches. Progress unit is samples fetched.
2425 ///
2426 /// # Returns
2427 ///
2428 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2429 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2430 pub async fn samples(
2431 &self,
2432 dataset_id: DatasetID,
2433 annotation_set_id: Option<AnnotationSetID>,
2434 annotation_types: &[AnnotationType],
2435 groups: &[String],
2436 types: &[FileType],
2437 progress: Option<Sender<Progress>>,
2438 ) -> Result<Vec<Sample>, Error> {
2439 // Use server type names for API calls (e.g., "box" instead of "box2d")
2440 let types_vec = annotation_types
2441 .iter()
2442 .map(|t| t.as_server_type().to_string())
2443 .chain(types.iter().map(|t| t.to_string()))
2444 .collect::<Vec<_>>();
2445 let labels = self
2446 .labels(dataset_id)
2447 .await?
2448 .into_iter()
2449 .map(|label| (label.name().to_string(), label.index()))
2450 .collect::<HashMap<_, _>>();
2451 let total = self
2452 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2453 .await?
2454 .total as usize;
2455
2456 if total == 0 {
2457 return Ok(vec![]);
2458 }
2459
2460 let context = FetchContext {
2461 dataset_id,
2462 annotation_set_id,
2463 groups,
2464 types: types_vec,
2465 labels: &labels,
2466 };
2467
2468 self.fetch_samples_paginated(context, total, progress).await
2469 }
2470
2471 /// Get all sample names in a dataset.
2472 ///
2473 /// This is an efficient method for checking which samples already exist,
2474 /// useful for resuming interrupted imports. It only retrieves sample names
2475 /// without loading full annotation data.
2476 ///
2477 /// # Arguments
2478 ///
2479 /// * `dataset_id` - The dataset to query
2480 /// * `groups` - Optional group filter (empty = all groups)
2481 /// * `progress` - Optional progress channel
2482 ///
2483 /// # Progress
2484 ///
2485 /// Reports progress with `status: None` as sample names are fetched from
2486 /// the server in paginated batches. Progress unit is samples fetched.
2487 ///
2488 /// # Returns
2489 ///
2490 /// A HashSet of sample names (image_name field) that exist in the dataset.
2491 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2492 pub async fn sample_names(
2493 &self,
2494 dataset_id: DatasetID,
2495 groups: &[String],
2496 progress: Option<Sender<Progress>>,
2497 ) -> Result<std::collections::HashSet<String>, Error> {
2498 use std::collections::HashSet;
2499
2500 let total = self
2501 .samples_count(dataset_id, None, &[], groups, &[])
2502 .await?
2503 .total as usize;
2504
2505 if total == 0 {
2506 return Ok(HashSet::new());
2507 }
2508
2509 let mut names = HashSet::with_capacity(total);
2510 let mut continue_token: Option<String> = None;
2511 let mut current = 0;
2512
2513 loop {
2514 let params = SamplesListParams {
2515 dataset_id,
2516 annotation_set_id: None,
2517 types: vec![], // No type filter - we just want names
2518 group_names: groups.to_vec(),
2519 continue_token: continue_token.clone(),
2520 };
2521
2522 let result: SamplesListResult =
2523 self.rpc("samples.list".to_owned(), Some(params)).await?;
2524 current += result.samples.len();
2525 continue_token = result.continue_token;
2526
2527 if result.samples.is_empty() {
2528 break;
2529 }
2530
2531 // Extract sample names (normalized without extension)
2532 for sample in result.samples {
2533 if let Some(name) = sample.name() {
2534 names.insert(name);
2535 }
2536 }
2537
2538 if let Some(ref p) = progress {
2539 let _ = p
2540 .send(Progress {
2541 current,
2542 total,
2543 status: None,
2544 })
2545 .await;
2546 }
2547
2548 match &continue_token {
2549 Some(token) if !token.is_empty() => continue,
2550 _ => break,
2551 }
2552 }
2553
2554 Ok(names)
2555 }
2556
2557 async fn fetch_samples_paginated(
2558 &self,
2559 context: FetchContext<'_>,
2560 total: usize,
2561 progress: Option<Sender<Progress>>,
2562 ) -> Result<Vec<Sample>, Error> {
2563 let mut samples = vec![];
2564 let mut continue_token: Option<String> = None;
2565 let mut current = 0;
2566
2567 loop {
2568 let params = SamplesListParams {
2569 dataset_id: context.dataset_id,
2570 annotation_set_id: context.annotation_set_id,
2571 types: context.types.clone(),
2572 group_names: context.groups.to_vec(),
2573 continue_token: continue_token.clone(),
2574 };
2575
2576 let result: SamplesListResult =
2577 self.rpc("samples.list".to_owned(), Some(params)).await?;
2578 current += result.samples.len();
2579 continue_token = result.continue_token;
2580
2581 if result.samples.is_empty() {
2582 break;
2583 }
2584
2585 samples.append(
2586 &mut result
2587 .samples
2588 .into_iter()
2589 .map(|s| {
2590 // Use server's frame_number if valid (>= 0 after deserialization)
2591 // Otherwise parse from image_name as fallback
2592 // This ensures we respect explicit frame_number from uploads
2593 // while still handling legacy data that only has filename encoding
2594 let frame_number = s.frame_number.or_else(|| {
2595 Self::parse_frame_from_image_name(
2596 s.image_name.as_ref(),
2597 s.sequence_name.as_ref(),
2598 )
2599 });
2600
2601 let mut anns = s.annotations().to_vec();
2602 for ann in &mut anns {
2603 // Set annotation fields from parent sample
2604 ann.set_name(s.name());
2605 ann.set_group(s.group().cloned());
2606 ann.set_sequence_name(s.sequence_name().cloned());
2607 ann.set_frame_number(frame_number);
2608 Self::set_label_index_from_map(ann, context.labels);
2609 }
2610 s.with_annotations(anns).with_frame_number(frame_number)
2611 })
2612 .collect::<Vec<_>>(),
2613 );
2614
2615 if let Some(progress) = &progress {
2616 let _ = progress
2617 .send(Progress {
2618 current,
2619 total,
2620 status: None,
2621 })
2622 .await;
2623 }
2624
2625 match &continue_token {
2626 Some(token) if !token.is_empty() => continue,
2627 _ => break,
2628 }
2629 }
2630
2631 drop(progress);
2632 Ok(samples)
2633 }
2634
2635 /// Populates (imports) samples into a dataset using the `samples.populate2`
2636 /// API.
2637 ///
2638 /// This method creates new samples in the specified dataset, optionally
2639 /// with annotations and sensor data files. For each sample, the `files`
2640 /// field is checked for local file paths. If a filename is a valid path
2641 /// to an existing file, the file will be automatically uploaded to S3
2642 /// using presigned URLs returned by the server. The filename in the
2643 /// request is replaced with the basename (path removed) before sending
2644 /// to the server.
2645 ///
2646 /// # Important Notes
2647 ///
2648 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2649 /// annotations. Without it, the server will accept the request but will
2650 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2651 /// available annotation sets for a dataset, or create a new one via the
2652 /// Studio UI.
2653 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2654 /// boxes. Divide pixel coordinates by image width/height before creating
2655 /// [`Box2d`](crate::Box2d) annotations.
2656 /// - **Files are uploaded automatically** when the filename is a valid
2657 /// local path. The method will replace the full path with just the
2658 /// basename before sending to the server.
2659 /// - **Image dimensions are extracted automatically** for image files using
2660 /// the `imagesize` crate. The width/height are sent to the server, but
2661 /// note that the server currently doesn't return these fields when
2662 /// fetching samples back.
2663 /// - **UUIDs are generated automatically** if not provided. If you need
2664 /// deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
2665 /// that the server doesn't currently return UUIDs in sample queries.
2666 ///
2667 /// # Arguments
2668 ///
2669 /// * `dataset_id` - The ID of the dataset to populate
2670 /// * `annotation_set_id` - **Required** if samples contain annotations,
2671 /// otherwise they will be ignored. Query with
2672 /// [`Client::annotation_sets`].
2673 /// * `samples` - Vector of samples to import with metadata and file
2674 /// references. For files, use the full local path - it will be uploaded
2675 /// automatically. UUIDs and image dimensions will be
2676 /// auto-generated/extracted if not provided.
2677 /// * `progress` - Optional channel for progress updates
2678 ///
2679 /// # Progress
2680 ///
2681 /// Reports progress with `status: None` as each sample's files are
2682 /// uploaded. Progress unit is samples (not individual files). Each
2683 /// sample may contain multiple files (image, lidar, radar, etc.) which
2684 /// are all uploaded before the sample is counted as complete.
2685 ///
2686 /// # Returns
2687 ///
2688 /// Returns the API result with sample UUIDs and upload status.
2689 ///
2690 /// # Example
2691 ///
2692 /// ```no_run
2693 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
2694 ///
2695 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2696 /// # let client = Client::new()?.with_login("user", "pass").await?;
2697 /// # let dataset_id = DatasetID::from(1);
2698 /// // Query available annotation sets for the dataset
2699 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
2700 /// let annotation_set_id = annotation_sets
2701 /// .first()
2702 /// .ok_or_else(|| {
2703 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
2704 /// })?
2705 /// .id();
2706 ///
2707 /// // Create sample with annotation (UUID will be auto-generated)
2708 /// let mut sample = Sample::new();
2709 /// sample.width = Some(1920);
2710 /// sample.height = Some(1080);
2711 /// sample.group = Some("train".to_string());
2712 ///
2713 /// // Add file - use full path to local file, it will be uploaded automatically
2714 /// sample.files = vec![SampleFile::with_filename(
2715 /// "image".to_string(),
2716 /// "/path/to/image.jpg".to_string(),
2717 /// )];
2718 ///
2719 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
2720 /// let mut annotation = Annotation::new();
2721 /// annotation.set_label(Some("person".to_string()));
2722 /// // Normalize pixel coordinates by dividing by image dimensions
2723 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
2724 /// annotation.set_box2d(Some(bbox));
2725 /// sample.annotations = vec![annotation];
2726 ///
2727 /// // Populate with annotation_set_id (REQUIRED for annotations)
2728 /// let result = client
2729 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
2730 /// .await?;
2731 /// # Ok(())
2732 /// # }
2733 /// ```
2734 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2735 pub async fn populate_samples(
2736 &self,
2737 dataset_id: DatasetID,
2738 annotation_set_id: Option<AnnotationSetID>,
2739 samples: Vec<Sample>,
2740 progress: Option<Sender<Progress>>,
2741 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2742 self.populate_samples_with_concurrency(
2743 dataset_id,
2744 annotation_set_id,
2745 samples,
2746 progress,
2747 None,
2748 )
2749 .await
2750 }
2751
2752 /// Populate samples with custom upload concurrency.
2753 ///
2754 /// Same as [`populate_samples`](Self::populate_samples) but allows
2755 /// specifying the maximum number of concurrent file uploads. Use this
2756 /// for bulk imports where higher concurrency can significantly reduce
2757 /// upload time.
2758 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
2759 pub async fn populate_samples_with_concurrency(
2760 &self,
2761 dataset_id: DatasetID,
2762 annotation_set_id: Option<AnnotationSetID>,
2763 samples: Vec<Sample>,
2764 progress: Option<Sender<Progress>>,
2765 concurrency: Option<usize>,
2766 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
2767 use crate::api::SamplesPopulateParams;
2768
2769 // Track which files need to be uploaded
2770 let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();
2771
2772 // Process samples to detect local files and generate UUIDs
2773 let samples = self.prepare_samples_for_upload(samples, &mut files_to_upload)?;
2774
2775 let has_files_to_upload = !files_to_upload.is_empty();
2776
2777 // Call populate API with presigned_urls=true if we have files to upload
2778 let params = SamplesPopulateParams {
2779 dataset_id,
2780 annotation_set_id,
2781 presigned_urls: Some(has_files_to_upload),
2782 samples,
2783 };
2784
2785 let results: Vec<crate::SamplesPopulateResult> = self
2786 .rpc("samples.populate2".to_owned(), Some(params))
2787 .await?;
2788
2789 // Upload files if we have any
2790 if has_files_to_upload {
2791 self.upload_sample_files(&results, files_to_upload, progress, concurrency)
2792 .await?;
2793 }
2794
2795 Ok(results)
2796 }
2797
2798 fn prepare_samples_for_upload(
2799 &self,
2800 samples: Vec<Sample>,
2801 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2802 ) -> Result<Vec<Sample>, Error> {
2803 Ok(samples
2804 .into_iter()
2805 .map(|mut sample| {
2806 // Generate UUID if not provided
2807 if sample.uuid.is_none() {
2808 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
2809 }
2810
2811 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
2812
2813 // Process files: detect local paths and queue for upload
2814 let files_copy = sample.files.clone();
2815 let updated_files: Vec<crate::SampleFile> = files_copy
2816 .iter()
2817 .map(|file| {
2818 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
2819 })
2820 .collect();
2821
2822 sample.files = updated_files;
2823 sample
2824 })
2825 .collect())
2826 }
2827
2828 fn process_sample_file(
2829 &self,
2830 file: &crate::SampleFile,
2831 sample_uuid: &str,
2832 sample: &mut Sample,
2833 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
2834 ) -> crate::SampleFile {
2835 use std::path::Path;
2836
2837 // Handle files with raw bytes (e.g., from ZIP archives)
2838 if let Some(bytes) = file.bytes()
2839 && let Some(filename) = file.filename()
2840 {
2841 // For image files with bytes, try to extract dimensions if not already set
2842 if file.file_type() == "image"
2843 && (sample.width.is_none() || sample.height.is_none())
2844 && let Ok(size) = imagesize::blob_size(bytes)
2845 {
2846 sample.width = Some(size.width as u32);
2847 sample.height = Some(size.height as u32);
2848 }
2849
2850 // Store the bytes for later upload
2851 files_to_upload.push((
2852 sample_uuid.to_string(),
2853 file.file_type().to_string(),
2854 FileSource::Bytes(bytes.to_vec()),
2855 filename.to_string(),
2856 ));
2857
2858 // Return SampleFile with just the filename
2859 return crate::SampleFile::with_filename(
2860 file.file_type().to_string(),
2861 filename.to_string(),
2862 );
2863 }
2864
2865 // Handle files with local paths
2866 if let Some(filename) = file.filename() {
2867 let path = Path::new(filename);
2868
2869 // Check if this is a valid local file path
2870 if path.exists()
2871 && path.is_file()
2872 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
2873 {
2874 // For image files, try to extract dimensions if not already set
2875 if file.file_type() == "image"
2876 && (sample.width.is_none() || sample.height.is_none())
2877 && let Ok(size) = imagesize::size(path)
2878 {
2879 sample.width = Some(size.width as u32);
2880 sample.height = Some(size.height as u32);
2881 }
2882
2883 // Store the full path for later upload
2884 files_to_upload.push((
2885 sample_uuid.to_string(),
2886 file.file_type().to_string(),
2887 FileSource::Path(path.to_path_buf()),
2888 basename.to_string(),
2889 ));
2890
2891 // Return SampleFile with just the basename
2892 return crate::SampleFile::with_filename(
2893 file.file_type().to_string(),
2894 basename.to_string(),
2895 );
2896 }
2897 }
2898 // Return the file unchanged if not a local path
2899 file.clone()
2900 }
2901
2902 async fn upload_sample_files(
2903 &self,
2904 results: &[crate::SamplesPopulateResult],
2905 files_to_upload: Vec<(String, String, FileSource, String)>,
2906 progress: Option<Sender<Progress>>,
2907 concurrency: Option<usize>,
2908 ) -> Result<(), Error> {
2909 // Build a map from (sample_uuid, basename) -> file source
2910 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
2911 for (uuid, _file_type, source, basename) in files_to_upload {
2912 upload_map.insert((uuid, basename), source);
2913 }
2914
2915 let http = self.bulk_http.clone();
2916
2917 // Extract the data we need for parallel upload
2918 let upload_tasks: Vec<_> = results
2919 .iter()
2920 .map(|result| (result.uuid.clone(), result.urls.clone()))
2921 .collect();
2922
2923 parallel_foreach_items(
2924 upload_tasks,
2925 progress.clone(),
2926 concurrency,
2927 move |(uuid, urls)| {
2928 let http = http.clone();
2929 let upload_map = upload_map.clone();
2930
2931 async move {
2932 // Upload all files for this sample
2933 for url_info in &urls {
2934 if let Some(source) =
2935 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
2936 {
2937 match source {
2938 FileSource::Path(path) => {
2939 upload_file_to_presigned_url(
2940 http.clone(),
2941 &url_info.url,
2942 path.clone(),
2943 )
2944 .await?;
2945 }
2946 FileSource::Bytes(bytes) => {
2947 upload_bytes_to_presigned_url(
2948 http.clone(),
2949 &url_info.url,
2950 bytes.clone(),
2951 &url_info.filename,
2952 )
2953 .await?;
2954 }
2955 }
2956 }
2957 }
2958
2959 Ok(())
2960 }
2961 },
2962 )
2963 .await
2964 }
2965
2966 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2967 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
2968 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
2969 if !url.starts_with("http://") && !url.starts_with("https://") {
2970 return Err(Error::InvalidParameters(format!(
2971 "Invalid URL (must be absolute): {}",
2972 url
2973 )));
2974 }
2975
2976 let resp = self.bulk_http.get(url).send().await?;
2977
2978 if !resp.status().is_success() {
2979 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
2980 }
2981
2982 let bytes = resp.bytes().await?;
2983 Ok(bytes.to_vec())
2984 }
2985
2986 /// Get samples as a DataFrame with complete 2025.10 schema.
2987 ///
2988 /// This is the recommended method for obtaining dataset annotations in
2989 /// DataFrame format. It includes all sample metadata (size, location,
2990 /// pose, degradation) as optional columns.
2991 ///
2992 /// # Arguments
2993 ///
2994 /// * `dataset_id` - Dataset identifier
2995 /// * `annotation_set_id` - Optional annotation set filter
2996 /// * `groups` - Dataset groups to include (train, val, test)
2997 /// * `types` - Annotation types to filter (bbox, box3d, mask)
2998 /// * `progress` - Optional progress callback
2999 ///
3000 /// # Progress
3001 ///
3002 /// Reports progress with `status: None` as samples are fetched from the
3003 /// server in paginated batches. Progress unit is samples fetched. This
3004 /// method delegates to [`samples()`](Self::samples) and shares its
3005 /// progress behavior.
3006 ///
3007 /// # Example
3008 ///
3009 /// ```rust,no_run
3010 /// use edgefirst_client::Client;
3011 ///
3012 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3013 /// # let client = Client::new()?;
3014 /// # let dataset_id = 1.into();
3015 /// # let annotation_set_id = 1.into();
3016 /// let df = client
3017 /// .samples_dataframe(
3018 /// dataset_id,
3019 /// Some(annotation_set_id),
3020 /// &["train".to_string()],
3021 /// &[],
3022 /// None,
3023 /// )
3024 /// .await?;
3025 /// println!("DataFrame shape: {:?}", df.shape());
3026 /// # Ok(())
3027 /// # }
3028 /// ```
3029 #[cfg(feature = "polars")]
3030 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3031 pub async fn samples_dataframe(
3032 &self,
3033 dataset_id: DatasetID,
3034 annotation_set_id: Option<AnnotationSetID>,
3035 groups: &[String],
3036 types: &[AnnotationType],
3037 progress: Option<Sender<Progress>>,
3038 ) -> Result<DataFrame, Error> {
3039 use crate::dataset::samples_dataframe;
3040
3041 let samples = self
3042 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
3043 .await?;
3044 samples_dataframe(&samples)
3045 }
3046
3047 /// List available snapshots. If a name is provided, only snapshots
3048 /// containing that name are returned.
3049 ///
3050 /// Results are sorted by match quality: exact matches first, then
3051 /// case-insensitive exact matches, then shorter descriptions (more
3052 /// specific), then alphabetically.
3053 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3054 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
3055 let snapshots: Vec<Snapshot> = self
3056 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
3057 .await?;
3058 if let Some(name) = name {
3059 Ok(filter_and_sort_by_name(snapshots, name, |s| {
3060 s.description()
3061 }))
3062 } else {
3063 Ok(snapshots)
3064 }
3065 }
3066
3067 /// Get the snapshot with the specified id.
3068 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3069 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
3070 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3071 self.rpc("snapshots.get".to_owned(), Some(params)).await
3072 }
3073
3074 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
3075 ///
3076 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
3077 /// pairs) that serve two primary purposes:
3078 ///
3079 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
3080 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
3081 /// restored with AGTG (Automatic Ground Truth Generation) and optional
3082 /// auto-depth processing.
3083 ///
3084 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
3085 /// migration between EdgeFirst Studio instances using the create →
3086 /// download → upload → restore workflow.
3087 ///
3088 /// Large files are automatically chunked into 100MB parts and uploaded
3089 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
3090 /// is streamed without loading into memory, maintaining constant memory
3091 /// usage.
3092 ///
3093 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3094 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
3095 /// better for large files to avoid timeout issues. Higher values (16-32)
3096 /// are better for many small files.
3097 ///
3098 /// # Arguments
3099 ///
3100 /// * `path` - Local file path to MCAP file or directory containing
3101 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
3102 /// * `progress` - Optional channel to receive upload progress updates
3103 ///
3104 /// # Progress
3105 ///
3106 /// Reports progress with `status: None` as file data is uploaded. Progress
3107 /// unit is bytes uploaded. For single files, total is the file size. For
3108 /// directories, total is the combined size of all files.
3109 ///
3110 /// # Returns
3111 ///
3112 /// Returns a `Snapshot` object with ID, description, status, path, and
3113 /// creation timestamp on success.
3114 ///
3115 /// # Errors
3116 ///
3117 /// Returns an error if:
3118 /// * Path doesn't exist or contains invalid UTF-8
3119 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
3120 /// * Upload fails or network error occurs
3121 /// * Server rejects the snapshot
3122 ///
3123 /// # Example
3124 ///
3125 /// ```no_run
3126 /// # use edgefirst_client::{Client, Progress};
3127 /// # use tokio::sync::mpsc;
3128 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3129 /// let client = Client::new()?.with_token_path(None)?;
3130 ///
3131 /// // Upload MCAP file with progress tracking
3132 /// let (tx, mut rx) = mpsc::channel(1);
3133 /// tokio::spawn(async move {
3134 /// while let Some(Progress {
3135 /// current,
3136 /// total,
3137 /// status,
3138 /// }) = rx.recv().await
3139 /// {
3140 /// println!(
3141 /// "{}: {}/{} bytes ({:.1}%)",
3142 /// status.as_deref().unwrap_or("Upload"),
3143 /// current,
3144 /// total,
3145 /// (current as f64 / total as f64) * 100.0
3146 /// );
3147 /// }
3148 /// });
3149 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
3150 /// println!("Created snapshot: {:?}", snapshot.id());
3151 ///
3152 /// // Upload dataset directory (no progress)
3153 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
3154 /// # Ok(())
3155 /// # }
3156 /// ```
3157 ///
3158 /// # See Also
3159 ///
3160 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3161 /// dataset
3162 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3163 /// data
3164 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3165 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3166 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
3167 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3168 pub async fn create_snapshot(
3169 &self,
3170 path: &str,
3171 progress: Option<Sender<Progress>>,
3172 ) -> Result<Snapshot, Error> {
3173 let path = Path::new(path);
3174
3175 if path.is_dir() {
3176 let path_str = path.to_str().ok_or_else(|| {
3177 Error::IoError(std::io::Error::new(
3178 std::io::ErrorKind::InvalidInput,
3179 "Path contains invalid UTF-8",
3180 ))
3181 })?;
3182 return self.create_snapshot_folder(path_str, progress).await;
3183 }
3184
3185 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3186 Error::IoError(std::io::Error::new(
3187 std::io::ErrorKind::InvalidInput,
3188 "Invalid filename",
3189 ))
3190 })?;
3191 let total = path.metadata()?.len() as usize;
3192 let current = Arc::new(AtomicUsize::new(0));
3193
3194 if let Some(progress) = &progress {
3195 let _ = progress
3196 .send(Progress {
3197 current: 0,
3198 total,
3199 status: None,
3200 })
3201 .await;
3202 }
3203
3204 let params = SnapshotCreateMultipartParams {
3205 snapshot_name: name.to_owned(),
3206 keys: vec![name.to_owned()],
3207 file_sizes: vec![total],
3208 snapshot_type: None,
3209 };
3210 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3211 .rpc(
3212 "snapshots.create_upload_url_multipart".to_owned(),
3213 Some(params),
3214 )
3215 .await?;
3216
3217 let snapshot_id = match multipart.get("snapshot_id") {
3218 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3219 _ => return Err(Error::InvalidResponse),
3220 };
3221
3222 let snapshot = self.snapshot(snapshot_id).await?;
3223 let part_prefix = snapshot
3224 .path()
3225 .split("::/")
3226 .last()
3227 .ok_or(Error::InvalidResponse)?
3228 .to_owned();
3229 let part_key = format!("{}/{}", part_prefix, name);
3230 let mut part = match multipart.get(&part_key) {
3231 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3232 _ => return Err(Error::InvalidResponse),
3233 }
3234 .clone();
3235 part.key = Some(part_key);
3236
3237 let params = upload_multipart(
3238 self.bulk_http.clone(),
3239 part.clone(),
3240 path.to_path_buf(),
3241 total,
3242 current,
3243 progress.clone(),
3244 )
3245 .await?;
3246
3247 let complete: String = self
3248 .rpc(
3249 "snapshots.complete_multipart_upload".to_owned(),
3250 Some(params),
3251 )
3252 .await?;
3253 debug!("Snapshot Multipart Complete: {:?}", complete);
3254
3255 let params: SnapshotStatusParams = SnapshotStatusParams {
3256 snapshot_id,
3257 status: "available".to_owned(),
3258 };
3259 let _: SnapshotStatusResult = self
3260 .rpc("snapshots.update".to_owned(), Some(params))
3261 .await?;
3262
3263 if let Some(progress) = progress {
3264 drop(progress);
3265 }
3266
3267 self.snapshot(snapshot_id).await
3268 }
3269
3270 async fn create_snapshot_folder(
3271 &self,
3272 path: &str,
3273 progress: Option<Sender<Progress>>,
3274 ) -> Result<Snapshot, Error> {
3275 let path = Path::new(path);
3276 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3277 Error::IoError(std::io::Error::new(
3278 std::io::ErrorKind::InvalidInput,
3279 "Invalid directory name",
3280 ))
3281 })?;
3282
3283 let files = WalkDir::new(path)
3284 .into_iter()
3285 .filter_map(|entry| entry.ok())
3286 .filter(|entry| entry.file_type().is_file())
3287 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3288 .collect::<Vec<_>>();
3289
3290 let total: usize = files
3291 .iter()
3292 .filter_map(|file| path.join(file).metadata().ok())
3293 .map(|metadata| metadata.len() as usize)
3294 .sum();
3295 let current = Arc::new(AtomicUsize::new(0));
3296
3297 if let Some(progress) = &progress {
3298 let _ = progress
3299 .send(Progress {
3300 current: 0,
3301 total,
3302 status: None,
3303 })
3304 .await;
3305 }
3306
3307 let keys = files
3308 .iter()
3309 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3310 .collect::<Vec<_>>();
3311 let file_sizes = files
3312 .iter()
3313 .filter_map(|key| path.join(key).metadata().ok())
3314 .map(|metadata| metadata.len() as usize)
3315 .collect::<Vec<_>>();
3316
3317 let params = SnapshotCreateMultipartParams {
3318 snapshot_name: name.to_owned(),
3319 keys,
3320 file_sizes,
3321 snapshot_type: None,
3322 };
3323
3324 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3325 .rpc(
3326 "snapshots.create_upload_url_multipart".to_owned(),
3327 Some(params),
3328 )
3329 .await?;
3330
3331 let snapshot_id = match multipart.get("snapshot_id") {
3332 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3333 _ => return Err(Error::InvalidResponse),
3334 };
3335
3336 let snapshot = self.snapshot(snapshot_id).await?;
3337 let part_prefix = snapshot
3338 .path()
3339 .split("::/")
3340 .last()
3341 .ok_or(Error::InvalidResponse)?
3342 .to_owned();
3343
3344 for file in files {
3345 let file_str = file.to_str().ok_or_else(|| {
3346 Error::IoError(std::io::Error::new(
3347 std::io::ErrorKind::InvalidInput,
3348 "File path contains invalid UTF-8",
3349 ))
3350 })?;
3351 let part_key = format!("{}/{}", part_prefix, file_str);
3352 let mut part = match multipart.get(&part_key) {
3353 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3354 _ => return Err(Error::InvalidResponse),
3355 }
3356 .clone();
3357 part.key = Some(part_key);
3358
3359 let params = upload_multipart(
3360 self.bulk_http.clone(),
3361 part.clone(),
3362 path.join(file),
3363 total,
3364 current.clone(),
3365 progress.clone(),
3366 )
3367 .await?;
3368
3369 let complete: String = self
3370 .rpc(
3371 "snapshots.complete_multipart_upload".to_owned(),
3372 Some(params),
3373 )
3374 .await?;
3375 debug!("Snapshot Part Complete: {:?}", complete);
3376 }
3377
3378 let params = SnapshotStatusParams {
3379 snapshot_id,
3380 status: "available".to_owned(),
3381 };
3382 let _: SnapshotStatusResult = self
3383 .rpc("snapshots.update".to_owned(), Some(params))
3384 .await?;
3385
3386 if let Some(progress) = progress {
3387 drop(progress);
3388 }
3389
3390 self.snapshot(snapshot_id).await
3391 }
3392
3393 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3394 ///
3395 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3396 /// This format is the native EdgeFirst Dataset Format used for efficient
3397 /// dataset storage and transfer.
3398 ///
3399 /// # Arguments
3400 ///
3401 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3402 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3403 /// * `description` - Optional description for the snapshot
3404 /// * `progress` - Optional progress channel for upload tracking
3405 ///
3406 /// # File Requirements
3407 ///
3408 /// - Arrow file must have `.arrow` extension
3409 /// - ZIP file must have `.zip` extension
3410 /// - Both files must exist and be readable
3411 ///
3412 /// # Example
3413 ///
3414 /// ```no_run
3415 /// # use edgefirst_client::Client;
3416 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3417 /// let client = Client::new()?.with_token_path(None)?;
3418 ///
3419 /// let snapshot = client
3420 /// .create_snapshot_edgefirst_format(
3421 /// "dataset.arrow",
3422 /// "dataset.zip",
3423 /// Some("My Dataset Snapshot"),
3424 /// None,
3425 /// )
3426 /// .await?;
3427 /// println!("Created snapshot: {}", snapshot.id());
3428 /// # Ok(())
3429 /// # }
3430 /// ```
3431 ///
3432 /// # See Also
3433 ///
3434 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3435 /// folder
3436 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3437 /// dataset
3438 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3439 pub async fn create_snapshot_edgefirst_format(
3440 &self,
3441 arrow_path: &str,
3442 zip_path: &str,
3443 description: Option<&str>,
3444 progress: Option<Sender<Progress>>,
3445 ) -> Result<Snapshot, Error> {
3446 let arrow_path = Path::new(arrow_path);
3447 let zip_path = Path::new(zip_path);
3448
3449 // Validate files exist
3450 if !arrow_path.exists() {
3451 return Err(Error::IoError(std::io::Error::new(
3452 std::io::ErrorKind::NotFound,
3453 format!("Arrow file not found: {}", arrow_path.display()),
3454 )));
3455 }
3456 if !zip_path.exists() {
3457 return Err(Error::IoError(std::io::Error::new(
3458 std::io::ErrorKind::NotFound,
3459 format!("ZIP file not found: {}", zip_path.display()),
3460 )));
3461 }
3462
3463 // Get file names
3464 let arrow_name = arrow_path
3465 .file_name()
3466 .and_then(|n| n.to_str())
3467 .ok_or_else(|| {
3468 Error::IoError(std::io::Error::new(
3469 std::io::ErrorKind::InvalidInput,
3470 "Invalid Arrow filename",
3471 ))
3472 })?;
3473 let zip_name = zip_path
3474 .file_name()
3475 .and_then(|n| n.to_str())
3476 .ok_or_else(|| {
3477 Error::IoError(std::io::Error::new(
3478 std::io::ErrorKind::InvalidInput,
3479 "Invalid ZIP filename",
3480 ))
3481 })?;
3482
3483 // Generate snapshot name from arrow file (without extension)
3484 let snapshot_name = description
3485 .map(|s| s.to_string())
3486 .or_else(|| {
3487 arrow_path
3488 .file_stem()
3489 .and_then(|s| s.to_str())
3490 .map(|s| s.to_string())
3491 })
3492 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3493
3494 // Calculate file sizes
3495 let arrow_size = arrow_path.metadata()?.len() as usize;
3496 let zip_size = zip_path.metadata()?.len() as usize;
3497 let total = arrow_size + zip_size;
3498 let current = Arc::new(AtomicUsize::new(0));
3499
3500 if let Some(progress) = &progress {
3501 let _ = progress
3502 .send(Progress {
3503 current: 0,
3504 total,
3505 status: None,
3506 })
3507 .await;
3508 }
3509
3510 // Create multipart upload request with "ziparrow" type
3511 let params = SnapshotCreateMultipartParams {
3512 snapshot_name,
3513 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
3514 file_sizes: vec![arrow_size, zip_size],
3515 snapshot_type: Some("ziparrow".to_string()),
3516 };
3517
3518 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3519 .rpc(
3520 "snapshots.create_upload_url_multipart".to_owned(),
3521 Some(params),
3522 )
3523 .await?;
3524
3525 let snapshot_id = match multipart.get("snapshot_id") {
3526 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3527 _ => return Err(Error::InvalidResponse),
3528 };
3529
3530 let snapshot = self.snapshot(snapshot_id).await?;
3531 let part_prefix = snapshot
3532 .path()
3533 .split("::/")
3534 .last()
3535 .ok_or(Error::InvalidResponse)?
3536 .to_owned();
3537
3538 // Upload Arrow file
3539 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
3540 let mut arrow_part = match multipart.get(&arrow_key) {
3541 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3542 _ => return Err(Error::InvalidResponse),
3543 };
3544 arrow_part.key = Some(arrow_key);
3545
3546 let params = upload_multipart(
3547 self.bulk_http.clone(),
3548 arrow_part,
3549 arrow_path.to_path_buf(),
3550 total,
3551 current.clone(),
3552 progress.clone(),
3553 )
3554 .await?;
3555
3556 let _: String = self
3557 .rpc(
3558 "snapshots.complete_multipart_upload".to_owned(),
3559 Some(params),
3560 )
3561 .await?;
3562 debug!("Arrow file upload complete");
3563
3564 // Upload ZIP file
3565 let zip_key = format!("{}/{}", part_prefix, zip_name);
3566 let mut zip_part = match multipart.get(&zip_key) {
3567 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
3568 _ => return Err(Error::InvalidResponse),
3569 };
3570 zip_part.key = Some(zip_key);
3571
3572 let params = upload_multipart(
3573 self.bulk_http.clone(),
3574 zip_part,
3575 zip_path.to_path_buf(),
3576 total,
3577 current.clone(),
3578 progress.clone(),
3579 )
3580 .await?;
3581
3582 let _: String = self
3583 .rpc(
3584 "snapshots.complete_multipart_upload".to_owned(),
3585 Some(params),
3586 )
3587 .await?;
3588 debug!("ZIP file upload complete");
3589
3590 // Mark snapshot as available
3591 let params = SnapshotStatusParams {
3592 snapshot_id,
3593 status: "available".to_owned(),
3594 };
3595 let _: SnapshotStatusResult = self
3596 .rpc("snapshots.update".to_owned(), Some(params))
3597 .await?;
3598
3599 if let Some(progress) = progress {
3600 drop(progress);
3601 }
3602
3603 self.snapshot(snapshot_id).await
3604 }
3605
3606 /// Delete a snapshot from EdgeFirst Studio.
3607 ///
3608 /// Permanently removes a snapshot and its associated data. This operation
3609 /// cannot be undone.
3610 ///
3611 /// # Arguments
3612 ///
3613 /// * `snapshot_id` - The snapshot ID to delete
3614 ///
3615 /// # Errors
3616 ///
3617 /// Returns an error if:
3618 /// * Snapshot doesn't exist
3619 /// * User lacks permission to delete the snapshot
3620 /// * Server error occurs
3621 ///
3622 /// # Example
3623 ///
3624 /// ```no_run
3625 /// # use edgefirst_client::{Client, SnapshotID};
3626 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3627 /// let client = Client::new()?.with_token_path(None)?;
3628 /// let snapshot_id = SnapshotID::from(123);
3629 /// client.delete_snapshot(snapshot_id).await?;
3630 /// # Ok(())
3631 /// # }
3632 /// ```
3633 ///
3634 /// # See Also
3635 ///
3636 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3637 /// * [`snapshots`](Self::snapshots) - List all snapshots
3638 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3639 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
3640 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3641 let _: serde_json::Value = self
3642 .rpc("snapshots.delete".to_owned(), Some(params))
3643 .await?;
3644 Ok(())
3645 }
3646
3647 /// Create a snapshot from an existing dataset on the server.
3648 ///
3649 /// Triggers server-side snapshot generation which exports the dataset's
3650 /// images and annotations into a downloadable EdgeFirst Dataset Format
3651 /// snapshot.
3652 ///
3653 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
3654 /// while restore creates a dataset from a snapshot, this method creates a
3655 /// snapshot from a dataset.
3656 ///
3657 /// # Arguments
3658 ///
3659 /// * `dataset_id` - The dataset ID to create snapshot from
3660 /// * `description` - Description for the created snapshot
3661 ///
3662 /// # Returns
3663 ///
3664 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
3665 /// for monitoring progress.
3666 ///
3667 /// # Errors
3668 ///
3669 /// Returns an error if:
3670 /// * Dataset doesn't exist
3671 /// * User lacks permission to access the dataset
3672 /// * Server rejects the request
3673 ///
3674 /// # Example
3675 ///
3676 /// ```no_run
3677 /// # use edgefirst_client::{Client, DatasetID};
3678 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3679 /// let client = Client::new()?.with_token_path(None)?;
3680 /// let dataset_id = DatasetID::from(123);
3681 ///
3682 /// // Create snapshot from dataset (all annotation sets)
3683 /// let result = client
3684 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
3685 /// .await?;
3686 /// println!("Created snapshot: {:?}", result.id);
3687 ///
3688 /// // Monitor progress via task ID
3689 /// if let Some(task_id) = result.task_id {
3690 /// println!("Task: {}", task_id);
3691 /// }
3692 /// # Ok(())
3693 /// # }
3694 /// ```
3695 ///
3696 /// # See Also
3697 ///
3698 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
3699 /// snapshot
3700 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3701 /// dataset
3702 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3703 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3704 pub async fn create_snapshot_from_dataset(
3705 &self,
3706 dataset_id: DatasetID,
3707 description: &str,
3708 annotation_set_id: Option<AnnotationSetID>,
3709 ) -> Result<SnapshotFromDatasetResult, Error> {
3710 // Resolve annotation_set_id: use provided value or fetch default
3711 let annotation_set_id = match annotation_set_id {
3712 Some(id) => id,
3713 None => {
3714 // Fetch annotation sets and find default ("annotations") or use first
3715 let sets = self.annotation_sets(dataset_id).await?;
3716 if sets.is_empty() {
3717 return Err(Error::InvalidParameters(
3718 "No annotation sets available for dataset".to_owned(),
3719 ));
3720 }
3721 // Look for "annotations" set (default), otherwise use first
3722 sets.iter()
3723 .find(|s| s.name() == "annotations")
3724 .unwrap_or(&sets[0])
3725 .id()
3726 }
3727 };
3728 let params = SnapshotCreateFromDataset {
3729 description: description.to_owned(),
3730 dataset_id,
3731 annotation_set_id,
3732 };
3733 self.rpc("snapshots.create".to_owned(), Some(params)).await
3734 }
3735
3736 /// Download a snapshot from EdgeFirst Studio to local storage.
3737 ///
3738 /// Downloads all files in a snapshot (single MCAP file or directory of
3739 /// EdgeFirst Dataset Format files) to the specified output path. Files are
3740 /// downloaded concurrently with progress tracking.
3741 ///
3742 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3743 /// downloads (default: half of CPU cores, min 2, max 8).
3744 ///
3745 /// # Arguments
3746 ///
3747 /// * `snapshot_id` - The snapshot ID to download
3748 /// * `output` - Local directory path to save downloaded files
3749 /// * `progress` - Optional channel to receive download progress updates
3750 ///
3751 /// # Progress
3752 ///
3753 /// Reports progress with `status: None` as file data is received. Progress
3754 /// unit is bytes downloaded across all files combined. The total
3755 /// accumulates as file sizes become known (from HTTP Content-Length
3756 /// headers), so both `current` and `total` may increase during
3757 /// download.
3758 ///
3759 /// # Errors
3760 ///
3761 /// Returns an error if:
3762 /// * Snapshot doesn't exist
3763 /// * Output directory cannot be created
3764 /// * Download fails or network error occurs
3765 ///
3766 /// # Example
3767 ///
3768 /// ```no_run
3769 /// # use edgefirst_client::{Client, SnapshotID, Progress};
3770 /// # use tokio::sync::mpsc;
3771 /// # use std::path::PathBuf;
3772 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3773 /// let client = Client::new()?.with_token_path(None)?;
3774 /// let snapshot_id = SnapshotID::from(123);
3775 ///
3776 /// // Download with progress tracking
3777 /// let (tx, mut rx) = mpsc::channel(1);
3778 /// tokio::spawn(async move {
3779 /// while let Some(Progress {
3780 /// current,
3781 /// total,
3782 /// status,
3783 /// }) = rx.recv().await
3784 /// {
3785 /// println!(
3786 /// "{}: {}/{} bytes",
3787 /// status.as_deref().unwrap_or("Download"),
3788 /// current,
3789 /// total
3790 /// );
3791 /// }
3792 /// });
3793 /// client
3794 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
3795 /// .await?;
3796 /// # Ok(())
3797 /// # }
3798 /// ```
3799 ///
3800 /// # See Also
3801 ///
3802 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3803 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3804 /// dataset
3805 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3806 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
3807 pub async fn download_snapshot(
3808 &self,
3809 snapshot_id: SnapshotID,
3810 output: PathBuf,
3811 progress: Option<Sender<Progress>>,
3812 ) -> Result<(), Error> {
3813 fs::create_dir_all(&output).await?;
3814
3815 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3816 let items: HashMap<String, String> = self
3817 .rpc("snapshots.create_download_url".to_owned(), Some(params))
3818 .await?;
3819
3820 // Single-phase: each task holds its semaphore permit for the full
3821 // lifetime of the request (GET → headers → stream → disk). This bounds
3822 // the number of simultaneously-open connections to max_tasks() and
3823 // avoids accumulating all responses in memory before streaming.
3824 //
3825 // total is updated atomically as each response's Content-Length header
3826 // arrives, so progress tracking is accurate without a separate phase.
3827 let http = self.bulk_http.clone();
3828 let current = Arc::new(AtomicUsize::new(0));
3829 let total = Arc::new(AtomicUsize::new(0));
3830 let sem = Arc::new(Semaphore::new(max_tasks()));
3831
3832 let tasks = items
3833 .into_iter()
3834 .map(|(key, url)| {
3835 let http = http.clone();
3836 let output = output.clone();
3837 let progress = progress.clone();
3838 let current = current.clone();
3839 let total = total.clone();
3840 let sem = sem.clone();
3841
3842 tokio::spawn(async move {
3843 let _permit = sem.acquire().await.map_err(|_| {
3844 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
3845 })?;
3846
3847 let res = http.get(url).send().await?;
3848 let res = res.error_for_status()?;
3849
3850 // Contribute this file's size to the running total so the
3851 // caller's progress bar knows the overall scope.
3852 if let Some(len) = res.content_length() {
3853 total.fetch_add(len as usize, Ordering::SeqCst);
3854 }
3855
3856 let mut file = File::create(output.join(key)).await?;
3857 let mut stream = res.bytes_stream();
3858
3859 while let Some(chunk) = stream.next().await {
3860 let chunk = chunk?;
3861 file.write_all(&chunk).await?;
3862 let len = chunk.len();
3863
3864 if let Some(progress) = &progress {
3865 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
3866 let tot = total.load(Ordering::SeqCst);
3867 let _ = progress
3868 .send(Progress {
3869 current: cur,
3870 total: tot,
3871 status: None,
3872 })
3873 .await;
3874 }
3875 }
3876
3877 Ok::<(), Error>(())
3878 })
3879 })
3880 .collect::<Vec<_>>();
3881
3882 join_all(tasks)
3883 .await
3884 .into_iter()
3885 .collect::<Result<Vec<_>, _>>()?
3886 .into_iter()
3887 .collect::<Result<Vec<_>, _>>()?;
3888
3889 Ok(())
3890 }
3891
3892 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
3893 ///
3894 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
3895 /// the specified project. For MCAP files, supports:
3896 ///
3897 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
3898 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
3899 /// present)
3900 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
3901 /// * **Topic filtering**: Select specific MCAP topics to restore
3902 ///
3903 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
3904 /// dataset structure.
3905 ///
3906 /// # Arguments
3907 ///
3908 /// * `project_id` - Target project ID
3909 /// * `snapshot_id` - Snapshot ID to restore
3910 /// * `topics` - MCAP topics to include (empty = all topics)
3911 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
3912 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
3913 /// * `dataset_name` - Optional custom dataset name
3914 /// * `dataset_description` - Optional dataset description
3915 ///
3916 /// # Returns
3917 ///
3918 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
3919 ///
3920 /// # Errors
3921 ///
3922 /// Returns an error if:
3923 /// * Snapshot or project doesn't exist
3924 /// * Snapshot format is invalid
3925 /// * Server rejects restoration parameters
3926 ///
3927 /// # Example
3928 ///
3929 /// ```no_run
3930 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
3931 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3932 /// let client = Client::new()?.with_token_path(None)?;
3933 /// let project_id = ProjectID::from(1);
3934 /// let snapshot_id = SnapshotID::from(123);
3935 ///
3936 /// // Restore MCAP with AGTG for "person" and "car" detection
3937 /// let result = client
3938 /// .restore_snapshot(
3939 /// project_id,
3940 /// snapshot_id,
3941 /// &[], // All topics
3942 /// &["person".to_string(), "car".to_string()], // AGTG labels
3943 /// true, // Auto-depth
3944 /// Some("Highway Dataset"),
3945 /// Some("Collected on I-95"),
3946 /// )
3947 /// .await?;
3948 /// println!("Restored to dataset: {:?}", result.dataset_id);
3949 /// # Ok(())
3950 /// # }
3951 /// ```
3952 ///
3953 /// # See Also
3954 ///
3955 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
3956 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3957 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3958 #[allow(clippy::too_many_arguments)]
3959 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3960 pub async fn restore_snapshot(
3961 &self,
3962 project_id: ProjectID,
3963 snapshot_id: SnapshotID,
3964 topics: &[String],
3965 autolabel: &[String],
3966 autodepth: bool,
3967 dataset_name: Option<&str>,
3968 dataset_description: Option<&str>,
3969 ) -> Result<SnapshotRestoreResult, Error> {
3970 let params = SnapshotRestore {
3971 project_id,
3972 snapshot_id,
3973 fps: 1,
3974 autodepth,
3975 agtg_pipeline: !autolabel.is_empty(),
3976 autolabel: autolabel.to_vec(),
3977 topics: topics.to_vec(),
3978 dataset_name: dataset_name.map(|s| s.to_owned()),
3979 dataset_description: dataset_description.map(|s| s.to_owned()),
3980 };
3981 self.rpc("snapshots.restore".to_owned(), Some(params)).await
3982 }
3983
3984 /// Returns a list of experiments available to the user. The experiments
3985 /// are returned as a vector of Experiment objects. If name is provided
3986 /// then only experiments containing this string are returned.
3987 ///
3988 /// Results are sorted by match quality: exact matches first, then
3989 /// case-insensitive exact matches, then shorter names (more specific),
3990 /// then alphabetically.
3991 ///
3992 /// Experiments provide a method of organizing training and validation
3993 /// sessions together and are akin to an Experiment in MLFlow terminology.
3994 /// Each experiment can have multiple trainer sessions associated with it,
3995 /// these would be akin to runs in MLFlow terminology.
3996 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3997 pub async fn experiments(
3998 &self,
3999 project_id: ProjectID,
4000 name: Option<&str>,
4001 ) -> Result<Vec<Experiment>, Error> {
4002 let params = HashMap::from([("project_id", project_id)]);
4003 let experiments: Vec<Experiment> =
4004 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
4005 if let Some(name) = name {
4006 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
4007 } else {
4008 Ok(experiments)
4009 }
4010 }
4011
4012 /// Return the experiment with the specified experiment ID. If the
4013 /// experiment does not exist, an error is returned.
4014 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4015 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
4016 let params = HashMap::from([("trainer_id", experiment_id)]);
4017 self.rpc("trainer.get".to_owned(), Some(params)).await
4018 }
4019
4020 /// Returns a list of trainer sessions available to the user. The trainer
4021 /// sessions are returned as a vector of TrainingSession objects. If name
4022 /// is provided then only trainer sessions containing this string are
4023 /// returned.
4024 ///
4025 /// Results are sorted by match quality: exact matches first, then
4026 /// case-insensitive exact matches, then shorter names (more specific),
4027 /// then alphabetically.
4028 ///
4029 /// Trainer sessions are akin to runs in MLFlow terminology. These
4030 /// represent an actual training session which will produce metrics and
4031 /// model artifacts.
4032 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4033 pub async fn training_sessions(
4034 &self,
4035 experiment_id: ExperimentID,
4036 name: Option<&str>,
4037 ) -> Result<Vec<TrainingSession>, Error> {
4038 let params = HashMap::from([("trainer_id", experiment_id)]);
4039 let sessions: Vec<TrainingSession> = self
4040 .rpc("trainer.session.list".to_owned(), Some(params))
4041 .await?;
4042 if let Some(name) = name {
4043 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
4044 } else {
4045 Ok(sessions)
4046 }
4047 }
4048
4049 /// Return the trainer session with the specified trainer session ID. If
4050 /// the trainer session does not exist, an error is returned.
4051 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4052 pub async fn training_session(
4053 &self,
4054 session_id: TrainingSessionID,
4055 ) -> Result<TrainingSession, Error> {
4056 let params = HashMap::from([("trainer_session_id", session_id)]);
4057 self.rpc("trainer.session.get".to_owned(), Some(params))
4058 .await
4059 }
4060
4061 /// List validation sessions for the given project.
4062 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4063 pub async fn validation_sessions(
4064 &self,
4065 project_id: ProjectID,
4066 ) -> Result<Vec<ValidationSession>, Error> {
4067 let params = HashMap::from([("project_id", project_id)]);
4068 self.rpc("validate.session.list".to_owned(), Some(params))
4069 .await
4070 }
4071
4072 /// Retrieve a specific validation session.
4073 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4074 pub async fn validation_session(
4075 &self,
4076 session_id: ValidationSessionID,
4077 ) -> Result<ValidationSession, Error> {
4078 let params = HashMap::from([("validate_session_id", session_id)]);
4079 self.rpc("validate.session.get".to_owned(), Some(params))
4080 .await
4081 }
4082
4083 /// Create a new validation session via Studio's `cloud.server.start`.
4084 ///
4085 /// Pass `is_local: true` in the [`StartValidationRequest`] to create
4086 /// a **user-managed** session: the database row is created and the
4087 /// session is fully usable for data uploads / downloads / metrics,
4088 /// but no EC2 instance is provisioned and no automated validator
4089 /// pipeline is started. That is the mode our integration tests use
4090 /// — they create a session, exercise the wrapper APIs against it,
4091 /// then call [`Client::delete_validation_sessions`] in teardown so
4092 /// no stray sessions accumulate on the test account.
4093 ///
4094 /// Returns a [`NewValidationSession`] carrying the backing task id
4095 /// and the freshly-minted validation session id.
4096 ///
4097 /// # Errors
4098 ///
4099 /// Surfaces any RPC error from `cloud.server.start`. Common cases:
4100 /// `RpcError(101, …)` if a required entity is missing (project,
4101 /// training session, dataset, …); `PermissionDenied` if the caller
4102 /// can't write to the target project.
4103 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, req)))]
4104 pub async fn start_validation_session(
4105 &self,
4106 req: StartValidationRequest,
4107 ) -> Result<NewValidationSession, Error> {
4108 // Build the params shape the server expects. `cloud.server.start`
4109 // is intentionally generic — different server types pull
4110 // different fields out of `params` — so we serialize manually to
4111 // match the JS frontend's call site verbatim (see
4112 // `dve-frontend/src/components/ValidationPage/StartValidatorModal.vue`).
4113 let mut body = serde_json::Map::new();
4114 body.insert(
4115 "type".into(),
4116 serde_json::Value::String("validation".into()),
4117 );
4118 body.insert("name".into(), serde_json::Value::String(req.name));
4119 body.insert("project_id".into(), serde_json::to_value(req.project_id)?);
4120 body.insert(
4121 "training_session_id".into(),
4122 serde_json::to_value(req.training_session_id)?,
4123 );
4124 body.insert(
4125 "model_file".into(),
4126 serde_json::Value::String(req.model_file),
4127 );
4128 body.insert("val_type".into(), serde_json::Value::String(req.val_type));
4129 body.insert("is_local".into(), serde_json::Value::Bool(req.is_local));
4130 body.insert(
4131 "is_kubernetes".into(),
4132 serde_json::Value::Bool(req.is_kubernetes),
4133 );
4134
4135 // `validate.session` reads its config from `params.params` (one
4136 // extra envelope level). The outer `params` wrapper is required
4137 // even when the inner map is empty.
4138 let inner = serde_json::to_value(req.params)?;
4139 let mut outer = serde_json::Map::new();
4140 outer.insert("params".into(), inner);
4141 body.insert("params".into(), serde_json::Value::Object(outer));
4142
4143 if let Some(d) = req.description {
4144 body.insert("description".into(), serde_json::Value::String(d));
4145 }
4146 if let Some(id) = req.dataset_id {
4147 body.insert("dataset_id".into(), serde_json::to_value(id)?);
4148 }
4149 if let Some(id) = req.annotation_set_id {
4150 body.insert("annotation_set_id".into(), serde_json::to_value(id)?);
4151 }
4152 if let Some(id) = req.snapshot_id {
4153 body.insert("snapshot_id".into(), serde_json::to_value(id)?);
4154 }
4155
4156 self.rpc("cloud.server.start".to_owned(), Some(body)).await
4157 }
4158
4159 /// Delete one or more validation sessions via
4160 /// `validate.session.delete`.
4161 ///
4162 /// Used by integration tests to tear down sessions they created
4163 /// with [`Client::start_validation_session`]; idempotent against
4164 /// already-deleted ids on the server side (the RPC accepts the
4165 /// list, deletes what it can, and surfaces an error only if none
4166 /// of the ids were resolvable).
4167 ///
4168 /// # Errors
4169 ///
4170 /// Surfaces any RPC error from `validate.session.delete`. A
4171 /// `PermissionDenied` indicates the caller lacks
4172 /// `TrainerWrite` on at least one of the listed sessions.
4173 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4174 pub async fn delete_validation_sessions(
4175 &self,
4176 session_ids: &[ValidationSessionID],
4177 ) -> Result<(), Error> {
4178 let mut body = serde_json::Map::new();
4179 body.insert("session_ids".into(), serde_json::to_value(session_ids)?);
4180 let _: serde_json::Value = self
4181 .rpc("validate.session.delete".to_owned(), Some(body))
4182 .await?;
4183 Ok(())
4184 }
4185
4186 /// List the artifacts for the specified trainer session. The artifacts
4187 /// are returned as a vector of strings.
4188 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4189 pub async fn artifacts(
4190 &self,
4191 training_session_id: TrainingSessionID,
4192 ) -> Result<Vec<Artifact>, Error> {
4193 let params = HashMap::from([("training_session_id", training_session_id)]);
4194 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
4195 .await
4196 }
4197
4198 /// Download the model artifact for the specified trainer session to the
4199 /// specified file path, if path is not provided it will be downloaded to
4200 /// the current directory with the same filename.
4201 ///
4202 /// # Progress
4203 ///
4204 /// Reports progress with `status: None` as file data is received. Progress
4205 /// unit is bytes downloaded. Total is determined from the HTTP
4206 /// Content-Length header (may be 0 if server doesn't provide it).
4207 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4208 pub async fn download_artifact(
4209 &self,
4210 training_session_id: TrainingSessionID,
4211 modelname: &str,
4212 filename: Option<PathBuf>,
4213 progress: Option<Sender<Progress>>,
4214 ) -> Result<(), Error> {
4215 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
4216 let resp = self
4217 .bulk_http
4218 .get(format!(
4219 "{}/download_model?training_session_id={}&file={}",
4220 self.url,
4221 training_session_id.value(),
4222 modelname
4223 ))
4224 .header("Authorization", format!("Bearer {}", self.token().await))
4225 .send()
4226 .await?;
4227 if !resp.status().is_success() {
4228 let err = resp.error_for_status_ref().unwrap_err();
4229 return Err(Error::HttpError(err));
4230 }
4231
4232 if let Some(parent) = filename.parent() {
4233 fs::create_dir_all(parent).await?;
4234 }
4235
4236 stream_response_to_file(resp, &filename, progress).await
4237 }
4238
4239 /// Download the model checkpoint associated with the specified trainer
4240 /// session to the specified file path, if path is not provided it will be
4241 /// downloaded to the current directory with the same filename.
4242 ///
4243 /// There is no API for listing checkpoints it is expected that trainers are
4244 /// aware of possible checkpoints and their names within the checkpoint
4245 /// folder on the server.
4246 ///
4247 /// # Progress
4248 ///
4249 /// Reports progress with `status: None` as file data is received. Progress
4250 /// unit is bytes downloaded. Total is determined from the HTTP
4251 /// Content-Length header (may be 0 if server doesn't provide it).
4252 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4253 pub async fn download_checkpoint(
4254 &self,
4255 training_session_id: TrainingSessionID,
4256 checkpoint: &str,
4257 filename: Option<PathBuf>,
4258 progress: Option<Sender<Progress>>,
4259 ) -> Result<(), Error> {
4260 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
4261 let resp = self
4262 .bulk_http
4263 .get(format!(
4264 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
4265 self.url,
4266 training_session_id.value(),
4267 checkpoint
4268 ))
4269 .header("Authorization", format!("Bearer {}", self.token().await))
4270 .send()
4271 .await?;
4272 if !resp.status().is_success() {
4273 let err = resp.error_for_status_ref().unwrap_err();
4274 return Err(Error::HttpError(err));
4275 }
4276
4277 if let Some(parent) = filename.parent() {
4278 fs::create_dir_all(parent).await?;
4279 }
4280
4281 stream_response_to_file(resp, &filename, progress).await
4282 }
4283
4284 /// Return a list of tasks for the current user.
4285 ///
4286 /// # Arguments
4287 ///
4288 /// * `name` - Optional filter for task name (client-side substring match)
4289 /// * `workflow` - Optional filter for workflow/task type. If provided,
4290 /// filters server-side by exact match. Valid values include: "trainer",
4291 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
4292 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
4293 /// "convertor", "twostage"
4294 /// * `status` - Optional filter for task status (e.g., "running",
4295 /// "complete", "error")
4296 /// * `manager` - Optional filter for task manager type (e.g., "aws",
4297 /// "user", "kubernetes")
4298 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4299 pub async fn tasks(
4300 &self,
4301 name: Option<&str>,
4302 workflow: Option<&str>,
4303 status: Option<&str>,
4304 manager: Option<&str>,
4305 ) -> Result<Vec<Task>, Error> {
4306 let mut params = TasksListParams {
4307 continue_token: None,
4308 types: workflow.map(|w| vec![w.to_owned()]),
4309 status: status.map(|s| vec![s.to_owned()]),
4310 manager: manager.map(|m| vec![m.to_owned()]),
4311 };
4312 let mut tasks = Vec::new();
4313
4314 loop {
4315 let result = self
4316 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4317 .await?;
4318 tasks.extend(result.tasks);
4319
4320 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4321 params.continue_token = None;
4322 } else {
4323 params.continue_token = result.continue_token;
4324 }
4325
4326 if params.continue_token.is_none() {
4327 break;
4328 }
4329 }
4330
4331 if let Some(name) = name {
4332 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4333 }
4334
4335 Ok(tasks)
4336 }
4337
4338 /// Submits a job (app run) to the server and returns the resulting `Job`
4339 /// record (which carries the linked task id alongside the cloud-batch
4340 /// metadata).
4341 ///
4342 /// # Arguments
4343 /// * `app_name` - The name of the registered app to run (e.g., `"edgefirst-validator"`).
4344 /// * `job_name` - A user-defined label for this run.
4345 /// * `env` - Environment variables passed to the job (string-string map).
4346 /// * `data` - Job input payload (e.g., session ids, parameters).
4347 ///
4348 /// # Returns
4349 /// The full `Job` record returned by the server (wraps the BK_BATCH object),
4350 /// including AWS Batch job ID, state, and the linked `task_id`. Callers that
4351 /// only need the task ID can call `.task_id()` on the returned `Job`.
4352 pub async fn job_run(
4353 &self,
4354 app_name: &str,
4355 job_name: &str,
4356 env: std::collections::HashMap<String, String>,
4357 data: std::collections::HashMap<String, crate::api::Parameter>,
4358 ) -> Result<crate::api::Job, Error> {
4359 let req = JobRunRequest {
4360 name: app_name.to_owned(),
4361 job_name: job_name.to_owned(),
4362 env,
4363 data,
4364 };
4365 let resp: crate::api::Job = match self.rpc("job.run".to_owned(), Some(&req)).await {
4366 Ok(r) => r,
4367 Err(Error::RpcError(code, msg)) => {
4368 return Err(map_rpc_error("job.run", code, msg, None));
4369 }
4370 Err(e) => return Err(e),
4371 };
4372 Ok(resp)
4373 }
4374
4375 /// Requests a running job task be stopped.
4376 ///
4377 /// Returns `Ok(())` if the stop request was accepted by the server. The
4378 /// task may still take time to fully terminate; poll `task_info` if you
4379 /// need to wait for shutdown.
4380 pub async fn job_stop(&self, task_id: crate::api::TaskID) -> Result<(), Error> {
4381 let req = JobStopRequest {
4382 task_id: task_id.value(),
4383 };
4384 // We don't care about the response body; deserialize as serde_json::Value.
4385 let _resp: serde_json::Value = match self.rpc("job.stop".to_owned(), Some(&req)).await {
4386 Ok(r) => r,
4387 Err(Error::RpcError(code, msg)) => {
4388 return Err(map_rpc_error("job.stop", code, msg, Some(task_id)));
4389 }
4390 Err(e) => return Err(e),
4391 };
4392 Ok(())
4393 }
4394
4395 /// Lists job (app-run) entries visible to the authenticated user.
4396 ///
4397 /// The server returns AWS Batch-wrapper entries (not bare `Task` objects),
4398 /// surfacing cloud-batch state (`RUNNING`/`SUCCEEDED`/...) and the linked
4399 /// `task_id`. Use `Job::task_id()` + `Client::task_info` to fetch the
4400 /// underlying task details.
4401 ///
4402 /// The server does not support server-side filters, so the optional
4403 /// `name` argument is applied client-side as a substring match against
4404 /// each job's `job_name`.
4405 pub async fn jobs(&self, name: Option<&str>) -> Result<Vec<crate::api::Job>, Error> {
4406 let req = JobsListRequest {};
4407 let mut jobs: Vec<crate::api::Job> = match self.rpc("job.list".to_owned(), Some(&req)).await
4408 {
4409 Ok(r) => r,
4410 Err(Error::RpcError(code, msg)) => {
4411 return Err(map_rpc_error("job.list", code, msg, None));
4412 }
4413 Err(e) => return Err(e),
4414 };
4415 if let Some(name) = name {
4416 let needle = name.to_lowercase();
4417 jobs.retain(|j| j.job_name.to_lowercase().contains(&needle));
4418 jobs.sort_by(|a, b| a.job_name.cmp(&b.job_name));
4419 }
4420 Ok(jobs)
4421 }
4422
4423 /// Retrieve the task information and status.
4424 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4425 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4426 self.rpc(
4427 "task.get".to_owned(),
4428 Some(HashMap::from([("id", task_id)])),
4429 )
4430 .await
4431 }
4432
4433 /// Updates the tasks status.
4434 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4435 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4436 let status = TaskStatus {
4437 task_id,
4438 status: status.to_owned(),
4439 };
4440 self.rpc("docker.update.status".to_owned(), Some(status))
4441 .await
4442 }
4443
4444 /// Defines the stages for the task. The stages are defined as a mapping
4445 /// from stage names to their descriptions. Once stages are defined their
4446 /// status can be updated using the update_stage method.
4447 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4448 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4449 let stages: Vec<HashMap<String, String>> = stages
4450 .iter()
4451 .map(|(key, value)| {
4452 let mut stage_map = HashMap::new();
4453 stage_map.insert(key.to_string(), value.to_string());
4454 stage_map
4455 })
4456 .collect();
4457 let params = TaskStages { task_id, stages };
4458 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4459 Ok(())
4460 }
4461
4462 /// Updates the progress of the task for the provided stage and status
4463 /// information.
4464 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4465 pub async fn update_stage(
4466 &self,
4467 task_id: TaskID,
4468 stage: &str,
4469 status: &str,
4470 message: &str,
4471 percentage: u8,
4472 ) -> Result<(), Error> {
4473 let stage = Stage::new(
4474 Some(task_id),
4475 stage.to_owned(),
4476 Some(status.to_owned()),
4477 Some(message.to_owned()),
4478 percentage,
4479 );
4480 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4481 Ok(())
4482 }
4483
4484 /// Authenticated fetch from the Studio server using the bulk HTTP client
4485 /// (no total-request timeout; idle read timeout per chunk).
4486 ///
4487 /// **Buffers the entire response body into memory.** Suitable for small to
4488 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4489 /// checkpoints), prefer a streaming approach that writes directly to disk.
4490 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4491 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4492 let req = self
4493 .bulk_http
4494 .get(format!("{}/{}", self.url, query))
4495 .header("User-Agent", "EdgeFirst Client")
4496 .header("Authorization", format!("Bearer {}", self.token().await));
4497 let resp = req.send().await?;
4498
4499 if resp.status().is_success() {
4500 let body = resp.bytes().await?;
4501
4502 if log_enabled!(Level::Trace) {
4503 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
4504 }
4505
4506 Ok(body.to_vec())
4507 } else {
4508 let err = resp.error_for_status_ref().unwrap_err();
4509 Err(Error::HttpError(err))
4510 }
4511 }
4512
4513 /// Sends a multipart post request to the server. This is used by the
4514 /// upload and download APIs which do not use JSON-RPC but instead transfer
4515 /// files using multipart/form-data.
4516 ///
4517 /// The result field is deserialized as `serde_json::Value` rather than
4518 /// `String` because different server endpoints return different shapes —
4519 /// `val.data.upload` returns a plain string while `task.data.upload`
4520 /// returns an object `{"message":…,"path":…,"size":…}`. All current
4521 /// callers discard the return value so this is backwards-compatible.
4522 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
4523 pub async fn post_multipart(
4524 &self,
4525 method: &str,
4526 form: Form,
4527 ) -> Result<serde_json::Value, Error> {
4528 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
4529 .ok()
4530 .and_then(|s| s.parse().ok())
4531 .unwrap_or(600u64);
4532
4533 let req = self
4534 .http
4535 .post(format!("{}/api?method={}", self.url, method))
4536 .header("Accept", "application/json")
4537 .header("User-Agent", "EdgeFirst Client")
4538 .header("Authorization", format!("Bearer {}", self.token().await))
4539 .timeout(Duration::from_secs(upload_timeout_secs))
4540 .multipart(form);
4541 let resp = req.send().await?;
4542
4543 if resp.status().is_success() {
4544 let body = resp.bytes().await?;
4545
4546 if log_enabled!(Level::Trace) {
4547 trace!(
4548 "POST Multipart Response: {}",
4549 String::from_utf8_lossy(&body)
4550 );
4551 }
4552
4553 let response: RpcResponse<serde_json::Value> = match serde_json::from_slice(&body) {
4554 Ok(response) => response,
4555 Err(err) => {
4556 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4557 return Err(err.into());
4558 }
4559 };
4560
4561 if let Some(error) = response.error {
4562 Err(Error::RpcError(error.code, error.message))
4563 } else if let Some(result) = response.result {
4564 Ok(result)
4565 } else {
4566 Err(Error::InvalidResponse)
4567 }
4568 } else {
4569 // HTTP-level failure on the multipart upload. Map 413 to the
4570 // typed `PayloadTooLarge` variant so callers see the same error
4571 // type from both single-file rpc_download paths and multipart
4572 // upload paths; everything else falls through to HttpError.
4573 let status = resp.status();
4574 if status.as_u16() == 413 {
4575 return Err(Error::PayloadTooLarge {
4576 method: method.to_string(),
4577 size_hint: None,
4578 });
4579 }
4580 let err = resp.error_for_status_ref().unwrap_err();
4581 Err(Error::HttpError(err))
4582 }
4583 }
4584
4585 /// Internal helper: POST a JSON-RPC request and stream the binary response
4586 /// to `output_path`. The response is assumed to be raw binary (not a JSON
4587 /// envelope). Use for endpoints that return file contents directly.
4588 ///
4589 /// On HTTP non-success, the response body is read as text and surfaced
4590 /// via `Error::RpcError(status_code, body)`.
4591 pub(crate) async fn rpc_download<P: Serialize>(
4592 &self,
4593 method: &str,
4594 params: &P,
4595 output_path: &std::path::Path,
4596 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
4597 ) -> Result<(), Error> {
4598 let envelope = serde_json::json!({
4599 "jsonrpc": "2.0",
4600 "id": 0,
4601 "method": method,
4602 "params": params,
4603 });
4604
4605 let url = format!("{}/api", self.url);
4606 let resp = self
4607 .bulk_http
4608 .post(&url)
4609 .header("Authorization", format!("Bearer {}", self.token().await))
4610 .json(&envelope)
4611 .send()
4612 .await?;
4613
4614 let status = resp.status();
4615 if !status.is_success() {
4616 if status.as_u16() == 413 {
4617 return Err(Error::PayloadTooLarge {
4618 method: method.to_string(),
4619 size_hint: None,
4620 });
4621 }
4622 let body = resp.text().await.unwrap_or_default();
4623 return Err(Error::RpcError(status.as_u16() as i32, body));
4624 }
4625
4626 // HTTP 200 with Content-Type: application/json can mean two things:
4627 // (a) a JSON-RPC error envelope when the server failed mid-way
4628 // (e.g. {"jsonrpc":"2.0","error":{"code":N,"message":"..."}}),
4629 // (b) a legitimate JSON file payload — validation traces, chart
4630 // bodies, metrics, etc., are typically served with this MIME.
4631 //
4632 // Disambiguate structurally: a JSON-RPC 2.0 envelope is required to
4633 // carry a `jsonrpc` member, and an *error* envelope further requires
4634 // an `error.code` integer (per RFC 8259 + JSON-RPC 2.0 §5). Only
4635 // decode the body as an error if both markers are present. This is
4636 // strict enough to leave legitimate JSON artifacts that happen to
4637 // contain a free-form `error` field (metrics, diagnostics, log
4638 // dumps) untouched, while still catching every real server
4639 // failure.
4640 let content_type = resp
4641 .headers()
4642 .get(reqwest::header::CONTENT_TYPE)
4643 .and_then(|v| v.to_str().ok())
4644 .unwrap_or("")
4645 .to_owned();
4646 if content_type.contains("application/json") {
4647 let body = resp.bytes().await?;
4648 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&body)
4649 && is_jsonrpc_error_envelope(&val)
4650 && let Some(err_obj) = val.get("error")
4651 {
4652 let code = err_obj.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
4653 let message = err_obj
4654 .get("message")
4655 .and_then(|m| m.as_str())
4656 .unwrap_or("unknown error")
4657 .to_string();
4658 return Err(Error::RpcError(code, message));
4659 }
4660 // Not an error envelope — body is a JSON file. Write it to disk
4661 // and emit a single completion progress event so callers (e.g.,
4662 // Python download_data progress callbacks) see the download
4663 // finish.
4664 //
4665 // `Path::parent` returns `Some("")` for a bare filename like
4666 // "metrics.json"; `create_dir_all("")` errors out with
4667 // `NotFound`, so only create the parent when it actually names
4668 // a directory.
4669 if let Some(parent) = output_path.parent()
4670 && !parent.as_os_str().is_empty()
4671 {
4672 tokio::fs::create_dir_all(parent).await?;
4673 }
4674 let mut file = tokio::fs::File::create(output_path).await?;
4675 file.write_all(&body).await?;
4676 file.flush().await?;
4677 if let Some(tx) = progress {
4678 let total = body.len();
4679 // Use the awaited send for the final event so completion
4680 // handlers are never silently dropped.
4681 let _ = tx
4682 .send(Progress {
4683 current: total,
4684 total,
4685 status: None,
4686 })
4687 .await;
4688 }
4689 return Ok(());
4690 }
4691
4692 // Same empty-parent guard for the streaming download path: passing
4693 // a bare filename like "metrics.json" must write to the current
4694 // directory rather than failing on `create_dir_all("")`.
4695 if let Some(parent) = output_path.parent()
4696 && !parent.as_os_str().is_empty()
4697 {
4698 tokio::fs::create_dir_all(parent).await?;
4699 }
4700
4701 stream_response_to_file(resp, output_path, progress).await
4702 }
4703
4704 /// Send a JSON-RPC request to the server. The method is the name of the
4705 /// method to call on the server. The params are the parameters to pass to
4706 /// the method. The method and params are serialized into a JSON-RPC
4707 /// request and sent to the server. The response is deserialized into
4708 /// the specified type and returned to the caller.
4709 ///
4710 /// NOTE: This API would generally not be called directly and instead users
4711 /// should use the higher-level methods provided by the client.
4712 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
4713 pub async fn rpc<Params, RpcResult>(
4714 &self,
4715 method: String,
4716 params: Option<Params>,
4717 ) -> Result<RpcResult, Error>
4718 where
4719 Params: Serialize,
4720 RpcResult: DeserializeOwned,
4721 {
4722 let auth_expires = self.token_expiration().await?;
4723 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
4724 self.renew_token().await?;
4725 }
4726
4727 self.rpc_without_auth(method, params).await
4728 }
4729
4730 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
4731 async fn rpc_without_auth<Params, RpcResult>(
4732 &self,
4733 method: String,
4734 params: Option<Params>,
4735 ) -> Result<RpcResult, Error>
4736 where
4737 Params: Serialize,
4738 RpcResult: DeserializeOwned,
4739 {
4740 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
4741 .ok()
4742 .and_then(|s| s.parse().ok())
4743 .unwrap_or(5usize);
4744
4745 let url = format!("{}/api", self.url);
4746
4747 // Serialize request body once before retry loop to avoid Clone bound on Params
4748 let request = RpcRequest {
4749 method: method.clone(),
4750 params,
4751 ..Default::default()
4752 };
4753
4754 // Log request for debugging (log crate) and profiling (tracing crate)
4755 let request_json = if method == "auth.login" {
4756 // Redact auth.login params (contains password)
4757 serde_json::json!({
4758 "jsonrpc": "2.0",
4759 "method": &method,
4760 "params": "[REDACTED - contains credentials]",
4761 "id": request.id
4762 })
4763 .to_string()
4764 } else {
4765 serde_json::to_string(&request)?
4766 };
4767
4768 if log_enabled!(Level::Trace) {
4769 trace!("RPC Request: {}", request_json);
4770 }
4771
4772 // Record request on current span for Perfetto when profiling is enabled
4773 #[cfg(feature = "profiling")]
4774 tracing::Span::current().record("request", &request_json);
4775
4776 let request_body = serde_json::to_vec(&request)?;
4777 let mut last_error: Option<Error> = None;
4778
4779 for attempt in 0..=max_retries {
4780 if attempt > 0 {
4781 // Exponential backoff with jitter: base delay * 2^attempt, capped at 30s
4782 // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
4783 // while ensuring we never retry faster than the base delay
4784 let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
4785 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
4786 let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
4787 let delay = Duration::from_millis(delay_ms);
4788 warn!(
4789 "Retry {}/{} for RPC '{}' after {:?}",
4790 attempt, max_retries, method, delay
4791 );
4792 tokio::time::sleep(delay).await;
4793 }
4794
4795 let result = self
4796 .http
4797 .post(&url)
4798 .header("Accept", "application/json")
4799 .header("Content-Type", "application/json")
4800 .header("User-Agent", "EdgeFirst Client")
4801 .header("Authorization", format!("Bearer {}", self.token().await))
4802 .body(request_body.clone())
4803 .send()
4804 .await;
4805
4806 match result {
4807 Ok(res) => {
4808 let status = res.status();
4809 let status_code = status.as_u16();
4810
4811 // Check for retryable HTTP status codes before processing response
4812 if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
4813 && attempt < max_retries
4814 {
4815 warn!(
4816 "RPC '{}' failed with HTTP {} (retrying)",
4817 method, status_code
4818 );
4819 last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
4820 continue;
4821 }
4822
4823 // Process the response
4824 match self.process_rpc_response(res).await {
4825 Ok(result) => {
4826 if attempt > 0 {
4827 debug!("RPC '{}' succeeded on retry {}", method, attempt);
4828 }
4829 return Ok(result);
4830 }
4831 Err(e) => {
4832 // Don't retry client errors (4xx except 408, 429)
4833 if attempt > 0 {
4834 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
4835 }
4836 return Err(e);
4837 }
4838 }
4839 }
4840 Err(e) => {
4841 // Transport error (timeout, connection failure, etc.)
4842 let is_timeout = e.is_timeout();
4843 let is_connect = e.is_connect();
4844
4845 if (is_timeout || is_connect) && attempt < max_retries {
4846 warn!(
4847 "RPC '{}' transport error (retrying): {}",
4848 method,
4849 if is_timeout {
4850 "timeout"
4851 } else {
4852 "connection failed"
4853 }
4854 );
4855 last_error = Some(Error::HttpError(e));
4856 continue;
4857 }
4858
4859 if attempt > 0 {
4860 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
4861 }
4862 return Err(Error::HttpError(e));
4863 }
4864 }
4865 }
4866
4867 // Should not reach here
4868 Err(last_error.unwrap_or_else(|| {
4869 Error::InvalidParameters(format!(
4870 "RPC '{}' failed after {} retries",
4871 method, max_retries
4872 ))
4873 }))
4874 }
4875
4876 async fn process_rpc_response<RpcResult>(
4877 &self,
4878 res: reqwest::Response,
4879 ) -> Result<RpcResult, Error>
4880 where
4881 RpcResult: DeserializeOwned,
4882 {
4883 let body = res.bytes().await?;
4884 let response_str = String::from_utf8_lossy(&body);
4885
4886 if log_enabled!(Level::Trace) {
4887 trace!("RPC Response: {}", response_str);
4888 }
4889
4890 // Record response on current span for Perfetto when profiling is enabled
4891 // Truncate large responses to avoid bloating trace files
4892 #[cfg(feature = "profiling")]
4893 {
4894 const MAX_RESPONSE_LEN: usize = 4096;
4895 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
4896 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
4897 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
4898 format!(
4899 "{}...[truncated {} bytes]",
4900 &response_str[..safe_end],
4901 response_str.len() - safe_end
4902 )
4903 } else {
4904 response_str.to_string()
4905 };
4906 tracing::Span::current().record("response", &truncated);
4907 }
4908
4909 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
4910 Ok(response) => response,
4911 Err(err) => {
4912 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
4913 return Err(err.into());
4914 }
4915 };
4916
4917 // FIXME: Studio Server always returns 999 as the id.
4918 // if request.id.to_string() != response.id {
4919 // return Err(Error::InvalidRpcId(response.id));
4920 // }
4921
4922 if let Some(error) = response.error {
4923 Err(Error::RpcError(error.code, error.message))
4924 } else if let Some(result) = response.result {
4925 Ok(result)
4926 } else {
4927 Err(Error::InvalidResponse)
4928 }
4929 }
4930}
4931
4932/// Process items in parallel with semaphore concurrency control and progress
4933/// tracking.
4934///
4935/// This helper eliminates boilerplate for parallel item processing with:
4936/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
4937/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
4938/// - Atomic progress counter with automatic item-level updates
4939/// - Progress updates sent after each item completes (not byte-level streaming)
4940/// - Proper error propagation from spawned tasks
4941///
4942/// Note: This is optimized for discrete items with post-completion progress
4943/// updates. For byte-level streaming progress or custom retry logic, use
4944/// specialized implementations.
4945///
4946/// # Arguments
4947///
4948/// * `items` - Collection of items to process in parallel
4949/// * `progress` - Optional progress channel for tracking completion
4950/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
4951/// * `work_fn` - Async function to execute for each item
4952///
4953/// # Examples
4954///
4955/// ```rust,ignore
4956/// // Use default concurrency
4957/// parallel_foreach_items(samples, progress, None, |sample| async move {
4958/// sample.download(&client, file_type).await?;
4959/// Ok(())
4960/// }).await?;
4961/// ```
4962async fn parallel_foreach_items<T, F, Fut>(
4963 items: Vec<T>,
4964 progress: Option<Sender<Progress>>,
4965 concurrency: Option<usize>,
4966 work_fn: F,
4967) -> Result<(), Error>
4968where
4969 T: Send + 'static,
4970 F: Fn(T) -> Fut + Send + Sync + 'static,
4971 Fut: Future<Output = Result<(), Error>> + Send + 'static,
4972{
4973 let total = items.len();
4974 let current = Arc::new(AtomicUsize::new(0));
4975 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
4976 let work_fn = Arc::new(work_fn);
4977
4978 let tasks = items
4979 .into_iter()
4980 .map(|item| {
4981 let sem = sem.clone();
4982 let current = current.clone();
4983 let progress = progress.clone();
4984 let work_fn = work_fn.clone();
4985
4986 tokio::spawn(async move {
4987 let _permit = sem.acquire().await.map_err(|_| {
4988 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4989 })?;
4990
4991 // Execute the actual work
4992 work_fn(item).await?;
4993
4994 // Update progress
4995 if let Some(progress) = &progress {
4996 let current = current.fetch_add(1, Ordering::SeqCst);
4997 let _ = progress
4998 .send(Progress {
4999 current: current + 1,
5000 total,
5001 status: None,
5002 })
5003 .await;
5004 }
5005
5006 Ok::<(), Error>(())
5007 })
5008 })
5009 .collect::<Vec<_>>();
5010
5011 join_all(tasks)
5012 .await
5013 .into_iter()
5014 .collect::<Result<Vec<_>, _>>()?
5015 .into_iter()
5016 .collect::<Result<Vec<_>, _>>()?;
5017
5018 if let Some(progress) = progress {
5019 drop(progress);
5020 }
5021
5022 Ok(())
5023}
5024
5025/// Upload a file to S3 using multipart upload with presigned URLs.
5026///
5027/// Splits a file into chunks (100MB each) and uploads them in parallel using
5028/// S3 multipart upload protocol. Returns completion parameters with ETags for
5029/// finalizing the upload.
5030///
5031/// This function handles:
5032/// - Splitting files into parts based on PART_SIZE (100MB)
5033/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
5034/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
5035/// - Retry logic (handled by reqwest client)
5036/// - Progress tracking across all parts
5037///
5038/// # Arguments
5039///
5040/// * `http` - HTTP client for making requests
5041/// * `part` - Snapshot part info with presigned URLs for each chunk
5042/// * `path` - Local file path to upload
5043/// * `total` - Total bytes across all files for progress calculation
5044/// * `current` - Atomic counter tracking bytes uploaded across all operations
5045/// * `progress` - Optional channel for sending progress updates
5046///
5047/// # Returns
5048///
5049/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
5050async fn upload_multipart(
5051 http: reqwest::Client,
5052 part: SnapshotPart,
5053 path: PathBuf,
5054 total: usize,
5055 confirmed_bytes: Arc<AtomicUsize>,
5056 progress: Option<Sender<Progress>>,
5057) -> Result<SnapshotCompleteMultipartParams, Error> {
5058 let filesize = path.metadata()?.len() as usize;
5059 let n_parts = filesize.div_ceil(PART_SIZE);
5060 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
5061
5062 let key = part.key.ok_or(Error::InvalidResponse)?;
5063 let upload_id = part.upload_id;
5064
5065 let urls = part.urls.clone();
5066
5067 // Pre-allocate ETag slots for all parts
5068 let etags = Arc::new(tokio::sync::Mutex::new(vec![
5069 EtagPart {
5070 etag: "".to_owned(),
5071 part_number: 0,
5072 };
5073 n_parts
5074 ]));
5075
5076 // Per-part byte counters for streaming progress (reset on retry)
5077 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
5078 (0..n_parts)
5079 .map(|_| AtomicUsize::new(0))
5080 .collect::<Vec<_>>(),
5081 );
5082
5083 // Upload all parts in parallel with concurrency limiting
5084 let tasks = (0..n_parts)
5085 .map(|part_idx| {
5086 let http = http.clone();
5087 let url = urls[part_idx].clone();
5088 let etags = etags.clone();
5089 let path = path.to_owned();
5090 let sem = sem.clone();
5091 let progress = progress.clone();
5092 let confirmed_bytes = confirmed_bytes.clone();
5093 let part_bytes = part_bytes.clone();
5094
5095 // Calculate this part's size
5096 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5097 filesize % PART_SIZE
5098 } else {
5099 PART_SIZE
5100 };
5101
5102 tokio::spawn(async move {
5103 // Acquire semaphore permit to limit concurrent uploads
5104 let _permit = sem.acquire().await.map_err(|_| {
5105 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5106 })?;
5107
5108 // Upload part with streaming progress and retry logic
5109 let etag = upload_part_with_progress(
5110 http,
5111 url,
5112 path,
5113 part_idx,
5114 n_parts,
5115 part_size,
5116 total,
5117 confirmed_bytes.clone(),
5118 part_bytes.clone(),
5119 progress.clone(),
5120 )
5121 .await?;
5122
5123 // Store ETag for this part (needed to complete multipart upload)
5124 let mut etags_guard = etags.lock().await;
5125 etags_guard[part_idx] = EtagPart {
5126 etag,
5127 part_number: part_idx + 1,
5128 };
5129
5130 // Part completed successfully - add to confirmed bytes
5131 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
5132 // Reset part counter since it's now confirmed
5133 part_bytes[part_idx].store(0, Ordering::SeqCst);
5134
5135 // Send final progress update for this part
5136 if let Some(progress) = &progress {
5137 let current = confirmed_bytes.load(Ordering::SeqCst)
5138 + part_bytes
5139 .iter()
5140 .map(|p| p.load(Ordering::SeqCst))
5141 .sum::<usize>();
5142 let _ = progress
5143 .send(Progress {
5144 current,
5145 total,
5146 status: None,
5147 })
5148 .await;
5149 }
5150
5151 Ok::<(), Error>(())
5152 })
5153 })
5154 .collect::<Vec<_>>();
5155
5156 // Wait for all parts to complete (double collect to handle both JoinError and
5157 // inner Error)
5158 join_all(tasks)
5159 .await
5160 .into_iter()
5161 .collect::<Result<Vec<_>, _>>()?
5162 .into_iter()
5163 .collect::<Result<Vec<_>, _>>()?;
5164
5165 Ok(SnapshotCompleteMultipartParams {
5166 key,
5167 upload_id,
5168 etag_list: etags.lock().await.clone(),
5169 })
5170}
5171
5172/// Upload a single part with streaming progress tracking and retry logic.
5173///
5174/// Progress is reported continuously as bytes are sent. On retry, the part's
5175/// progress counter is reset to avoid over-reporting.
5176#[allow(clippy::too_many_arguments)]
5177async fn upload_part_with_progress(
5178 http: reqwest::Client,
5179 url: String,
5180 path: PathBuf,
5181 part_idx: usize,
5182 n_parts: usize,
5183 part_size: usize,
5184 total: usize,
5185 confirmed_bytes: Arc<AtomicUsize>,
5186 part_bytes: Arc<Vec<AtomicUsize>>,
5187 progress: Option<Sender<Progress>>,
5188) -> Result<String, Error> {
5189 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5190 .ok()
5191 .and_then(|s| s.parse().ok())
5192 .unwrap_or(5usize);
5193
5194 // Per-part total upload timeout. Covers the send phase (request body) where
5195 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
5196 // this bounds how long a stalled upload can block before retrying.
5197 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5198 .ok()
5199 .and_then(|s| s.parse().ok())
5200 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
5201
5202 let mut last_error: Option<Error> = None;
5203
5204 for attempt in 0..=max_retries {
5205 if attempt > 0 {
5206 // Reset this part's progress counter before retry
5207 part_bytes[part_idx].store(0, Ordering::SeqCst);
5208
5209 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5210 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5211 warn!(
5212 "Retry {}/{} for part {} after {:?}",
5213 attempt, max_retries, part_idx, delay
5214 );
5215 tokio::time::sleep(delay).await;
5216 }
5217
5218 match upload_part_streaming(
5219 http.clone(),
5220 url.clone(),
5221 path.clone(),
5222 part_idx,
5223 n_parts,
5224 part_size,
5225 total,
5226 upload_timeout_secs,
5227 confirmed_bytes.clone(),
5228 part_bytes.clone(),
5229 progress.clone(),
5230 )
5231 .await
5232 {
5233 Ok(etag) => return Ok(etag),
5234 Err(e) => {
5235 // Check if error is retryable
5236 let is_retryable = matches!(
5237 &e,
5238 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
5239 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
5240 );
5241
5242 if is_retryable && attempt < max_retries {
5243 last_error = Some(e);
5244 continue;
5245 }
5246
5247 return Err(e);
5248 }
5249 }
5250 }
5251
5252 Err(last_error
5253 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
5254}
5255
5256/// Perform the actual upload with streaming progress.
5257#[allow(clippy::too_many_arguments)]
5258async fn upload_part_streaming(
5259 http: reqwest::Client,
5260 url: String,
5261 path: PathBuf,
5262 part_idx: usize,
5263 n_parts: usize,
5264 _part_size: usize,
5265 total: usize,
5266 upload_timeout_secs: u64,
5267 confirmed_bytes: Arc<AtomicUsize>,
5268 part_bytes: Arc<Vec<AtomicUsize>>,
5269 progress: Option<Sender<Progress>>,
5270) -> Result<String, Error> {
5271 let filesize = path.metadata()?.len() as usize;
5272 let mut file = File::open(&path).await?;
5273 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
5274 .await?;
5275 let file = file.take(PART_SIZE as u64);
5276
5277 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5278 filesize % PART_SIZE
5279 } else {
5280 PART_SIZE
5281 };
5282
5283 // Create stream with progress tracking
5284 let stream = FramedRead::new(file, BytesCodec::new());
5285
5286 // Wrap stream to track bytes sent and report progress
5287 let progress_stream = stream.map(move |result| {
5288 if let Ok(ref bytes) = result {
5289 let bytes_len = bytes.len();
5290 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
5291
5292 // Send progress update (fire-and-forget via try_send to avoid blocking)
5293 if let Some(ref progress) = progress {
5294 let current = confirmed_bytes.load(Ordering::SeqCst)
5295 + part_bytes
5296 .iter()
5297 .map(|p| p.load(Ordering::SeqCst))
5298 .sum::<usize>();
5299 // Best-effort progress reporting: use try_send to avoid blocking.
5300 // If the channel is full or closed, we intentionally skip this update
5301 // to avoid stalling the upload; subsequent updates will still be delivered.
5302 let _ = progress.try_send(Progress {
5303 current,
5304 total,
5305 status: None,
5306 });
5307 }
5308 }
5309 result.map(|b| b.freeze())
5310 });
5311
5312 let body = Body::wrap_stream(progress_stream);
5313
5314 let resp = http
5315 .put(url)
5316 .header(CONTENT_LENGTH, body_length)
5317 .timeout(Duration::from_secs(upload_timeout_secs))
5318 .body(body)
5319 .send()
5320 .await?
5321 .error_for_status()?;
5322
5323 let etag = resp
5324 .headers()
5325 .get("etag")
5326 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
5327 .to_str()
5328 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
5329 .to_owned();
5330
5331 // Studio Server requires etag without the quotes.
5332 let etag = etag
5333 .strip_prefix("\"")
5334 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
5335 let etag = etag
5336 .strip_suffix("\"")
5337 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
5338
5339 Ok(etag.to_owned())
5340}
5341
5342/// Upload a complete file to a presigned S3 URL using HTTP PUT.
5343///
5344/// This is used for populate_samples to upload files to S3 after
5345/// receiving presigned URLs from the server.
5346///
5347/// Includes explicit retry logic with exponential backoff for transient
5348/// failures.
5349async fn upload_file_to_presigned_url(
5350 http: reqwest::Client,
5351 url: &str,
5352 path: PathBuf,
5353) -> Result<(), Error> {
5354 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5355 .ok()
5356 .and_then(|s| s.parse().ok())
5357 .unwrap_or(5usize);
5358
5359 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5360 .ok()
5361 .and_then(|s| s.parse().ok())
5362 .unwrap_or(600u64);
5363
5364 // Read the entire file into memory once
5365 let file_data = fs::read(&path).await?;
5366 let file_size = file_data.len();
5367 let filename = path.file_name().unwrap_or_default().to_string_lossy();
5368
5369 let mut last_error: Option<Error> = None;
5370
5371 for attempt in 0..=max_retries {
5372 if attempt > 0 {
5373 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5374 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5375 warn!(
5376 "Retry {}/{} for upload '{}' after {:?}",
5377 attempt, max_retries, filename, delay
5378 );
5379 tokio::time::sleep(delay).await;
5380 }
5381
5382 // Attempt upload
5383 let result = http
5384 .put(url)
5385 .header(CONTENT_LENGTH, file_size)
5386 .timeout(Duration::from_secs(upload_timeout_secs))
5387 .body(file_data.clone())
5388 .send()
5389 .await;
5390
5391 match result {
5392 Ok(resp) => {
5393 if resp.status().is_success() {
5394 if attempt > 0 {
5395 debug!(
5396 "Upload '{}' succeeded on retry {} ({} bytes)",
5397 filename, attempt, file_size
5398 );
5399 } else {
5400 debug!(
5401 "Successfully uploaded file: {} ({} bytes)",
5402 filename, file_size
5403 );
5404 }
5405 return Ok(());
5406 }
5407
5408 let status = resp.status();
5409 let status_code = status.as_u16();
5410
5411 // Check if error is retryable
5412 let is_retryable =
5413 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5414
5415 if is_retryable && attempt < max_retries {
5416 let error_text = resp.text().await.unwrap_or_default();
5417 warn!(
5418 "Upload '{}' failed with HTTP {} (retryable): {}",
5419 filename, status_code, error_text
5420 );
5421 last_error = Some(Error::InvalidParameters(format!(
5422 "Upload failed: HTTP {} - {}",
5423 status, error_text
5424 )));
5425 continue;
5426 }
5427
5428 // Non-retryable error or max retries exceeded
5429 let error_text = resp.text().await.unwrap_or_default();
5430 if attempt > 0 {
5431 error!(
5432 "Upload '{}' failed after {} retries: HTTP {} - {}",
5433 filename, attempt, status, error_text
5434 );
5435 }
5436 return Err(Error::InvalidParameters(format!(
5437 "Upload failed: HTTP {} - {}",
5438 status, error_text
5439 )));
5440 }
5441 Err(e) => {
5442 // Transport error (timeout, connection failure, etc.)
5443 let is_timeout = e.is_timeout();
5444 let is_connect = e.is_connect();
5445
5446 if (is_timeout || is_connect) && attempt < max_retries {
5447 warn!(
5448 "Upload '{}' transport error (retrying): {}",
5449 filename,
5450 if is_timeout {
5451 "timeout"
5452 } else {
5453 "connection failed"
5454 }
5455 );
5456 last_error = Some(Error::HttpError(e));
5457 continue;
5458 }
5459
5460 // Non-retryable or max retries exceeded
5461 if attempt > 0 {
5462 error!(
5463 "Upload '{}' failed after {} retries: {}",
5464 filename, attempt, e
5465 );
5466 }
5467 return Err(Error::HttpError(e));
5468 }
5469 }
5470 }
5471
5472 // Should not reach here, but return last error if we do
5473 Err(last_error.unwrap_or_else(|| {
5474 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5475 }))
5476}
5477
5478/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
5479///
5480/// This is used for populate_samples to upload file content from memory
5481/// (e.g., from ZIP archives) without writing to disk first.
5482///
5483/// Includes explicit retry logic with exponential backoff for transient
5484/// failures.
5485async fn upload_bytes_to_presigned_url(
5486 http: reqwest::Client,
5487 url: &str,
5488 file_data: Vec<u8>,
5489 filename: &str,
5490) -> Result<(), Error> {
5491 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5492 .ok()
5493 .and_then(|s| s.parse().ok())
5494 .unwrap_or(5usize);
5495
5496 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5497 .ok()
5498 .and_then(|s| s.parse().ok())
5499 .unwrap_or(600u64);
5500
5501 let file_size = file_data.len();
5502 let mut last_error: Option<Error> = None;
5503
5504 for attempt in 0..=max_retries {
5505 if attempt > 0 {
5506 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5507 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5508 warn!(
5509 "Retry {}/{} for upload '{}' after {:?}",
5510 attempt, max_retries, filename, delay
5511 );
5512 tokio::time::sleep(delay).await;
5513 }
5514
5515 // Attempt upload
5516 let result = http
5517 .put(url)
5518 .header(CONTENT_LENGTH, file_size)
5519 .timeout(Duration::from_secs(upload_timeout_secs))
5520 .body(file_data.clone())
5521 .send()
5522 .await;
5523
5524 match result {
5525 Ok(resp) => {
5526 if resp.status().is_success() {
5527 if attempt > 0 {
5528 debug!(
5529 "Upload '{}' succeeded on retry {} ({} bytes)",
5530 filename, attempt, file_size
5531 );
5532 } else {
5533 debug!(
5534 "Successfully uploaded file: {} ({} bytes)",
5535 filename, file_size
5536 );
5537 }
5538 return Ok(());
5539 }
5540
5541 let status = resp.status();
5542 let status_code = status.as_u16();
5543
5544 // Check if error is retryable
5545 let is_retryable =
5546 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5547
5548 if is_retryable && attempt < max_retries {
5549 let error_text = resp.text().await.unwrap_or_default();
5550 warn!(
5551 "Upload '{}' failed with HTTP {} (retryable): {}",
5552 filename, status_code, error_text
5553 );
5554 last_error = Some(Error::InvalidParameters(format!(
5555 "Upload failed: HTTP {} - {}",
5556 status, error_text
5557 )));
5558 continue;
5559 }
5560
5561 // Non-retryable error or max retries exceeded
5562 let error_text = resp.text().await.unwrap_or_default();
5563 if attempt > 0 {
5564 error!(
5565 "Upload '{}' failed after {} retries: HTTP {} - {}",
5566 filename, attempt, status, error_text
5567 );
5568 }
5569 return Err(Error::InvalidParameters(format!(
5570 "Upload failed: HTTP {} - {}",
5571 status, error_text
5572 )));
5573 }
5574 Err(e) => {
5575 // Transport error (timeout, connection failure, etc.)
5576 let is_timeout = e.is_timeout();
5577 let is_connect = e.is_connect();
5578
5579 if (is_timeout || is_connect) && attempt < max_retries {
5580 warn!(
5581 "Upload '{}' transport error (retrying): {}",
5582 filename,
5583 if is_timeout {
5584 "timeout"
5585 } else {
5586 "connection failed"
5587 }
5588 );
5589 last_error = Some(Error::HttpError(e));
5590 continue;
5591 }
5592
5593 // Non-retryable or max retries exceeded
5594 if attempt > 0 {
5595 error!(
5596 "Upload '{}' failed after {} retries: {}",
5597 filename, attempt, e
5598 );
5599 }
5600 return Err(Error::HttpError(e));
5601 }
5602 }
5603 }
5604
5605 // Should not reach here, but return last error if we do
5606 Err(last_error.unwrap_or_else(|| {
5607 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
5608 }))
5609}
5610
5611#[cfg(test)]
5612mod tests {
5613 use super::*;
5614
5615 #[test]
5616 fn test_filter_and_sort_by_name_exact_match_first() {
5617 // Test that exact matches come first
5618 let items = vec![
5619 "Deer Roundtrip 123".to_string(),
5620 "Deer".to_string(),
5621 "Reindeer".to_string(),
5622 "DEER".to_string(),
5623 ];
5624 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
5625 assert_eq!(result[0], "Deer"); // Exact match first
5626 assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
5627 }
5628
5629 #[test]
5630 fn test_filter_and_sort_by_name_shorter_names_preferred() {
5631 // Test that shorter names (more specific) come before longer ones
5632 let items = vec![
5633 "Test Dataset ABC".to_string(),
5634 "Test".to_string(),
5635 "Test Dataset".to_string(),
5636 ];
5637 let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
5638 assert_eq!(result[0], "Test"); // Exact match first
5639 assert_eq!(result[1], "Test Dataset"); // Shorter substring match
5640 assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
5641 }
5642
5643 #[test]
5644 fn test_filter_and_sort_by_name_case_insensitive_filter() {
5645 // Test that filtering is case-insensitive
5646 let items = vec![
5647 "UPPERCASE".to_string(),
5648 "lowercase".to_string(),
5649 "MixedCase".to_string(),
5650 ];
5651 let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
5652 assert_eq!(result.len(), 3); // All items should match
5653 }
5654
5655 #[test]
5656 fn test_filter_and_sort_by_name_no_matches() {
5657 // Test that empty result is returned when no matches
5658 let items = vec!["Apple".to_string(), "Banana".to_string()];
5659 let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
5660 assert!(result.is_empty());
5661 }
5662
5663 #[test]
5664 fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
5665 // Test alphabetical ordering for same-length names
5666 let items = vec![
5667 "TestC".to_string(),
5668 "TestA".to_string(),
5669 "TestB".to_string(),
5670 ];
5671 let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
5672 assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
5673 }
5674
5675 #[test]
5676 fn test_build_filename_no_flatten() {
5677 // When flatten=false, should return base_name unchanged
5678 let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
5679 assert_eq!(result, "image.jpg");
5680
5681 let result = Client::build_filename("test.png", false, None, None);
5682 assert_eq!(result, "test.png");
5683 }
5684
5685 #[test]
5686 fn test_build_filename_flatten_no_sequence() {
5687 // When flatten=true but no sequence, should return base_name unchanged
5688 let result = Client::build_filename("standalone.jpg", true, None, None);
5689 assert_eq!(result, "standalone.jpg");
5690 }
5691
5692 #[test]
5693 fn test_build_filename_flatten_with_sequence_not_prefixed() {
5694 // When flatten=true, in sequence, filename not prefixed → add prefix
5695 let result = Client::build_filename(
5696 "image.camera.jpeg",
5697 true,
5698 Some(&"deer_sequence".to_string()),
5699 Some(42),
5700 );
5701 assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
5702 }
5703
5704 #[test]
5705 fn test_build_filename_flatten_with_sequence_no_frame() {
5706 // When flatten=true, in sequence, no frame number → prefix with sequence only
5707 let result =
5708 Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
5709 assert_eq!(result, "sequence_A_image.jpg");
5710 }
5711
5712 #[test]
5713 fn test_build_filename_flatten_already_prefixed() {
5714 // When flatten=true, filename already starts with sequence_ → return unchanged
5715 let result = Client::build_filename(
5716 "deer_sequence_042.camera.jpeg",
5717 true,
5718 Some(&"deer_sequence".to_string()),
5719 Some(42),
5720 );
5721 assert_eq!(result, "deer_sequence_042.camera.jpeg");
5722 }
5723
5724 #[test]
5725 fn test_build_filename_flatten_already_prefixed_different_frame() {
5726 // Edge case: filename has sequence prefix but we're adding different frame
5727 // Should still respect existing prefix
5728 let result = Client::build_filename(
5729 "sequence_A_001.jpg",
5730 true,
5731 Some(&"sequence_A".to_string()),
5732 Some(2),
5733 );
5734 assert_eq!(result, "sequence_A_001.jpg");
5735 }
5736
5737 #[test]
5738 fn test_build_filename_flatten_partial_match() {
5739 // Edge case: filename contains sequence name but not as prefix
5740 let result = Client::build_filename(
5741 "test_sequence_A_image.jpg",
5742 true,
5743 Some(&"sequence_A".to_string()),
5744 Some(5),
5745 );
5746 // Should add prefix because it doesn't START with "sequence_A_"
5747 assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
5748 }
5749
5750 #[test]
5751 fn test_build_filename_flatten_preserves_extension() {
5752 // Verify that file extensions are preserved correctly
5753 let extensions = vec![
5754 "jpeg",
5755 "jpg",
5756 "png",
5757 "camera.jpeg",
5758 "lidar.pcd",
5759 "depth.png",
5760 ];
5761
5762 for ext in extensions {
5763 let filename = format!("image.{}", ext);
5764 let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
5765 assert!(
5766 result.ends_with(&format!(".{}", ext)),
5767 "Extension .{} not preserved in {}",
5768 ext,
5769 result
5770 );
5771 }
5772 }
5773
5774 #[test]
5775 fn test_build_filename_flatten_sanitization_compatibility() {
5776 // Test with sanitized path components (no special chars)
5777 let result = Client::build_filename(
5778 "sample_001.jpg",
5779 true,
5780 Some(&"seq_name_with_underscores".to_string()),
5781 Some(10),
5782 );
5783 assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
5784 }
5785
5786 // =========================================================================
5787 // Additional filter_and_sort_by_name tests for exact match determinism
5788 // =========================================================================
5789
5790 #[test]
5791 fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
5792 // Test that searching for "Deer" always returns "Deer" first, not
5793 // "Deer Roundtrip 20251129" or similar
5794 let items = vec![
5795 "Deer Roundtrip 20251129".to_string(),
5796 "White-Tailed Deer".to_string(),
5797 "Deer".to_string(),
5798 "Deer Snapshot Test".to_string(),
5799 "Reindeer Dataset".to_string(),
5800 ];
5801
5802 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
5803
5804 // CRITICAL: First result must be exact match "Deer"
5805 assert_eq!(
5806 result.first().map(|s| s.as_str()),
5807 Some("Deer"),
5808 "Expected exact match 'Deer' first, got: {:?}",
5809 result.first()
5810 );
5811
5812 // Verify all items containing "Deer" are present (case-insensitive)
5813 assert_eq!(result.len(), 5);
5814 }
5815
5816 #[test]
5817 fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
5818 // Verify case-sensitive exact match takes priority over case-insensitive
5819 let items = vec![
5820 "DEER".to_string(),
5821 "deer".to_string(),
5822 "Deer".to_string(),
5823 "Deer Test".to_string(),
5824 ];
5825
5826 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
5827
5828 // Priority 1: Case-sensitive exact match "Deer" first
5829 assert_eq!(result[0], "Deer");
5830 // Priority 2: Case-insensitive exact matches next
5831 assert!(result[1] == "DEER" || result[1] == "deer");
5832 assert!(result[2] == "DEER" || result[2] == "deer");
5833 }
5834
5835 #[test]
5836 fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
5837 // Realistic scenario: User searches for snapshot "Deer" and multiple
5838 // snapshots exist with similar names
5839 let items = vec![
5840 "Unit Testing - Deer Dataset Backup".to_string(),
5841 "Deer".to_string(),
5842 "Deer Snapshot 2025-01-15".to_string(),
5843 "Original Deer".to_string(),
5844 ];
5845
5846 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
5847
5848 // MUST return exact match first for deterministic test behavior
5849 assert_eq!(
5850 result[0], "Deer",
5851 "Searching for 'Deer' should return exact 'Deer' first"
5852 );
5853 }
5854
5855 #[test]
5856 fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
5857 // Realistic scenario: User searches for dataset "Deer" but multiple
5858 // datasets have "Deer" in their name
5859 let items = vec![
5860 "Deer Roundtrip".to_string(),
5861 "Deer".to_string(),
5862 "deer".to_string(),
5863 "White-Tailed Deer".to_string(),
5864 "Deer-V2".to_string(),
5865 ];
5866
5867 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
5868
5869 // Exact case-sensitive match must be first
5870 assert_eq!(result[0], "Deer");
5871 // Case-insensitive exact match should be second
5872 assert_eq!(result[1], "deer");
5873 // Shorter names should come before longer names
5874 assert!(
5875 result.iter().position(|s| s == "Deer-V2").unwrap()
5876 < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
5877 );
5878 }
5879
5880 #[test]
5881 fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
5882 // CRITICAL: The first result should ALWAYS be the best match
5883 // This is essential for deterministic test behavior
5884 let scenarios = vec![
5885 // (items, filter, expected_first)
5886 (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
5887 (vec!["test", "TEST", "Test Data"], "test", "test"),
5888 (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
5889 ];
5890
5891 for (items, filter, expected_first) in scenarios {
5892 let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
5893 let result = filter_and_sort_by_name(items, filter, |s| s.as_str());
5894
5895 assert_eq!(
5896 result.first().map(|s| s.as_str()),
5897 Some(expected_first),
5898 "For filter '{}', expected first result '{}', got: {:?}",
5899 filter,
5900 expected_first,
5901 result.first()
5902 );
5903 }
5904 }
5905
5906 #[test]
5907 fn test_with_server_clears_storage() {
5908 use crate::storage::MemoryTokenStorage;
5909
5910 // Create client with memory storage and a token
5911 let storage = Arc::new(MemoryTokenStorage::new());
5912 storage.store("test-token").unwrap();
5913
5914 let client = Client::new().unwrap().with_storage(storage.clone());
5915
5916 // Verify token is loaded
5917 assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));
5918
5919 // Change server - should clear storage
5920 let _new_client = client.with_server("test").unwrap();
5921
5922 // Verify storage was cleared
5923 assert_eq!(storage.load().unwrap(), None);
5924 }
5925
5926 #[test]
5927 fn test_with_server_clears_storage_even_for_full_url() {
5928 // Regression: `with_server` used to short-circuit to `with_url`
5929 // when given a full URL, which preserved the bearer token. The
5930 // contract for `with_server` is that switching servers means
5931 // the token from the old server is no longer trusted.
5932 use crate::storage::MemoryTokenStorage;
5933
5934 let storage = Arc::new(MemoryTokenStorage::new());
5935 storage.store("token-from-old-server").unwrap();
5936 let client = Client::new().unwrap().with_storage(storage.clone());
5937 assert_eq!(
5938 storage.load().unwrap(),
5939 Some("token-from-old-server".to_string())
5940 );
5941
5942 // Switch to a self-hosted Studio (full URL). Storage must be
5943 // cleared, and the new client must have a blank in-memory token.
5944 let new_client = client
5945 .with_server("https://studio.example.com")
5946 .expect("https full URL through with_server");
5947 assert_eq!(storage.load().unwrap(), None);
5948 assert_eq!(new_client.url(), "https://studio.example.com");
5949
5950 // The new client should not carry the old token in memory either.
5951 let in_mem = tokio::runtime::Runtime::new()
5952 .unwrap()
5953 .block_on(async { new_client.token.read().await.clone() });
5954 assert!(in_mem.is_empty(), "expected blank token, got {in_mem:?}");
5955 }
5956
5957 #[test]
5958 fn test_with_server_rejects_insecure_full_url() {
5959 // `with_server` validates full URLs through `with_url`, so the
5960 // HTTPS rule applies uniformly. Plain http to a public host
5961 // must be rejected — the bearer token would otherwise leak in
5962 // plaintext when the caller next authenticates.
5963 let client = Client::new().unwrap();
5964 let err = client.with_server("http://studio.example.com").unwrap_err();
5965 assert!(matches!(err, Error::InsecureUrl(_)));
5966 }
5967
5968 // ===== with_url HTTPS enforcement =====
5969 //
5970 // The bearer token rides in the Authorization header, so plain
    // http:// to a public host would leak it in the clear. `with_url`
    // must reject those URLs but still let wiremock / local-dev URLs
    // through (loopback addresses, "localhost", "*.localhost").
5974
5975 #[test]
5976 fn with_url_accepts_https_public_host() {
5977 let client = Client::new().unwrap();
5978 let out = client
5979 .with_url("https://studio.example.com")
5980 .expect("https public host must be accepted");
5981 assert_eq!(out.url(), "https://studio.example.com");
5982 }
5983
5984 #[test]
5985 fn with_url_accepts_http_loopback_ipv4() {
5986 let client = Client::new().unwrap();
5987 let out = client
5988 .with_url("http://127.0.0.1:8080")
5989 .expect("http://127.0.0.1 must be accepted (loopback)");
5990 assert_eq!(out.url(), "http://127.0.0.1:8080");
5991 }
5992
5993 #[test]
5994 fn with_url_accepts_http_loopback_ipv6() {
5995 let client = Client::new().unwrap();
5996 let out = client
5997 .with_url("http://[::1]:8080")
5998 .expect("http://[::1] must be accepted (loopback)");
5999 assert!(out.url().starts_with("http://[::1]"));
6000 }
6001
6002 #[test]
6003 fn with_url_accepts_http_localhost() {
6004 let client = Client::new().unwrap();
6005 client
6006 .with_url("http://localhost:8080")
6007 .expect("http://localhost must be accepted");
6008 client
6009 .with_url("http://LOCALHOST")
6010 .expect("http://LOCALHOST must be accepted (case-insensitive)");
6011 client
6012 .with_url("http://wiremock.localhost")
6013 .expect("http://*.localhost must be accepted");
6014 }
6015
6016 #[test]
6017 fn with_url_rejects_http_public_host() {
6018 let client = Client::new().unwrap();
6019 let err = client.with_url("http://studio.example.com").unwrap_err();
6020 match err {
6021 Error::InsecureUrl(u) => assert_eq!(u, "http://studio.example.com"),
6022 other => panic!("expected InsecureUrl, got {other:?}"),
6023 }
6024 }
6025
6026 #[test]
6027 fn with_url_rejects_http_public_ip() {
6028 let client = Client::new().unwrap();
6029 // 8.8.8.8 is not loopback; must be rejected.
6030 let err = client.with_url("http://8.8.8.8").unwrap_err();
6031 assert!(matches!(err, Error::InsecureUrl(_)));
6032 }
6033
6034 #[test]
6035 fn with_url_rejects_non_http_scheme() {
6036 let client = Client::new().unwrap();
6037 // file:// would otherwise parse, but it's not a transport we
6038 // can use for RPC and we don't want to silently accept it.
6039 let err = client.with_url("file:///etc/passwd").unwrap_err();
6040 assert!(matches!(err, Error::InsecureUrl(_)));
6041 }
6042}
6043
#[cfg(test)]
mod tests_map_rpc_error {
    use super::*;
    use crate::api::TaskID;

    /// Task ID handed to every case that needs one. None of the assertions
    /// inspect its value, only whether one was supplied.
    fn sample_task_id() -> TaskID {
        TaskID::try_from("task-1a2b").unwrap()
    }

    #[test]
    fn maps_not_found_with_task_id_to_typed_variant() {
        // Code 101, a "not found" message and a task id → TaskNotFound.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("task not found"),
            Some(sample_task_id()),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_cannot_find_phrasing_to_typed_variant() {
        // The DVE server phrases this as "Cannot find task...". A plain
        // "not found" substring check would miss it, leaving callers with a
        // generic RpcError.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("Cannot find task with id 6789"),
            Some(sample_task_id()),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "'Cannot find task' should map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_does_not_exist_phrasing_to_typed_variant() {
        let err = map_rpc_error(
            "task.chart.get",
            101,
            String::from("task does not exist"),
            Some(sample_task_id()),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_code_101_with_unknown_phrasing_when_task_id_supplied() {
        // Code 101 means "resource not found" by contract, so even an
        // unfamiliar message must yield the typed variant. Callers rely on
        // a stable `match`.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("completely novel server message"),
            Some(sample_task_id()),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "code 101 + task_id should always map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_permission_codes_to_typed_variant() {
        // Both 401 and 403 count as permission failures.
        for &code in &[401, 403] {
            let err = map_rpc_error("task.chart.add", code, String::from("denied"), None);
            assert!(
                matches!(err, Error::PermissionDenied(_)),
                "code {} did not map",
                code
            );
        }
    }

    #[test]
    fn permission_denied_records_method_for_diagnostics() {
        // The RPC method name is kept in the error for troubleshooting.
        match map_rpc_error("task.data.upload", 403, String::from("forbidden"), None) {
            Error::PermissionDenied(method) => assert_eq!(method, "task.data.upload"),
            other => panic!("expected PermissionDenied, got {:?}", other),
        }
    }

    #[test]
    fn maps_payload_too_large_to_typed_variant() {
        // 413 records the method; the size hint stays unset for this message.
        match map_rpc_error("val.data.upload", 413, String::from("request too large"), None) {
            Error::PayloadTooLarge { method, size_hint } => {
                assert_eq!(method, "val.data.upload");
                assert!(size_hint.is_none());
            }
            other => panic!("expected PayloadTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn falls_through_to_generic_rpc_error_for_unknown_codes() {
        // Unrecognised codes come back as RpcError, with code and message
        // preserved verbatim.
        match map_rpc_error("task.data.list", -99999, String::from("weird"), None) {
            Error::RpcError(code, msg) => {
                assert_eq!(code, -99999);
                assert_eq!(msg, "weird");
            }
            other => panic!("expected RpcError, got {:?}", other),
        }
    }

    #[test]
    fn not_found_without_task_id_falls_through() {
        // Without a task id there is nothing to name in TaskNotFound, so code
        // 101 stays a generic RpcError.
        let err = map_rpc_error("task.data.list", 101, String::from("not found"), None);
        assert!(matches!(err, Error::RpcError(101, _)));
    }

    #[test]
    fn code_101_with_task_id_always_maps_even_with_unrelated_message() {
        // Code 101 is defined server-side (api.go) as "resource not found".
        // So when a task id is present, the typed variant is returned no
        // matter what the message says, even one that reads like a
        // permission error.
        let err = map_rpc_error(
            "task.data.list",
            101,
            String::from("permission denied"),
            Some(sample_task_id()),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }
}
6177
#[cfg(test)]
mod tests_jobs {
    use super::*;

    #[test]
    fn jobs_list_request_serializes_to_empty_object() {
        // The jobs-list request has no fields, so it must encode as `{}`.
        let encoded = serde_json::to_value(JobsListRequest {}).unwrap();
        assert_eq!(encoded, serde_json::json!({}));
    }

    #[test]
    fn job_deserializes_from_bk_batch_shape() {
        // Representative BK_BATCH record. `extra_field` checks that unknown
        // keys are tolerated rather than rejected.
        let record = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789,
            "docker_task": {},
            "extra_field": "ignored"
        }"#;
        let job = serde_json::from_str::<crate::api::Job>(record).unwrap();
        assert_eq!(job.code, "edgefirst-validator:2.9.5");
        assert_eq!(job.state, "RUNNING");
        assert_eq!(job.task_id, 6789);
        // The typed accessor must agree with the raw field.
        assert_eq!(job.task_id().value(), 6789);
    }
}
6208
#[cfg(test)]
mod tests_job_run {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn job_run_request_serializes_with_expected_fields() {
        // Each field lands at the top level under its own key. `env` and
        // `data` become nested JSON objects.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "post-profile-run".into(),
            env: HashMap::from([("LOG_LEVEL".into(), "info".into())]),
            data: HashMap::from([("validation_session_id".into(), Parameter::Integer(2707))]),
        };
        let json = serde_json::to_value(&req).unwrap();
        let (env, data) = (&json["env"], &json["data"]);

        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["job_name"], "post-profile-run");
        assert_eq!(env["LOG_LEVEL"], "info");
        assert_eq!(data["validation_session_id"], 2707);
    }

    #[test]
    fn job_run_response_deserializes_as_job() {
        // job.run returns the full BK_BATCH record, so its response is
        // decoded with the same `Job` type used for listings.
        let record = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "post-profile-run",
            "job_id": "aws-batch-job-xxx",
            "state": "SUBMITTED",
            "task_id": 6789
        }"#;
        let job = serde_json::from_str::<crate::api::Job>(record).unwrap();
        assert_eq!(job.task_id, 6789);
        assert_eq!(job.job_id, "aws-batch-job-xxx");
        assert_eq!(job.state, "SUBMITTED");
    }
}
6247
#[cfg(test)]
mod tests_job_stop {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn job_stop_request_serializes_with_task_id() {
        // The stop request is identified only by the numeric task id.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let json = serde_json::to_value(&JobStopRequest { task_id: raw_id }).unwrap();
        assert_eq!(json["task_id"], raw_id);
    }
}
6263
#[cfg(test)]
mod tests_task_data_list_request {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_list_request_serializes_with_task_id() {
        // The listing request is keyed by the numeric task id alone.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let json = serde_json::to_value(&TaskDataListRequest { task_id: raw_id }).unwrap();
        assert_eq!(json["task_id"], raw_id);
    }
}
6279
#[cfg(test)]
mod tests_task_data_download {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_data_download_request_serializes_with_all_fields() {
        // Task id, folder and file are each sent under their own key.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let req = TaskDataDownloadRequest {
            task_id: raw_id,
            folder: "predictions".into(),
            file: "predictions.parquet".into(),
        };
        let json = serde_json::to_value(&req).unwrap();

        assert_eq!(json["task_id"], raw_id);
        for (key, expected) in [("folder", "predictions"), ("file", "predictions.parquet")] {
            assert_eq!(json[key], expected);
        }
    }
}
6299
#[cfg(test)]
mod tests_task_chart_add {
    use super::*;
    use crate::api::{Parameter, TaskID};

    #[test]
    fn task_chart_add_request_serializes_with_correct_fields() {
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let line_chart = Parameter::Object(std::collections::HashMap::from([(
            "type".into(),
            Parameter::String("line".into()),
        )]));
        let json = serde_json::to_value(&TaskChartAddRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: None,
            data: line_chart,
        })
        .unwrap();

        assert_eq!(json["task_id"], raw_id);
        assert_eq!(json["group_name"], "metrics");
        assert_eq!(json["chart_name"], "loss");
        // The chart payload is embedded as a plain nested JSON object.
        assert_eq!(json["data"]["type"], "line");
        // `params: None` produces no value. The key is either absent or
        // null, since indexing a missing key also yields `Value::Null`.
        assert!(json["params"].is_null());
    }
}
6327
#[cfg(test)]
mod tests_task_chart_list {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_chart_list_request_omits_empty_group_name() {
        // Despite the test name, an empty `group_name` is NOT omitted from
        // the wire. It is serialized as `""` next to `task_id`. Comparing
        // the whole object pins that shape exactly: the assertion fails if
        // the field is dropped or if any extra key appears.
        // NOTE(review): confirm with the server side that an empty string,
        // rather than an absent key, is what it expects, then rename this
        // test to match.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let req = TaskChartListRequest {
            task_id: task_id.value(),
            group_name: String::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "task_id": task_id.value(), "group_name": "" })
        );
    }
}
6345
#[cfg(test)]
mod tests_task_chart_get {
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn task_chart_get_request_serializes_with_all_fields() {
        // A chart is addressed by task id plus its group and chart names.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let req = TaskChartGetRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
        };
        let json = serde_json::to_value(&req).unwrap();

        assert_eq!(json["task_id"], raw_id);
        for (key, expected) in [("group_name", "metrics"), ("chart_name", "loss")] {
            assert_eq!(json[key], expected);
        }
    }
}
6365
#[cfg(test)]
mod tests_val_data_download {
    use super::*;

    #[test]
    fn val_data_download_request_serializes() {
        // The filename, slash included, is sent as a single string.
        let json = serde_json::to_value(&ValDataDownloadRequest {
            session_id: 2707,
            filename: "trace/imx95.json".into(),
        })
        .unwrap();
        assert_eq!(json["session_id"], 2707);
        assert_eq!(json["filename"], "trace/imx95.json");
    }
}
6381
#[cfg(test)]
mod tests_val_data_list {
    use super::*;

    #[test]
    fn val_data_list_request_serializes() {
        // The listing request carries nothing but the session id.
        let actual = serde_json::to_value(ValDataListRequest { session_id: 2707 }).unwrap();
        let expected = serde_json::json!({ "session_id": 2707 });
        assert_eq!(actual, expected);
    }
}
6395
#[cfg(test)]
mod tests_jsonrpc_envelope_detection {
    use super::*;

    #[test]
    fn detects_real_envelope() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": 101, "message": "Cannot find task" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_plain_json_artifact_with_error_field() {
        // A diagnostics file with a free-form `error` object — must not be
        // misread as an RPC envelope just because the key collides.
        let v = serde_json::json!({
            "metric": "loss",
            "value": 0.42,
            "error": { "code": "ENV_NOT_FOUND", "message": "missing var" },
        });
        assert!(
            !is_jsonrpc_error_envelope(&v),
            "missing jsonrpc sentinel should mean 'not an envelope'"
        );
    }

    #[test]
    fn rejects_envelope_missing_jsonrpc_sentinel() {
        // Bare `error` block without the protocol-version marker.
        let v = serde_json::json!({
            "id": 0,
            "error": { "code": 101, "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_object_error_field() {
        // A diagnostics file shaped like JSON-RPC accidentally but using
        // a string for `error`.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": "something went wrong",
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_without_error_code() {
        // Real envelopes always carry an integer error.code; missing one
        // is suspicious enough to refuse the envelope classification.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "message": "no code" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_envelope_with_non_numeric_error_code() {
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": "ENOENT", "message": "x" },
        });
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn rejects_non_object_root() {
        // A JSON file whose root is an array — common for metrics dumps —
        // must not be misread.
        let v = serde_json::json!([1, 2, 3]);
        assert!(!is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn accepts_unsigned_error_code() {
        // The server's code is technically i32, but JSON has no signed/
        // unsigned distinction. serde_json stores every non-negative integer
        // in the same internal form, so `101u32` is indistinguishable from a
        // plain `101`. The signed path is exercised by
        // `accepts_negative_error_code`.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": 101u32, "message": "x" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }

    #[test]
    fn accepts_negative_error_code() {
        // JSON-RPC 2.0 uses negative codes for protocol errors, e.g.
        // -32601 "Method not found". serde_json stores these as signed
        // integers, and they are still an integer `error.code`, so the
        // envelope must be recognised.
        let v = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "error": { "code": -32601, "message": "Method not found" },
        });
        assert!(is_jsonrpc_error_envelope(&v));
    }
}
6485
#[cfg(test)]
mod tests_validate_chart_args {
    use super::*;

    /// Asserts that the `(group, name)` pair is refused with
    /// `Error::InvalidParameters`. `#[track_caller]` keeps failure locations
    /// pointing at the calling test rather than at this helper.
    #[track_caller]
    fn assert_invalid(group: &str, name: &str) {
        let err = validate_chart_args(group, name).unwrap_err();
        assert!(matches!(err, Error::InvalidParameters(_)));
    }

    #[test]
    fn rejects_empty_group() {
        assert_invalid("", "name");
    }

    #[test]
    fn rejects_empty_name() {
        assert_invalid("group", "");
    }

    #[test]
    fn rejects_both_empty() {
        assert_invalid("", "");
    }

    #[test]
    fn accepts_valid_args() {
        assert!(validate_chart_args("group", "name").is_ok());
    }

    #[test]
    fn accepts_unicode_args() {
        // Only emptiness is rejected, so non-ASCII names are fine.
        assert!(validate_chart_args("metrics-集合", "损失").is_ok());
    }
}
6519
6520// ---------------------------------------------------------------------------
6521// Additional offline tests for request shapes + helpers added in DE-2565.
6522//
6523// These focus on the wire-shape and helper logic that does not require a
6524// live Studio server — they significantly boost coverage of client.rs.
6525// ---------------------------------------------------------------------------
6526
#[cfg(test)]
mod tests_job_run_request_shape {
    use super::*;
    use crate::api::Parameter;
    use std::collections::HashMap;

    #[test]
    fn empty_env_and_data_serialize_as_empty_objects() {
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "smoke".into(),
            env: HashMap::new(),
            data: HashMap::new(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["name"], "edgefirst-validator");
        assert_eq!(json["env"], serde_json::json!({}));
        assert_eq!(json["data"], serde_json::json!({}));
    }

    #[test]
    fn data_passes_through_parameter_object_payloads() {
        // Checks that each Parameter variant used below serializes to its
        // natural JSON form inside `data`: Boolean, Integer, Real, String,
        // and a nested Object. This is a one-way serialization check, not a
        // round-trip.
        let req = JobRunRequest {
            name: "edgefirst-validator".into(),
            job_name: "feat".into(),
            env: HashMap::new(),
            data: HashMap::from([
                ("flag".into(), Parameter::Boolean(true)),
                ("epochs".into(), Parameter::Integer(50)),
                ("lr".into(), Parameter::Real(1e-3)),
                ("name".into(), Parameter::String("hello".into())),
                (
                    "nested".into(),
                    Parameter::Object(HashMap::from([(
                        "mode".into(),
                        Parameter::String("fast".into()),
                    )])),
                ),
            ]),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["data"]["flag"], true);
        assert_eq!(json["data"]["epochs"], 50);
        // Compare with a tolerance rather than just `> 0.0`, so the value is
        // actually pinned. The tolerance lets the check hold whether `Real`
        // stores an f32 or an f64.
        let lr = json["data"]["lr"]
            .as_f64()
            .expect("lr should serialize as a JSON number");
        assert!((lr - 1e-3).abs() < 1e-6, "lr serialized as {lr}");
        assert_eq!(json["data"]["name"], "hello");
        // Object payloads must nest as plain JSON objects, not tagged enums.
        assert_eq!(json["data"]["nested"]["mode"], "fast");
    }
}
6570
#[cfg(test)]
mod tests_task_data_chart_request_shape {
    use super::*;
    use crate::api::{Parameter, TaskID};

    /// Builds a one-entry `Parameter::Object` mapping `key` to a string value.
    fn string_object(key: &'static str, value: &'static str) -> Parameter {
        Parameter::Object(std::collections::HashMap::from([(
            key.into(),
            Parameter::String(value.into()),
        )]))
    }

    #[test]
    fn chart_add_request_with_params_serializes_object() {
        // When present, `params` is emitted as a nested JSON object.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let req = TaskChartAddRequest {
            task_id: raw_id,
            group_name: "metrics".into(),
            chart_name: "loss".into(),
            params: Some(string_object("y_axis", "log")),
            data: string_object("type", "line"),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["params"]["y_axis"], "log");
    }

    #[test]
    fn task_data_list_request_round_trips() {
        // Compares the serialized text exactly. With a single field there is
        // no key ordering to worry about.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let encoded = serde_json::to_string(&TaskDataListRequest { task_id: raw_id }).unwrap();
        let expected = format!("{{\"task_id\":{}}}", raw_id);
        assert_eq!(encoded, expected);
    }

    #[test]
    fn task_data_download_request_treats_folder_and_file_independently() {
        // The server takes folder and file separately, not as one combined
        // path, so callers never have to escape slashes themselves.
        let raw_id = TaskID::try_from("task-1a2b").unwrap().value();
        let json = serde_json::to_value(&TaskDataDownloadRequest {
            task_id: raw_id,
            folder: "validation/run-01".into(),
            file: "metrics.json".into(),
        })
        .unwrap();
        assert_eq!(json["folder"], "validation/run-01");
        assert_eq!(json["file"], "metrics.json");
    }
}
6625
#[cfg(test)]
mod tests_val_data_request_shape {
    use super::*;

    #[test]
    fn val_data_list_round_trips() {
        // Serialize to text and re-parse, so the assertion sees exactly what
        // would be sent.
        let wire = serde_json::to_string(&ValDataListRequest { session_id: 2707 }).unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&wire).unwrap();
        assert_eq!(parsed["session_id"], 2707);
    }

    #[test]
    fn val_data_download_round_trips_with_nested_path() {
        // A nested filename must survive the text round-trip unchanged.
        let wire = serde_json::to_string(&ValDataDownloadRequest {
            session_id: 2707,
            filename: "subfolder/imx95.json".into(),
        })
        .unwrap();
        let parsed = serde_json::from_str::<serde_json::Value>(&wire).unwrap();
        assert_eq!(parsed["session_id"], 2707);
        assert_eq!(parsed["filename"], "subfolder/imx95.json");
    }
}
6650
#[cfg(test)]
mod tests_progress_struct {
    use super::*;

    #[test]
    fn progress_can_be_constructed_with_zero_total() {
        // A zero total must be representable, since servers sometimes omit
        // Content-Length. This guards the public field-level API.
        let p = Progress {
            current: 0,
            total: 0,
            status: None,
        };
        assert_eq!((p.current, p.total), (0, 0));
        assert!(p.status.is_none());
    }

    #[test]
    fn progress_tracks_current_independently_of_total() {
        let snapshot = Progress {
            current: 123,
            total: 456,
            status: Some("Downloading".into()),
        };
        assert_eq!((snapshot.current, snapshot.total), (123, 456));
        assert_eq!(snapshot.status.as_deref(), Some("Downloading"));
    }

    #[test]
    fn progress_can_be_cloned() {
        // Progress sinks may keep their own copy of an event, so Progress
        // must stay Clone, and a clone must match its source field by field.
        let original = Progress {
            current: 10,
            total: 20,
            status: Some("phase".into()),
        };
        let copy = original.clone();
        assert_eq!((copy.current, copy.total), (original.current, original.total));
        assert_eq!(copy.status, original.status);
    }
}
6696
6697#[cfg(test)]
6698mod tests_bare_filename_parent {
6699 // Documents the empty-parent guard added for `rpc_download` so that
6700 // callers passing a bare filename like "metrics.json" download to the
6701 // current directory instead of erroring on `create_dir_all("")`.
6702 use std::path::Path;
6703
6704 #[test]
6705 fn bare_filename_parent_is_empty_path() {
6706 // This is the invariant our guard depends on. If a future Rust
6707 // release ever changed `Path::parent` for bare filenames, the guard
6708 // would need revisiting.
6709 let p = Path::new("metrics.json");
6710 let parent = p.parent().expect("bare filename always has Some parent");
6711 assert!(
6712 parent.as_os_str().is_empty(),
6713 "Path::parent for bare filename should be empty, got: {parent:?}"
6714 );
6715 }
6716
6717 #[test]
6718 fn path_with_directory_has_non_empty_parent() {
6719 // The companion case: when the path includes a directory, the
6720 // parent is non-empty and `create_dir_all` should be invoked.
6721 let p = Path::new("dir/metrics.json");
6722 let parent = p.parent().expect("path-with-dir always has Some parent");
6723 assert!(!parent.as_os_str().is_empty());
6724 assert_eq!(parent, Path::new("dir"));
6725 }
6726}