edgefirst_client/client.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult,
8 NewValidationSession, Organization, Project, ProjectID, SampleID, SamplesCountResult,
9 SamplesListParams, SamplesListResult, Snapshot, SnapshotCreateFromDataset,
10 SnapshotFromDatasetResult, SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage,
11 StartValidationRequest, TaskID, TaskInfo, TaskStages, TaskStatus, TasksListParams,
12 TasksListResult, TrainingSession, TrainingSessionID, ValidationSession,
13 ValidationSessionID,
14 },
15 dataset::{
16 AnnotationSet, AnnotationType, Dataset, FileType, Group, Label, NewLabel, NewLabelObject,
17 },
18 retry::{create_retry_policy, log_retry_configuration},
19 storage::{FileTokenStorage, MemoryTokenStorage, TokenStorage},
20};
21use base64::Engine as _;
22use chrono::{DateTime, Utc};
23use directories::ProjectDirs;
24use futures::{StreamExt as _, future::join_all};
25use log::{Level, debug, error, log_enabled, trace, warn};
26use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::{
29 collections::HashMap,
30 ffi::OsStr,
31 fs::create_dir_all,
32 io::{SeekFrom, Write as _},
33 path::{Path, PathBuf},
34 sync::{
35 Arc,
36 atomic::{AtomicUsize, Ordering},
37 },
38 time::Duration,
39 vec,
40};
41use tokio::{
42 fs::{self, File},
43 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
44 sync::{RwLock, Semaphore, mpsc::Sender},
45};
46use tokio_util::codec::{BytesCodec, FramedRead};
47use walkdir::WalkDir;
48
49#[cfg(feature = "polars")]
50use polars::prelude::*;
51
52/// Maps a JSON-RPC error code to a typed `Error` variant when the code is
53/// well-known; otherwise returns `Error::RpcError(code, message)` unchanged.
54///
55/// Scoped to the new DE-2565 methods. Existing methods continue to return
56/// `Error::RpcError` directly.
57///
58/// Server error codes (from `api.go` via `jrpc.Fail`):
59/// - `1` – generic server error
60/// - `3` – validation / bad request
61/// - `10` – internal server error
62/// - `101` – resource not found (e.g. "Cannot find task...", "not found in DB")
63/// - `401` – unauthenticated
64/// - `403` – forbidden
65/// - `413` – payload too large
66pub(crate) fn map_rpc_error(
67 method: &str,
68 code: i32,
69 message: String,
70 task_id: Option<crate::api::TaskID>,
71) -> Error {
72 // Server emits "Cannot find task...", "not found in DB", and other phrasings
73 // for code 101. Code 101 with a task_id is task-not-found by contract
74 // (see api.go), so we return the typed variant unconditionally when the
75 // caller supplied a task_id — message phrasing is treated as informational
76 // and is preserved by the RPC layer for diagnostic logging upstream.
77 if code == 101
78 && let Some(id) = task_id
79 {
80 return Error::TaskNotFound(id);
81 }
82 match code {
83 401 | 403 => Error::PermissionDenied(method.to_string()),
84 413 => Error::PayloadTooLarge {
85 method: method.to_string(),
86 size_hint: None,
87 },
88 _ => Error::RpcError(code, message),
89 }
90}
91
92/// Returns true if `val` is structurally a JSON-RPC 2.0 *error* envelope.
93///
94/// A real envelope must:
95/// 1. Be a JSON object,
96/// 2. Carry a `"jsonrpc"` member (the protocol-version sentinel — JSON-RPC
97/// 2.0 §5 mandates this on every response object),
98/// 3. Carry an `"error"` object that includes a numeric `"code"` field.
99///
100/// This is intentionally stricter than a "looks for a top-level `error`
101/// key" check so that legitimate JSON file payloads (validation traces,
102/// metrics dumps, diagnostics) which happen to include a free-form `error`
103/// field are *not* misclassified as RPC failures.
104///
105/// Extracted so it can be unit-tested without a live server.
106pub(crate) fn is_jsonrpc_error_envelope(val: &serde_json::Value) -> bool {
107 let Some(obj) = val.as_object() else {
108 return false;
109 };
110 // Protocol-version sentinel — only JSON-RPC envelopes carry this.
111 if !obj.contains_key("jsonrpc") {
112 return false;
113 }
114 let Some(err) = obj.get("error").and_then(|e| e.as_object()) else {
115 return false;
116 };
117 err.get("code")
118 .map(|c| c.is_i64() || c.is_u64())
119 .unwrap_or(false)
120}
121
122/// Validates that `group` and `name` are both non-empty strings for chart
123/// operations (`add_chart`, `get_chart`). Extracted so it can be unit-tested
124/// without a live server.
125pub(crate) fn validate_chart_args(group: &str, name: &str) -> Result<(), Error> {
126 if group.is_empty() || name.is_empty() {
127 return Err(Error::InvalidParameters(
128 "chart: group and name must be non-empty".into(),
129 ));
130 }
131 Ok(())
132}
133
/// Part size in bytes (100 MiB) — presumably the chunk size used for multipart uploads.
static PART_SIZE: usize = 100 << 20;
135
/// Where the content of a file being uploaded comes from.
#[derive(Clone)]
enum FileSource {
    /// Read the content from this path on the local filesystem.
    Path(PathBuf),
    /// Use these in-memory bytes directly (e.g. an entry extracted from a ZIP archive).
    Bytes(Vec<u8>),
}
144
/// Maximum number of concurrent tasks (e.g. presigned-URL uploads per batch).
///
/// Reads `MAX_TASKS` from the environment; surrounding whitespace is ignored.
/// When the variable is unset, is not a valid unsigned integer, or is `0`,
/// the default is used: half the available CPUs, clamped to `2..=8` (4 CPUs
/// are assumed if the parallelism cannot be queried).
///
/// Zero is rejected because a concurrency limit of zero is never meaningful
/// and, if used as a permit count, would stall every task.
fn max_tasks() -> usize {
    std::env::var("MAX_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or_else(|| {
            // Default to half the number of CPUs, minimum 2, maximum 8
            let cpus = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(4);
            (cpus / 2).clamp(2, 8)
        })
}
157
/// Maximum concurrent upload tasks for multipart S3 uploads.
///
/// Higher concurrency improves upload throughput by saturating available
/// bandwidth. Can be overridden via the `MAX_UPLOAD_TASKS` environment
/// variable (surrounding whitespace ignored). Unset, unparsable, or zero
/// values fall back to the default of 8. Zero is rejected because a limit of
/// zero concurrent parts is never meaningful and would stall the upload.
fn max_upload_tasks() -> usize {
    std::env::var("MAX_UPLOAD_TASKS")
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|&n| n > 0)
        .unwrap_or(8) // Default to 8 concurrent part uploads
}
168
/// Filters items by name and sorts by match quality.
///
/// An item is kept when its name contains `filter`, compared
/// case-insensitively. Kept items are ordered best match first:
/// 1. Exact match (case-sensitive)
/// 2. Exact match (case-insensitive)
/// 3. Shorter names (byte length) before longer ones
/// 4. Byte-wise alphabetical order as the final tie-breaker
///
/// The sort is stable, so items with identical names keep their input order.
///
/// This ensures that searching for "Deer" returns "Deer" before
/// "Deer Roundtrip 20251129" or "Reindeer".
fn filter_and_sort_by_name<T, F>(items: Vec<T>, filter: &str, get_name: F) -> Vec<T>
where
    F: Fn(&T) -> &str,
{
    let filter_lower = filter.to_lowercase();

    // Lower-case each name exactly once (instead of on every comparison) and
    // precompute its rank key `(not exact, not case-insensitive exact, len)`.
    // `false` orders before `true`, so better matches sort first.
    let mut ranked: Vec<((bool, bool, usize), T)> = items
        .into_iter()
        .filter_map(|item| {
            let name = get_name(&item);
            let name_lower = name.to_lowercase();
            if !name_lower.contains(&filter_lower) {
                return None;
            }
            let key = (name != filter, name_lower != filter_lower, name.len());
            Some((key, item))
        })
        .collect();

    // Stable sort; ties on the rank key fall back to byte-wise name order.
    ranked.sort_by(|(key_a, a), (key_b, b)| {
        key_a.cmp(key_b).then_with(|| get_name(a).cmp(get_name(b)))
    });

    ranked.into_iter().map(|(_, item)| item).collect()
}
218
219/// Whether `host` refers to a loopback (machine-local) endpoint.
220///
221/// Used by [`Client::with_url`] to decide whether a plain-`http://` URL is
222/// safe to accept. Loopback traffic never leaves the machine, so the
223/// usual concern about leaking the Studio bearer token in plaintext does
224/// not apply — that's how wiremock and local dev servers connect.
225fn is_loopback_host(host: Option<&url::Host<&str>>) -> bool {
226 match host {
227 Some(url::Host::Ipv4(ip)) => ip.is_loopback(),
228 Some(url::Host::Ipv6(ip)) => ip.is_loopback(),
229 // RFC 6761 reserves "localhost" (and `*.localhost`) as a loopback
230 // name. Compare case-insensitively because URL hosts are matched
231 // that way and developers do type capitalized variants.
232 Some(url::Host::Domain(d)) => {
233 d.eq_ignore_ascii_case("localhost") || d.to_ascii_lowercase().ends_with(".localhost")
234 }
235 None => false,
236 }
237}
238
/// Turns an arbitrary (possibly server-supplied) name into a single safe
/// filesystem path component.
///
/// - Surrounding whitespace is trimmed; an empty result becomes `"unnamed"`.
/// - If the name looks like a path, only its final component is kept
///   (`"a/b"` → `"b"`).
/// - Characters that are path separators or invalid on Windows
///   (`/ \ : * ? " < > |`) and control characters (including NUL) become `_`.
/// - `"."` and `".."` become `"unnamed"`: they survive character replacement
///   but resolve to the current/parent directory when joined onto a base
///   path, which would allow path traversal.
fn sanitize_path_component(name: &str) -> String {
    const FALLBACK: &str = "unnamed";

    let trimmed = name.trim();
    if trimmed.is_empty() {
        return FALLBACK.to_string();
    }

    // `file_name` is `None` for "." / ".." / paths ending in ".."; fall back
    // to the whole trimmed input and let the checks below neutralise it.
    let component = Path::new(trimmed)
        .file_name()
        .unwrap_or_else(|| OsStr::new(trimmed));

    let sanitized: String = component
        .to_string_lossy()
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            c if c.is_control() => '_',
            _ => c,
        })
        .collect();

    if sanitized.is_empty() || sanitized == "." || sanitized == ".." {
        FALLBACK.to_string()
    } else {
        sanitized
    }
}
264
/// Progress report emitted by long-running client operations.
///
/// Uploads, downloads and dataset processing periodically send a `Progress`
/// value describing how much work is done (`current`), how much there is in
/// total (`total`) and, optionally, which phase is running (`status`).
///
/// # Multi-Stage Progress
///
/// Some operations run in several phases. A change in `status` marks the
/// start of a new phase, so applications should reset their progress display
/// whenever the status differs from the previous report.
///
/// # Operation Progress Details
///
/// | Operation | Status | Unit | Notes |
/// |-----------|--------|------|-------|
/// | [`download_dataset`] | `None` then `"Downloading"` | samples | Two phases: fetch metadata, then download files |
/// | [`populate_samples`] | `None` | samples | Each sample may contain multiple files |
/// | [`samples`] | `None` | samples | Paginated API fetch |
/// | [`sample_names`] | `None` | samples | Paginated API fetch, names only |
/// | [`annotations`] | `None` | samples | Samples processed for annotations |
/// | [`download_artifact`] | `None` | bytes | Single file byte-level progress |
/// | [`download_checkpoint`] | `None` | bytes | Single file byte-level progress |
/// | [`download_snapshot`] | `None` | bytes | Combined byte progress across all files |
///
/// [`download_dataset`]: Client::download_dataset
/// [`populate_samples`]: Client::populate_samples
/// [`samples`]: Client::samples
/// [`sample_names`]: Client::sample_names
/// [`annotations`]: Client::annotations
/// [`download_artifact`]: Client::download_artifact
/// [`download_checkpoint`]: Client::download_checkpoint
/// [`download_snapshot`]: Client::download_snapshot
///
/// # Examples
///
/// Rendering a single report:
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
///     status: Some("Downloading".to_string()),
/// };
/// // `total` can be 0 when the size is unknown; avoid dividing by zero.
/// let percentage = if progress.total == 0 {
///     0.0
/// } else {
///     (progress.current as f64 / progress.total as f64) * 100.0
/// };
/// println!(
///     "{}: {:.1}% ({}/{})",
///     progress.status.as_deref().unwrap_or("Progress"),
///     percentage,
///     progress.current,
///     progress.total
/// );
/// ```
///
/// Handling phase changes (e.g. for `download_dataset`):
///
/// ```rust,ignore
/// let mut last_status: Option<String> = None;
///
/// while let Some(progress) = rx.recv().await {
///     // A new status means a new phase: reset the display.
///     if progress.status != last_status {
///         if let Some(ref status) = progress.status {
///             println!("\n{}", status);
///         }
///         last_status = progress.status.clone();
///     }
///
///     if progress.total > 0 {
///         let pct = (progress.current as f64 / progress.total as f64) * 100.0;
///         print!("\r{:.1}% ({}/{})", pct, progress.current, progress.total);
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Units of work completed so far (items or bytes, per the table above).
    pub current: usize,
    /// Total units of work expected. For byte downloads this is the
    /// response's `Content-Length`, or 0 when the server does not send one.
    pub total: usize,
    /// Optional label for the phase currently running.
    ///
    /// A change of value — from `None` to `Some(...)`, or between two
    /// different strings — signals that a new phase has begun and that
    /// progress displays should be reset.
    ///
    /// At present only [`Client::download_dataset`] switches status:
    /// - Phase 1: `None` while fetching sample metadata
    /// - Phase 2: `"Downloading"` while downloading files
    ///
    /// Every other operation reports `None` for its whole duration.
    pub status: Option<String>,
}
359
360#[derive(Serialize)]
361struct RpcRequest<Params> {
362 id: u64,
363 jsonrpc: String,
364 method: String,
365 params: Option<Params>,
366}
367
368impl<T> Default for RpcRequest<T> {
369 fn default() -> Self {
370 RpcRequest {
371 id: 0,
372 jsonrpc: "2.0".to_string(),
373 method: "".to_string(),
374 params: None,
375 }
376 }
377}
378
379#[derive(Deserialize)]
380struct RpcError {
381 code: i32,
382 message: String,
383}
384
385#[derive(Deserialize)]
386struct RpcResponse<RpcResult> {
387 #[allow(dead_code)]
388 id: String,
389 #[allow(dead_code)]
390 jsonrpc: String,
391 error: Option<RpcError>,
392 result: Option<RpcResult>,
393}
394
395#[derive(Deserialize)]
396#[allow(dead_code)]
397struct EmptyResult {}
398
399#[derive(Debug, Serialize)]
400#[allow(dead_code)]
401struct SnapshotCreateParams {
402 snapshot_name: String,
403 keys: Vec<String>,
404}
405
406#[derive(Debug, Deserialize)]
407#[allow(dead_code)]
408struct SnapshotCreateResult {
409 snapshot_id: SnapshotID,
410 urls: Vec<String>,
411}
412
413#[derive(Debug, Serialize)]
414struct SnapshotCreateMultipartParams {
415 snapshot_name: String,
416 keys: Vec<String>,
417 file_sizes: Vec<usize>,
418 /// Optional snapshot type (e.g., "ziparrow" for EdgeFirst Dataset Format)
419 #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
420 snapshot_type: Option<String>,
421}
422
423#[derive(Debug, Deserialize)]
424#[serde(untagged)]
425enum SnapshotCreateMultipartResultField {
426 Id(u64),
427 Part(SnapshotPart),
428}
429
430#[derive(Debug, Serialize)]
431struct SnapshotCompleteMultipartParams {
432 key: String,
433 upload_id: String,
434 etag_list: Vec<EtagPart>,
435}
436
437#[derive(Debug, Clone, Serialize)]
438struct EtagPart {
439 #[serde(rename = "ETag")]
440 etag: String,
441 #[serde(rename = "PartNumber")]
442 part_number: usize,
443}
444
445#[derive(Debug, Clone, Deserialize)]
446struct SnapshotPart {
447 key: Option<String>,
448 upload_id: String,
449 urls: Vec<String>,
450}
451
452#[derive(Debug, Serialize)]
453struct SnapshotStatusParams {
454 snapshot_id: SnapshotID,
455 status: String,
456}
457
458#[derive(Deserialize, Debug)]
459struct SnapshotStatusResult {
460 #[allow(dead_code)]
461 pub id: SnapshotID,
462 #[allow(dead_code)]
463 pub uid: String,
464 #[allow(dead_code)]
465 pub description: String,
466 #[allow(dead_code)]
467 pub date: String,
468 #[allow(dead_code)]
469 pub status: String,
470}
471
472#[derive(Serialize)]
473#[allow(dead_code)]
474struct ImageListParams {
475 images_filter: ImagesFilter,
476 image_files_filter: HashMap<String, String>,
477 only_ids: bool,
478}
479
480#[derive(Serialize)]
481#[allow(dead_code)]
482struct ImagesFilter {
483 dataset_id: DatasetID,
484}
485
486/// Main client for interacting with EdgeFirst Studio Server.
487///
488/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
489/// and manages authentication, RPC calls, and data operations. It provides
490/// methods for managing projects, datasets, experiments, training sessions,
491/// and various utility functions for data processing.
492///
493/// The client supports multiple authentication methods and can work with both
494/// SaaS and self-hosted EdgeFirst Studio instances.
495///
496/// # Features
497///
498/// - **Authentication**: Token-based authentication with automatic persistence
499/// - **Dataset Management**: Upload, download, and manipulate datasets
500/// - **Project Operations**: Create and manage projects and experiments
501/// - **Training & Validation**: Submit and monitor ML training jobs
502/// - **Data Integration**: Convert between EdgeFirst datasets and popular
503/// formats
504/// - **Progress Tracking**: Real-time progress updates for long-running
505/// operations
506///
507/// # Examples
508///
509/// ```no_run
510/// use edgefirst_client::{Client, DatasetID};
511/// use std::str::FromStr;
512///
513/// # async fn example() -> Result<(), edgefirst_client::Error> {
514/// // Create a new client and authenticate
515/// let mut client = Client::new()?;
516/// let client = client
517/// .with_login("your-email@example.com", "password")
518/// .await?;
519///
520/// // Or use an existing token
521/// let base_client = Client::new()?;
522/// let client = base_client.with_token("your-token-here")?;
523///
524/// // Get organization and projects
525/// let org = client.organization().await?;
526/// let projects = client.projects(None).await?;
527///
528/// // Work with datasets
529/// let dataset_id = DatasetID::from_str("ds-abc123")?;
530/// let dataset = client.dataset(dataset_id).await?;
531/// # Ok(())
532/// # }
533/// ```
534/// Client is Clone but cannot derive Debug due to dyn TokenStorage
535#[derive(Clone)]
536pub struct Client {
537 http: reqwest::Client,
538 /// HTTP client for long-running bulk transfers (uploads/downloads, no total-request
539 /// timeout). An idle read timeout is still configured on the underlying client, and
540 /// some operations (such as uploads) may apply additional per-request timeouts.
541 bulk_http: reqwest::Client,
542 url: String,
543 token: Arc<RwLock<String>>,
544 /// Token storage backend. When set, tokens are automatically persisted.
545 storage: Option<Arc<dyn TokenStorage>>,
546 /// Legacy token path field for backwards compatibility with
547 /// with_token_path(). Deprecated: Use with_storage() instead.
548 token_path: Option<PathBuf>,
549}
550
551impl std::fmt::Debug for Client {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("Client")
554 .field("url", &self.url)
555 .field("has_storage", &self.storage.is_some())
556 .field("token_path", &self.token_path)
557 .finish()
558 }
559}
560
561/// Private context struct for pagination operations
562struct FetchContext<'a> {
563 dataset_id: DatasetID,
564 annotation_set_id: Option<AnnotationSetID>,
565 groups: &'a [String],
566 types: Vec<String>,
567 labels: &'a HashMap<String, u64>,
568}
569
570#[derive(Debug, Serialize)]
571struct JobsListRequest {}
572
573#[derive(Debug, Serialize)]
574struct JobRunRequest {
575 name: String,
576 job_name: String,
577 env: std::collections::HashMap<String, String>,
578 data: std::collections::HashMap<String, crate::api::Parameter>,
579}
580
581#[derive(Debug, Serialize)]
582struct JobStopRequest {
583 task_id: u64,
584}
585
586#[derive(Debug, Serialize)]
587pub(crate) struct TaskDataListRequest {
588 pub(crate) task_id: u64,
589}
590
591#[derive(Debug, Serialize)]
592pub(crate) struct TaskDataDownloadRequest {
593 pub(crate) task_id: u64,
594 pub(crate) folder: String,
595 pub(crate) file: String,
596}
597
598#[derive(Debug, Serialize)]
599pub(crate) struct TaskChartAddRequest {
600 pub(crate) task_id: u64,
601 pub(crate) group_name: String,
602 pub(crate) chart_name: String,
603 pub(crate) params: Option<crate::api::Parameter>,
604 pub(crate) data: crate::api::Parameter,
605}
606
607#[derive(Debug, Serialize)]
608pub(crate) struct TaskChartListRequest {
609 pub(crate) task_id: u64,
610 pub(crate) group_name: String,
611}
612
613#[derive(Debug, Serialize)]
614pub(crate) struct TaskChartGetRequest {
615 pub(crate) task_id: u64,
616 pub(crate) group_name: String,
617 pub(crate) chart_name: String,
618}
619
620#[derive(Debug, Serialize)]
621pub(crate) struct ValDataDownloadRequest {
622 pub(crate) session_id: u64,
623 pub(crate) filename: String,
624}
625
626#[derive(Debug, Serialize)]
627pub(crate) struct ValDataListRequest {
628 pub(crate) session_id: u64,
629}
630
631/// Streams the body of a successful `reqwest` response to a file on disk,
632/// emitting optional progress events.
633///
634/// Both `download_artifact` and `rpc_download` share this logic. The caller is
635/// responsible for creating any required parent directories before calling this
636/// function.
637///
638/// # Arguments
639/// * `resp` - A successful (HTTP 2xx) `reqwest::Response` whose body will
640/// be streamed to `path`.
641/// * `path` - Destination file path (created or truncated).
642/// * `progress` - Optional channel; events carry bytes received and
643/// `Content-Length` total (0 if the server omits it).
644///
645/// # Errors
646/// Returns `Error::IoError` on file I/O failures or propagates stream errors.
647async fn stream_response_to_file(
648 resp: reqwest::Response,
649 path: &std::path::Path,
650 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
651) -> Result<(), Error> {
652 use tokio::io::AsyncWriteExt as _;
653 let total = resp.content_length().unwrap_or(0) as usize;
654 let mut stream = resp.bytes_stream();
655 let mut file = tokio::fs::File::create(path).await?;
656 let mut current = 0usize;
657
658 if let Some(ref tx) = progress {
659 let _ = tx
660 .send(Progress {
661 current: 0,
662 total,
663 status: None,
664 })
665 .await;
666 }
667
668 while let Some(chunk) = stream.next().await {
669 let chunk = chunk?;
670 file.write_all(&chunk).await?;
671 current += chunk.len();
672 if let Some(ref tx) = progress {
673 let _ = tx
674 .send(Progress {
675 current,
676 total,
677 status: None,
678 })
679 .await;
680 }
681 }
682
683 // Flush tokio's internal write buffer to the OS before returning.
684 // tokio::fs::File buffers writes internally; without this, the buffer
685 // may not reach the filesystem before the caller reads the file.
686 file.flush().await?;
687 Ok(())
688}
689
690impl Client {
691 /// Create a new unauthenticated client with the default saas server.
692 ///
693 /// By default, the client uses [`FileTokenStorage`] for token persistence.
694 /// Use [`with_storage`][Self::with_storage],
695 /// [`with_memory_storage`][Self::with_memory_storage],
696 /// or [`with_no_storage`][Self::with_no_storage] to configure storage
697 /// behavior.
698 ///
699 /// To connect to a different server, use [`with_server`][Self::with_server]
700 /// or [`with_token`][Self::with_token] (tokens include the server
701 /// instance).
702 ///
703 /// This client is created without a token and will need to authenticate
704 /// before using methods that require authentication.
705 ///
706 /// # Examples
707 ///
708 /// ```rust,no_run
709 /// use edgefirst_client::Client;
710 ///
711 /// # fn main() -> Result<(), edgefirst_client::Error> {
712 /// // Create client with default file storage
713 /// let client = Client::new()?;
714 ///
715 /// // Create client without token persistence
716 /// let client = Client::new()?.with_memory_storage();
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub fn new() -> Result<Self, Error> {
721 log_retry_configuration();
722
723 // Get timeout from environment or use default
724 let timeout_secs = std::env::var("EDGEFIRST_TIMEOUT")
725 .ok()
726 .and_then(|s| s.parse().ok())
727 .unwrap_or(30); // Default 30s total deadline for API calls
728
729 // Per-chunk idle timeout for bulk transfers: fires only when no bytes
730 // arrive for this duration. Resets after every received chunk, so a
731 // healthy multi-GB transfer will never be interrupted.
732 let read_timeout_secs = std::env::var("EDGEFIRST_READ_TIMEOUT")
733 .ok()
734 .and_then(|s| s.parse().ok())
735 .unwrap_or(120); // Default 120s idle timeout for bulk transfers
736
737 // Create single HTTP client with URL-based retry policy
738 //
739 // The retry policy classifies requests into two categories:
740 // - StudioApi (*.edgefirst.studio/api): Fast-fail on auth errors, retry server
741 // errors
742 // - FileIO (S3, CloudFront, etc.): Retry all transient errors for robustness
743 //
744 // This allows the same client to handle both API calls and file operations
745 // with appropriate retry behavior for each. See retry.rs for details.
746 let http = reqwest::Client::builder()
747 .connect_timeout(Duration::from_secs(10))
748 .timeout(Duration::from_secs(timeout_secs))
749 .pool_idle_timeout(Duration::from_secs(90))
750 .pool_max_idle_per_host(10)
751 .retry(create_retry_policy())
752 .build()?;
753
754 // Separate HTTP client for bulk transfers (uploads and downloads).
755 // No total-request timeout (EDGEFIRST_TIMEOUT does not apply here).
756 // Uses read_timeout instead: resets after every received chunk, so a
757 // healthy large transfer is never interrupted, but a truly stalled
758 // connection (no bytes for EDGEFIRST_READ_TIMEOUT seconds) is aborted.
759 let bulk_http = reqwest::Client::builder()
760 .connect_timeout(Duration::from_secs(30))
761 .read_timeout(Duration::from_secs(read_timeout_secs))
762 .pool_idle_timeout(Duration::from_secs(90))
763 // Bulk file transfers fan out to many concurrent presigned-URL
764 // uploads — up to `EDGEFIRST_UPLOAD_BATCHES` pipelined batches ×
765 // `max_tasks()` uploads each. Keep enough idle connections warm to
766 // reuse across that fan-out instead of churning new TLS handshakes.
767 .pool_max_idle_per_host(64)
768 .retry(create_retry_policy())
769 .build()?;
770
771 // Default to file storage, loading any existing token
772 let storage: Arc<dyn TokenStorage> = match FileTokenStorage::new() {
773 Ok(file_storage) => Arc::new(file_storage),
774 Err(e) => {
775 warn!(
776 "Could not initialize file token storage: {}. Using memory storage.",
777 e
778 );
779 Arc::new(MemoryTokenStorage::new())
780 }
781 };
782
783 // Try to load existing token from storage
784 let token = match storage.load() {
785 Ok(Some(t)) => t,
786 Ok(None) => String::new(),
787 Err(e) => {
788 warn!(
789 "Failed to load token from storage: {}. Starting with empty token.",
790 e
791 );
792 String::new()
793 }
794 };
795
796 // Extract server from token if available
797 let url = if !token.is_empty() {
798 match Self::extract_server_from_token(&token) {
799 Ok(server) => format!("https://{}.edgefirst.studio", server),
800 Err(e) => {
801 warn!(
802 "Failed to extract server from token: {}. Using default server.",
803 e
804 );
805 "https://edgefirst.studio".to_string()
806 }
807 }
808 } else {
809 "https://edgefirst.studio".to_string()
810 };
811
812 Ok(Client {
813 http,
814 bulk_http,
815 url,
816 token: Arc::new(tokio::sync::RwLock::new(token)),
817 storage: Some(storage),
818 token_path: None,
819 })
820 }
821
822 /// Returns a new client connected to the specified server instance.
823 ///
824 /// The server parameter is an instance name that maps to a URL:
825 /// - `""` or `"saas"` → `https://edgefirst.studio` (default production
826 /// server)
827 /// - `"test"` → `https://test.edgefirst.studio`
828 /// - `"stage"` → `https://stage.edgefirst.studio`
829 /// - `"dev"` → `https://dev.edgefirst.studio`
830 /// - `"{name}"` → `https://{name}.edgefirst.studio`
831 ///
832 /// # Server Selection Priority
833 ///
834 /// When using the CLI or Python API, server selection follows this
835 /// priority:
836 ///
837 /// 1. **Token's server** (highest priority) - JWT tokens encode the server
838 /// they were issued for. If you have a valid token, its server is used.
839 /// 2. **`with_server()` / `--server`** - Used when logging in or when no
840 /// token is available. If a token exists with a different server, a
841 /// warning is emitted and the token's server takes priority.
842 /// 3. **Default `"saas"`** - If no token and no server specified, the
843 /// production server (`https://edgefirst.studio`) is used.
844 ///
845 /// # Important Notes
846 ///
847 /// - If a token is already set in the client, calling this method will
848 /// **drop the token** as tokens are specific to the server instance.
849 /// - Use [`parse_token_server`][Self::parse_token_server] to check a
850 /// token's server before calling this method.
851 /// - For login operations, call `with_server()` first, then authenticate.
852 ///
853 /// # Examples
854 ///
855 /// ```rust,no_run
856 /// use edgefirst_client::Client;
857 ///
858 /// # fn main() -> Result<(), edgefirst_client::Error> {
859 /// let client = Client::new()?.with_server("test")?;
860 /// assert_eq!(client.url(), "https://test.edgefirst.studio");
861 /// # Ok(())
862 /// # }
863 /// ```
864 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
865 // Resolve the target URL. Full URLs (self-hosted Studio,
866 // wiremock) are validated through `with_url` so the HTTPS rules
867 // there apply uniformly. Short names map to the SaaS pattern.
868 // We extract only the URL string and rebuild the Client below,
869 // because `with_url` preserves the in-memory token (the contract
870 // for self-hosted deployments) whereas `with_server` deliberately
871 // clears it (a different server means a stale token).
872 let url = if server.starts_with("http://") || server.starts_with("https://") {
873 self.with_url(server)?.url().to_string()
874 } else {
875 match server {
876 "" | "saas" => "https://edgefirst.studio".to_string(),
877 name => format!("https://{}.edgefirst.studio", name),
878 }
879 };
880
881 // Clear token from storage when changing servers to prevent
882 // authentication issues with stale tokens from different
883 // instances. This runs whether the caller passed a short name
884 // or a full URL — both reach a new server.
885 if let Some(ref storage) = self.storage
886 && let Err(e) = storage.clear()
887 {
888 warn!(
889 "Failed to clear token from storage when changing servers: {}",
890 e
891 );
892 }
893
894 Ok(Client {
895 url,
896 token: Arc::new(tokio::sync::RwLock::new(String::new())),
897 ..self.clone()
898 })
899 }
900
901 /// Returns a new client pointed at an explicit URL.
902 ///
903 /// Used for self-hosted Studio deployments (e.g.
904 /// `https://studio.example.com`) and for offline integration tests
905 /// against a mock HTTP server (e.g. `http://127.0.0.1:8080`). The
906 /// token is preserved so callers can chain
907 /// `Client::new()?.with_url(...)?.with_token(...)`.
908 ///
909 /// # Errors
910 ///
911 /// Returns [`Error::UrlParseError`] for syntactically invalid URLs and
912 /// [`Error::InsecureUrl`] for plain `http://` URLs that resolve to a
913 /// non-loopback host: the Studio bearer token rides in the
914 /// `Authorization` header, and plain HTTP would leak it in the clear.
915 /// Loopback URLs (`127.0.0.1`, `::1`, `localhost`, `*.localhost`) are
916 /// permitted because traffic never leaves the machine — wiremock and
917 /// local dev servers go through that path.
918 pub fn with_url(&self, url: &str) -> Result<Self, Error> {
919 // Reject malformed inputs early so test failures point at the test
920 // rather than a downstream reqwest send.
921 let parsed = url::Url::parse(url)?;
922 let scheme = parsed.scheme();
923 if scheme == "http" {
924 if !is_loopback_host(parsed.host().as_ref()) {
925 return Err(Error::InsecureUrl(url.to_string()));
926 }
927 } else if scheme != "https" {
928 return Err(Error::InsecureUrl(url.to_string()));
929 }
930 Ok(Client {
931 url: url.trim_end_matches('/').to_string(),
932 ..self.clone()
933 })
934 }
935
936 /// Returns a new client with the specified token storage backend.
937 ///
938 /// Use this to configure custom token storage, such as platform-specific
939 /// secure storage (iOS Keychain, Android EncryptedSharedPreferences).
940 ///
941 /// # Examples
942 ///
943 /// ```rust,no_run
944 /// use edgefirst_client::{Client, FileTokenStorage};
945 /// use std::{path::PathBuf, sync::Arc};
946 ///
947 /// # fn main() -> Result<(), edgefirst_client::Error> {
948 /// // Use a custom file path for token storage
949 /// let storage = FileTokenStorage::with_path(PathBuf::from("/custom/path/token"));
950 /// let client = Client::new()?.with_storage(Arc::new(storage));
951 /// # Ok(())
952 /// # }
953 /// ```
954 pub fn with_storage(self, storage: Arc<dyn TokenStorage>) -> Self {
955 // Try to load existing token from the new storage
956 let token = match storage.load() {
957 Ok(Some(t)) => t,
958 Ok(None) => String::new(),
959 Err(e) => {
960 warn!(
961 "Failed to load token from storage: {}. Starting with empty token.",
962 e
963 );
964 String::new()
965 }
966 };
967
968 Client {
969 token: Arc::new(tokio::sync::RwLock::new(token)),
970 storage: Some(storage),
971 token_path: None,
972 ..self
973 }
974 }
975
976 /// Returns a new client with in-memory token storage (no persistence).
977 ///
978 /// Tokens are stored in memory only and lost when the application exits.
979 /// This is useful for testing or when you want to manage token persistence
980 /// externally.
981 ///
982 /// # Examples
983 ///
984 /// ```rust,no_run
985 /// use edgefirst_client::Client;
986 ///
987 /// # fn main() -> Result<(), edgefirst_client::Error> {
988 /// let client = Client::new()?.with_memory_storage();
989 /// # Ok(())
990 /// # }
991 /// ```
992 pub fn with_memory_storage(self) -> Self {
993 Client {
994 token: Arc::new(tokio::sync::RwLock::new(String::new())),
995 storage: Some(Arc::new(MemoryTokenStorage::new())),
996 token_path: None,
997 ..self
998 }
999 }
1000
1001 /// Returns a new client with no token storage.
1002 ///
1003 /// Tokens are not persisted. Use this when you want to manage tokens
1004 /// entirely manually.
1005 ///
1006 /// # Examples
1007 ///
1008 /// ```rust,no_run
1009 /// use edgefirst_client::Client;
1010 ///
1011 /// # fn main() -> Result<(), edgefirst_client::Error> {
1012 /// let client = Client::new()?.with_no_storage();
1013 /// # Ok(())
1014 /// # }
1015 /// ```
1016 pub fn with_no_storage(self) -> Self {
1017 Client {
1018 storage: None,
1019 token_path: None,
1020 ..self
1021 }
1022 }
1023
1024 /// Returns a new client authenticated with the provided username and
1025 /// password.
1026 ///
1027 /// The token is automatically persisted to storage (if configured).
1028 ///
1029 /// # Examples
1030 ///
1031 /// ```rust,no_run
1032 /// use edgefirst_client::Client;
1033 ///
1034 /// # async fn example() -> Result<(), edgefirst_client::Error> {
1035 /// let client = Client::new()?
1036 /// .with_server("test")?
1037 /// .with_login("user@example.com", "password")
1038 /// .await?;
1039 /// # Ok(())
1040 /// # }
1041 /// ```
1042 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, password)))]
1043 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
1044 let params = HashMap::from([("username", username), ("password", password)]);
1045 let login: LoginResult = self
1046 .rpc_without_auth("auth.login".to_owned(), Some(params))
1047 .await?;
1048
1049 // Validate that the server returned a non-empty token
1050 if login.token.is_empty() {
1051 return Err(Error::EmptyToken);
1052 }
1053
1054 // Persist token to storage if configured
1055 if let Some(ref storage) = self.storage
1056 && let Err(e) = storage.store(&login.token)
1057 {
1058 warn!("Failed to persist token to storage: {}", e);
1059 }
1060
1061 Ok(Client {
1062 token: Arc::new(tokio::sync::RwLock::new(login.token)),
1063 ..self.clone()
1064 })
1065 }
1066
1067 /// Returns a new client which will load and save the token to the specified
1068 /// path.
1069 ///
1070 /// **Deprecated**: Use [`with_storage`][Self::with_storage] with
1071 /// [`FileTokenStorage`] instead for more flexible token management.
1072 ///
1073 /// This method is maintained for backwards compatibility with existing
1074 /// code. It disables the default storage and uses file-based storage at
1075 /// the specified path.
1076 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
1077 let token_path = match token_path {
1078 Some(path) => path.to_path_buf(),
1079 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1080 .ok_or_else(|| {
1081 Error::IoError(std::io::Error::new(
1082 std::io::ErrorKind::NotFound,
1083 "Could not determine user config directory",
1084 ))
1085 })?
1086 .config_dir()
1087 .join("token"),
1088 };
1089
1090 debug!("Using token path (legacy): {:?}", token_path);
1091
1092 let token = match token_path.exists() {
1093 true => std::fs::read_to_string(&token_path)?,
1094 false => "".to_string(),
1095 };
1096
1097 if !token.is_empty() {
1098 match self.with_token(&token) {
1099 Ok(client) => Ok(Client {
1100 token_path: Some(token_path),
1101 storage: None, // Disable new storage when using legacy token_path
1102 ..client
1103 }),
1104 Err(e) => {
1105 // Token is corrupted or invalid - remove it and continue with no token
1106 warn!(
1107 "Invalid or corrupted token file at {:?}: {:?}. Removing token file.",
1108 token_path, e
1109 );
1110 if let Err(remove_err) = std::fs::remove_file(&token_path) {
1111 warn!("Failed to remove corrupted token file: {:?}", remove_err);
1112 }
1113 // Clear any token from default storage to ensure we don't use it
1114 Ok(Client {
1115 token_path: Some(token_path),
1116 storage: None,
1117 token: Arc::new(RwLock::new("".to_string())),
1118 ..self.clone()
1119 })
1120 }
1121 }
1122 } else {
1123 // No token in the legacy file - clear any token from default storage
1124 Ok(Client {
1125 token_path: Some(token_path),
1126 storage: None,
1127 token: Arc::new(RwLock::new("".to_string())),
1128 ..self.clone()
1129 })
1130 }
1131 }
1132
1133 /// Returns a new client authenticated with the provided token.
1134 ///
1135 /// The token is automatically persisted to storage (if configured).
1136 /// The server URL is extracted from the token payload.
1137 ///
1138 /// # Examples
1139 ///
1140 /// ```rust,no_run
1141 /// use edgefirst_client::Client;
1142 ///
1143 /// # fn main() -> Result<(), edgefirst_client::Error> {
1144 /// let client = Client::new()?.with_token("your-jwt-token")?;
1145 /// # Ok(())
1146 /// # }
1147 /// ```
1148 /// Extract server name from JWT token payload.
1149 ///
1150 /// Helper method to parse the JWT token and extract the "server" field
1151 /// from the payload. Returns the server name (e.g., "test", "stage", "")
1152 /// or an error if the token is invalid.
1153 fn extract_server_from_token(token: &str) -> Result<String, Error> {
1154 let token_parts: Vec<&str> = token.split('.').collect();
1155 if token_parts.len() != 3 {
1156 return Err(Error::InvalidToken);
1157 }
1158
1159 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1160 .decode(token_parts[1])
1161 .map_err(|_| Error::InvalidToken)?;
1162 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1163 let server = match payload.get("server") {
1164 Some(value) => value.as_str().ok_or(Error::InvalidToken)?.to_string(),
1165 None => return Err(Error::InvalidToken),
1166 };
1167
1168 Ok(server)
1169 }
1170
1171 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
1172 if token.is_empty() {
1173 return Ok(self.clone());
1174 }
1175
1176 let server = Self::extract_server_from_token(token)?;
1177
1178 // Persist token to storage if configured
1179 if let Some(ref storage) = self.storage
1180 && let Err(e) = storage.store(token)
1181 {
1182 warn!("Failed to persist token to storage: {}", e);
1183 }
1184
1185 Ok(Client {
1186 url: format!("https://{}.edgefirst.studio", server),
1187 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
1188 ..self.clone()
1189 })
1190 }
1191
1192 /// Persist the current token to storage.
1193 ///
1194 /// This is automatically called when using [`with_login`][Self::with_login]
1195 /// or [`with_token`][Self::with_token], so you typically don't need to call
1196 /// this directly.
1197 ///
1198 /// If using the legacy `token_path` configuration, saves to the file path.
1199 /// If using the new storage abstraction, saves to the configured storage.
1200 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1201 pub async fn save_token(&self) -> Result<(), Error> {
1202 let token = self.token.read().await;
1203
1204 // Try new storage first
1205 if let Some(ref storage) = self.storage {
1206 storage.store(&token)?;
1207 debug!("Token saved to storage");
1208 return Ok(());
1209 }
1210
1211 // Fall back to legacy token_path behavior
1212 let path = self.token_path.clone().unwrap_or_else(|| {
1213 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
1214 .map(|dirs| dirs.config_dir().join("token"))
1215 .unwrap_or_else(|| PathBuf::from(".token"))
1216 });
1217
1218 create_dir_all(path.parent().ok_or_else(|| {
1219 Error::IoError(std::io::Error::new(
1220 std::io::ErrorKind::InvalidInput,
1221 "Token path has no parent directory",
1222 ))
1223 })?)?;
1224 let mut file = std::fs::File::create(&path)?;
1225 file.write_all(token.as_bytes())?;
1226
1227 debug!("Saved token to {:?}", path);
1228
1229 Ok(())
1230 }
1231
1232 /// Return the version of the EdgeFirst Studio server for the current
1233 /// client connection.
1234 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1235 pub async fn version(&self) -> Result<String, Error> {
1236 let version: HashMap<String, String> = self
1237 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
1238 .await?;
1239 let version = version.get("version").ok_or(Error::InvalidResponse)?;
1240 Ok(version.to_owned())
1241 }
1242
1243 /// Clear the token used to authenticate the client with the server.
1244 ///
1245 /// Clears the token from memory and from storage (if configured).
1246 /// If using the legacy `token_path` configuration, removes the token file.
1247 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1248 pub async fn logout(&self) -> Result<(), Error> {
1249 {
1250 let mut token = self.token.write().await;
1251 *token = "".to_string();
1252 }
1253
1254 // Clear from new storage if configured
1255 if let Some(ref storage) = self.storage
1256 && let Err(e) = storage.clear()
1257 {
1258 warn!("Failed to clear token from storage: {}", e);
1259 }
1260
1261 // Also clear legacy token_path if configured
1262 if let Some(path) = &self.token_path
1263 && path.exists()
1264 {
1265 fs::remove_file(path).await?;
1266 }
1267
1268 Ok(())
1269 }
1270
1271 /// Return the token used to authenticate the client with the server. When
1272 /// logging into the server using a username and password, the token is
1273 /// returned by the server and stored in the client for future interactions.
1274 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1275 pub async fn token(&self) -> String {
1276 self.token.read().await.clone()
1277 }
1278
1279 /// Verify the token used to authenticate the client with the server. This
1280 /// method is used to ensure that the token is still valid and has not
1281 /// expired. If the token is invalid, the server will return an error and
1282 /// the client will need to login again.
1283 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1284 pub async fn verify_token(&self) -> Result<(), Error> {
1285 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
1286 .await?;
1287 Ok::<(), Error>(())
1288 }
1289
1290 /// Renew the token used to authenticate the client with the server.
1291 ///
1292 /// Refreshes the token before it expires. If the token has already expired,
1293 /// the server will return an error and you will need to login again.
1294 ///
1295 /// The new token is automatically persisted to storage (if configured).
1296 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1297 pub async fn renew_token(&self) -> Result<(), Error> {
1298 let params = HashMap::from([("username".to_string(), self.username().await?)]);
1299 let result: LoginResult = self
1300 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
1301 .await?;
1302
1303 {
1304 let mut token = self.token.write().await;
1305 *token = result.token.clone();
1306 }
1307
1308 // Persist to new storage if configured
1309 if let Some(ref storage) = self.storage
1310 && let Err(e) = storage.store(&result.token)
1311 {
1312 warn!("Failed to persist renewed token to storage: {}", e);
1313 }
1314
1315 // Also persist to legacy token_path if configured
1316 if self.token_path.is_some() {
1317 self.save_token().await?;
1318 }
1319
1320 Ok(())
1321 }
1322
1323 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
1324 let token = self.token.read().await;
1325 if token.is_empty() {
1326 return Err(Error::EmptyToken);
1327 }
1328
1329 let token_parts: Vec<&str> = token.split('.').collect();
1330 if token_parts.len() != 3 {
1331 return Err(Error::InvalidToken);
1332 }
1333
1334 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
1335 .decode(token_parts[1])
1336 .map_err(|_| Error::InvalidToken)?;
1337 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
1338 match payload.get(field) {
1339 Some(value) => Ok(value.to_owned()),
1340 None => Err(Error::InvalidToken),
1341 }
1342 }
1343
1344 /// Returns the URL of the EdgeFirst Studio server for the current client.
1345 pub fn url(&self) -> &str {
1346 &self.url
1347 }
1348
1349 /// Returns the server name for the current client.
1350 ///
1351 /// This extracts the server name from the client's URL:
1352 /// - `https://edgefirst.studio` → `"saas"`
1353 /// - `https://test.edgefirst.studio` → `"test"`
1354 /// - `https://{name}.edgefirst.studio` → `"{name}"`
1355 ///
1356 /// # Examples
1357 ///
1358 /// ```rust,no_run
1359 /// use edgefirst_client::Client;
1360 ///
1361 /// # fn main() -> Result<(), edgefirst_client::Error> {
1362 /// let client = Client::new()?.with_server("test")?;
1363 /// assert_eq!(client.server(), "test");
1364 ///
1365 /// let client = Client::new()?; // default
1366 /// assert_eq!(client.server(), "saas");
1367 /// # Ok(())
1368 /// # }
1369 /// ```
1370 pub fn server(&self) -> &str {
1371 if self.url == "https://edgefirst.studio" {
1372 "saas"
1373 } else if let Some(name) = self.url.strip_prefix("https://") {
1374 name.strip_suffix(".edgefirst.studio").unwrap_or("saas")
1375 } else {
1376 "saas"
1377 }
1378 }
1379
1380 /// Returns the username associated with the current token.
1381 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1382 pub async fn username(&self) -> Result<String, Error> {
1383 match self.token_field("username").await? {
1384 serde_json::Value::String(username) => Ok(username),
1385 _ => Err(Error::InvalidToken),
1386 }
1387 }
1388
1389 /// Returns the expiration time for the current token.
1390 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1391 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
1392 let ts = match self.token_field("exp").await? {
1393 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
1394 _ => return Err(Error::InvalidToken),
1395 };
1396
1397 match DateTime::<Utc>::from_timestamp(ts, 0) {
1398 Some(dt) => Ok(dt),
1399 None => Err(Error::InvalidToken),
1400 }
1401 }
1402
1403 /// Returns the organization information for the current user.
1404 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1405 pub async fn organization(&self) -> Result<Organization, Error> {
1406 self.rpc::<(), Organization>("org.get".to_owned(), None)
1407 .await
1408 }
1409
1410 /// Returns a list of projects available to the user. The projects are
1411 /// returned as a vector of Project objects. If a name filter is
1412 /// provided, only projects matching the filter are returned.
1413 ///
1414 /// Results are sorted by match quality: exact matches first, then
1415 /// case-insensitive exact matches, then shorter names (more specific),
1416 /// then alphabetically.
1417 ///
1418 /// Projects are the top-level organizational unit in EdgeFirst Studio.
1419 /// Projects contain datasets, trainers, and trainer sessions. Projects
1420 /// are used to group related datasets and trainers together.
1421 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1422 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
1423 let projects = self
1424 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
1425 .await?;
1426 if let Some(name) = name {
1427 Ok(filter_and_sort_by_name(projects, name, |p| p.name()))
1428 } else {
1429 Ok(projects)
1430 }
1431 }
1432
1433 /// Return the project with the specified project ID. If the project does
1434 /// not exist, an error is returned.
1435 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(project_id = %project_id)))]
1436 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
1437 let params = HashMap::from([("project_id", project_id)]);
1438 self.rpc("project.get".to_owned(), Some(params)).await
1439 }
1440
1441 /// Returns a list of datasets available to the user. The datasets are
1442 /// returned as a vector of Dataset objects. If a name filter is
1443 /// provided, only datasets matching the filter are returned.
1444 ///
1445 /// Results are sorted by match quality: exact matches first, then
1446 /// case-insensitive exact matches, then shorter names (more specific),
1447 /// then alphabetically. This ensures "Deer" returns before "Deer
1448 /// Roundtrip".
1449 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1450 pub async fn datasets(
1451 &self,
1452 project_id: ProjectID,
1453 name: Option<&str>,
1454 ) -> Result<Vec<Dataset>, Error> {
1455 let params = HashMap::from([("project_id", project_id)]);
1456 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
1457 if let Some(name) = name {
1458 Ok(filter_and_sort_by_name(datasets, name, |d| d.name()))
1459 } else {
1460 Ok(datasets)
1461 }
1462 }
1463
1464 /// Return the dataset with the specified dataset ID. If the dataset does
1465 /// not exist, an error is returned.
1466 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1467 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
1468 let params = HashMap::from([("dataset_id", dataset_id)]);
1469 self.rpc("dataset.get".to_owned(), Some(params)).await
1470 }
1471
1472 /// Lists the labels for the specified dataset.
1473 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1474 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
1475 let params = HashMap::from([("dataset_id", dataset_id)]);
1476 self.rpc("label.list".to_owned(), Some(params)).await
1477 }
1478
1479 /// Add a new label to the dataset with the specified name.
1480 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1481 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
1482 self.add_labels(dataset_id, std::slice::from_ref(&name.to_owned()))
1483 .await
1484 }
1485
1486 /// Add multiple labels to the dataset in a single request.
1487 ///
1488 /// Equivalent to calling [`add_label`](Self::add_label) for each name but in
1489 /// one round-trip. Useful before a bulk/concurrent upload: pre-creating the
1490 /// full label set serially avoids many concurrent `populate2` calls racing to
1491 /// create the same label server-side. Names already present are not
1492 /// duplicated by the server. A no-op when `names` is empty.
1493 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, names), fields(dataset_id = %dataset_id, count = names.len())))]
1494 pub async fn add_labels(&self, dataset_id: DatasetID, names: &[String]) -> Result<(), Error> {
1495 if names.is_empty() {
1496 return Ok(());
1497 }
1498
1499 let existing = self.labels(dataset_id).await?;
1500 let existing_names: std::collections::HashSet<String> =
1501 existing.iter().map(|l| l.name().to_string()).collect();
1502
1503 let to_create: Vec<&String> = names
1504 .iter()
1505 .filter(|name| !existing_names.contains(name.as_str()))
1506 .collect();
1507
1508 if to_create.is_empty() {
1509 return Ok(());
1510 }
1511
1512 let new_label = NewLabel {
1513 dataset_id,
1514 labels: to_create
1515 .iter()
1516 .map(|name| NewLabelObject {
1517 name: (*name).clone(),
1518 })
1519 .collect(),
1520 };
1521 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1522 Ok(())
1523 }
1524
1525 /// Add a label with a caller-specified source-faithful index.
1526 ///
1527 /// Thin wrapper around [`add_labels_with_indices`](Self::add_labels_with_indices)
1528 /// for single-label use. The `index` is preserved by assigning it via
1529 /// `label.update` after creation, enabling round-trips through COCO or other
1530 /// formats where category IDs are not contiguous starting at zero.
1531 ///
1532 /// # Arguments
1533 ///
1534 /// * `dataset_id` - The dataset to add the label to
1535 /// * `name` - Label name (must be unique within the dataset)
1536 /// * `index` - The `label_index` to assign (e.g. COCO `category_id`)
1537 ///
1538 /// # Returns
1539 ///
1540 /// Returns `Ok(())` on success, or an error if the index is already held by
1541 /// a different label on the server.
1542 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1543 pub async fn add_label_with_index(
1544 &self,
1545 dataset_id: DatasetID,
1546 name: &str,
1547 index: u64,
1548 ) -> Result<(), Error> {
1549 let names = [name.to_owned()];
1550 let indices = [Some(index)];
1551 self.add_labels_with_indices(dataset_id, &names, &indices)
1552 .await
1553 }
1554
1555 /// Add multiple labels, optionally assigning source-faithful table indices.
1556 ///
1557 /// Creates missing labels via `label.add2` (names only), then assigns indices
1558 /// via a two-pass `label.update` for entries where `indices[i]` is `Some`.
1559 /// Each `None` leaves that label at the server-assigned index. The two-pass
1560 /// strategy avoids index collisions when labels within the same batch would
1561 /// swap positions. Names already present on the server are not duplicated.
1562 ///
1563 /// # Arguments
1564 ///
1565 /// * `dataset_id` - The dataset to add labels to
1566 /// * `names` - Label names to create (existing names are skipped)
1567 /// * `indices` - Parallel slice of optional indices; `None` means use server default
1568 ///
1569 /// # Returns
1570 ///
1571 /// Returns `Ok(())` on success. A no-op if `names` is empty.
1572 ///
1573 /// # Errors
1574 ///
1575 /// Returns `Error::InvalidParameters` if `names` and `indices` have different
1576 /// lengths, if any desired index conflicts with an existing unrelated label,
1577 /// or if the batch contains duplicate index values.
1578 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, names, indices), fields(dataset_id = %dataset_id, count = names.len())))]
1579 pub async fn add_labels_with_indices(
1580 &self,
1581 dataset_id: DatasetID,
1582 names: &[String],
1583 indices: &[Option<u64>],
1584 ) -> Result<(), Error> {
1585 if names.is_empty() {
1586 return Ok(());
1587 }
1588
1589 if indices.len() != names.len() {
1590 return Err(Error::InvalidParameters(format!(
1591 "add_labels_with_indices: names and indices length mismatch ({} vs {})",
1592 names.len(),
1593 indices.len()
1594 )));
1595 }
1596
1597 Self::validate_label_batch(names, Some(indices))?;
1598
1599 let existing = self.labels(dataset_id).await?;
1600 let existing_names: std::collections::HashSet<String> =
1601 existing.iter().map(|l| l.name().to_string()).collect();
1602
1603 let to_create: Vec<&String> = names
1604 .iter()
1605 .filter(|name| !existing_names.contains(name.as_str()))
1606 .collect();
1607
1608 if !to_create.is_empty() {
1609 let new_label = NewLabel {
1610 dataset_id,
1611 labels: to_create
1612 .iter()
1613 .map(|name| NewLabelObject {
1614 name: (*name).clone(),
1615 })
1616 .collect(),
1617 };
1618 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
1619 }
1620
1621 self.apply_label_indices(dataset_id, names, indices).await
1622 }
1623
1624 /// Removes the label with the specified ID from the dataset. Label IDs are
1625 /// globally unique so the dataset_id is not required.
1626 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1627 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
1628 let params = HashMap::from([("label_id", label_id)]);
1629 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
1630 Ok(())
1631 }
1632
1633 /// Creates a new dataset in the specified project.
1634 ///
1635 /// # Arguments
1636 ///
1637 /// * `project_id` - The ID of the project to create the dataset in
1638 /// * `name` - The name of the new dataset
1639 /// * `description` - Optional description for the dataset
1640 ///
1641 /// # Returns
1642 ///
1643 /// Returns the dataset ID of the newly created dataset.
1644 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
1645 pub async fn create_dataset(
1646 &self,
1647 project_id: &str,
1648 name: &str,
1649 description: Option<&str>,
1650 ) -> Result<DatasetID, Error> {
1651 let mut params = HashMap::new();
1652 params.insert("project_id", project_id);
1653 params.insert("name", name);
1654 if let Some(desc) = description {
1655 params.insert("description", desc);
1656 }
1657
1658 #[derive(Deserialize)]
1659 struct CreateDatasetResult {
1660 id: DatasetID,
1661 }
1662
1663 let result: CreateDatasetResult =
1664 self.rpc("dataset.create".to_owned(), Some(params)).await?;
1665 Ok(result.id)
1666 }
1667
1668 /// Deletes a dataset by marking it as deleted.
1669 ///
1670 /// # Arguments
1671 ///
1672 /// * `dataset_id` - The ID of the dataset to delete
1673 ///
1674 /// # Returns
1675 ///
1676 /// Returns `Ok(())` if the dataset was successfully marked as deleted.
1677 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1678 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
1679 let params = HashMap::from([("id", dataset_id)]);
1680 let _: serde_json::Value = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
1681 Ok(())
1682 }
1683
1684 /// Updates the label with the specified ID to have the new name or index.
1685 /// Label IDs cannot be changed. Label IDs are globally unique so the
1686 /// dataset_id is not required.
1687 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, label)))]
1688 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
1689 #[derive(Serialize)]
1690 struct Params {
1691 dataset_id: DatasetID,
1692 label_id: u64,
1693 label_name: String,
1694 label_index: u64,
1695 }
1696
1697 let _: String = self
1698 .rpc(
1699 "label.update".to_owned(),
1700 Some(Params {
1701 dataset_id: label.dataset_id(),
1702 label_id: label.id(),
1703 label_name: label.name().to_owned(),
1704 label_index: label.index(),
1705 }),
1706 )
1707 .await?;
1708 Ok(())
1709 }
1710
    /// Offset added to a requested label index to form the temporary "staging"
    /// index used by the two-pass assignment in `apply_label_indices`. Parking
    /// each moving label at `target + offset` first avoids index collisions
    /// while labels are reassigned. Chosen to clear real COCO/LVIS category IDs
    /// (up to ~1723).
    const LABEL_INDEX_ASSIGN_TEMP_OFFSET: u64 = 100_000;
1714
1715 /// Collect parallel label name/index arrays from upload samples for
1716 /// [`add_labels_with_indices`](Self::add_labels_with_indices).
1717 ///
1718 /// Annotations without `label_index` contribute `None` at the matching position.
1719 /// Returns an error if the same label name maps to different indices.
1720 pub fn collect_labels_from_samples(
1721 samples: &[Sample],
1722 ) -> Result<(Vec<String>, Vec<Option<u64>>), Error> {
1723 let mut specs: HashMap<String, Option<u64>> = HashMap::new();
1724 let mut order: Vec<String> = Vec::new();
1725 for annotation in samples.iter().flat_map(|s| s.annotations()) {
1726 let Some(name) = annotation.label() else {
1727 continue;
1728 };
1729 match (specs.get(name), annotation.label_index()) {
1730 (Some(&Some(existing)), Some(index)) if existing != index => {
1731 return Err(Error::InvalidParameters(format!(
1732 "inconsistent label_index for '{name}': {existing} vs {index}"
1733 )));
1734 }
1735 (Some(&Some(_)), _) => {}
1736 (Some(&None), Some(index)) => {
1737 specs.insert(name.clone(), Some(index));
1738 }
1739 (None, Some(index)) => {
1740 order.push(name.clone());
1741 specs.insert(name.clone(), Some(index));
1742 }
1743 (None, None) => {
1744 order.push(name.clone());
1745 specs.insert(name.clone(), None);
1746 }
1747 (Some(&None), None) => {}
1748 }
1749 }
1750 let indices: Vec<Option<u64>> = order.iter().map(|name| specs[name]).collect();
1751 Ok((order, indices))
1752 }
1753
1754 /// Validate label batch: unique names and unique indices among entries with `Some(index)`.
1755 fn validate_label_batch(
1756 names: &[String],
1757 indices: Option<&[Option<u64>]>,
1758 ) -> Result<(), Error> {
1759 let mut seen_names = HashMap::new();
1760 let mut index_to_name = HashMap::new();
1761 for (i, name) in names.iter().enumerate() {
1762 if seen_names.insert(name.as_str(), ()).is_some() {
1763 return Err(Error::InvalidParameters(format!(
1764 "duplicate label name '{name}'"
1765 )));
1766 }
1767 if let Some(indices) = indices
1768 && let Some(index) = indices[i]
1769 && let Some(other) = index_to_name.insert(index, name.as_str())
1770 {
1771 return Err(Error::InvalidParameters(format!(
1772 "duplicate label_index {index} for labels '{other}' and '{name}'"
1773 )));
1774 }
1775 }
1776 Ok(())
1777 }
1778
1779 /// Assign label table indices (two-pass update).
1780 async fn apply_label_indices(
1781 &self,
1782 dataset_id: DatasetID,
1783 names: &[String],
1784 indices: &[Option<u64>],
1785 ) -> Result<(), Error> {
1786 let batch_names: HashMap<&str, ()> = names.iter().map(|n| (n.as_str(), ())).collect();
1787
1788 let with_index: HashMap<&str, u64> = names
1789 .iter()
1790 .zip(indices.iter())
1791 .filter_map(|(name, index)| index.map(|i| (name.as_str(), i)))
1792 .collect();
1793
1794 if with_index.is_empty() {
1795 return Ok(());
1796 }
1797
1798 let current = self.labels(dataset_id).await?;
1799 let by_name: HashMap<String, Label> = current
1800 .iter()
1801 .map(|l| (l.name().to_string(), l.clone()))
1802 .collect();
1803
1804 let mut to_sync = Vec::new();
1805 for (name, &target_index) in &with_index {
1806 let label = by_name.get(*name).ok_or_else(|| {
1807 Error::InvalidParameters(format!(
1808 "label '{name}' not found in dataset after label.add2"
1809 ))
1810 })?;
1811 if label.index() != target_index {
1812 to_sync.push((name.to_string(), target_index));
1813 }
1814 }
1815
1816 if to_sync.is_empty() {
1817 return Ok(());
1818 }
1819
1820 // Unrelated labels (not in this batch) occupying a target index block reassignment.
1821 for (name, target_index) in &to_sync {
1822 for label in ¤t {
1823 if label.index() == *target_index
1824 && label.name() != name.as_str()
1825 && !batch_names.contains_key(label.name())
1826 {
1827 return Err(Error::InvalidParameters(format!(
1828 "label_index {target_index} already used by '{}' \
1829 (needed for '{name}'); use a clean dataset or resolve the conflict",
1830 label.name()
1831 )));
1832 }
1833 }
1834 // Batch labels without an explicit index that occupy the target block reassignment.
1835 for (batch_name, batch_index) in names.iter().zip(indices.iter()) {
1836 if batch_index.is_some() || batch_name == name {
1837 continue;
1838 }
1839 if let Some(label) = by_name.get(batch_name.as_str())
1840 && label.index() == *target_index
1841 {
1842 return Err(Error::InvalidParameters(format!(
1843 "label '{batch_name}' occupies label_index {target_index} \
1844 (needed for '{name}') but no index was specified; \
1845 assign explicit indices for all labels in the batch or use a clean dataset"
1846 )));
1847 }
1848 }
1849 }
1850
1851 // Compute and validate temporary staging indices before any server writes.
1852 // checked_add guards against caller-supplied target_index values large enough
1853 // to wrap u64 when the offset is added. The occupancy check ensures no label
1854 // outside the batch already sits at the temp slot (it would be displaced by
1855 // the first pass and potentially clobber the second pass).
1856 let mut staged: Vec<(String, u64, u64)> = Vec::with_capacity(to_sync.len());
1857 for (name, target_index) in &to_sync {
1858 let temp_index = Self::LABEL_INDEX_ASSIGN_TEMP_OFFSET
1859 .checked_add(*target_index)
1860 .ok_or_else(|| {
1861 Error::InvalidParameters(format!(
1862 "label_index {target_index} for '{name}' is too large: \
1863 adding the staging offset would overflow u64"
1864 ))
1865 })?;
1866 for label in ¤t {
1867 if label.index() == temp_index && !batch_names.contains_key(label.name()) {
1868 return Err(Error::InvalidParameters(format!(
1869 "staging index {temp_index} (needed to move '{name}' to \
1870 index {target_index}) is already occupied by label '{}'; \
1871 use a clean dataset or resolve the conflict",
1872 label.name()
1873 )));
1874 }
1875 }
1876 staged.push((name.clone(), *target_index, temp_index));
1877 }
1878
1879 for (name, _, temp_index) in &staged {
1880 let mut label = by_name.get(name).cloned().expect("validated above");
1881 label.set_index(self, *temp_index).await?;
1882 }
1883
1884 for (name, target_index, _) in &staged {
1885 let mut label = by_name.get(name).cloned().expect("validated above");
1886 label.set_index(self, *target_index).await?;
1887 }
1888
1889 Ok(())
1890 }
1891
1892 /// Lists the groups for the specified dataset.
1893 ///
1894 /// Groups are used to organize samples into logical subsets such as
1895 /// "train", "val", "test", etc. Each sample can belong to at most one
1896 /// group at a time.
1897 ///
1898 /// # Arguments
1899 ///
1900 /// * `dataset_id` - The ID of the dataset to list groups for
1901 ///
1902 /// # Returns
1903 ///
1904 /// Returns a vector of [`Group`] objects for the dataset. Returns an
1905 /// empty vector if no groups have been created yet.
1906 ///
1907 /// # Errors
1908 ///
1909 /// Returns an error if the dataset does not exist or cannot be accessed.
1910 ///
1911 /// # Example
1912 ///
1913 /// ```rust,no_run
1914 /// # use edgefirst_client::{Client, DatasetID};
1915 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1916 /// let client = Client::new()?.with_token_path(None)?;
1917 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1918 ///
1919 /// let groups = client.groups(dataset_id).await?;
1920 /// for group in groups {
1921 /// println!("{}: {}", group.id, group.name);
1922 /// }
1923 /// # Ok(())
1924 /// # }
1925 /// ```
1926 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1927 pub async fn groups(&self, dataset_id: DatasetID) -> Result<Vec<Group>, Error> {
1928 let params = HashMap::from([("dataset_id", dataset_id)]);
1929 self.rpc("groups.list".to_owned(), Some(params)).await
1930 }
1931
1932 /// Gets an existing group by name or creates a new one.
1933 ///
1934 /// This is a convenience method that first checks if a group with the
1935 /// specified name exists, and creates it if not. This is useful when
1936 /// you need to ensure a group exists before assigning samples to it.
1937 ///
1938 /// # Arguments
1939 ///
1940 /// * `dataset_id` - The ID of the dataset
1941 /// * `name` - The name of the group (e.g., "train", "val", "test")
1942 ///
1943 /// # Returns
1944 ///
1945 /// Returns the group ID (either existing or newly created).
1946 ///
1947 /// # Errors
1948 ///
1949 /// Returns an error if:
1950 /// - The dataset does not exist or cannot be accessed
1951 /// - The group creation fails
1952 ///
1953 /// # Concurrency
1954 ///
1955 /// This method handles concurrent creation attempts gracefully. If another
1956 /// process creates the group between the existence check and creation,
1957 /// this method will return the existing group's ID.
1958 ///
1959 /// # Example
1960 ///
1961 /// ```rust,no_run
1962 /// # use edgefirst_client::{Client, DatasetID};
1963 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
1964 /// let client = Client::new()?.with_token_path(None)?;
1965 /// let dataset_id: DatasetID = "ds-123".try_into()?;
1966 ///
1967 /// // Get or create a "train" group
1968 /// let train_group_id = client
1969 /// .get_or_create_group(dataset_id.clone(), "train")
1970 /// .await?;
1971 /// println!("Train group ID: {}", train_group_id);
1972 ///
1973 /// // Calling again returns the same ID
1974 /// let same_id = client.get_or_create_group(dataset_id, "train").await?;
1975 /// assert_eq!(train_group_id, same_id);
1976 /// # Ok(())
1977 /// # }
1978 /// ```
1979 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
1980 pub async fn get_or_create_group(
1981 &self,
1982 dataset_id: DatasetID,
1983 name: &str,
1984 ) -> Result<u64, Error> {
1985 // First check if the group already exists
1986 let groups = self.groups(dataset_id).await?;
1987 if let Some(group) = groups.iter().find(|g| g.name == name) {
1988 return Ok(group.id);
1989 }
1990
1991 // Create the group
1992 #[derive(Serialize)]
1993 struct CreateGroupParams {
1994 dataset_id: DatasetID,
1995 group_names: Vec<String>,
1996 group_splits: Vec<i64>,
1997 }
1998
1999 let params = CreateGroupParams {
2000 dataset_id,
2001 group_names: vec![name.to_string()],
2002 group_splits: vec![0], // No automatic splitting
2003 };
2004
2005 let created_groups: Vec<Group> = self.rpc("groups.create".to_owned(), Some(params)).await?;
2006 if let Some(group) = created_groups.into_iter().find(|g| g.name == name) {
2007 Ok(group.id)
2008 } else {
2009 // Group might have been created by concurrent call, try fetching again
2010 let groups = self.groups(dataset_id).await?;
2011 groups
2012 .iter()
2013 .find(|g| g.name == name)
2014 .map(|g| g.id)
2015 .ok_or_else(|| {
2016 Error::RpcError(0, format!("Failed to create or find group '{}'", name))
2017 })
2018 }
2019 }
2020
2021 /// Sets the group for a sample.
2022 ///
2023 /// Assigns a sample to a specific group. Each sample can belong to at most
2024 /// one group at a time. Setting a new group replaces any existing group
2025 /// assignment.
2026 ///
2027 /// # Arguments
2028 ///
2029 /// * `sample_id` - The ID of the sample (image) to update
2030 /// * `group_id` - The ID of the group to assign. Use
2031 /// [`get_or_create_group`] to obtain a group ID from a name.
2032 ///
2033 /// # Returns
2034 ///
2035 /// Returns `Ok(())` on success.
2036 ///
2037 /// # Errors
2038 ///
2039 /// Returns an error if:
2040 /// - The sample does not exist
2041 /// - The group does not exist
2042 /// - Insufficient permissions to modify the sample
2043 ///
2044 /// # Example
2045 ///
2046 /// ```rust,no_run
2047 /// # use edgefirst_client::{Client, DatasetID, SampleID};
2048 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2049 /// let client = Client::new()?.with_token_path(None)?;
2050 /// let dataset_id: DatasetID = "ds-123".try_into()?;
2051 /// let sample_id: SampleID = 12345.into();
2052 ///
2053 /// // Get or create the "val" group
2054 /// let val_group_id = client.get_or_create_group(dataset_id, "val").await?;
2055 ///
2056 /// // Assign the sample to the "val" group
2057 /// client.set_sample_group_id(sample_id, val_group_id).await?;
2058 /// # Ok(())
2059 /// # }
2060 /// ```
2061 ///
2062 /// [`get_or_create_group`]: Self::get_or_create_group
2063 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2064 pub async fn set_sample_group_id(
2065 &self,
2066 sample_id: SampleID,
2067 group_id: u64,
2068 ) -> Result<(), Error> {
2069 #[derive(Serialize)]
2070 struct SetGroupParams {
2071 image_id: SampleID,
2072 group_id: u64,
2073 }
2074
2075 let params = SetGroupParams {
2076 image_id: sample_id,
2077 group_id,
2078 };
2079 let _: String = self
2080 .rpc("image.set_group_id".to_owned(), Some(params))
2081 .await?;
2082 Ok(())
2083 }
2084
2085 /// Downloads dataset samples to the local filesystem.
2086 ///
2087 /// # Arguments
2088 ///
2089 /// * `dataset_id` - The unique identifier of the dataset
2090 /// * `groups` - Dataset groups to include (e.g., "train", "val")
2091 /// * `file_types` - File types to download. Supported types:
2092 /// - `FileType::Image` - Standard image files (JPEG, PNG, etc.)
2093 /// - `FileType::LidarPcd` - LiDAR point cloud data (.pcd format)
2094 /// - `FileType::LidarDepth` - LiDAR depth images (.png format)
2095 /// - `FileType::LidarReflect` - LiDAR reflectance images (.jpg format)
2096 /// - `FileType::RadarPcd` - Radar point cloud data (.pcd format)
2097 /// - `FileType::RadarCube` - Radar cube data (.png format)
2098 /// - `FileType::All` - All sensor types (expands to all of the above)
2099 /// * `output` - Local directory to save downloaded files
2100 /// * `flatten` - If true, download all files to output root without
2101 /// sequence subdirectories. When flattening, filenames are prefixed with
2102 /// `{sequence_name}_{frame}_` (or `{sequence_name}_` if frame is
2103 /// unavailable) unless the filename already starts with
2104 /// `{sequence_name}_`, to avoid conflicts between sequences.
2105 /// * `progress` - Optional channel for progress updates
2106 ///
2107 /// # Progress
2108 ///
2109 /// This operation has two phases with distinct progress reporting:
2110 ///
2111 /// 1. **Fetching metadata** (`status: None`): Retrieves sample information
2112 /// from the server. Progress counts samples fetched.
2113 /// 2. **Downloading files** (`status: "Downloading"`): Downloads actual
2114 /// files to disk. Progress counts samples completed (each sample may
2115 /// have multiple files for different sensor types).
2116 ///
2117 /// Applications should detect the status change from `None` to
2118 /// `"Downloading"` to reset their progress bar for the second phase.
2119 ///
2120 /// # Returns
2121 ///
2122 /// Returns `Ok(())` on success or an error if download fails.
2123 ///
2124 /// # Example
2125 ///
2126 /// ```rust,no_run
2127 /// # use edgefirst_client::{Client, DatasetID, FileType};
2128 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
2129 /// let client = Client::new()?.with_token_path(None)?;
2130 /// let dataset_id: DatasetID = "ds-123".try_into()?;
2131 ///
2132 /// // Download with sequence subdirectories (default)
2133 /// client
2134 /// .download_dataset(
2135 /// dataset_id,
2136 /// &[],
2137 /// &[FileType::Image],
2138 /// "./data".into(),
2139 /// false,
2140 /// None,
2141 /// )
2142 /// .await?;
2143 ///
2144 /// // Download flattened (all files in one directory)
2145 /// client
2146 /// .download_dataset(
2147 /// dataset_id,
2148 /// &[],
2149 /// &[FileType::Image],
2150 /// "./data".into(),
2151 /// true,
2152 /// None,
2153 /// )
2154 /// .await?;
2155 ///
2156 /// // Download all sensor types
2157 /// client
2158 /// .download_dataset(
2159 /// dataset_id,
2160 /// &[],
2161 /// &FileType::expand_types(&[FileType::All]),
2162 /// "./data".into(),
2163 /// false,
2164 /// None,
2165 /// )
2166 /// .await?;
2167 /// # Ok(())
2168 /// # }
2169 /// ```
2170 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, groups, file_types, progress), fields(dataset_id = %dataset_id, output = %output.display())))]
2171 pub async fn download_dataset(
2172 &self,
2173 dataset_id: DatasetID,
2174 groups: &[String],
2175 file_types: &[FileType],
2176 output: PathBuf,
2177 flatten: bool,
2178 progress: Option<Sender<Progress>>,
2179 ) -> Result<(), Error> {
2180 // Phase 1: Fetch sample metadata (pass progress directly, no wrapper)
2181 let samples = self
2182 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
2183 .await?;
2184 fs::create_dir_all(&output).await?;
2185
2186 // Phase 2: Download actual files using direct semaphore pattern
2187 let total = samples.len();
2188 let current = Arc::new(AtomicUsize::new(0));
2189 let sem = Arc::new(Semaphore::new(max_tasks()));
2190
2191 // Send initial progress for download phase
2192 if let Some(ref progress) = progress {
2193 let _ = progress
2194 .send(Progress {
2195 current: 0,
2196 total,
2197 status: Some("Downloading".to_string()),
2198 })
2199 .await;
2200 }
2201
2202 let tasks = samples
2203 .into_iter()
2204 .map(|sample| {
2205 let client = self.clone();
2206 let file_types = file_types.to_vec();
2207 let output = output.clone();
2208 let progress = progress.clone();
2209 let current = current.clone();
2210 let sem = sem.clone();
2211
2212 tokio::spawn(async move {
2213 let _permit = sem.acquire().await.map_err(|_| {
2214 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
2215 })?;
2216
2217 for file_type in &file_types {
2218 if let Some(data) = sample.download(&client, file_type.clone()).await? {
2219 let (file_ext, is_image) = match file_type {
2220 FileType::Image => (
2221 infer::get(&data)
2222 .expect("Failed to identify image file format for sample")
2223 .extension()
2224 .to_string(),
2225 true,
2226 ),
2227 other => (other.file_extension().to_string(), false),
2228 };
2229
2230 // Determine target directory based on sequence membership and
2231 // flatten option
2232 // - flatten=false + sequence_name: dataset/sequence_name/
2233 // - flatten=false + no sequence: dataset/ (root level)
2234 // - flatten=true: dataset/ (all files in output root)
2235 // NOTE: group (train/val/test) is NOT used for directory structure
2236 let sequence_dir = sample
2237 .sequence_name()
2238 .map(|name| sanitize_path_component(name));
2239
2240 let target_dir = if flatten {
2241 output.clone()
2242 } else {
2243 sequence_dir
2244 .as_ref()
2245 .map(|seq| output.join(seq))
2246 .unwrap_or_else(|| output.clone())
2247 };
2248 fs::create_dir_all(&target_dir).await?;
2249
2250 let sanitized_sample_name = sample
2251 .name()
2252 .map(|name| sanitize_path_component(&name))
2253 .unwrap_or_else(|| "unknown".to_string());
2254
2255 let image_name = sample.image_name().map(sanitize_path_component);
2256
2257 // Construct filename with smart prefixing for flatten mode
2258 // When flatten=true and sample belongs to a sequence:
2259 // - Check if filename already starts with "{sequence_name}_"
2260 // - If not, prepend "{sequence_name}_{frame}_" to avoid conflicts
2261 // - If yes, use filename as-is (already uniquely named)
2262 let file_name = if is_image {
2263 if let Some(img_name) = image_name {
2264 Client::build_filename(
2265 &img_name,
2266 flatten,
2267 sequence_dir.as_ref(),
2268 sample.frame_number(),
2269 )
2270 } else {
2271 format!("{}.{}", sanitized_sample_name, file_ext)
2272 }
2273 } else {
2274 let base_name = format!("{}.{}", sanitized_sample_name, file_ext);
2275 Client::build_filename(
2276 &base_name,
2277 flatten,
2278 sequence_dir.as_ref(),
2279 sample.frame_number(),
2280 )
2281 };
2282
2283 let file_path = target_dir.join(&file_name);
2284
2285 let mut file = File::create(&file_path).await?;
2286 file.write_all(&data).await?;
2287 }
2288 }
2289
2290 // Update progress after sample completes
2291 if let Some(progress) = &progress {
2292 let completed = current.fetch_add(1, Ordering::SeqCst) + 1;
2293 let _ = progress
2294 .send(Progress {
2295 current: completed,
2296 total,
2297 status: Some("Downloading".to_string()),
2298 })
2299 .await;
2300 }
2301
2302 Ok::<(), Error>(())
2303 })
2304 })
2305 .collect::<Vec<_>>();
2306
2307 join_all(tasks)
2308 .await
2309 .into_iter()
2310 .collect::<Result<Vec<_>, _>>()?
2311 .into_iter()
2312 .collect::<Result<Vec<_>, _>>()?;
2313
2314 Ok(())
2315 }
2316
2317 /// Builds a filename with smart prefixing for flatten mode.
2318 ///
2319 /// When flattening sequences into a single directory, this function ensures
2320 /// unique filenames by checking if the sequence prefix already exists and
2321 /// adding it if necessary.
2322 ///
2323 /// # Logic
2324 ///
2325 /// - If `flatten=false`: returns `base_name` unchanged
2326 /// - If `flatten=true` and no sequence: returns `base_name` unchanged
2327 /// - If `flatten=true` and in sequence:
2328 /// - Already prefixed with `{sequence_name}_`: returns `base_name`
2329 /// unchanged
2330 /// - Not prefixed: returns `{sequence_name}_{frame}_{base_name}` or
2331 /// `{sequence_name}_{base_name}`
2332 fn build_filename(
2333 base_name: &str,
2334 flatten: bool,
2335 sequence_name: Option<&String>,
2336 frame_number: Option<u32>,
2337 ) -> String {
2338 if !flatten || sequence_name.is_none() {
2339 return base_name.to_string();
2340 }
2341
2342 let seq_name = sequence_name.unwrap();
2343 let prefix = format!("{}_", seq_name);
2344
2345 // Check if already prefixed with sequence name
2346 if base_name.starts_with(&prefix) {
2347 base_name.to_string()
2348 } else {
2349 // Add sequence (and optionally frame) prefix
2350 match frame_number {
2351 Some(frame) => format!("{}{}_{}", prefix, frame, base_name),
2352 None => format!("{}{}", prefix, base_name),
2353 }
2354 }
2355 }
2356
2357 /// List available annotation sets for the specified dataset.
2358 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2359 pub async fn annotation_sets(
2360 &self,
2361 dataset_id: DatasetID,
2362 ) -> Result<Vec<AnnotationSet>, Error> {
2363 let params = HashMap::from([("dataset_id", dataset_id)]);
2364 self.rpc("annset.list".to_owned(), Some(params)).await
2365 }
2366
2367 /// Create a new annotation set for the specified dataset.
2368 ///
2369 /// # Arguments
2370 ///
2371 /// * `dataset_id` - The ID of the dataset to create the annotation set in
2372 /// * `name` - The name of the new annotation set
2373 /// * `description` - Optional description for the annotation set
2374 ///
2375 /// # Returns
2376 ///
2377 /// Returns the annotation set ID of the newly created annotation set.
2378 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
2379 pub async fn create_annotation_set(
2380 &self,
2381 dataset_id: DatasetID,
2382 name: &str,
2383 description: Option<&str>,
2384 ) -> Result<AnnotationSetID, Error> {
2385 #[derive(Serialize)]
2386 struct Params<'a> {
2387 dataset_id: DatasetID,
2388 name: &'a str,
2389 operator: &'a str,
2390 #[serde(skip_serializing_if = "Option::is_none")]
2391 description: Option<&'a str>,
2392 }
2393
2394 #[derive(Deserialize)]
2395 struct CreateAnnotationSetResult {
2396 id: AnnotationSetID,
2397 }
2398
2399 let username = self.username().await?;
2400 let result: CreateAnnotationSetResult = self
2401 .rpc(
2402 "annset.add".to_owned(),
2403 Some(Params {
2404 dataset_id,
2405 name,
2406 operator: &username,
2407 description,
2408 }),
2409 )
2410 .await?;
2411 Ok(result.id)
2412 }
2413
2414 /// Deletes an annotation set by marking it as deleted.
2415 ///
2416 /// # Arguments
2417 ///
2418 /// * `annotation_set_id` - The ID of the annotation set to delete
2419 ///
2420 /// # Returns
2421 ///
2422 /// Returns `Ok(())` if the annotation set was successfully marked as
2423 /// deleted.
2424 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2425 pub async fn delete_annotation_set(
2426 &self,
2427 annotation_set_id: AnnotationSetID,
2428 ) -> Result<(), Error> {
2429 let params = HashMap::from([("id", annotation_set_id)]);
2430 let _: serde_json::Value = self.rpc("annset.delete".to_owned(), Some(params)).await?;
2431 Ok(())
2432 }
2433
2434 /// Retrieve the annotation set with the specified ID.
2435 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2436 pub async fn annotation_set(
2437 &self,
2438 annotation_set_id: AnnotationSetID,
2439 ) -> Result<AnnotationSet, Error> {
2440 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
2441 self.rpc("annset.get".to_owned(), Some(params)).await
2442 }
2443
2444 /// Get the annotations for the specified annotation set with the
2445 /// requested annotation types. The annotation types are used to filter
2446 /// the annotations returned. The groups parameter is used to filter for
2447 /// dataset groups (train, val, test). Images which do not have any
2448 /// annotations are also included in the result as long as they are in the
2449 /// requested groups (when specified).
2450 ///
2451 /// The result is a vector of Annotations objects which contain the
2452 /// full dataset along with the annotations for the specified types.
2453 ///
2454 /// # Progress
2455 ///
2456 /// Reports progress with `status: None` as samples are fetched and
2457 /// processed for their annotations. Progress unit is samples processed
2458 /// (not individual annotations).
2459 ///
2460 /// To get the annotations as a DataFrame, use the `samples_dataframe`
2461 /// method instead.
2462 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(annotation_set_id = %annotation_set_id)))]
2463 pub async fn annotations(
2464 &self,
2465 annotation_set_id: AnnotationSetID,
2466 groups: &[String],
2467 annotation_types: &[AnnotationType],
2468 progress: Option<Sender<Progress>>,
2469 ) -> Result<Vec<Annotation>, Error> {
2470 let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
2471 let labels = self
2472 .labels(dataset_id)
2473 .await?
2474 .into_iter()
2475 .map(|label| (label.name().to_string(), label.index()))
2476 .collect::<HashMap<_, _>>();
2477 let total = self
2478 .samples_count(
2479 dataset_id,
2480 Some(annotation_set_id),
2481 annotation_types,
2482 groups,
2483 &[],
2484 )
2485 .await?
2486 .total as usize;
2487
2488 if total == 0 {
2489 return Ok(vec![]);
2490 }
2491
2492 let context = FetchContext {
2493 dataset_id,
2494 annotation_set_id: Some(annotation_set_id),
2495 groups,
2496 types: annotation_types.iter().map(|t| t.to_string()).collect(),
2497 labels: &labels,
2498 };
2499
2500 self.fetch_annotations_paginated(context, total, progress)
2501 .await
2502 }
2503
2504 async fn fetch_annotations_paginated(
2505 &self,
2506 context: FetchContext<'_>,
2507 total: usize,
2508 progress: Option<Sender<Progress>>,
2509 ) -> Result<Vec<Annotation>, Error> {
2510 let mut annotations = vec![];
2511 let mut continue_token: Option<String> = None;
2512 let mut current = 0;
2513
2514 loop {
2515 let params = SamplesListParams {
2516 dataset_id: context.dataset_id,
2517 annotation_set_id: context.annotation_set_id,
2518 types: context.types.clone(),
2519 group_names: context.groups.to_vec(),
2520 continue_token,
2521 };
2522
2523 let result: SamplesListResult =
2524 self.rpc("samples.list".to_owned(), Some(params)).await?;
2525 current += result.samples.len();
2526 continue_token = result.continue_token;
2527
2528 if result.samples.is_empty() {
2529 break;
2530 }
2531
2532 self.process_sample_annotations(&result.samples, context.labels, &mut annotations);
2533
2534 if let Some(progress) = &progress {
2535 let _ = progress
2536 .send(Progress {
2537 current,
2538 total,
2539 status: None,
2540 })
2541 .await;
2542 }
2543
2544 match &continue_token {
2545 Some(token) if !token.is_empty() => continue,
2546 _ => break,
2547 }
2548 }
2549
2550 drop(progress);
2551 Ok(annotations)
2552 }
2553
2554 fn process_sample_annotations(
2555 &self,
2556 samples: &[Sample],
2557 labels: &HashMap<String, u64>,
2558 annotations: &mut Vec<Annotation>,
2559 ) {
2560 for sample in samples {
2561 if sample.annotations().is_empty() {
2562 let mut annotation = Annotation::new();
2563 annotation.set_sample_id(sample.id());
2564 annotation.set_name(sample.name());
2565 annotation.set_sequence_name(sample.sequence_name().cloned());
2566 annotation.set_frame_number(sample.frame_number());
2567 annotation.set_group(sample.group().cloned());
2568 annotations.push(annotation);
2569 continue;
2570 }
2571
2572 for annotation in sample.annotations() {
2573 let mut annotation = annotation.clone();
2574 annotation.set_sample_id(sample.id());
2575 annotation.set_name(sample.name());
2576 annotation.set_sequence_name(sample.sequence_name().cloned());
2577 annotation.set_frame_number(sample.frame_number());
2578 annotation.set_group(sample.group().cloned());
2579 Self::set_label_index_from_map(&mut annotation, labels);
2580 annotations.push(annotation);
2581 }
2582 }
2583 }
2584
2585 /// Delete annotations in bulk from specified samples.
2586 ///
2587 /// This method calls the `annotation.bulk.del` API to efficiently remove
2588 /// annotations from multiple samples at once. Useful for clearing
2589 /// annotations before re-importing updated data.
2590 ///
2591 /// # Arguments
2592 /// * `annotation_set_id` - The annotation set containing the annotations
2593 /// * `annotation_types` - Types to delete: "box" for bounding boxes, "seg"
2594 /// for masks
2595 /// * `sample_ids` - Sample IDs (image IDs) to delete annotations from
2596 ///
2597 /// # Example
2598 /// ```no_run
2599 /// # use edgefirst_client::{Client, AnnotationSetID, SampleID};
2600 /// # async fn example() -> Result<(), edgefirst_client::Error> {
2601 /// # let client = Client::new()?.with_login("user", "pass").await?;
2602 /// let annotation_set_id = AnnotationSetID::from(123);
2603 /// let sample_ids = vec![SampleID::from(1), SampleID::from(2)];
2604 ///
2605 /// client
2606 /// .delete_annotations_bulk(
2607 /// annotation_set_id,
2608 /// &["box".to_string(), "seg".to_string()],
2609 /// &sample_ids,
2610 /// )
2611 /// .await?;
2612 /// # Ok(())
2613 /// # }
2614 /// ```
2615 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, sample_ids), fields(annotation_set_id = %annotation_set_id)))]
2616 pub async fn delete_annotations_bulk(
2617 &self,
2618 annotation_set_id: AnnotationSetID,
2619 annotation_types: &[String],
2620 sample_ids: &[SampleID],
2621 ) -> Result<(), Error> {
2622 use crate::api::AnnotationBulkDeleteParams;
2623
2624 let params = AnnotationBulkDeleteParams {
2625 annotation_set_id: annotation_set_id.into(),
2626 annotation_types: annotation_types.to_vec(),
2627 image_ids: sample_ids.iter().map(|id| (*id).into()).collect(),
2628 delete_all: None,
2629 };
2630
2631 let _: String = self
2632 .rpc("annotation.bulk.del".to_owned(), Some(params))
2633 .await?;
2634 Ok(())
2635 }
2636
2637 /// Add annotations in bulk.
2638 ///
2639 /// This method calls the `annotation.add_bulk` API to efficiently add
2640 /// multiple annotations at once. The annotations must be in server format
2641 /// with image_id references.
2642 ///
2643 /// # Arguments
2644 /// * `annotation_set_id` - The annotation set to add annotations to
2645 /// * `annotations` - Vector of server-format annotations to add
2646 ///
2647 /// # Returns
2648 /// Vector of created annotation records from the server.
2649 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotations), fields(annotation_count = annotations.len())))]
2650 pub async fn add_annotations_bulk(
2651 &self,
2652 annotation_set_id: AnnotationSetID,
2653 annotations: Vec<crate::api::ServerAnnotation>,
2654 ) -> Result<Vec<serde_json::Value>, Error> {
2655 use crate::api::AnnotationAddBulkParams;
2656
2657 let params = AnnotationAddBulkParams {
2658 annotation_set_id: annotation_set_id.into(),
2659 annotations,
2660 };
2661
2662 self.rpc("annotation.add_bulk".to_owned(), Some(params))
2663 .await
2664 }
2665
2666 /// Helper to parse frame number from image_name when sequence_name is
2667 /// present. This ensures frame_number is always derived from the image
2668 /// filename, not from the server's frame_number field (which may be
2669 /// inconsistent).
2670 ///
2671 /// Returns Some(frame_number) if sequence_name is present and frame can be
2672 /// parsed, otherwise None.
2673 fn parse_frame_from_image_name(
2674 image_name: Option<&String>,
2675 sequence_name: Option<&String>,
2676 ) -> Option<u32> {
2677 use std::path::Path;
2678
2679 let sequence = sequence_name?;
2680 let name = image_name?;
2681
2682 // Extract stem (remove extension)
2683 let stem = Path::new(name).file_stem().and_then(|s| s.to_str())?;
2684
2685 // Parse frame from format: "sequence_XXX" where XXX is the frame number
2686 stem.strip_prefix(sequence)
2687 .and_then(|suffix| suffix.strip_prefix('_'))
2688 .and_then(|frame_str| frame_str.parse::<u32>().ok())
2689 }
2690
2691 /// Helper to set label index from a label map
2692 fn set_label_index_from_map(annotation: &mut Annotation, labels: &HashMap<String, u64>) {
2693 if let Some(label) = annotation.label() {
2694 annotation.set_label_index(Some(labels[label.as_str()]));
2695 }
2696 }
2697
2698 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2699 pub async fn samples_count(
2700 &self,
2701 dataset_id: DatasetID,
2702 annotation_set_id: Option<AnnotationSetID>,
2703 annotation_types: &[AnnotationType],
2704 groups: &[String],
2705 types: &[FileType],
2706 ) -> Result<SamplesCountResult, Error> {
2707 // Use server type names for API calls (e.g., "box" instead of "box2d")
2708 let types = annotation_types
2709 .iter()
2710 .map(|t| t.as_server_type().to_string())
2711 .chain(types.iter().map(|t| t.to_string()))
2712 .collect::<Vec<_>>();
2713
2714 let params = SamplesListParams {
2715 dataset_id,
2716 annotation_set_id,
2717 group_names: groups.to_vec(),
2718 types,
2719 continue_token: None,
2720 };
2721
2722 self.rpc("samples.count".to_owned(), Some(params)).await
2723 }
2724
2725 /// Fetches samples from a dataset with optional annotation and file type
2726 /// filters.
2727 ///
2728 /// # Arguments
2729 ///
2730 /// * `dataset_id` - The dataset to fetch samples from
2731 /// * `annotation_set_id` - Optional annotation set to include annotations
2732 /// from
2733 /// * `annotation_types` - Filter by annotation types (box2d, box3d, mask)
2734 /// * `groups` - Filter by sample groups (e.g., "train", "val", "test")
2735 /// * `types` - File types to include metadata for
2736 /// * `progress` - Optional channel for progress updates
2737 ///
2738 /// # Progress
2739 ///
2740 /// Reports progress with `status: None` as samples are fetched from the
2741 /// server in paginated batches. Progress unit is samples fetched.
2742 ///
2743 /// # Returns
2744 ///
2745 /// Vector of [`Sample`] objects with metadata and optionally annotations.
2746 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, annotation_types, groups, types, progress), fields(dataset_id = %dataset_id, annotation_set_id = ?annotation_set_id)))]
2747 pub async fn samples(
2748 &self,
2749 dataset_id: DatasetID,
2750 annotation_set_id: Option<AnnotationSetID>,
2751 annotation_types: &[AnnotationType],
2752 groups: &[String],
2753 types: &[FileType],
2754 progress: Option<Sender<Progress>>,
2755 ) -> Result<Vec<Sample>, Error> {
2756 // Use server type names for API calls (e.g., "box" instead of "box2d")
2757 let types_vec = annotation_types
2758 .iter()
2759 .map(|t| t.as_server_type().to_string())
2760 .chain(types.iter().map(|t| t.to_string()))
2761 .collect::<Vec<_>>();
2762 let labels = self
2763 .labels(dataset_id)
2764 .await?
2765 .into_iter()
2766 .map(|label| (label.name().to_string(), label.index()))
2767 .collect::<HashMap<_, _>>();
2768 let total = self
2769 .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
2770 .await?
2771 .total as usize;
2772
2773 if total == 0 {
2774 return Ok(vec![]);
2775 }
2776
2777 let context = FetchContext {
2778 dataset_id,
2779 annotation_set_id,
2780 groups,
2781 types: types_vec,
2782 labels: &labels,
2783 };
2784
2785 self.fetch_samples_paginated(context, total, progress).await
2786 }
2787
2788 /// Get all sample names in a dataset.
2789 ///
2790 /// This is an efficient method for checking which samples already exist,
2791 /// useful for resuming interrupted imports. It only retrieves sample names
2792 /// without loading full annotation data.
2793 ///
2794 /// # Arguments
2795 ///
2796 /// * `dataset_id` - The dataset to query
2797 /// * `groups` - Optional group filter (empty = all groups)
2798 /// * `progress` - Optional progress channel
2799 ///
2800 /// # Progress
2801 ///
2802 /// Reports progress with `status: None` as sample names are fetched from
2803 /// the server in paginated batches. Progress unit is samples fetched.
2804 ///
2805 /// # Returns
2806 ///
2807 /// A HashSet of sample names (image_name field) that exist in the dataset.
2808 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
2809 pub async fn sample_names(
2810 &self,
2811 dataset_id: DatasetID,
2812 groups: &[String],
2813 progress: Option<Sender<Progress>>,
2814 ) -> Result<std::collections::HashSet<String>, Error> {
2815 use std::collections::HashSet;
2816
2817 let total = self
2818 .samples_count(dataset_id, None, &[], groups, &[])
2819 .await?
2820 .total as usize;
2821
2822 if total == 0 {
2823 return Ok(HashSet::new());
2824 }
2825
2826 let mut names = HashSet::with_capacity(total);
2827 let mut continue_token: Option<String> = None;
2828 let mut current = 0;
2829
2830 loop {
2831 let params = SamplesListParams {
2832 dataset_id,
2833 annotation_set_id: None,
2834 types: vec![], // No type filter - we just want names
2835 group_names: groups.to_vec(),
2836 continue_token: continue_token.clone(),
2837 };
2838
2839 let result: SamplesListResult =
2840 self.rpc("samples.list".to_owned(), Some(params)).await?;
2841 current += result.samples.len();
2842 continue_token = result.continue_token;
2843
2844 if result.samples.is_empty() {
2845 break;
2846 }
2847
2848 // Extract sample names (normalized without extension)
2849 for sample in result.samples {
2850 if let Some(name) = sample.name() {
2851 names.insert(name);
2852 }
2853 }
2854
2855 if let Some(ref p) = progress {
2856 let _ = p
2857 .send(Progress {
2858 current,
2859 total,
2860 status: None,
2861 })
2862 .await;
2863 }
2864
2865 match &continue_token {
2866 Some(token) if !token.is_empty() => continue,
2867 _ => break,
2868 }
2869 }
2870
2871 Ok(names)
2872 }
2873
2874 async fn fetch_samples_paginated(
2875 &self,
2876 context: FetchContext<'_>,
2877 total: usize,
2878 progress: Option<Sender<Progress>>,
2879 ) -> Result<Vec<Sample>, Error> {
2880 let mut samples = vec![];
2881 let mut continue_token: Option<String> = None;
2882 let mut current = 0;
2883
2884 loop {
2885 let params = SamplesListParams {
2886 dataset_id: context.dataset_id,
2887 annotation_set_id: context.annotation_set_id,
2888 types: context.types.clone(),
2889 group_names: context.groups.to_vec(),
2890 continue_token: continue_token.clone(),
2891 };
2892
2893 let result: SamplesListResult =
2894 self.rpc("samples.list".to_owned(), Some(params)).await?;
2895 current += result.samples.len();
2896 continue_token = result.continue_token;
2897
2898 if result.samples.is_empty() {
2899 break;
2900 }
2901
2902 samples.append(
2903 &mut result
2904 .samples
2905 .into_iter()
2906 .map(|s| {
2907 // Use server's frame_number if valid (>= 0 after deserialization)
2908 // Otherwise parse from image_name as fallback
2909 // This ensures we respect explicit frame_number from uploads
2910 // while still handling legacy data that only has filename encoding
2911 let frame_number = s.frame_number.or_else(|| {
2912 Self::parse_frame_from_image_name(
2913 s.image_name.as_ref(),
2914 s.sequence_name.as_ref(),
2915 )
2916 });
2917
2918 let mut anns = s.annotations().to_vec();
2919 for ann in &mut anns {
2920 // Set annotation fields from parent sample
2921 ann.set_name(s.name());
2922 ann.set_group(s.group().cloned());
2923 ann.set_sequence_name(s.sequence_name().cloned());
2924 ann.set_frame_number(frame_number);
2925 Self::set_label_index_from_map(ann, context.labels);
2926 }
2927 s.with_annotations(anns).with_frame_number(frame_number)
2928 })
2929 .collect::<Vec<_>>(),
2930 );
2931
2932 if let Some(progress) = &progress {
2933 let _ = progress
2934 .send(Progress {
2935 current,
2936 total,
2937 status: None,
2938 })
2939 .await;
2940 }
2941
2942 match &continue_token {
2943 Some(token) if !token.is_empty() => continue,
2944 _ => break,
2945 }
2946 }
2947
2948 drop(progress);
2949 Ok(samples)
2950 }
2951
2952 /// Populates (imports) samples into a dataset using the `samples.populate2`
2953 /// API.
2954 ///
2955 /// This method creates new samples in the specified dataset, optionally
2956 /// with annotations and sensor data files. For each sample, the `files`
2957 /// field is checked for local file paths. If a filename is a valid path
2958 /// to an existing file, the file will be automatically uploaded to S3
2959 /// using presigned URLs returned by the server. The filename in the
2960 /// request is replaced with the basename (path removed) before sending
2961 /// to the server.
2962 ///
2963 /// # Important Notes
2964 ///
2965 /// - **`annotation_set_id` is REQUIRED** when importing samples with
2966 /// annotations. Without it, the server will accept the request but will
2967 /// not save the annotation data. Use [`Client::annotation_sets`] to query
2968 /// available annotation sets for a dataset, or create a new one via the
2969 /// Studio UI.
2970 /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
2971 /// boxes. Divide pixel coordinates by image width/height before creating
2972 /// [`Box2d`](crate::Box2d) annotations.
2973 /// - **Files are uploaded automatically** when the filename is a valid
2974 /// local path. The method will replace the full path with just the
2975 /// basename before sending to the server.
2976 /// - **Image dimensions are extracted automatically** for image files using
2977 /// the `imagesize` crate. The width/height are sent to the server and
2978 /// stored in the `image_files` table. These dimensions are returned by
2979 /// `samples.list` and used in [`samples_dataframe`](crate::samples_dataframe)
2980 /// to populate the `size` column.
2981 /// - **UUIDs are generated automatically** if not provided. If you need
2982 /// deterministic UUIDs, set `sample.uuid` explicitly before calling.
2983 ///
2984 /// # Arguments
2985 ///
2986 /// * `dataset_id` - The ID of the dataset to populate
    /// * `annotation_set_id` - **Required** if samples contain annotations;
    ///   without it the annotations are not saved. Query with
    ///   [`Client::annotation_sets`].
2990 /// * `samples` - Vector of samples to import with metadata and file
2991 /// references. For files, use the full local path - it will be uploaded
2992 /// automatically. UUIDs and image dimensions will be
2993 /// auto-generated/extracted if not provided.
2994 /// * `progress` - Optional channel for progress updates
2995 ///
2996 /// # Progress
2997 ///
2998 /// Reports progress with `status: None` as each sample's files are
2999 /// uploaded. Progress unit is samples (not individual files). Each
3000 /// sample may contain multiple files (image, lidar, radar, etc.) which
3001 /// are all uploaded before the sample is counted as complete.
3002 ///
3003 /// # Returns
3004 ///
3005 /// Returns the API result with sample UUIDs and upload status.
3006 ///
3007 /// # Example
3008 ///
3009 /// ```no_run
3010 /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
3011 ///
3012 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3013 /// # let client = Client::new()?.with_login("user", "pass").await?;
3014 /// # let dataset_id = DatasetID::from(1);
3015 /// // Query available annotation sets for the dataset
3016 /// let annotation_sets = client.annotation_sets(dataset_id).await?;
3017 /// let annotation_set_id = annotation_sets
3018 /// .first()
3019 /// .ok_or_else(|| {
3020 /// edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
3021 /// })?
3022 /// .id();
3023 ///
3024 /// // Create sample with annotation (UUID will be auto-generated)
3025 /// let mut sample = Sample::new();
3026 /// sample.width = Some(1920);
3027 /// sample.height = Some(1080);
3028 /// sample.group = Some("train".to_string());
3029 ///
3030 /// // Add file - use full path to local file, it will be uploaded automatically
3031 /// sample.files = vec![SampleFile::with_filename(
3032 /// "image".to_string(),
3033 /// "/path/to/image.jpg".to_string(),
3034 /// )];
3035 ///
3036 /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
3037 /// let mut annotation = Annotation::new();
3038 /// annotation.set_label(Some("person".to_string()));
3039 /// // Normalize pixel coordinates by dividing by image dimensions
3040 /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
3041 /// annotation.set_box2d(Some(bbox));
3042 /// sample.annotations = vec![annotation];
3043 ///
3044 /// // Populate with annotation_set_id (REQUIRED for annotations)
3045 /// let result = client
3046 /// .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
3047 /// .await?;
3048 /// # Ok(())
3049 /// # }
3050 /// ```
3051 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
3052 pub async fn populate_samples(
3053 &self,
3054 dataset_id: DatasetID,
3055 annotation_set_id: Option<AnnotationSetID>,
3056 samples: Vec<Sample>,
3057 progress: Option<Sender<Progress>>,
3058 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
3059 self.populate_samples_with_concurrency(
3060 dataset_id,
3061 annotation_set_id,
3062 samples,
3063 progress,
3064 None,
3065 )
3066 .await
3067 }
3068
    /// Populate samples with custom upload concurrency.
    ///
    /// Same as [`populate_samples`](Self::populate_samples) but allows
    /// specifying the maximum number of concurrent file uploads. Use this
    /// for bulk imports where higher concurrency can significantly reduce
    /// upload time.
    ///
    /// The work happens in three steps:
    ///
    /// 1. Samples are prepared locally. Missing UUIDs are generated. File
    ///    entries that carry bytes or name an existing local file are queued
    ///    for upload and reduced to their upload name.
    /// 2. `samples.populate2` is called once for the whole batch.
    ///    `presigned_urls` is requested only when something was queued.
    /// 3. Queued files are uploaded to the URLs returned in step 2.
    ///
    /// The `samples.populate2` call completes before any file upload starts.
    /// An upload failure in step 3 is therefore reported after that RPC has
    /// already succeeded.
    ///
    /// # Arguments
    ///
    /// * `dataset_id`, `annotation_set_id`, `samples` - As for
    ///   [`populate_samples`](Self::populate_samples).
    /// * `progress` - Optional progress channel. It is only passed to the
    ///   upload step, so no progress is sent when nothing needs uploading.
    /// * `concurrency` - Maximum concurrent uploads, forwarded to the parallel
    ///   upload helper. `None` presumably selects that helper's default
    ///   (TODO confirm in `parallel_foreach_items`).
    ///
    /// # Errors
    ///
    /// Propagates errors from sample preparation, from the
    /// `samples.populate2` RPC, and from the file uploads.
    #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, samples, progress), fields(sample_count = samples.len())))]
    pub async fn populate_samples_with_concurrency(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
        concurrency: Option<usize>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;
        #[cfg(feature = "profiling")]
        use tracing::Instrument as _;

        // Track which files need to be uploaded. Entries are
        // (sample_uuid, file_type, source, upload_name), as pushed by
        // `process_sample_file`.
        let mut files_to_upload: Vec<(String, String, FileSource, String)> = Vec::new();

        // Process samples to detect local files and generate UUIDs. This is
        // synchronous CPU/metadata work; the span uses `.entered()` since it
        // runs on the current task with no await inside.
        let samples = {
            #[cfg(feature = "profiling")]
            let _prepare_span = tracing::info_span!("prepare_samples", n = samples.len()).entered();
            self.prepare_samples_for_upload(samples, &mut files_to_upload)?
        };

        let has_files_to_upload = !files_to_upload.is_empty();

        // Call populate API with presigned_urls=true if we have files to upload
        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            presigned_urls: Some(has_files_to_upload),
            samples,
        };

        // The returned results (per-sample uuid + presigned URLs) drive the
        // upload step below.
        #[cfg(feature = "profiling")]
        let rpc_start = std::time::Instant::now();
        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate2".to_owned(), Some(params))
            .await?;
        #[cfg(feature = "profiling")]
        upload_stats::add_rpc_nanos(rpc_start.elapsed().as_nanos() as u64);

        // Upload files if we have any. The S3 fan-out is async, so the span is
        // attached to the future with `.instrument()` (not `.entered()`) to stay
        // correct when this batch overlaps others.
        if has_files_to_upload {
            #[cfg(feature = "profiling")]
            let n_files = files_to_upload.len();
            #[cfg(feature = "profiling")]
            let upload_start = std::time::Instant::now();
            let upload_fut =
                self.upload_sample_files(&results, files_to_upload, progress, concurrency);
            #[cfg(feature = "profiling")]
            let upload_fut =
                upload_fut.instrument(tracing::info_span!("upload_files", files = n_files));
            upload_fut.await?;
            #[cfg(feature = "profiling")]
            upload_stats::add_upload_nanos(upload_start.elapsed().as_nanos() as u64);
        }

        Ok(results)
    }
3138
3139 fn prepare_samples_for_upload(
3140 &self,
3141 samples: Vec<Sample>,
3142 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
3143 ) -> Result<Vec<Sample>, Error> {
3144 Ok(samples
3145 .into_iter()
3146 .map(|mut sample| {
3147 // Generate UUID if not provided
3148 if sample.uuid.is_none() {
3149 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
3150 }
3151
3152 let sample_uuid = sample.uuid.clone().expect("UUID just set above");
3153
3154 // Process files: detect local paths and queue for upload
3155 let files_copy = sample.files.clone();
3156 let updated_files: Vec<crate::SampleFile> = files_copy
3157 .iter()
3158 .map(|file| {
3159 self.process_sample_file(file, &sample_uuid, &mut sample, files_to_upload)
3160 })
3161 .collect();
3162
3163 sample.files = updated_files;
3164 sample
3165 })
3166 .collect())
3167 }
3168
3169 fn process_sample_file(
3170 &self,
3171 file: &crate::SampleFile,
3172 sample_uuid: &str,
3173 sample: &mut Sample,
3174 files_to_upload: &mut Vec<(String, String, FileSource, String)>,
3175 ) -> crate::SampleFile {
3176 use std::path::Path;
3177
3178 // Handle files with raw bytes (e.g., from ZIP archives)
3179 if let Some(bytes) = file.bytes()
3180 && let Some(filename) = file.filename()
3181 {
3182 // For image files with bytes, try to extract dimensions if not already set
3183 if file.file_type() == "image"
3184 && (sample.width.is_none() || sample.height.is_none())
3185 && let Ok(size) = imagesize::blob_size(bytes)
3186 {
3187 sample.width = Some(size.width as u32);
3188 sample.height = Some(size.height as u32);
3189 }
3190
3191 // Store the bytes for later upload
3192 files_to_upload.push((
3193 sample_uuid.to_string(),
3194 file.file_type().to_string(),
3195 FileSource::Bytes(bytes.to_vec()),
3196 filename.to_string(),
3197 ));
3198
3199 // Return SampleFile with just the filename
3200 return crate::SampleFile::with_filename(
3201 file.file_type().to_string(),
3202 filename.to_string(),
3203 );
3204 }
3205
3206 // Handle files with local paths
3207 if let Some(filename) = file.filename() {
3208 let path = Path::new(filename);
3209
3210 // Check if this is a valid local file path
3211 if path.exists()
3212 && path.is_file()
3213 && let Some(basename) = path.file_name().and_then(|s| s.to_str())
3214 {
3215 // For image files, try to extract dimensions if not already set
3216 if file.file_type() == "image"
3217 && (sample.width.is_none() || sample.height.is_none())
3218 && let Ok(size) = imagesize::size(path)
3219 {
3220 sample.width = Some(size.width as u32);
3221 sample.height = Some(size.height as u32);
3222 }
3223
3224 // Store the full path for later upload
3225 files_to_upload.push((
3226 sample_uuid.to_string(),
3227 file.file_type().to_string(),
3228 FileSource::Path(path.to_path_buf()),
3229 basename.to_string(),
3230 ));
3231
3232 // Return SampleFile with just the basename
3233 return crate::SampleFile::with_filename(
3234 file.file_type().to_string(),
3235 basename.to_string(),
3236 );
3237 }
3238 }
3239 // Return the file unchanged if not a local path
3240 file.clone()
3241 }
3242
3243 async fn upload_sample_files(
3244 &self,
3245 results: &[crate::SamplesPopulateResult],
3246 files_to_upload: Vec<(String, String, FileSource, String)>,
3247 progress: Option<Sender<Progress>>,
3248 concurrency: Option<usize>,
3249 ) -> Result<(), Error> {
3250 // Build a map from (sample_uuid, basename) -> file source
3251 let mut upload_map: HashMap<(String, String), FileSource> = HashMap::new();
3252 for (uuid, _file_type, source, basename) in files_to_upload {
3253 upload_map.insert((uuid, basename), source);
3254 }
3255
3256 let http = self.bulk_http.clone();
3257
3258 // Extract the data we need for parallel upload
3259 let upload_tasks: Vec<_> = results
3260 .iter()
3261 .map(|result| (result.uuid.clone(), result.urls.clone()))
3262 .collect();
3263
3264 parallel_foreach_items(
3265 upload_tasks,
3266 progress.clone(),
3267 concurrency,
3268 move |(uuid, urls)| {
3269 let http = http.clone();
3270 let upload_map = upload_map.clone();
3271
3272 async move {
3273 // Upload all files for this sample
3274 for url_info in &urls {
3275 if let Some(source) =
3276 upload_map.get(&(uuid.clone(), url_info.filename.clone()))
3277 {
3278 match source {
3279 FileSource::Path(path) => {
3280 upload_file_to_presigned_url(
3281 http.clone(),
3282 &url_info.url,
3283 path.clone(),
3284 )
3285 .await?;
3286 }
3287 FileSource::Bytes(bytes) => {
3288 upload_bytes_to_presigned_url(
3289 http.clone(),
3290 &url_info.url,
3291 bytes.clone(),
3292 &url_info.filename,
3293 )
3294 .await?;
3295 }
3296 }
3297 }
3298 }
3299
3300 Ok(())
3301 }
3302 },
3303 )
3304 .await
3305 }
3306
3307 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3308 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
3309 // Validate URL is absolute (has scheme) to avoid RelativeUrlWithoutBase error
3310 if !url.starts_with("http://") && !url.starts_with("https://") {
3311 return Err(Error::InvalidParameters(format!(
3312 "Invalid URL (must be absolute): {}",
3313 url
3314 )));
3315 }
3316
3317 let resp = self.bulk_http.get(url).send().await?;
3318
3319 if !resp.status().is_success() {
3320 return Err(Error::HttpError(resp.error_for_status().unwrap_err()));
3321 }
3322
3323 let bytes = resp.bytes().await?;
3324 Ok(bytes.to_vec())
3325 }
3326
3327 /// Get samples as a DataFrame with complete 2025.10 schema.
3328 ///
3329 /// This is the recommended method for obtaining dataset annotations in
3330 /// DataFrame format. It includes all sample metadata (size, location,
3331 /// pose, degradation) as optional columns.
3332 ///
3333 /// # Arguments
3334 ///
3335 /// * `dataset_id` - Dataset identifier
3336 /// * `annotation_set_id` - Optional annotation set filter
3337 /// * `groups` - Dataset groups to include (train, val, test)
3338 /// * `types` - Annotation types to filter (bbox, box3d, mask)
    /// * `progress` - Optional channel for progress updates
3340 ///
3341 /// # Progress
3342 ///
3343 /// Reports progress with `status: None` as samples are fetched from the
3344 /// server in paginated batches. Progress unit is samples fetched. This
3345 /// method delegates to [`samples()`](Self::samples) and shares its
3346 /// progress behavior.
3347 ///
3348 /// # Example
3349 ///
3350 /// ```rust,no_run
3351 /// use edgefirst_client::Client;
3352 ///
3353 /// # async fn example() -> Result<(), edgefirst_client::Error> {
3354 /// # let client = Client::new()?;
3355 /// # let dataset_id = 1.into();
3356 /// # let annotation_set_id = 1.into();
3357 /// let df = client
3358 /// .samples_dataframe(
3359 /// dataset_id,
3360 /// Some(annotation_set_id),
3361 /// &["train".to_string()],
3362 /// &[],
3363 /// None,
3364 /// )
3365 /// .await?;
3366 /// println!("DataFrame shape: {:?}", df.shape());
3367 /// # Ok(())
3368 /// # }
3369 /// ```
3370 #[cfg(feature = "polars")]
3371 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
3372 pub async fn samples_dataframe(
3373 &self,
3374 dataset_id: DatasetID,
3375 annotation_set_id: Option<AnnotationSetID>,
3376 groups: &[String],
3377 types: &[AnnotationType],
3378 progress: Option<Sender<Progress>>,
3379 ) -> Result<DataFrame, Error> {
3380 use crate::dataset::samples_dataframe;
3381
3382 let samples = self
3383 .samples(dataset_id, annotation_set_id, types, groups, &[], progress)
3384 .await?;
3385 samples_dataframe(&samples)
3386 }
3387
3388 /// Update image dimensions for existing samples in a dataset.
3389 ///
3390 /// This is useful for backfilling width/height data on samples that were
3391 /// uploaded before dimension extraction was added, or where dimensions
3392 /// could not be determined at upload time.
3393 ///
3394 /// # Arguments
3395 ///
3396 /// * `dataset_id` - The dataset containing the samples
3397 /// * `updates` - List of dimension updates (sample ID, width, height)
3398 ///
3399 /// # Returns
3400 ///
3401 /// The number of samples that were successfully updated.
3402 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, updates), fields(dataset_id = %dataset_id, count = updates.len())))]
3403 pub async fn update_sample_dimensions(
3404 &self,
3405 dataset_id: DatasetID,
3406 updates: Vec<crate::SampleDimensionUpdate>,
3407 ) -> Result<u64, Error> {
3408 use crate::api::SamplesUpdateDimensionsParams;
3409
3410 if updates.is_empty() {
3411 return Ok(0);
3412 }
3413
3414 // Batch in groups of 500 to stay within server limits
3415 let mut total_updated = 0u64;
3416 for chunk in updates.chunks(500) {
3417 let params = SamplesUpdateDimensionsParams {
3418 dataset_id,
3419 samples: chunk.to_vec(),
3420 };
3421 let result: crate::SamplesUpdateDimensionsResult = self
3422 .rpc("samples.update_dimensions".to_owned(), Some(params))
3423 .await?;
3424 total_updated += result.updated;
3425 }
3426 Ok(total_updated)
3427 }
3428
3429 /// Backfill missing image dimensions for a dataset.
3430 ///
3431 /// Downloads image data for samples that are missing width/height,
3432 /// extracts the dimensions using the `imagesize` crate, and updates
3433 /// the server with the computed values.
3434 ///
3435 /// This is a one-time repair operation for datasets that were uploaded
3436 /// before the client added automatic dimension extraction.
3437 ///
3438 /// # Arguments
3439 ///
3440 /// * `dataset_id` - The dataset to backfill
3441 /// * `progress` - Optional progress channel
3442 ///
3443 /// # Returns
3444 ///
3445 /// The number of samples whose dimensions were updated.
3446 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(dataset_id = %dataset_id)))]
3447 pub async fn backfill_sample_dimensions(
3448 &self,
3449 dataset_id: DatasetID,
3450 progress: Option<Sender<Progress>>,
3451 ) -> Result<u64, Error> {
3452 // Fetch all samples; listing progress is not forwarded to the caller
3453 // since it would interleave with the dimension-computing phase.
3454 let samples = self.samples(dataset_id, None, &[], &[], &[], None).await?;
3455
3456 // Filter to samples missing dimensions
3457 let missing: Vec<&Sample> = samples
3458 .iter()
3459 .filter(|s| s.width.is_none() || s.height.is_none())
3460 .collect();
3461
3462 if missing.is_empty() {
3463 return Ok(0);
3464 }
3465
3466 let total = missing.len();
3467 let mut updates: Vec<crate::SampleDimensionUpdate> = Vec::with_capacity(total);
3468
3469 for (i, sample) in missing.into_iter().enumerate() {
3470 let current = i + 1;
3471
3472 let Some(id) = sample.id() else {
3473 Self::send_progress(&progress, current, total).await;
3474 continue;
3475 };
3476
3477 let Some(url) = sample.image_url() else {
3478 #[cfg(feature = "profiling")]
3479 tracing::warn!(sample_id = %id, "skipping sample: no image URL");
3480 Self::send_progress(&progress, current, total).await;
3481 continue;
3482 };
3483
3484 // Download image data to determine dimensions
3485 let resp = self.bulk_http.get(url).send().await;
3486 let Ok(resp) = resp else {
3487 #[cfg(feature = "profiling")]
3488 tracing::warn!(sample_id = %id, "skipping sample: download failed");
3489 Self::send_progress(&progress, current, total).await;
3490 continue;
3491 };
3492
3493 // Skip non-success responses (e.g. 404, 500) rather than parsing error pages
3494 if !resp.status().is_success() {
3495 #[cfg(feature = "profiling")]
3496 tracing::warn!(sample_id = %id, status = %resp.status(), "skipping sample: non-success HTTP status");
3497 Self::send_progress(&progress, current, total).await;
3498 continue;
3499 }
3500
3501 let Ok(bytes) = resp.bytes().await else {
3502 #[cfg(feature = "profiling")]
3503 tracing::warn!(sample_id = %id, "skipping sample: failed to read response body");
3504 Self::send_progress(&progress, current, total).await;
3505 continue;
3506 };
3507
3508 // Extract dimensions from the downloaded image
3509 let Ok(size) = imagesize::blob_size(&bytes) else {
3510 #[cfg(feature = "profiling")]
3511 tracing::warn!(sample_id = %id, "skipping sample: could not determine dimensions");
3512 Self::send_progress(&progress, current, total).await;
3513 continue;
3514 };
3515
3516 let (Ok(width), Ok(height)) = (u32::try_from(size.width), u32::try_from(size.height))
3517 else {
3518 #[cfg(feature = "profiling")]
3519 tracing::warn!(sample_id = %id, width = size.width, height = size.height, "skipping sample: dimensions overflow u32");
3520 Self::send_progress(&progress, current, total).await;
3521 continue;
3522 };
3523
3524 updates.push(crate::SampleDimensionUpdate { id, width, height });
3525 Self::send_progress(&progress, current, total).await;
3526 }
3527
3528 // Send updates to server
3529 self.update_sample_dimensions(dataset_id, updates).await
3530 }
3531
3532 /// Emit a progress event if a progress channel is provided.
3533 async fn send_progress(progress: &Option<Sender<Progress>>, current: usize, total: usize) {
3534 if let Some(tx) = progress {
3535 let _ = tx
3536 .send(Progress {
3537 current,
3538 total,
3539 status: Some("Computing dimensions".to_string()),
3540 })
3541 .await;
3542 }
3543 }
3544
3545 /// List available snapshots. If a name is provided, only snapshots
3546 /// containing that name are returned.
3547 ///
3548 /// Results are sorted by match quality: exact matches first, then
3549 /// case-insensitive exact matches, then shorter descriptions (more
3550 /// specific), then alphabetically.
3551 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
3552 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
3553 let snapshots: Vec<Snapshot> = self
3554 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
3555 .await?;
3556 if let Some(name) = name {
3557 Ok(filter_and_sort_by_name(snapshots, name, |s| {
3558 s.description()
3559 }))
3560 } else {
3561 Ok(snapshots)
3562 }
3563 }
3564
3565 /// Get the snapshot with the specified id.
3566 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
3567 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
3568 let params = HashMap::from([("snapshot_id", snapshot_id)]);
3569 self.rpc("snapshots.get".to_owned(), Some(params)).await
3570 }
3571
3572 /// Create a new snapshot from an MCAP file or EdgeFirst Dataset directory.
3573 ///
3574 /// Snapshots are frozen datasets in EdgeFirst Dataset Format (Zip/Arrow
3575 /// pairs) that serve two primary purposes:
3576 ///
3577 /// 1. **MCAP uploads**: Upload MCAP files containing sensor data (images,
3578 /// point clouds, IMU, GPS) to EdgeFirst Studio. Snapshots can then be
3579 /// restored with AGTG (Automatic Ground Truth Generation) and optional
3580 /// auto-depth processing.
3581 ///
3582 /// 2. **Dataset exchange**: Export datasets for backup, sharing, or
3583 /// migration between EdgeFirst Studio instances using the create →
3584 /// download → upload → restore workflow.
3585 ///
3586 /// Large files are automatically chunked into 100MB parts and uploaded
3587 /// concurrently using S3 multipart upload with presigned URLs. Each chunk
3588 /// is streamed without loading into memory, maintaining constant memory
3589 /// usage.
3590 ///
3591 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
3592 /// uploads (default: half of CPU cores, min 2, max 8). Lower values work
3593 /// better for large files to avoid timeout issues. Higher values (16-32)
3594 /// are better for many small files.
3595 ///
3596 /// # Arguments
3597 ///
3598 /// * `path` - Local file path to MCAP file or directory containing
3599 /// EdgeFirst Dataset Format files (Zip/Arrow pairs)
3600 /// * `progress` - Optional channel to receive upload progress updates
3601 ///
3602 /// # Progress
3603 ///
3604 /// Reports progress with `status: None` as file data is uploaded. Progress
3605 /// unit is bytes uploaded. For single files, total is the file size. For
3606 /// directories, total is the combined size of all files.
3607 ///
3608 /// # Returns
3609 ///
3610 /// Returns a `Snapshot` object with ID, description, status, path, and
3611 /// creation timestamp on success.
3612 ///
3613 /// # Errors
3614 ///
3615 /// Returns an error if:
3616 /// * Path doesn't exist or contains invalid UTF-8
3617 /// * File format is invalid (not MCAP or EdgeFirst Dataset Format)
3618 /// * Upload fails or network error occurs
3619 /// * Server rejects the snapshot
3620 ///
3621 /// # Example
3622 ///
3623 /// ```no_run
3624 /// # use edgefirst_client::{Client, Progress};
3625 /// # use tokio::sync::mpsc;
3626 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3627 /// let client = Client::new()?.with_token_path(None)?;
3628 ///
3629 /// // Upload MCAP file with progress tracking
3630 /// let (tx, mut rx) = mpsc::channel(1);
3631 /// tokio::spawn(async move {
3632 /// while let Some(Progress {
3633 /// current,
3634 /// total,
3635 /// status,
3636 /// }) = rx.recv().await
3637 /// {
3638 /// println!(
3639 /// "{}: {}/{} bytes ({:.1}%)",
3640 /// status.as_deref().unwrap_or("Upload"),
3641 /// current,
3642 /// total,
3643 /// (current as f64 / total as f64) * 100.0
3644 /// );
3645 /// }
3646 /// });
3647 /// let snapshot = client.create_snapshot("data.mcap", Some(tx)).await?;
3648 /// println!("Created snapshot: {:?}", snapshot.id());
3649 ///
3650 /// // Upload dataset directory (no progress)
3651 /// let snapshot = client.create_snapshot("./dataset_export/", None).await?;
3652 /// # Ok(())
3653 /// # }
3654 /// ```
3655 ///
3656 /// # See Also
3657 ///
3658 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3659 /// dataset
3660 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
3661 /// data
3662 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
3663 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
3664 /// * [Snapshots Guide](https://doc.edgefirst.ai/latest/studio/snapshots/)
3665 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3666 pub async fn create_snapshot(
3667 &self,
3668 path: &str,
3669 progress: Option<Sender<Progress>>,
3670 ) -> Result<Snapshot, Error> {
3671 let path = Path::new(path);
3672
3673 if path.is_dir() {
3674 let path_str = path.to_str().ok_or_else(|| {
3675 Error::IoError(std::io::Error::new(
3676 std::io::ErrorKind::InvalidInput,
3677 "Path contains invalid UTF-8",
3678 ))
3679 })?;
3680 return self.create_snapshot_folder(path_str, progress).await;
3681 }
3682
3683 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3684 Error::IoError(std::io::Error::new(
3685 std::io::ErrorKind::InvalidInput,
3686 "Invalid filename",
3687 ))
3688 })?;
3689 let total = path.metadata()?.len() as usize;
3690 let current = Arc::new(AtomicUsize::new(0));
3691
3692 if let Some(progress) = &progress {
3693 let _ = progress
3694 .send(Progress {
3695 current: 0,
3696 total,
3697 status: None,
3698 })
3699 .await;
3700 }
3701
3702 let params = SnapshotCreateMultipartParams {
3703 snapshot_name: name.to_owned(),
3704 keys: vec![name.to_owned()],
3705 file_sizes: vec![total],
3706 snapshot_type: None,
3707 };
3708 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3709 .rpc(
3710 "snapshots.create_upload_url_multipart".to_owned(),
3711 Some(params),
3712 )
3713 .await?;
3714
3715 let snapshot_id = match multipart.get("snapshot_id") {
3716 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3717 _ => return Err(Error::InvalidResponse),
3718 };
3719
3720 let snapshot = self.snapshot(snapshot_id).await?;
3721 let part_prefix = snapshot
3722 .path()
3723 .split("::/")
3724 .last()
3725 .ok_or(Error::InvalidResponse)?
3726 .to_owned();
3727 let part_key = format!("{}/{}", part_prefix, name);
3728 let mut part = match multipart.get(&part_key) {
3729 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3730 _ => return Err(Error::InvalidResponse),
3731 }
3732 .clone();
3733 part.key = Some(part_key);
3734
3735 let params = upload_multipart(
3736 self.bulk_http.clone(),
3737 part.clone(),
3738 path.to_path_buf(),
3739 total,
3740 current,
3741 progress.clone(),
3742 )
3743 .await?;
3744
3745 let complete: String = self
3746 .rpc(
3747 "snapshots.complete_multipart_upload".to_owned(),
3748 Some(params),
3749 )
3750 .await?;
3751 debug!("Snapshot Multipart Complete: {:?}", complete);
3752
3753 let params: SnapshotStatusParams = SnapshotStatusParams {
3754 snapshot_id,
3755 status: "available".to_owned(),
3756 };
3757 let _: SnapshotStatusResult = self
3758 .rpc("snapshots.update".to_owned(), Some(params))
3759 .await?;
3760
3761 if let Some(progress) = progress {
3762 drop(progress);
3763 }
3764
3765 self.snapshot(snapshot_id).await
3766 }
3767
3768 async fn create_snapshot_folder(
3769 &self,
3770 path: &str,
3771 progress: Option<Sender<Progress>>,
3772 ) -> Result<Snapshot, Error> {
3773 let path = Path::new(path);
3774 let name = path.file_name().and_then(|n| n.to_str()).ok_or_else(|| {
3775 Error::IoError(std::io::Error::new(
3776 std::io::ErrorKind::InvalidInput,
3777 "Invalid directory name",
3778 ))
3779 })?;
3780
3781 let files = WalkDir::new(path)
3782 .into_iter()
3783 .filter_map(|entry| entry.ok())
3784 .filter(|entry| entry.file_type().is_file())
3785 .filter_map(|entry| entry.path().strip_prefix(path).ok().map(|p| p.to_owned()))
3786 .collect::<Vec<_>>();
3787
3788 let total: usize = files
3789 .iter()
3790 .filter_map(|file| path.join(file).metadata().ok())
3791 .map(|metadata| metadata.len() as usize)
3792 .sum();
3793 let current = Arc::new(AtomicUsize::new(0));
3794
3795 if let Some(progress) = &progress {
3796 let _ = progress
3797 .send(Progress {
3798 current: 0,
3799 total,
3800 status: None,
3801 })
3802 .await;
3803 }
3804
3805 let keys = files
3806 .iter()
3807 .filter_map(|key| key.to_str().map(|s| s.to_owned()))
3808 .collect::<Vec<_>>();
3809 let file_sizes = files
3810 .iter()
3811 .filter_map(|key| path.join(key).metadata().ok())
3812 .map(|metadata| metadata.len() as usize)
3813 .collect::<Vec<_>>();
3814
3815 let params = SnapshotCreateMultipartParams {
3816 snapshot_name: name.to_owned(),
3817 keys,
3818 file_sizes,
3819 snapshot_type: None,
3820 };
3821
3822 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
3823 .rpc(
3824 "snapshots.create_upload_url_multipart".to_owned(),
3825 Some(params),
3826 )
3827 .await?;
3828
3829 let snapshot_id = match multipart.get("snapshot_id") {
3830 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
3831 _ => return Err(Error::InvalidResponse),
3832 };
3833
3834 let snapshot = self.snapshot(snapshot_id).await?;
3835 let part_prefix = snapshot
3836 .path()
3837 .split("::/")
3838 .last()
3839 .ok_or(Error::InvalidResponse)?
3840 .to_owned();
3841
3842 for file in files {
3843 let file_str = file.to_str().ok_or_else(|| {
3844 Error::IoError(std::io::Error::new(
3845 std::io::ErrorKind::InvalidInput,
3846 "File path contains invalid UTF-8",
3847 ))
3848 })?;
3849 let part_key = format!("{}/{}", part_prefix, file_str);
3850 let mut part = match multipart.get(&part_key) {
3851 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
3852 _ => return Err(Error::InvalidResponse),
3853 }
3854 .clone();
3855 part.key = Some(part_key);
3856
3857 let params = upload_multipart(
3858 self.bulk_http.clone(),
3859 part.clone(),
3860 path.join(file),
3861 total,
3862 current.clone(),
3863 progress.clone(),
3864 )
3865 .await?;
3866
3867 let complete: String = self
3868 .rpc(
3869 "snapshots.complete_multipart_upload".to_owned(),
3870 Some(params),
3871 )
3872 .await?;
3873 debug!("Snapshot Part Complete: {:?}", complete);
3874 }
3875
3876 let params = SnapshotStatusParams {
3877 snapshot_id,
3878 status: "available".to_owned(),
3879 };
3880 let _: SnapshotStatusResult = self
3881 .rpc("snapshots.update".to_owned(), Some(params))
3882 .await?;
3883
3884 if let Some(progress) = progress {
3885 drop(progress);
3886 }
3887
3888 self.snapshot(snapshot_id).await
3889 }
3890
3891 /// Create a snapshot from EdgeFirst Dataset Format files (.arrow + .zip).
3892 ///
3893 /// Uploads a paired Arrow manifest and ZIP archive as a single snapshot.
3894 /// This format is the native EdgeFirst Dataset Format used for efficient
3895 /// dataset storage and transfer.
3896 ///
3897 /// # Arguments
3898 ///
3899 /// * `arrow_path` - Path to the Arrow manifest file (.arrow)
3900 /// * `zip_path` - Path to the ZIP archive containing images (.zip)
3901 /// * `description` - Optional description for the snapshot
3902 /// * `progress` - Optional progress channel for upload tracking
3903 ///
3904 /// # File Requirements
3905 ///
3906 /// - Arrow file must have `.arrow` extension
3907 /// - ZIP file must have `.zip` extension
3908 /// - Both files must exist and be readable
3909 ///
3910 /// # Example
3911 ///
3912 /// ```no_run
3913 /// # use edgefirst_client::Client;
3914 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
3915 /// let client = Client::new()?.with_token_path(None)?;
3916 ///
3917 /// let snapshot = client
3918 /// .create_snapshot_edgefirst_format(
3919 /// "dataset.arrow",
3920 /// "dataset.zip",
3921 /// Some("My Dataset Snapshot"),
3922 /// None,
3923 /// )
3924 /// .await?;
3925 /// println!("Created snapshot: {}", snapshot.id());
3926 /// # Ok(())
3927 /// # }
3928 /// ```
3929 ///
3930 /// # See Also
3931 ///
3932 /// * [`create_snapshot`](Self::create_snapshot) - Upload single file or
3933 /// folder
3934 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
3935 /// dataset
3936 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress)))]
3937 pub async fn create_snapshot_edgefirst_format(
3938 &self,
3939 arrow_path: &str,
3940 zip_path: &str,
3941 description: Option<&str>,
3942 progress: Option<Sender<Progress>>,
3943 ) -> Result<Snapshot, Error> {
3944 let arrow_path = Path::new(arrow_path);
3945 let zip_path = Path::new(zip_path);
3946
3947 // Validate files exist
3948 if !arrow_path.exists() {
3949 return Err(Error::IoError(std::io::Error::new(
3950 std::io::ErrorKind::NotFound,
3951 format!("Arrow file not found: {}", arrow_path.display()),
3952 )));
3953 }
3954 if !zip_path.exists() {
3955 return Err(Error::IoError(std::io::Error::new(
3956 std::io::ErrorKind::NotFound,
3957 format!("ZIP file not found: {}", zip_path.display()),
3958 )));
3959 }
3960
3961 // Get file names
3962 let arrow_name = arrow_path
3963 .file_name()
3964 .and_then(|n| n.to_str())
3965 .ok_or_else(|| {
3966 Error::IoError(std::io::Error::new(
3967 std::io::ErrorKind::InvalidInput,
3968 "Invalid Arrow filename",
3969 ))
3970 })?;
3971 let zip_name = zip_path
3972 .file_name()
3973 .and_then(|n| n.to_str())
3974 .ok_or_else(|| {
3975 Error::IoError(std::io::Error::new(
3976 std::io::ErrorKind::InvalidInput,
3977 "Invalid ZIP filename",
3978 ))
3979 })?;
3980
3981 // Generate snapshot name from arrow file (without extension)
3982 let snapshot_name = description
3983 .map(|s| s.to_string())
3984 .or_else(|| {
3985 arrow_path
3986 .file_stem()
3987 .and_then(|s| s.to_str())
3988 .map(|s| s.to_string())
3989 })
3990 .unwrap_or_else(|| "edgefirst_dataset".to_string());
3991
3992 // Calculate file sizes
3993 let arrow_size = arrow_path.metadata()?.len() as usize;
3994 let zip_size = zip_path.metadata()?.len() as usize;
3995 let total = arrow_size + zip_size;
3996 let current = Arc::new(AtomicUsize::new(0));
3997
3998 if let Some(progress) = &progress {
3999 let _ = progress
4000 .send(Progress {
4001 current: 0,
4002 total,
4003 status: None,
4004 })
4005 .await;
4006 }
4007
4008 // Create multipart upload request with "ziparrow" type
4009 let params = SnapshotCreateMultipartParams {
4010 snapshot_name,
4011 keys: vec![arrow_name.to_owned(), zip_name.to_owned()],
4012 file_sizes: vec![arrow_size, zip_size],
4013 snapshot_type: Some("ziparrow".to_string()),
4014 };
4015
4016 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
4017 .rpc(
4018 "snapshots.create_upload_url_multipart".to_owned(),
4019 Some(params),
4020 )
4021 .await?;
4022
4023 let snapshot_id = match multipart.get("snapshot_id") {
4024 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
4025 _ => return Err(Error::InvalidResponse),
4026 };
4027
4028 let snapshot = self.snapshot(snapshot_id).await?;
4029 let part_prefix = snapshot
4030 .path()
4031 .split("::/")
4032 .last()
4033 .ok_or(Error::InvalidResponse)?
4034 .to_owned();
4035
4036 // Upload Arrow file
4037 let arrow_key = format!("{}/{}", part_prefix, arrow_name);
4038 let mut arrow_part = match multipart.get(&arrow_key) {
4039 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
4040 _ => return Err(Error::InvalidResponse),
4041 };
4042 arrow_part.key = Some(arrow_key);
4043
4044 let params = upload_multipart(
4045 self.bulk_http.clone(),
4046 arrow_part,
4047 arrow_path.to_path_buf(),
4048 total,
4049 current.clone(),
4050 progress.clone(),
4051 )
4052 .await?;
4053
4054 let _: String = self
4055 .rpc(
4056 "snapshots.complete_multipart_upload".to_owned(),
4057 Some(params),
4058 )
4059 .await?;
4060 debug!("Arrow file upload complete");
4061
4062 // Upload ZIP file
4063 let zip_key = format!("{}/{}", part_prefix, zip_name);
4064 let mut zip_part = match multipart.get(&zip_key) {
4065 Some(SnapshotCreateMultipartResultField::Part(part)) => part.clone(),
4066 _ => return Err(Error::InvalidResponse),
4067 };
4068 zip_part.key = Some(zip_key);
4069
4070 let params = upload_multipart(
4071 self.bulk_http.clone(),
4072 zip_part,
4073 zip_path.to_path_buf(),
4074 total,
4075 current.clone(),
4076 progress.clone(),
4077 )
4078 .await?;
4079
4080 let _: String = self
4081 .rpc(
4082 "snapshots.complete_multipart_upload".to_owned(),
4083 Some(params),
4084 )
4085 .await?;
4086 debug!("ZIP file upload complete");
4087
4088 // Mark snapshot as available
4089 let params = SnapshotStatusParams {
4090 snapshot_id,
4091 status: "available".to_owned(),
4092 };
4093 let _: SnapshotStatusResult = self
4094 .rpc("snapshots.update".to_owned(), Some(params))
4095 .await?;
4096
4097 if let Some(progress) = progress {
4098 drop(progress);
4099 }
4100
4101 self.snapshot(snapshot_id).await
4102 }
4103
4104 /// Delete a snapshot from EdgeFirst Studio.
4105 ///
4106 /// Permanently removes a snapshot and its associated data. This operation
4107 /// cannot be undone.
4108 ///
4109 /// # Arguments
4110 ///
4111 /// * `snapshot_id` - The snapshot ID to delete
4112 ///
4113 /// # Errors
4114 ///
4115 /// Returns an error if:
4116 /// * Snapshot doesn't exist
4117 /// * User lacks permission to delete the snapshot
4118 /// * Server error occurs
4119 ///
4120 /// # Example
4121 ///
4122 /// ```no_run
4123 /// # use edgefirst_client::{Client, SnapshotID};
4124 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4125 /// let client = Client::new()?.with_token_path(None)?;
4126 /// let snapshot_id = SnapshotID::from(123);
4127 /// client.delete_snapshot(snapshot_id).await?;
4128 /// # Ok(())
4129 /// # }
4130 /// ```
4131 ///
4132 /// # See Also
4133 ///
4134 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4135 /// * [`snapshots`](Self::snapshots) - List all snapshots
4136 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(snapshot_id = %snapshot_id)))]
4137 pub async fn delete_snapshot(&self, snapshot_id: SnapshotID) -> Result<(), Error> {
4138 let params = HashMap::from([("snapshot_id", snapshot_id)]);
4139 let _: serde_json::Value = self
4140 .rpc("snapshots.delete".to_owned(), Some(params))
4141 .await?;
4142 Ok(())
4143 }
4144
4145 /// Create a snapshot from an existing dataset on the server.
4146 ///
4147 /// Triggers server-side snapshot generation which exports the dataset's
4148 /// images and annotations into a downloadable EdgeFirst Dataset Format
4149 /// snapshot.
4150 ///
4151 /// This is the inverse of [`restore_snapshot`](Self::restore_snapshot) -
4152 /// while restore creates a dataset from a snapshot, this method creates a
4153 /// snapshot from a dataset.
4154 ///
4155 /// # Arguments
4156 ///
4157 /// * `dataset_id` - The dataset ID to create snapshot from
4158 /// * `description` - Description for the created snapshot
4159 ///
4160 /// # Returns
4161 ///
4162 /// Returns a `SnapshotCreateResult` containing the snapshot ID and task ID
4163 /// for monitoring progress.
4164 ///
4165 /// # Errors
4166 ///
4167 /// Returns an error if:
4168 /// * Dataset doesn't exist
4169 /// * User lacks permission to access the dataset
4170 /// * Server rejects the request
4171 ///
4172 /// # Example
4173 ///
4174 /// ```no_run
4175 /// # use edgefirst_client::{Client, DatasetID};
4176 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4177 /// let client = Client::new()?.with_token_path(None)?;
4178 /// let dataset_id = DatasetID::from(123);
4179 ///
4180 /// // Create snapshot from dataset (all annotation sets)
4181 /// let result = client
4182 /// .create_snapshot_from_dataset(dataset_id, "My Dataset Backup", None)
4183 /// .await?;
4184 /// println!("Created snapshot: {:?}", result.id);
4185 ///
4186 /// // Monitor progress via task ID
4187 /// if let Some(task_id) = result.task_id {
4188 /// println!("Task: {}", task_id);
4189 /// }
4190 /// # Ok(())
4191 /// # }
4192 /// ```
4193 ///
4194 /// # See Also
4195 ///
4196 /// * [`create_snapshot`](Self::create_snapshot) - Upload local files as
4197 /// snapshot
4198 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
4199 /// dataset
4200 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
4201 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(dataset_id = %dataset_id)))]
4202 pub async fn create_snapshot_from_dataset(
4203 &self,
4204 dataset_id: DatasetID,
4205 description: &str,
4206 annotation_set_id: Option<AnnotationSetID>,
4207 ) -> Result<SnapshotFromDatasetResult, Error> {
4208 // Resolve annotation_set_id: use provided value or fetch default
4209 let annotation_set_id = match annotation_set_id {
4210 Some(id) => id,
4211 None => {
4212 // Fetch annotation sets and find default ("annotations") or use first
4213 let sets = self.annotation_sets(dataset_id).await?;
4214 if sets.is_empty() {
4215 return Err(Error::InvalidParameters(
4216 "No annotation sets available for dataset".to_owned(),
4217 ));
4218 }
4219 // Look for "annotations" set (default), otherwise use first
4220 sets.iter()
4221 .find(|s| s.name() == "annotations")
4222 .unwrap_or(&sets[0])
4223 .id()
4224 }
4225 };
4226 let params = SnapshotCreateFromDataset {
4227 description: description.to_owned(),
4228 dataset_id,
4229 annotation_set_id,
4230 };
4231 self.rpc("snapshots.create".to_owned(), Some(params)).await
4232 }
4233
4234 /// Download a snapshot from EdgeFirst Studio to local storage.
4235 ///
4236 /// Downloads all files in a snapshot (single MCAP file or directory of
4237 /// EdgeFirst Dataset Format files) to the specified output path. Files are
4238 /// downloaded concurrently with progress tracking.
4239 ///
4240 /// **Concurrency tuning**: Set `MAX_TASKS` to control concurrent
4241 /// downloads (default: half of CPU cores, min 2, max 8).
4242 ///
4243 /// # Arguments
4244 ///
4245 /// * `snapshot_id` - The snapshot ID to download
4246 /// * `output` - Local directory path to save downloaded files
4247 /// * `progress` - Optional channel to receive download progress updates
4248 ///
4249 /// # Progress
4250 ///
4251 /// Reports progress with `status: None` as file data is received. Progress
4252 /// unit is bytes downloaded across all files combined. The total
4253 /// accumulates as file sizes become known (from HTTP Content-Length
4254 /// headers), so both `current` and `total` may increase during
4255 /// download.
4256 ///
4257 /// # Errors
4258 ///
4259 /// Returns an error if:
4260 /// * Snapshot doesn't exist
4261 /// * Output directory cannot be created
4262 /// * Download fails or network error occurs
4263 ///
4264 /// # Example
4265 ///
4266 /// ```no_run
4267 /// # use edgefirst_client::{Client, SnapshotID, Progress};
4268 /// # use tokio::sync::mpsc;
4269 /// # use std::path::PathBuf;
4270 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4271 /// let client = Client::new()?.with_token_path(None)?;
4272 /// let snapshot_id = SnapshotID::from(123);
4273 ///
4274 /// // Download with progress tracking
4275 /// let (tx, mut rx) = mpsc::channel(1);
4276 /// tokio::spawn(async move {
4277 /// while let Some(Progress {
4278 /// current,
4279 /// total,
4280 /// status,
4281 /// }) = rx.recv().await
4282 /// {
4283 /// println!(
4284 /// "{}: {}/{} bytes",
4285 /// status.as_deref().unwrap_or("Download"),
4286 /// current,
4287 /// total
4288 /// );
4289 /// }
4290 /// });
4291 /// client
4292 /// .download_snapshot(snapshot_id, PathBuf::from("./output"), Some(tx))
4293 /// .await?;
4294 /// # Ok(())
4295 /// # }
4296 /// ```
4297 ///
4298 /// # See Also
4299 ///
4300 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4301 /// * [`restore_snapshot`](Self::restore_snapshot) - Restore snapshot to
4302 /// dataset
4303 /// * [`delete_snapshot`](Self::delete_snapshot) - Delete snapshot
4304 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(snapshot_id = %snapshot_id, output = %output.display())))]
4305 pub async fn download_snapshot(
4306 &self,
4307 snapshot_id: SnapshotID,
4308 output: PathBuf,
4309 progress: Option<Sender<Progress>>,
4310 ) -> Result<(), Error> {
4311 fs::create_dir_all(&output).await?;
4312
4313 let params = HashMap::from([("snapshot_id", snapshot_id)]);
4314 let items: HashMap<String, String> = self
4315 .rpc("snapshots.create_download_url".to_owned(), Some(params))
4316 .await?;
4317
4318 // Single-phase: each task holds its semaphore permit for the full
4319 // lifetime of the request (GET → headers → stream → disk). This bounds
4320 // the number of simultaneously-open connections to max_tasks() and
4321 // avoids accumulating all responses in memory before streaming.
4322 //
4323 // total is updated atomically as each response's Content-Length header
4324 // arrives, so progress tracking is accurate without a separate phase.
4325 let http = self.bulk_http.clone();
4326 let current = Arc::new(AtomicUsize::new(0));
4327 let total = Arc::new(AtomicUsize::new(0));
4328 let sem = Arc::new(Semaphore::new(max_tasks()));
4329
4330 let tasks = items
4331 .into_iter()
4332 .map(|(key, url)| {
4333 let http = http.clone();
4334 let output = output.clone();
4335 let progress = progress.clone();
4336 let current = current.clone();
4337 let total = total.clone();
4338 let sem = sem.clone();
4339
4340 tokio::spawn(async move {
4341 let _permit = sem.acquire().await.map_err(|_| {
4342 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
4343 })?;
4344
4345 let res = http.get(url).send().await?;
4346 let res = res.error_for_status()?;
4347
4348 // Contribute this file's size to the running total so the
4349 // caller's progress bar knows the overall scope.
4350 if let Some(len) = res.content_length() {
4351 total.fetch_add(len as usize, Ordering::SeqCst);
4352 }
4353
4354 let mut file = File::create(output.join(key)).await?;
4355 let mut stream = res.bytes_stream();
4356
4357 while let Some(chunk) = stream.next().await {
4358 let chunk = chunk?;
4359 file.write_all(&chunk).await?;
4360 let len = chunk.len();
4361
4362 if let Some(progress) = &progress {
4363 let cur = current.fetch_add(len, Ordering::SeqCst) + len;
4364 let tot = total.load(Ordering::SeqCst);
4365 let _ = progress
4366 .send(Progress {
4367 current: cur,
4368 total: tot,
4369 status: None,
4370 })
4371 .await;
4372 }
4373 }
4374
4375 Ok::<(), Error>(())
4376 })
4377 })
4378 .collect::<Vec<_>>();
4379
4380 join_all(tasks)
4381 .await
4382 .into_iter()
4383 .collect::<Result<Vec<_>, _>>()?
4384 .into_iter()
4385 .collect::<Result<Vec<_>, _>>()?;
4386
4387 Ok(())
4388 }
4389
4390 /// Restore a snapshot to a dataset in EdgeFirst Studio with optional AGTG.
4391 ///
4392 /// Restores a snapshot (MCAP file or EdgeFirst Dataset) into a dataset in
4393 /// the specified project. For MCAP files, supports:
4394 ///
4395 /// * **AGTG (Automatic Ground Truth Generation)**: Automatically annotate
4396 /// detected objects with 2D masks/boxes and 3D boxes (if radar/LiDAR
4397 /// present)
4398 /// * **Auto-depth**: Generate depthmaps (Maivin/Raivin cameras only)
4399 /// * **Topic filtering**: Select specific MCAP topics to restore
4400 ///
4401 /// For EdgeFirst Dataset snapshots, this simply imports the pre-existing
4402 /// dataset structure.
4403 ///
4404 /// # Arguments
4405 ///
4406 /// * `project_id` - Target project ID
4407 /// * `snapshot_id` - Snapshot ID to restore
4408 /// * `topics` - MCAP topics to include (empty = all topics)
4409 /// * `autolabel` - Object labels for AGTG (empty = no auto-annotation)
4410 /// * `autodepth` - Generate depthmaps (Maivin/Raivin only)
4411 /// * `dataset_name` - Optional custom dataset name
4412 /// * `dataset_description` - Optional dataset description
4413 ///
4414 /// # Returns
4415 ///
4416 /// Returns a `SnapshotRestoreResult` with the new dataset ID and status.
4417 ///
4418 /// # Errors
4419 ///
4420 /// Returns an error if:
4421 /// * Snapshot or project doesn't exist
4422 /// * Snapshot format is invalid
4423 /// * Server rejects restoration parameters
4424 ///
4425 /// # Example
4426 ///
4427 /// ```no_run
4428 /// # use edgefirst_client::{Client, ProjectID, SnapshotID};
4429 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
4430 /// let client = Client::new()?.with_token_path(None)?;
4431 /// let project_id = ProjectID::from(1);
4432 /// let snapshot_id = SnapshotID::from(123);
4433 ///
4434 /// // Restore MCAP with AGTG for "person" and "car" detection
4435 /// let result = client
4436 /// .restore_snapshot(
4437 /// project_id,
4438 /// snapshot_id,
4439 /// &[], // All topics
4440 /// &["person".to_string(), "car".to_string()], // AGTG labels
4441 /// true, // Auto-depth
4442 /// Some("Highway Dataset"),
4443 /// Some("Collected on I-95"),
4444 /// )
4445 /// .await?;
4446 /// println!("Restored to dataset: {:?}", result.dataset_id);
4447 /// # Ok(())
4448 /// # }
4449 /// ```
4450 ///
4451 /// # See Also
4452 ///
4453 /// * [`create_snapshot`](Self::create_snapshot) - Upload snapshot
4454 /// * [`download_snapshot`](Self::download_snapshot) - Download snapshot
4455 /// * [AGTG Documentation](https://doc.edgefirst.ai/latest/datasets/tutorials/annotations/automatic/)
4456 #[allow(clippy::too_many_arguments)]
4457 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4458 pub async fn restore_snapshot(
4459 &self,
4460 project_id: ProjectID,
4461 snapshot_id: SnapshotID,
4462 topics: &[String],
4463 autolabel: &[String],
4464 autodepth: bool,
4465 dataset_name: Option<&str>,
4466 dataset_description: Option<&str>,
4467 ) -> Result<SnapshotRestoreResult, Error> {
4468 let params = SnapshotRestore {
4469 project_id,
4470 snapshot_id,
4471 fps: 1,
4472 autodepth,
4473 agtg_pipeline: !autolabel.is_empty(),
4474 autolabel: autolabel.to_vec(),
4475 topics: topics.to_vec(),
4476 dataset_name: dataset_name.map(|s| s.to_owned()),
4477 dataset_description: dataset_description.map(|s| s.to_owned()),
4478 };
4479 self.rpc("snapshots.restore".to_owned(), Some(params)).await
4480 }
4481
4482 /// Returns a list of experiments available to the user. The experiments
4483 /// are returned as a vector of Experiment objects. If name is provided
4484 /// then only experiments containing this string are returned.
4485 ///
4486 /// Results are sorted by match quality: exact matches first, then
4487 /// case-insensitive exact matches, then shorter names (more specific),
4488 /// then alphabetically.
4489 ///
4490 /// Experiments provide a method of organizing training and validation
4491 /// sessions together and are akin to an Experiment in MLFlow terminology.
4492 /// Each experiment can have multiple trainer sessions associated with it,
4493 /// these would be akin to runs in MLFlow terminology.
4494 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4495 pub async fn experiments(
4496 &self,
4497 project_id: ProjectID,
4498 name: Option<&str>,
4499 ) -> Result<Vec<Experiment>, Error> {
4500 let params = HashMap::from([("project_id", project_id)]);
4501 let experiments: Vec<Experiment> =
4502 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
4503 if let Some(name) = name {
4504 Ok(filter_and_sort_by_name(experiments, name, |e| e.name()))
4505 } else {
4506 Ok(experiments)
4507 }
4508 }
4509
4510 /// Return the experiment with the specified experiment ID. If the
4511 /// experiment does not exist, an error is returned.
4512 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4513 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
4514 let params = HashMap::from([("trainer_id", experiment_id)]);
4515 self.rpc("trainer.get".to_owned(), Some(params)).await
4516 }
4517
4518 /// Returns a list of trainer sessions available to the user. The trainer
4519 /// sessions are returned as a vector of TrainingSession objects. If name
4520 /// is provided then only trainer sessions containing this string are
4521 /// returned.
4522 ///
4523 /// Results are sorted by match quality: exact matches first, then
4524 /// case-insensitive exact matches, then shorter names (more specific),
4525 /// then alphabetically.
4526 ///
4527 /// Trainer sessions are akin to runs in MLFlow terminology. These
4528 /// represent an actual training session which will produce metrics and
4529 /// model artifacts.
4530 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4531 pub async fn training_sessions(
4532 &self,
4533 experiment_id: ExperimentID,
4534 name: Option<&str>,
4535 ) -> Result<Vec<TrainingSession>, Error> {
4536 let params = HashMap::from([("trainer_id", experiment_id)]);
4537 let sessions: Vec<TrainingSession> = self
4538 .rpc("trainer.session.list".to_owned(), Some(params))
4539 .await?;
4540 if let Some(name) = name {
4541 Ok(filter_and_sort_by_name(sessions, name, |s| s.name()))
4542 } else {
4543 Ok(sessions)
4544 }
4545 }
4546
4547 /// Return the trainer session with the specified trainer session ID. If
4548 /// the trainer session does not exist, an error is returned.
4549 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4550 pub async fn training_session(
4551 &self,
4552 session_id: TrainingSessionID,
4553 ) -> Result<TrainingSession, Error> {
4554 let params = HashMap::from([("trainer_session_id", session_id)]);
4555 self.rpc("trainer.session.get".to_owned(), Some(params))
4556 .await
4557 }
4558
4559 /// List validation sessions for the given project.
4560 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4561 pub async fn validation_sessions(
4562 &self,
4563 project_id: ProjectID,
4564 ) -> Result<Vec<ValidationSession>, Error> {
4565 let params = HashMap::from([("project_id", project_id)]);
4566 self.rpc("validate.session.list".to_owned(), Some(params))
4567 .await
4568 }
4569
4570 /// Retrieve a specific validation session.
4571 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4572 pub async fn validation_session(
4573 &self,
4574 session_id: ValidationSessionID,
4575 ) -> Result<ValidationSession, Error> {
4576 let params = HashMap::from([("validate_session_id", session_id)]);
4577 self.rpc("validate.session.get".to_owned(), Some(params))
4578 .await
4579 }
4580
4581 /// Create a new validation session via Studio's `cloud.server.start`.
4582 ///
4583 /// Pass `is_local: true` in the [`StartValidationRequest`] to create
4584 /// a **user-managed** session: the database row is created and the
4585 /// session is fully usable for data uploads / downloads / metrics,
4586 /// but no EC2 instance is provisioned and no automated validator
4587 /// pipeline is started. That is the mode our integration tests use
4588 /// — they create a session, exercise the wrapper APIs against it,
4589 /// then call [`Client::delete_validation_sessions`] in teardown so
4590 /// no stray sessions accumulate on the test account.
4591 ///
4592 /// Returns a [`NewValidationSession`] carrying the backing task id
4593 /// and the freshly-minted validation session id.
4594 ///
4595 /// # Errors
4596 ///
4597 /// Surfaces any RPC error from `cloud.server.start`. Common cases:
4598 /// `RpcError(101, …)` if a required entity is missing (project,
4599 /// training session, dataset, …); `PermissionDenied` if the caller
4600 /// can't write to the target project.
4601 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, req)))]
4602 pub async fn start_validation_session(
4603 &self,
4604 req: StartValidationRequest,
4605 ) -> Result<NewValidationSession, Error> {
4606 // Build the params shape the server expects. `cloud.server.start`
4607 // is intentionally generic — different server types pull
4608 // different fields out of `params` — so we serialize manually to
4609 // match the JS frontend's call site verbatim (see
4610 // `dve-frontend/src/components/ValidationPage/StartValidatorModal.vue`).
4611 let mut body = serde_json::Map::new();
4612 body.insert(
4613 "type".into(),
4614 serde_json::Value::String("validation".into()),
4615 );
4616 body.insert("name".into(), serde_json::Value::String(req.name));
4617 body.insert("project_id".into(), serde_json::to_value(req.project_id)?);
4618 body.insert(
4619 "training_session_id".into(),
4620 serde_json::to_value(req.training_session_id)?,
4621 );
4622 body.insert(
4623 "model_file".into(),
4624 serde_json::Value::String(req.model_file),
4625 );
4626 body.insert("val_type".into(), serde_json::Value::String(req.val_type));
4627 body.insert("is_local".into(), serde_json::Value::Bool(req.is_local));
4628 body.insert(
4629 "is_kubernetes".into(),
4630 serde_json::Value::Bool(req.is_kubernetes),
4631 );
4632
4633 // `validate.session` reads its config from `params.params` (one
4634 // extra envelope level). The outer `params` wrapper is required
4635 // even when the inner map is empty.
4636 let inner = serde_json::to_value(req.params)?;
4637 let mut outer = serde_json::Map::new();
4638 outer.insert("params".into(), inner);
4639 body.insert("params".into(), serde_json::Value::Object(outer));
4640
4641 if let Some(d) = req.description {
4642 body.insert("description".into(), serde_json::Value::String(d));
4643 }
4644 if let Some(id) = req.dataset_id {
4645 body.insert("dataset_id".into(), serde_json::to_value(id)?);
4646 }
4647 if let Some(id) = req.annotation_set_id {
4648 body.insert("annotation_set_id".into(), serde_json::to_value(id)?);
4649 }
4650 if let Some(id) = req.snapshot_id {
4651 body.insert("snapshot_id".into(), serde_json::to_value(id)?);
4652 }
4653
4654 self.rpc("cloud.server.start".to_owned(), Some(body)).await
4655 }
4656
4657 /// Delete one or more validation sessions via
4658 /// `validate.session.delete`.
4659 ///
4660 /// Used by integration tests to tear down sessions they created
4661 /// with [`Client::start_validation_session`]; idempotent against
4662 /// already-deleted ids on the server side (the RPC accepts the
4663 /// list, deletes what it can, and surfaces an error only if none
4664 /// of the ids were resolvable).
4665 ///
4666 /// # Errors
4667 ///
4668 /// Surfaces any RPC error from `validate.session.delete`. A
4669 /// `PermissionDenied` indicates the caller lacks
4670 /// `TrainerWrite` on at least one of the listed sessions.
4671 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4672 pub async fn delete_validation_sessions(
4673 &self,
4674 session_ids: &[ValidationSessionID],
4675 ) -> Result<(), Error> {
4676 let mut body = serde_json::Map::new();
4677 body.insert("session_ids".into(), serde_json::to_value(session_ids)?);
4678 let _: serde_json::Value = self
4679 .rpc("validate.session.delete".to_owned(), Some(body))
4680 .await?;
4681 Ok(())
4682 }
4683
4684 /// List the artifacts for the specified trainer session. The artifacts
4685 /// are returned as a vector of strings.
4686 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4687 pub async fn artifacts(
4688 &self,
4689 training_session_id: TrainingSessionID,
4690 ) -> Result<Vec<Artifact>, Error> {
4691 let params = HashMap::from([("training_session_id", training_session_id)]);
4692 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
4693 .await
4694 }
4695
4696 /// Download the model artifact for the specified trainer session to the
4697 /// specified file path, if path is not provided it will be downloaded to
4698 /// the current directory with the same filename.
4699 ///
4700 /// # Progress
4701 ///
4702 /// Reports progress with `status: None` as file data is received. Progress
4703 /// unit is bytes downloaded. Total is determined from the HTTP
4704 /// Content-Length header (may be 0 if server doesn't provide it).
4705 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4706 pub async fn download_artifact(
4707 &self,
4708 training_session_id: TrainingSessionID,
4709 modelname: &str,
4710 filename: Option<PathBuf>,
4711 progress: Option<Sender<Progress>>,
4712 ) -> Result<(), Error> {
4713 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
4714 let resp = self
4715 .bulk_http
4716 .get(format!(
4717 "{}/download_model?training_session_id={}&file={}",
4718 self.url,
4719 training_session_id.value(),
4720 modelname
4721 ))
4722 .header("Authorization", format!("Bearer {}", self.token().await))
4723 .send()
4724 .await?;
4725 if !resp.status().is_success() {
4726 let err = resp.error_for_status_ref().unwrap_err();
4727 return Err(Error::HttpError(err));
4728 }
4729
4730 if let Some(parent) = filename.parent() {
4731 fs::create_dir_all(parent).await?;
4732 }
4733
4734 stream_response_to_file(resp, &filename, progress).await
4735 }
4736
4737 /// Download the model checkpoint associated with the specified trainer
4738 /// session to the specified file path, if path is not provided it will be
4739 /// downloaded to the current directory with the same filename.
4740 ///
4741 /// There is no API for listing checkpoints it is expected that trainers are
4742 /// aware of possible checkpoints and their names within the checkpoint
4743 /// folder on the server.
4744 ///
4745 /// # Progress
4746 ///
4747 /// Reports progress with `status: None` as file data is received. Progress
4748 /// unit is bytes downloaded. Total is determined from the HTTP
4749 /// Content-Length header (may be 0 if server doesn't provide it).
4750 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, progress), fields(training_session_id = %training_session_id)))]
4751 pub async fn download_checkpoint(
4752 &self,
4753 training_session_id: TrainingSessionID,
4754 checkpoint: &str,
4755 filename: Option<PathBuf>,
4756 progress: Option<Sender<Progress>>,
4757 ) -> Result<(), Error> {
4758 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
4759 let resp = self
4760 .bulk_http
4761 .get(format!(
4762 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
4763 self.url,
4764 training_session_id.value(),
4765 checkpoint
4766 ))
4767 .header("Authorization", format!("Bearer {}", self.token().await))
4768 .send()
4769 .await?;
4770 if !resp.status().is_success() {
4771 let err = resp.error_for_status_ref().unwrap_err();
4772 return Err(Error::HttpError(err));
4773 }
4774
4775 if let Some(parent) = filename.parent() {
4776 fs::create_dir_all(parent).await?;
4777 }
4778
4779 stream_response_to_file(resp, &filename, progress).await
4780 }
4781
4782 /// Return a list of tasks for the current user.
4783 ///
4784 /// # Arguments
4785 ///
4786 /// * `name` - Optional filter for task name (client-side substring match)
4787 /// * `workflow` - Optional filter for workflow/task type. If provided,
4788 /// filters server-side by exact match. Valid values include: "trainer",
4789 /// "validation", "snapshot-create", "snapshot-restore", "copyds",
4790 /// "upload", "auto-ann", "auto-seg", "aigt", "import", "export",
4791 /// "convertor", "twostage"
4792 /// * `status` - Optional filter for task status (e.g., "running",
4793 /// "complete", "error")
4794 /// * `manager` - Optional filter for task manager type (e.g., "aws",
4795 /// "user", "kubernetes")
4796 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4797 pub async fn tasks(
4798 &self,
4799 name: Option<&str>,
4800 workflow: Option<&str>,
4801 status: Option<&str>,
4802 manager: Option<&str>,
4803 ) -> Result<Vec<Task>, Error> {
4804 let mut params = TasksListParams {
4805 continue_token: None,
4806 types: workflow.map(|w| vec![w.to_owned()]),
4807 status: status.map(|s| vec![s.to_owned()]),
4808 manager: manager.map(|m| vec![m.to_owned()]),
4809 };
4810 let mut tasks = Vec::new();
4811
4812 loop {
4813 let result = self
4814 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
4815 .await?;
4816 tasks.extend(result.tasks);
4817
4818 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
4819 params.continue_token = None;
4820 } else {
4821 params.continue_token = result.continue_token;
4822 }
4823
4824 if params.continue_token.is_none() {
4825 break;
4826 }
4827 }
4828
4829 if let Some(name) = name {
4830 tasks = filter_and_sort_by_name(tasks, name, |t| t.name());
4831 }
4832
4833 Ok(tasks)
4834 }
4835
4836 /// Submits a job (app run) to the server and returns the resulting `Job`
4837 /// record (which carries the linked task id alongside the cloud-batch
4838 /// metadata).
4839 ///
4840 /// # Arguments
4841 /// * `app_name` - The name of the registered app to run (e.g., `"edgefirst-validator"`).
4842 /// * `job_name` - A user-defined label for this run.
4843 /// * `env` - Environment variables passed to the job (string-string map).
4844 /// * `data` - Job input payload (e.g., session ids, parameters).
4845 ///
4846 /// # Returns
4847 /// The full `Job` record returned by the server (wraps the BK_BATCH object),
4848 /// including AWS Batch job ID, state, and the linked `task_id`. Callers that
4849 /// only need the task ID can call `.task_id()` on the returned `Job`.
4850 pub async fn job_run(
4851 &self,
4852 app_name: &str,
4853 job_name: &str,
4854 env: std::collections::HashMap<String, String>,
4855 data: std::collections::HashMap<String, crate::api::Parameter>,
4856 ) -> Result<crate::api::Job, Error> {
4857 let req = JobRunRequest {
4858 name: app_name.to_owned(),
4859 job_name: job_name.to_owned(),
4860 env,
4861 data,
4862 };
4863 let resp: crate::api::Job = match self.rpc("job.run".to_owned(), Some(&req)).await {
4864 Ok(r) => r,
4865 Err(Error::RpcError(code, msg)) => {
4866 return Err(map_rpc_error("job.run", code, msg, None));
4867 }
4868 Err(e) => return Err(e),
4869 };
4870 Ok(resp)
4871 }
4872
4873 /// Requests a running job task be stopped.
4874 ///
4875 /// Returns `Ok(())` if the stop request was accepted by the server. The
4876 /// task may still take time to fully terminate; poll `task_info` if you
4877 /// need to wait for shutdown.
4878 pub async fn job_stop(&self, task_id: crate::api::TaskID) -> Result<(), Error> {
4879 let req = JobStopRequest {
4880 task_id: task_id.value(),
4881 };
4882 // We don't care about the response body; deserialize as serde_json::Value.
4883 let _resp: serde_json::Value = match self.rpc("job.stop".to_owned(), Some(&req)).await {
4884 Ok(r) => r,
4885 Err(Error::RpcError(code, msg)) => {
4886 return Err(map_rpc_error("job.stop", code, msg, Some(task_id)));
4887 }
4888 Err(e) => return Err(e),
4889 };
4890 Ok(())
4891 }
4892
4893 /// Lists job (app-run) entries visible to the authenticated user.
4894 ///
4895 /// The server returns AWS Batch-wrapper entries (not bare `Task` objects),
4896 /// surfacing cloud-batch state (`RUNNING`/`SUCCEEDED`/...) and the linked
4897 /// `task_id`. Use `Job::task_id()` + `Client::task_info` to fetch the
4898 /// underlying task details.
4899 ///
4900 /// The server does not support server-side filters, so the optional
4901 /// `name` argument is applied client-side as a substring match against
4902 /// each job's `job_name`.
4903 pub async fn jobs(&self, name: Option<&str>) -> Result<Vec<crate::api::Job>, Error> {
4904 let req = JobsListRequest {};
4905 let mut jobs: Vec<crate::api::Job> = match self.rpc("job.list".to_owned(), Some(&req)).await
4906 {
4907 Ok(r) => r,
4908 Err(Error::RpcError(code, msg)) => {
4909 return Err(map_rpc_error("job.list", code, msg, None));
4910 }
4911 Err(e) => return Err(e),
4912 };
4913 if let Some(name) = name {
4914 let needle = name.to_lowercase();
4915 jobs.retain(|j| j.job_name.to_lowercase().contains(&needle));
4916 jobs.sort_by(|a, b| a.job_name.cmp(&b.job_name));
4917 }
4918 Ok(jobs)
4919 }
4920
4921 /// Retrieve the task information and status.
4922 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self), fields(task_id = %task_id)))]
4923 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
4924 self.rpc(
4925 "task.get".to_owned(),
4926 Some(HashMap::from([("id", task_id)])),
4927 )
4928 .await
4929 }
4930
4931 /// Updates the tasks status.
4932 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4933 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
4934 let status = TaskStatus {
4935 task_id,
4936 status: status.to_owned(),
4937 };
4938 self.rpc("docker.update.status".to_owned(), Some(status))
4939 .await
4940 }
4941
4942 /// Defines the stages for the task. The stages are defined as a mapping
4943 /// from stage names to their descriptions. Once stages are defined their
4944 /// status can be updated using the update_stage method.
4945 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, stages)))]
4946 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
4947 let stages: Vec<HashMap<String, String>> = stages
4948 .iter()
4949 .map(|(key, value)| {
4950 let mut stage_map = HashMap::new();
4951 stage_map.insert(key.to_string(), value.to_string());
4952 stage_map
4953 })
4954 .collect();
4955 let params = TaskStages { task_id, stages };
4956 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
4957 Ok(())
4958 }
4959
4960 /// Updates the progress of the task for the provided stage and status
4961 /// information.
4962 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4963 pub async fn update_stage(
4964 &self,
4965 task_id: TaskID,
4966 stage: &str,
4967 status: &str,
4968 message: &str,
4969 percentage: u8,
4970 ) -> Result<(), Error> {
4971 let stage = Stage::new(
4972 Some(task_id),
4973 stage.to_owned(),
4974 Some(status.to_owned()),
4975 Some(message.to_owned()),
4976 percentage,
4977 );
4978 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
4979 Ok(())
4980 }
4981
4982 /// Authenticated fetch from the Studio server using the bulk HTTP client
4983 /// (no total-request timeout; idle read timeout per chunk).
4984 ///
4985 /// **Buffers the entire response body into memory.** Suitable for small to
4986 /// medium payloads. For very large binary downloads (multi-GB artifacts or
4987 /// checkpoints), prefer a streaming approach that writes directly to disk.
4988 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self)))]
4989 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
4990 let req = self
4991 .bulk_http
4992 .get(format!("{}/{}", self.url, query))
4993 .header("User-Agent", "EdgeFirst Client")
4994 .header("Authorization", format!("Bearer {}", self.token().await));
4995 let resp = req.send().await?;
4996
4997 if resp.status().is_success() {
4998 let body = resp.bytes().await?;
4999
5000 if log_enabled!(Level::Trace) {
5001 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
5002 }
5003
5004 Ok(body.to_vec())
5005 } else {
5006 let err = resp.error_for_status_ref().unwrap_err();
5007 Err(Error::HttpError(err))
5008 }
5009 }
5010
5011 /// Sends a multipart post request to the server. This is used by the
5012 /// upload and download APIs which do not use JSON-RPC but instead transfer
5013 /// files using multipart/form-data.
5014 ///
5015 /// The result field is deserialized as `serde_json::Value` rather than
5016 /// `String` because different server endpoints return different shapes —
5017 /// `val.data.upload` returns a plain string while `task.data.upload`
5018 /// returns an object `{"message":…,"path":…,"size":…}`. All current
5019 /// callers discard the return value so this is backwards-compatible.
5020 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, form)))]
5021 pub async fn post_multipart(
5022 &self,
5023 method: &str,
5024 form: Form,
5025 ) -> Result<serde_json::Value, Error> {
5026 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5027 .ok()
5028 .and_then(|s| s.parse().ok())
5029 .unwrap_or(600u64);
5030
5031 let req = self
5032 .http
5033 .post(format!("{}/api?method={}", self.url, method))
5034 .header("Accept", "application/json")
5035 .header("User-Agent", "EdgeFirst Client")
5036 .header("Authorization", format!("Bearer {}", self.token().await))
5037 .timeout(Duration::from_secs(upload_timeout_secs))
5038 .multipart(form);
5039 let resp = req.send().await?;
5040
5041 if resp.status().is_success() {
5042 let body = resp.bytes().await?;
5043
5044 if log_enabled!(Level::Trace) {
5045 trace!(
5046 "POST Multipart Response: {}",
5047 String::from_utf8_lossy(&body)
5048 );
5049 }
5050
5051 let response: RpcResponse<serde_json::Value> = match serde_json::from_slice(&body) {
5052 Ok(response) => response,
5053 Err(err) => {
5054 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
5055 return Err(err.into());
5056 }
5057 };
5058
5059 if let Some(error) = response.error {
5060 Err(Error::RpcError(error.code, error.message))
5061 } else if let Some(result) = response.result {
5062 Ok(result)
5063 } else {
5064 Err(Error::InvalidResponse)
5065 }
5066 } else {
5067 // HTTP-level failure on the multipart upload. Map 413 to the
5068 // typed `PayloadTooLarge` variant so callers see the same error
5069 // type from both single-file rpc_download paths and multipart
5070 // upload paths; everything else falls through to HttpError.
5071 let status = resp.status();
5072 if status.as_u16() == 413 {
5073 return Err(Error::PayloadTooLarge {
5074 method: method.to_string(),
5075 size_hint: None,
5076 });
5077 }
5078 let err = resp.error_for_status_ref().unwrap_err();
5079 Err(Error::HttpError(err))
5080 }
5081 }
5082
    /// Internal helper: POST a JSON-RPC request and stream the binary response
    /// to `output_path`. The response is assumed to be raw binary (not a JSON
    /// envelope). Use for endpoints that return file contents directly.
    ///
    /// The request is sent to `{url}/api` on the bulk HTTP client with a
    /// JSON-RPC 2.0 envelope (`id` is always `0`).
    ///
    /// # Errors
    ///
    /// * HTTP 413 → `Error::PayloadTooLarge { method, size_hint: None }`.
    /// * Any other HTTP non-success → the response body is read as text
    ///   (empty string if it cannot be read) and surfaced via
    ///   `Error::RpcError(status_code, body)`.
    /// * HTTP success with an `application/json` body that is a JSON-RPC
    ///   error envelope → `Error::RpcError(code, message)`, using `-1` /
    ///   `"unknown error"` when `code` / `message` are missing.
    /// * Transport, directory-creation, file-write and streaming failures
    ///   propagate via `?`.
    ///
    /// # Progress
    ///
    /// A JSON body is buffered, written in one go, and followed by a single
    /// `Progress { current: len, total: len, status: None }` event. Any other
    /// body is handed to `stream_response_to_file` together with `progress`.
    pub(crate) async fn rpc_download<P: Serialize>(
        &self,
        method: &str,
        params: &P,
        output_path: &std::path::Path,
        progress: Option<tokio::sync::mpsc::Sender<Progress>>,
    ) -> Result<(), Error> {
        let envelope = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 0,
            "method": method,
            "params": params,
        });

        let url = format!("{}/api", self.url);
        let resp = self
            .bulk_http
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.token().await))
            .json(&envelope)
            .send()
            .await?;

        let status = resp.status();
        if !status.is_success() {
            // Typed variant for oversized payloads, matching post_multipart.
            if status.as_u16() == 413 {
                return Err(Error::PayloadTooLarge {
                    method: method.to_string(),
                    size_hint: None,
                });
            }
            let body = resp.text().await.unwrap_or_default();
            return Err(Error::RpcError(status.as_u16() as i32, body));
        }

        // HTTP 200 with Content-Type: application/json can mean two things:
        //   (a) a JSON-RPC error envelope when the server failed mid-way
        //       (e.g. {"jsonrpc":"2.0","error":{"code":N,"message":"..."}}),
        //   (b) a legitimate JSON file payload — validation traces, chart
        //       bodies, metrics, etc., are typically served with this MIME.
        //
        // Disambiguate structurally: a JSON-RPC 2.0 envelope is required to
        // carry a `jsonrpc` member, and an *error* envelope further requires
        // an `error.code` integer (per RFC 8259 + JSON-RPC 2.0 §5). Only
        // decode the body as an error if both markers are present. This is
        // strict enough to leave legitimate JSON artifacts that happen to
        // contain a free-form `error` field (metrics, diagnostics, log
        // dumps) untouched, while still catching every real server
        // failure.
        let content_type = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_owned();
        if content_type.contains("application/json") {
            // JSON bodies are buffered fully in memory (not streamed) so they
            // can be inspected for an error envelope before touching disk.
            let body = resp.bytes().await?;
            if let Ok(val) = serde_json::from_slice::<serde_json::Value>(&body)
                && is_jsonrpc_error_envelope(&val)
                && let Some(err_obj) = val.get("error")
            {
                // NOTE(review): `as i32` wraps codes outside the i32 range.
                let code = err_obj.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
                let message = err_obj
                    .get("message")
                    .and_then(|m| m.as_str())
                    .unwrap_or("unknown error")
                    .to_string();
                return Err(Error::RpcError(code, message));
            }
            // Not an error envelope — body is a JSON file. Write it to disk
            // and emit a single completion progress event so callers (e.g.,
            // Python download_data progress callbacks) see the download
            // finish.
            //
            // `Path::parent` returns `Some("")` for a bare filename like
            // "metrics.json"; only attempt directory creation when the
            // parent actually names a directory, so the file simply lands
            // in the current directory.
            if let Some(parent) = output_path.parent()
                && !parent.as_os_str().is_empty()
            {
                tokio::fs::create_dir_all(parent).await?;
            }
            let mut file = tokio::fs::File::create(output_path).await?;
            file.write_all(&body).await?;
            file.flush().await?;
            if let Some(tx) = progress {
                let total = body.len();
                // Use the awaited send for the final event so completion
                // handlers are never silently dropped.
                let _ = tx
                    .send(Progress {
                        current: total,
                        total,
                        status: None,
                    })
                    .await;
            }
            return Ok(());
        }

        // Same empty-parent guard for the streaming download path: a bare
        // filename like "metrics.json" is written to the current directory
        // without attempting to create an (empty) parent directory.
        if let Some(parent) = output_path.parent()
            && !parent.as_os_str().is_empty()
        {
            tokio::fs::create_dir_all(parent).await?;
        }

        stream_response_to_file(resp, output_path, progress).await
    }
5201
5202 /// Send a JSON-RPC request to the server. The method is the name of the
5203 /// method to call on the server. The params are the parameters to pass to
5204 /// the method. The method and params are serialized into a JSON-RPC
5205 /// request and sent to the server. The response is deserialized into
5206 /// the specified type and returned to the caller.
5207 ///
5208 /// NOTE: This API would generally not be called directly and instead users
5209 /// should use the higher-level methods provided by the client.
5210 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method)))]
5211 pub async fn rpc<Params, RpcResult>(
5212 &self,
5213 method: String,
5214 params: Option<Params>,
5215 ) -> Result<RpcResult, Error>
5216 where
5217 Params: Serialize,
5218 RpcResult: DeserializeOwned,
5219 {
5220 let auth_expires = self.token_expiration().await?;
5221 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
5222 self.renew_token().await?;
5223 }
5224
5225 self.rpc_without_auth(method, params).await
5226 }
5227
5228 #[cfg_attr(feature = "profiling", tracing::instrument(skip(self, params), fields(method = %method, request = tracing::field::Empty, response = tracing::field::Empty)))]
5229 async fn rpc_without_auth<Params, RpcResult>(
5230 &self,
5231 method: String,
5232 params: Option<Params>,
5233 ) -> Result<RpcResult, Error>
5234 where
5235 Params: Serialize,
5236 RpcResult: DeserializeOwned,
5237 {
5238 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5239 .ok()
5240 .and_then(|s| s.parse().ok())
5241 .unwrap_or(5usize);
5242
5243 let url = format!("{}/api", self.url);
5244
5245 // Serialize request body once before retry loop to avoid Clone bound on Params
5246 let request = RpcRequest {
5247 method: method.clone(),
5248 params,
5249 ..Default::default()
5250 };
5251
5252 // Log request for debugging (log crate) and profiling (tracing crate)
5253 let request_json = if method == "auth.login" {
5254 // Redact auth.login params (contains password)
5255 serde_json::json!({
5256 "jsonrpc": "2.0",
5257 "method": &method,
5258 "params": "[REDACTED - contains credentials]",
5259 "id": request.id
5260 })
5261 .to_string()
5262 } else {
5263 serde_json::to_string(&request)?
5264 };
5265
5266 if log_enabled!(Level::Trace) {
5267 trace!("RPC Request: {}", request_json);
5268 }
5269
5270 // Record request on current span for Perfetto when profiling is enabled
5271 #[cfg(feature = "profiling")]
5272 tracing::Span::current().record("request", &request_json);
5273
5274 let request_body = serde_json::to_vec(&request)?;
5275 let mut last_error: Option<Error> = None;
5276
5277 for attempt in 0..=max_retries {
5278 if attempt > 0 {
5279 // Exponential backoff with jitter: base delay * 2^attempt, capped at 30s
5280 // Jitter: randomize between 100%-150% of base delay to avoid thundering herd
5281 // while ensuring we never retry faster than the base delay
5282 let base_delay_secs = (1u64 << (attempt - 1).min(5)).min(30);
5283 let jitter_factor = 1.0 + (rand::random::<f64>() * 0.5); // 1.0 to 1.5
5284 let delay_ms = (base_delay_secs as f64 * 1000.0 * jitter_factor) as u64;
5285 let delay = Duration::from_millis(delay_ms);
5286 warn!(
5287 "Retry {}/{} for RPC '{}' after {:?}",
5288 attempt, max_retries, method, delay
5289 );
5290 tokio::time::sleep(delay).await;
5291 }
5292
5293 let result = self
5294 .http
5295 .post(&url)
5296 .header("Accept", "application/json")
5297 .header("Content-Type", "application/json")
5298 .header("User-Agent", "EdgeFirst Client")
5299 .header("Authorization", format!("Bearer {}", self.token().await))
5300 .body(request_body.clone())
5301 .send()
5302 .await;
5303
5304 match result {
5305 Ok(res) => {
5306 let status = res.status();
5307 let status_code = status.as_u16();
5308
5309 // Check for retryable HTTP status codes before processing response
5310 if matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504)
5311 && attempt < max_retries
5312 {
5313 warn!(
5314 "RPC '{}' failed with HTTP {} (retrying)",
5315 method, status_code
5316 );
5317 last_error = Some(Error::HttpError(res.error_for_status().unwrap_err()));
5318 continue;
5319 }
5320
5321 // Process the response
5322 match self.process_rpc_response(res).await {
5323 Ok(result) => {
5324 if attempt > 0 {
5325 debug!("RPC '{}' succeeded on retry {}", method, attempt);
5326 }
5327 return Ok(result);
5328 }
5329 Err(e) => {
5330 // Don't retry client errors (4xx except 408, 429)
5331 if attempt > 0 {
5332 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
5333 }
5334 return Err(e);
5335 }
5336 }
5337 }
5338 Err(e) => {
5339 // Transport error (timeout, connection failure, etc.)
5340 let is_timeout = e.is_timeout();
5341 let is_connect = e.is_connect();
5342
5343 if (is_timeout || is_connect) && attempt < max_retries {
5344 warn!(
5345 "RPC '{}' transport error (retrying): {}",
5346 method,
5347 if is_timeout {
5348 "timeout"
5349 } else {
5350 "connection failed"
5351 }
5352 );
5353 last_error = Some(Error::HttpError(e));
5354 continue;
5355 }
5356
5357 if attempt > 0 {
5358 error!("RPC '{}' failed after {} retries: {}", method, attempt, e);
5359 }
5360 return Err(Error::HttpError(e));
5361 }
5362 }
5363 }
5364
5365 // Should not reach here
5366 Err(last_error.unwrap_or_else(|| {
5367 Error::InvalidParameters(format!(
5368 "RPC '{}' failed after {} retries",
5369 method, max_retries
5370 ))
5371 }))
5372 }
5373
5374 async fn process_rpc_response<RpcResult>(
5375 &self,
5376 res: reqwest::Response,
5377 ) -> Result<RpcResult, Error>
5378 where
5379 RpcResult: DeserializeOwned,
5380 {
5381 let body = res.bytes().await?;
5382 let response_str = String::from_utf8_lossy(&body);
5383
5384 if log_enabled!(Level::Trace) {
5385 trace!("RPC Response: {}", response_str);
5386 }
5387
5388 // Record response on current span for Perfetto when profiling is enabled
5389 // Truncate large responses to avoid bloating trace files
5390 #[cfg(feature = "profiling")]
5391 {
5392 const MAX_RESPONSE_LEN: usize = 4096;
5393 let truncated = if response_str.len() > MAX_RESPONSE_LEN {
5394 // Use floor_char_boundary to avoid panicking on multi-byte UTF-8 chars
5395 let safe_end = response_str.floor_char_boundary(MAX_RESPONSE_LEN);
5396 format!(
5397 "{}...[truncated {} bytes]",
5398 &response_str[..safe_end],
5399 response_str.len() - safe_end
5400 )
5401 } else {
5402 response_str.to_string()
5403 };
5404 tracing::Span::current().record("response", &truncated);
5405 }
5406
5407 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
5408 Ok(response) => response,
5409 Err(err) => {
5410 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
5411 return Err(err.into());
5412 }
5413 };
5414
5415 // FIXME: Studio Server always returns 999 as the id.
5416 // if request.id.to_string() != response.id {
5417 // return Err(Error::InvalidRpcId(response.id));
5418 // }
5419
5420 if let Some(error) = response.error {
5421 Err(Error::RpcError(error.code, error.message))
5422 } else if let Some(result) = response.result {
5423 Ok(result)
5424 } else {
5425 Err(Error::InvalidResponse)
5426 }
5427 }
5428}
5429
5430/// Process items in parallel with semaphore concurrency control and progress
5431/// tracking.
5432///
5433/// This helper eliminates boilerplate for parallel item processing with:
5434/// - Semaphore limiting concurrent tasks (configurable via `concurrency` param
5435/// or `MAX_TASKS` env var, default: half of CPU cores clamped to 2-8)
5436/// - Atomic progress counter with automatic item-level updates
5437/// - Progress updates sent after each item completes (not byte-level streaming)
5438/// - Proper error propagation from spawned tasks
5439///
5440/// Note: This is optimized for discrete items with post-completion progress
5441/// updates. For byte-level streaming progress or custom retry logic, use
5442/// specialized implementations.
5443///
5444/// # Arguments
5445///
5446/// * `items` - Collection of items to process in parallel
5447/// * `progress` - Optional progress channel for tracking completion
5448/// * `concurrency` - Optional max concurrent tasks (defaults to `max_tasks()`)
5449/// * `work_fn` - Async function to execute for each item
5450///
5451/// # Examples
5452///
5453/// ```rust,ignore
5454/// // Use default concurrency
5455/// parallel_foreach_items(samples, progress, None, |sample| async move {
5456/// sample.download(&client, file_type).await?;
5457/// Ok(())
5458/// }).await?;
5459/// ```
5460async fn parallel_foreach_items<T, F, Fut>(
5461 items: Vec<T>,
5462 progress: Option<Sender<Progress>>,
5463 concurrency: Option<usize>,
5464 work_fn: F,
5465) -> Result<(), Error>
5466where
5467 T: Send + 'static,
5468 F: Fn(T) -> Fut + Send + Sync + 'static,
5469 Fut: Future<Output = Result<(), Error>> + Send + 'static,
5470{
5471 let total = items.len();
5472 let current = Arc::new(AtomicUsize::new(0));
5473 let sem = Arc::new(Semaphore::new(concurrency.unwrap_or_else(max_tasks)));
5474 let work_fn = Arc::new(work_fn);
5475
5476 let tasks = items
5477 .into_iter()
5478 .map(|item| {
5479 let sem = sem.clone();
5480 let current = current.clone();
5481 let progress = progress.clone();
5482 let work_fn = work_fn.clone();
5483
5484 tokio::spawn(async move {
5485 let _permit = sem.acquire().await.map_err(|_| {
5486 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5487 })?;
5488
5489 // Execute the actual work
5490 work_fn(item).await?;
5491
5492 // Update progress
5493 if let Some(progress) = &progress {
5494 let current = current.fetch_add(1, Ordering::SeqCst);
5495 let _ = progress
5496 .send(Progress {
5497 current: current + 1,
5498 total,
5499 status: None,
5500 })
5501 .await;
5502 }
5503
5504 Ok::<(), Error>(())
5505 })
5506 })
5507 .collect::<Vec<_>>();
5508
5509 join_all(tasks)
5510 .await
5511 .into_iter()
5512 .collect::<Result<Vec<_>, _>>()?
5513 .into_iter()
5514 .collect::<Result<Vec<_>, _>>()?;
5515
5516 if let Some(progress) = progress {
5517 drop(progress);
5518 }
5519
5520 Ok(())
5521}
5522
5523/// Upload a file to S3 using multipart upload with presigned URLs.
5524///
5525/// Splits a file into chunks (100MB each) and uploads them in parallel using
5526/// S3 multipart upload protocol. Returns completion parameters with ETags for
5527/// finalizing the upload.
5528///
5529/// This function handles:
5530/// - Splitting files into parts based on PART_SIZE (100MB)
5531/// - Parallel upload with concurrency limiting via `max_tasks()` (configurable
5532/// with `MAX_TASKS`, default: half of CPU cores, min 2, max 8)
5533/// - Retry logic (handled by reqwest client)
5534/// - Progress tracking across all parts
5535///
5536/// # Arguments
5537///
5538/// * `http` - HTTP client for making requests
5539/// * `part` - Snapshot part info with presigned URLs for each chunk
5540/// * `path` - Local file path to upload
5541/// * `total` - Total bytes across all files for progress calculation
5542/// * `current` - Atomic counter tracking bytes uploaded across all operations
5543/// * `progress` - Optional channel for sending progress updates
5544///
5545/// # Returns
5546///
5547/// Parameters needed to complete the multipart upload (key, upload_id, ETags)
5548async fn upload_multipart(
5549 http: reqwest::Client,
5550 part: SnapshotPart,
5551 path: PathBuf,
5552 total: usize,
5553 confirmed_bytes: Arc<AtomicUsize>,
5554 progress: Option<Sender<Progress>>,
5555) -> Result<SnapshotCompleteMultipartParams, Error> {
5556 let filesize = path.metadata()?.len() as usize;
5557 let n_parts = filesize.div_ceil(PART_SIZE);
5558 let sem = Arc::new(Semaphore::new(max_upload_tasks()));
5559
5560 let key = part.key.ok_or(Error::InvalidResponse)?;
5561 let upload_id = part.upload_id;
5562
5563 let urls = part.urls.clone();
5564
5565 // Pre-allocate ETag slots for all parts
5566 let etags = Arc::new(tokio::sync::Mutex::new(vec![
5567 EtagPart {
5568 etag: "".to_owned(),
5569 part_number: 0,
5570 };
5571 n_parts
5572 ]));
5573
5574 // Per-part byte counters for streaming progress (reset on retry)
5575 let part_bytes: Arc<Vec<AtomicUsize>> = Arc::new(
5576 (0..n_parts)
5577 .map(|_| AtomicUsize::new(0))
5578 .collect::<Vec<_>>(),
5579 );
5580
5581 // Upload all parts in parallel with concurrency limiting
5582 let tasks = (0..n_parts)
5583 .map(|part_idx| {
5584 let http = http.clone();
5585 let url = urls[part_idx].clone();
5586 let etags = etags.clone();
5587 let path = path.to_owned();
5588 let sem = sem.clone();
5589 let progress = progress.clone();
5590 let confirmed_bytes = confirmed_bytes.clone();
5591 let part_bytes = part_bytes.clone();
5592
5593 // Calculate this part's size
5594 let part_size = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5595 filesize % PART_SIZE
5596 } else {
5597 PART_SIZE
5598 };
5599
5600 tokio::spawn(async move {
5601 // Acquire semaphore permit to limit concurrent uploads
5602 let _permit = sem.acquire().await.map_err(|_| {
5603 Error::IoError(std::io::Error::other("Semaphore closed unexpectedly"))
5604 })?;
5605
5606 // Upload part with streaming progress and retry logic
5607 let etag = upload_part_with_progress(
5608 http,
5609 url,
5610 path,
5611 part_idx,
5612 n_parts,
5613 part_size,
5614 total,
5615 confirmed_bytes.clone(),
5616 part_bytes.clone(),
5617 progress.clone(),
5618 )
5619 .await?;
5620
5621 // Store ETag for this part (needed to complete multipart upload)
5622 let mut etags_guard = etags.lock().await;
5623 etags_guard[part_idx] = EtagPart {
5624 etag,
5625 part_number: part_idx + 1,
5626 };
5627
5628 // Part completed successfully - add to confirmed bytes
5629 confirmed_bytes.fetch_add(part_size, Ordering::SeqCst);
5630 // Reset part counter since it's now confirmed
5631 part_bytes[part_idx].store(0, Ordering::SeqCst);
5632
5633 // Send final progress update for this part
5634 if let Some(progress) = &progress {
5635 let current = confirmed_bytes.load(Ordering::SeqCst)
5636 + part_bytes
5637 .iter()
5638 .map(|p| p.load(Ordering::SeqCst))
5639 .sum::<usize>();
5640 let _ = progress
5641 .send(Progress {
5642 current,
5643 total,
5644 status: None,
5645 })
5646 .await;
5647 }
5648
5649 Ok::<(), Error>(())
5650 })
5651 })
5652 .collect::<Vec<_>>();
5653
5654 // Wait for all parts to complete (double collect to handle both JoinError and
5655 // inner Error)
5656 join_all(tasks)
5657 .await
5658 .into_iter()
5659 .collect::<Result<Vec<_>, _>>()?
5660 .into_iter()
5661 .collect::<Result<Vec<_>, _>>()?;
5662
5663 Ok(SnapshotCompleteMultipartParams {
5664 key,
5665 upload_id,
5666 etag_list: etags.lock().await.clone(),
5667 })
5668}
5669
5670/// Upload a single part with streaming progress tracking and retry logic.
5671///
5672/// Progress is reported continuously as bytes are sent. On retry, the part's
5673/// progress counter is reset to avoid over-reporting.
5674#[allow(clippy::too_many_arguments)]
5675async fn upload_part_with_progress(
5676 http: reqwest::Client,
5677 url: String,
5678 path: PathBuf,
5679 part_idx: usize,
5680 n_parts: usize,
5681 part_size: usize,
5682 total: usize,
5683 confirmed_bytes: Arc<AtomicUsize>,
5684 part_bytes: Arc<Vec<AtomicUsize>>,
5685 progress: Option<Sender<Progress>>,
5686) -> Result<String, Error> {
5687 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5688 .ok()
5689 .and_then(|s| s.parse().ok())
5690 .unwrap_or(5usize);
5691
5692 // Per-part total upload timeout. Covers the send phase (request body) where
5693 // read_timeout does not apply. Each part is at most PART_SIZE (100MB), so
5694 // this bounds how long a stalled upload can block before retrying.
5695 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5696 .ok()
5697 .and_then(|s| s.parse().ok())
5698 .unwrap_or(600u64); // 600s = 100MB at ~170 KB/s minimum
5699
5700 let mut last_error: Option<Error> = None;
5701
5702 for attempt in 0..=max_retries {
5703 if attempt > 0 {
5704 // Reset this part's progress counter before retry
5705 part_bytes[part_idx].store(0, Ordering::SeqCst);
5706
5707 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5708 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5709 warn!(
5710 "Retry {}/{} for part {} after {:?}",
5711 attempt, max_retries, part_idx, delay
5712 );
5713 tokio::time::sleep(delay).await;
5714 }
5715
5716 match upload_part_streaming(
5717 http.clone(),
5718 url.clone(),
5719 path.clone(),
5720 part_idx,
5721 n_parts,
5722 part_size,
5723 total,
5724 upload_timeout_secs,
5725 confirmed_bytes.clone(),
5726 part_bytes.clone(),
5727 progress.clone(),
5728 )
5729 .await
5730 {
5731 Ok(etag) => return Ok(etag),
5732 Err(e) => {
5733 // Check if error is retryable
5734 let is_retryable = matches!(
5735 &e,
5736 Error::HttpError(re) if re.is_timeout() || re.is_connect() ||
5737 re.status().map(|s: reqwest::StatusCode| s.as_u16()).unwrap_or(0) >= 500
5738 );
5739
5740 if is_retryable && attempt < max_retries {
5741 last_error = Some(e);
5742 continue;
5743 }
5744
5745 return Err(e);
5746 }
5747 }
5748 }
5749
5750 Err(last_error
5751 .unwrap_or_else(|| Error::IoError(std::io::Error::other("Upload failed after retries"))))
5752}
5753
5754/// Perform the actual upload with streaming progress.
5755#[allow(clippy::too_many_arguments)]
5756async fn upload_part_streaming(
5757 http: reqwest::Client,
5758 url: String,
5759 path: PathBuf,
5760 part_idx: usize,
5761 n_parts: usize,
5762 _part_size: usize,
5763 total: usize,
5764 upload_timeout_secs: u64,
5765 confirmed_bytes: Arc<AtomicUsize>,
5766 part_bytes: Arc<Vec<AtomicUsize>>,
5767 progress: Option<Sender<Progress>>,
5768) -> Result<String, Error> {
5769 let filesize = path.metadata()?.len() as usize;
5770 let mut file = File::open(&path).await?;
5771 file.seek(SeekFrom::Start((part_idx * PART_SIZE) as u64))
5772 .await?;
5773 let file = file.take(PART_SIZE as u64);
5774
5775 let body_length = if part_idx + 1 == n_parts && !filesize.is_multiple_of(PART_SIZE) {
5776 filesize % PART_SIZE
5777 } else {
5778 PART_SIZE
5779 };
5780
5781 // Create stream with progress tracking
5782 let stream = FramedRead::new(file, BytesCodec::new());
5783
5784 // Wrap stream to track bytes sent and report progress
5785 let progress_stream = stream.map(move |result| {
5786 if let Ok(ref bytes) = result {
5787 let bytes_len = bytes.len();
5788 part_bytes[part_idx].fetch_add(bytes_len, Ordering::SeqCst);
5789
5790 // Send progress update (fire-and-forget via try_send to avoid blocking)
5791 if let Some(ref progress) = progress {
5792 let current = confirmed_bytes.load(Ordering::SeqCst)
5793 + part_bytes
5794 .iter()
5795 .map(|p| p.load(Ordering::SeqCst))
5796 .sum::<usize>();
5797 // Best-effort progress reporting: use try_send to avoid blocking.
5798 // If the channel is full or closed, we intentionally skip this update
5799 // to avoid stalling the upload; subsequent updates will still be delivered.
5800 let _ = progress.try_send(Progress {
5801 current,
5802 total,
5803 status: None,
5804 });
5805 }
5806 }
5807 result.map(|b| b.freeze())
5808 });
5809
5810 let body = Body::wrap_stream(progress_stream);
5811
5812 let resp = http
5813 .put(url)
5814 .header(CONTENT_LENGTH, body_length)
5815 .timeout(Duration::from_secs(upload_timeout_secs))
5816 .body(body)
5817 .send()
5818 .await?
5819 .error_for_status()?;
5820
5821 let etag = resp
5822 .headers()
5823 .get("etag")
5824 .ok_or_else(|| Error::InvalidEtag("Missing ETag header".to_string()))?
5825 .to_str()
5826 .map_err(|_| Error::InvalidEtag("Invalid ETag encoding".to_string()))?
5827 .to_owned();
5828
5829 // Studio Server requires etag without the quotes.
5830 let etag = etag
5831 .strip_prefix("\"")
5832 .ok_or_else(|| Error::InvalidEtag("Missing opening quote".to_string()))?;
5833 let etag = etag
5834 .strip_suffix("\"")
5835 .ok_or_else(|| Error::InvalidEtag("Missing closing quote".to_string()))?;
5836
5837 Ok(etag.to_owned())
5838}
5839
5840/// Upload a complete file to a presigned S3 URL using HTTP PUT.
5841///
5842/// This is used for populate_samples to upload files to S3 after
5843/// receiving presigned URLs from the server.
5844///
5845/// Includes explicit retry logic with exponential backoff for transient
5846/// failures.
5847/// Classify a reqwest transport error (one where no HTTP response was received)
5848/// as a transient failure worth retrying.
5849///
5850/// Presigned-URL uploads buffer the body in memory and a PUT to the same object
5851/// key is idempotent, so replaying any transport-level failure is safe. Besides
5852/// timeouts and connect failures this covers request/body send errors such as
5853/// hyper's `IncompleteMessage` (a peer closing a keep-alive connection mid-send)
5854/// — transients that pipelined, high-concurrency uploads provoke far more often
5855/// than serial ones, and which the previous `is_timeout() || is_connect()` gate
5856/// missed (aborting the whole upload on a single blip).
5857fn is_retryable_upload_error(e: &reqwest::Error) -> bool {
5858 e.is_timeout() || e.is_connect() || e.is_request() || e.is_body()
5859}
5860
5861/// Reliable, `Instant`-based upload timing accumulators (profiling builds only).
5862///
5863/// Async `tracing` spans cannot measure per-await latency or task concurrency
5864/// under a multi-threaded runtime — a future's span fragments across worker
5865/// threads — so these atomics accumulate real measured durations and byte counts
5866/// for a trustworthy phase breakdown. Durations are summed across concurrent
5867/// batches, so totals can exceed wall-clock; `(rpc + upload) / wall` gives the
5868/// effective parallelism, and `bytes / wall` the effective upload bandwidth.
5869#[cfg(feature = "profiling")]
5870pub mod upload_stats {
5871 use std::sync::atomic::{AtomicU64, Ordering};
5872
5873 static RPC_NANOS: AtomicU64 = AtomicU64::new(0);
5874 static UPLOAD_NANOS: AtomicU64 = AtomicU64::new(0);
5875 static UPLOAD_BYTES: AtomicU64 = AtomicU64::new(0);
5876
5877 pub(crate) fn add_rpc_nanos(n: u64) {
5878 RPC_NANOS.fetch_add(n, Ordering::Relaxed);
5879 }
5880 pub(crate) fn add_upload_nanos(n: u64) {
5881 UPLOAD_NANOS.fetch_add(n, Ordering::Relaxed);
5882 }
5883 pub(crate) fn add_upload_bytes(n: u64) {
5884 UPLOAD_BYTES.fetch_add(n, Ordering::Relaxed);
5885 }
5886
5887 /// Zero all accumulators. Call once before starting an upload.
5888 pub fn reset() {
5889 RPC_NANOS.store(0, Ordering::Relaxed);
5890 UPLOAD_NANOS.store(0, Ordering::Relaxed);
5891 UPLOAD_BYTES.store(0, Ordering::Relaxed);
5892 }
5893
5894 /// Snapshot of `(rpc_nanos, upload_nanos, upload_bytes)` accumulated so far.
5895 pub fn snapshot() -> (u64, u64, u64) {
5896 (
5897 RPC_NANOS.load(Ordering::Relaxed),
5898 UPLOAD_NANOS.load(Ordering::Relaxed),
5899 UPLOAD_BYTES.load(Ordering::Relaxed),
5900 )
5901 }
5902}
5903
5904async fn upload_file_to_presigned_url(
5905 http: reqwest::Client,
5906 url: &str,
5907 path: PathBuf,
5908) -> Result<(), Error> {
5909 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
5910 .ok()
5911 .and_then(|s| s.parse().ok())
5912 .unwrap_or(5usize);
5913
5914 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
5915 .ok()
5916 .and_then(|s| s.parse().ok())
5917 .unwrap_or(600u64);
5918
5919 // Read the entire file into memory once
5920 let file_data = fs::read(&path).await?;
5921 let file_size = file_data.len();
5922 let filename = path.file_name().unwrap_or_default().to_string_lossy();
5923
5924 let mut last_error: Option<Error> = None;
5925
5926 for attempt in 0..=max_retries {
5927 if attempt > 0 {
5928 // Exponential backoff: 1s, 2s, 4s, 8s, ...
5929 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
5930 warn!(
5931 "Retry {}/{} for upload '{}' after {:?}",
5932 attempt, max_retries, filename, delay
5933 );
5934 tokio::time::sleep(delay).await;
5935 }
5936
5937 // Attempt upload
5938 let result = http
5939 .put(url)
5940 .header(CONTENT_LENGTH, file_size)
5941 .timeout(Duration::from_secs(upload_timeout_secs))
5942 .body(file_data.clone())
5943 .send()
5944 .await;
5945
5946 match result {
5947 Ok(resp) => {
5948 if resp.status().is_success() {
5949 if attempt > 0 {
5950 debug!(
5951 "Upload '{}' succeeded on retry {} ({} bytes)",
5952 filename, attempt, file_size
5953 );
5954 } else {
5955 debug!(
5956 "Successfully uploaded file: {} ({} bytes)",
5957 filename, file_size
5958 );
5959 }
5960 #[cfg(feature = "profiling")]
5961 upload_stats::add_upload_bytes(file_size as u64);
5962 return Ok(());
5963 }
5964
5965 let status = resp.status();
5966 let status_code = status.as_u16();
5967
5968 // Check if error is retryable
5969 let is_retryable =
5970 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
5971
5972 if is_retryable && attempt < max_retries {
5973 let error_text = resp.text().await.unwrap_or_default();
5974 warn!(
5975 "Upload '{}' failed with HTTP {} (retryable): {}",
5976 filename, status_code, error_text
5977 );
5978 last_error = Some(Error::InvalidParameters(format!(
5979 "Upload failed: HTTP {} - {}",
5980 status, error_text
5981 )));
5982 continue;
5983 }
5984
5985 // Non-retryable error or max retries exceeded
5986 let error_text = resp.text().await.unwrap_or_default();
5987 if attempt > 0 {
5988 error!(
5989 "Upload '{}' failed after {} retries: HTTP {} - {}",
5990 filename, attempt, status, error_text
5991 );
5992 }
5993 return Err(Error::InvalidParameters(format!(
5994 "Upload failed: HTTP {} - {}",
5995 status, error_text
5996 )));
5997 }
5998 Err(e) => {
5999 // Transport error: no HTTP response was received. The body is
6000 // buffered in memory and the PUT is idempotent, so any transient
6001 // transport failure is safe to replay (see
6002 // `is_retryable_upload_error`).
6003 if is_retryable_upload_error(&e) && attempt < max_retries {
6004 warn!("Upload '{}' transport error (retrying): {}", filename, e);
6005 last_error = Some(Error::HttpError(e));
6006 continue;
6007 }
6008
6009 // Non-retryable or max retries exceeded
6010 if attempt > 0 {
6011 error!(
6012 "Upload '{}' failed after {} retries: {}",
6013 filename, attempt, e
6014 );
6015 }
6016 return Err(Error::HttpError(e));
6017 }
6018 }
6019 }
6020
6021 // Should not reach here, but return last error if we do
6022 Err(last_error.unwrap_or_else(|| {
6023 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
6024 }))
6025}
6026
6027/// Upload bytes directly to a presigned S3 URL using HTTP PUT.
6028///
6029/// This is used for populate_samples to upload file content from memory
6030/// (e.g., from ZIP archives) without writing to disk first.
6031///
6032/// Includes explicit retry logic with exponential backoff for transient
6033/// failures.
6034async fn upload_bytes_to_presigned_url(
6035 http: reqwest::Client,
6036 url: &str,
6037 file_data: Vec<u8>,
6038 filename: &str,
6039) -> Result<(), Error> {
6040 let max_retries = std::env::var("EDGEFIRST_MAX_RETRIES")
6041 .ok()
6042 .and_then(|s| s.parse().ok())
6043 .unwrap_or(5usize);
6044
6045 let upload_timeout_secs = std::env::var("EDGEFIRST_UPLOAD_TIMEOUT")
6046 .ok()
6047 .and_then(|s| s.parse().ok())
6048 .unwrap_or(600u64);
6049
6050 let file_size = file_data.len();
6051 let mut last_error: Option<Error> = None;
6052
6053 for attempt in 0..=max_retries {
6054 if attempt > 0 {
6055 // Exponential backoff: 1s, 2s, 4s, 8s, ...
6056 let delay = Duration::from_secs(1 << (attempt - 1).min(4));
6057 warn!(
6058 "Retry {}/{} for upload '{}' after {:?}",
6059 attempt, max_retries, filename, delay
6060 );
6061 tokio::time::sleep(delay).await;
6062 }
6063
6064 // Attempt upload
6065 let result = http
6066 .put(url)
6067 .header(CONTENT_LENGTH, file_size)
6068 .timeout(Duration::from_secs(upload_timeout_secs))
6069 .body(file_data.clone())
6070 .send()
6071 .await;
6072
6073 match result {
6074 Ok(resp) => {
6075 if resp.status().is_success() {
6076 if attempt > 0 {
6077 debug!(
6078 "Upload '{}' succeeded on retry {} ({} bytes)",
6079 filename, attempt, file_size
6080 );
6081 } else {
6082 debug!(
6083 "Successfully uploaded file: {} ({} bytes)",
6084 filename, file_size
6085 );
6086 }
6087 #[cfg(feature = "profiling")]
6088 upload_stats::add_upload_bytes(file_size as u64);
6089 return Ok(());
6090 }
6091
6092 let status = resp.status();
6093 let status_code = status.as_u16();
6094
6095 // Check if error is retryable
6096 let is_retryable =
6097 matches!(status_code, 408 | 429 | 500 | 502 | 503 | 504 | 409 | 423);
6098
6099 if is_retryable && attempt < max_retries {
6100 let error_text = resp.text().await.unwrap_or_default();
6101 warn!(
6102 "Upload '{}' failed with HTTP {} (retryable): {}",
6103 filename, status_code, error_text
6104 );
6105 last_error = Some(Error::InvalidParameters(format!(
6106 "Upload failed: HTTP {} - {}",
6107 status, error_text
6108 )));
6109 continue;
6110 }
6111
6112 // Non-retryable error or max retries exceeded
6113 let error_text = resp.text().await.unwrap_or_default();
6114 if attempt > 0 {
6115 error!(
6116 "Upload '{}' failed after {} retries: HTTP {} - {}",
6117 filename, attempt, status, error_text
6118 );
6119 }
6120 return Err(Error::InvalidParameters(format!(
6121 "Upload failed: HTTP {} - {}",
6122 status, error_text
6123 )));
6124 }
6125 Err(e) => {
6126 // Transport error: no HTTP response was received. The body is
6127 // buffered in memory and the PUT is idempotent, so any transient
6128 // transport failure is safe to replay (see
6129 // `is_retryable_upload_error`).
6130 if is_retryable_upload_error(&e) && attempt < max_retries {
6131 warn!("Upload '{}' transport error (retrying): {}", filename, e);
6132 last_error = Some(Error::HttpError(e));
6133 continue;
6134 }
6135
6136 // Non-retryable or max retries exceeded
6137 if attempt > 0 {
6138 error!(
6139 "Upload '{}' failed after {} retries: {}",
6140 filename, attempt, e
6141 );
6142 }
6143 return Err(Error::HttpError(e));
6144 }
6145 }
6146 }
6147
6148 // Should not reach here, but return last error if we do
6149 Err(last_error.unwrap_or_else(|| {
6150 Error::InvalidParameters(format!("Upload failed after {} retries", max_retries))
6151 }))
6152}
6153
6154#[cfg(test)]
6155mod tests {
6156 use super::*;
6157
6158 #[test]
6159 fn test_filter_and_sort_by_name_exact_match_first() {
6160 // Test that exact matches come first
6161 let items = vec![
6162 "Deer Roundtrip 123".to_string(),
6163 "Deer".to_string(),
6164 "Reindeer".to_string(),
6165 "DEER".to_string(),
6166 ];
6167 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
6168 assert_eq!(result[0], "Deer"); // Exact match first
6169 assert_eq!(result[1], "DEER"); // Case-insensitive exact match second
6170 }
6171
6172 #[test]
6173 fn test_filter_and_sort_by_name_shorter_names_preferred() {
6174 // Test that shorter names (more specific) come before longer ones
6175 let items = vec![
6176 "Test Dataset ABC".to_string(),
6177 "Test".to_string(),
6178 "Test Dataset".to_string(),
6179 ];
6180 let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
6181 assert_eq!(result[0], "Test"); // Exact match first
6182 assert_eq!(result[1], "Test Dataset"); // Shorter substring match
6183 assert_eq!(result[2], "Test Dataset ABC"); // Longer substring match
6184 }
6185
6186 #[test]
6187 fn test_filter_and_sort_by_name_case_insensitive_filter() {
6188 // Test that filtering is case-insensitive
6189 let items = vec![
6190 "UPPERCASE".to_string(),
6191 "lowercase".to_string(),
6192 "MixedCase".to_string(),
6193 ];
6194 let result = filter_and_sort_by_name(items, "case", |s| s.as_str());
6195 assert_eq!(result.len(), 3); // All items should match
6196 }
6197
6198 #[test]
6199 fn test_filter_and_sort_by_name_no_matches() {
6200 // Test that empty result is returned when no matches
6201 let items = vec!["Apple".to_string(), "Banana".to_string()];
6202 let result = filter_and_sort_by_name(items, "Cherry", |s| s.as_str());
6203 assert!(result.is_empty());
6204 }
6205
6206 #[test]
6207 fn test_filter_and_sort_by_name_alphabetical_tiebreaker() {
6208 // Test alphabetical ordering for same-length names
6209 let items = vec![
6210 "TestC".to_string(),
6211 "TestA".to_string(),
6212 "TestB".to_string(),
6213 ];
6214 let result = filter_and_sort_by_name(items, "Test", |s| s.as_str());
6215 assert_eq!(result, vec!["TestA", "TestB", "TestC"]);
6216 }
6217
6218 #[test]
6219 fn test_collect_labels_from_samples() {
6220 let mut sample = Sample::new();
6221 let mut ann = Annotation::new();
6222 ann.set_label(Some("ace".to_string()));
6223 ann.set_label_index(Some(12));
6224 sample.annotations.push(ann);
6225 let (names, indices) = Client::collect_labels_from_samples(&[sample]).unwrap();
6226 assert_eq!(names, vec!["ace".to_string()]);
6227 assert_eq!(indices, vec![Some(12)]);
6228 }
6229
6230 #[test]
6231 fn test_collect_labels_from_samples_inconsistent_name() {
6232 let mut s1 = Sample::new();
6233 let mut a1 = Annotation::new();
6234 a1.set_label(Some("ace".to_string()));
6235 a1.set_label_index(Some(12));
6236 s1.annotations.push(a1);
6237
6238 let mut s2 = Sample::new();
6239 let mut a2 = Annotation::new();
6240 a2.set_label(Some("ace".to_string()));
6241 a2.set_label_index(Some(2));
6242 s2.annotations.push(a2);
6243
6244 let err = Client::collect_labels_from_samples(&[s1, s2]).unwrap_err();
6245 assert!(err.to_string().contains("inconsistent label_index"));
6246 }
6247
6248 #[test]
6249 fn test_validate_label_batch_duplicate_index() {
6250 let names = vec!["ace".to_string(), "king".to_string()];
6251 let indices = [Some(12_u64), Some(12)];
6252 let err = Client::validate_label_batch(&names, Some(&indices)).unwrap_err();
6253 assert!(err.to_string().contains("duplicate label_index"));
6254 }
6255
6256 #[test]
6257 fn test_build_filename_no_flatten() {
6258 // When flatten=false, should return base_name unchanged
6259 let result = Client::build_filename("image.jpg", false, Some(&"seq".to_string()), Some(42));
6260 assert_eq!(result, "image.jpg");
6261
6262 let result = Client::build_filename("test.png", false, None, None);
6263 assert_eq!(result, "test.png");
6264 }
6265
6266 #[test]
6267 fn test_build_filename_flatten_no_sequence() {
6268 // When flatten=true but no sequence, should return base_name unchanged
6269 let result = Client::build_filename("standalone.jpg", true, None, None);
6270 assert_eq!(result, "standalone.jpg");
6271 }
6272
6273 #[test]
6274 fn test_build_filename_flatten_with_sequence_not_prefixed() {
6275 // When flatten=true, in sequence, filename not prefixed → add prefix
6276 let result = Client::build_filename(
6277 "image.camera.jpeg",
6278 true,
6279 Some(&"deer_sequence".to_string()),
6280 Some(42),
6281 );
6282 assert_eq!(result, "deer_sequence_42_image.camera.jpeg");
6283 }
6284
6285 #[test]
6286 fn test_build_filename_flatten_with_sequence_no_frame() {
6287 // When flatten=true, in sequence, no frame number → prefix with sequence only
6288 let result =
6289 Client::build_filename("image.jpg", true, Some(&"sequence_A".to_string()), None);
6290 assert_eq!(result, "sequence_A_image.jpg");
6291 }
6292
6293 #[test]
6294 fn test_build_filename_flatten_already_prefixed() {
6295 // When flatten=true, filename already starts with sequence_ → return unchanged
6296 let result = Client::build_filename(
6297 "deer_sequence_042.camera.jpeg",
6298 true,
6299 Some(&"deer_sequence".to_string()),
6300 Some(42),
6301 );
6302 assert_eq!(result, "deer_sequence_042.camera.jpeg");
6303 }
6304
6305 #[test]
6306 fn test_build_filename_flatten_already_prefixed_different_frame() {
6307 // Edge case: filename has sequence prefix but we're adding different frame
6308 // Should still respect existing prefix
6309 let result = Client::build_filename(
6310 "sequence_A_001.jpg",
6311 true,
6312 Some(&"sequence_A".to_string()),
6313 Some(2),
6314 );
6315 assert_eq!(result, "sequence_A_001.jpg");
6316 }
6317
6318 #[test]
6319 fn test_build_filename_flatten_partial_match() {
6320 // Edge case: filename contains sequence name but not as prefix
6321 let result = Client::build_filename(
6322 "test_sequence_A_image.jpg",
6323 true,
6324 Some(&"sequence_A".to_string()),
6325 Some(5),
6326 );
6327 // Should add prefix because it doesn't START with "sequence_A_"
6328 assert_eq!(result, "sequence_A_5_test_sequence_A_image.jpg");
6329 }
6330
6331 #[test]
6332 fn test_build_filename_flatten_preserves_extension() {
6333 // Verify that file extensions are preserved correctly
6334 let extensions = vec![
6335 "jpeg",
6336 "jpg",
6337 "png",
6338 "camera.jpeg",
6339 "lidar.pcd",
6340 "depth.png",
6341 ];
6342
6343 for ext in extensions {
6344 let filename = format!("image.{}", ext);
6345 let result = Client::build_filename(&filename, true, Some(&"seq".to_string()), Some(1));
6346 assert!(
6347 result.ends_with(&format!(".{}", ext)),
6348 "Extension .{} not preserved in {}",
6349 ext,
6350 result
6351 );
6352 }
6353 }
6354
6355 #[test]
6356 fn test_build_filename_flatten_sanitization_compatibility() {
6357 // Test with sanitized path components (no special chars)
6358 let result = Client::build_filename(
6359 "sample_001.jpg",
6360 true,
6361 Some(&"seq_name_with_underscores".to_string()),
6362 Some(10),
6363 );
6364 assert_eq!(result, "seq_name_with_underscores_10_sample_001.jpg");
6365 }
6366
6367 // =========================================================================
6368 // Additional filter_and_sort_by_name tests for exact match determinism
6369 // =========================================================================
6370
6371 #[test]
6372 fn test_filter_and_sort_by_name_exact_match_is_deterministic() {
6373 // Test that searching for "Deer" always returns "Deer" first, not
6374 // "Deer Roundtrip 20251129" or similar
6375 let items = vec![
6376 "Deer Roundtrip 20251129".to_string(),
6377 "White-Tailed Deer".to_string(),
6378 "Deer".to_string(),
6379 "Deer Snapshot Test".to_string(),
6380 "Reindeer Dataset".to_string(),
6381 ];
6382
6383 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
6384
6385 // CRITICAL: First result must be exact match "Deer"
6386 assert_eq!(
6387 result.first().map(|s| s.as_str()),
6388 Some("Deer"),
6389 "Expected exact match 'Deer' first, got: {:?}",
6390 result.first()
6391 );
6392
6393 // Verify all items containing "Deer" are present (case-insensitive)
6394 assert_eq!(result.len(), 5);
6395 }
6396
6397 #[test]
6398 fn test_filter_and_sort_by_name_exact_match_with_different_cases() {
6399 // Verify case-sensitive exact match takes priority over case-insensitive
6400 let items = vec![
6401 "DEER".to_string(),
6402 "deer".to_string(),
6403 "Deer".to_string(),
6404 "Deer Test".to_string(),
6405 ];
6406
6407 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
6408
6409 // Priority 1: Case-sensitive exact match "Deer" first
6410 assert_eq!(result[0], "Deer");
6411 // Priority 2: Case-insensitive exact matches next
6412 assert!(result[1] == "DEER" || result[1] == "deer");
6413 assert!(result[2] == "DEER" || result[2] == "deer");
6414 }
6415
6416 #[test]
6417 fn test_filter_and_sort_by_name_snapshot_realistic_scenario() {
6418 // Realistic scenario: User searches for snapshot "Deer" and multiple
6419 // snapshots exist with similar names
6420 let items = vec![
6421 "Unit Testing - Deer Dataset Backup".to_string(),
6422 "Deer".to_string(),
6423 "Deer Snapshot 2025-01-15".to_string(),
6424 "Original Deer".to_string(),
6425 ];
6426
6427 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
6428
6429 // MUST return exact match first for deterministic test behavior
6430 assert_eq!(
6431 result[0], "Deer",
6432 "Searching for 'Deer' should return exact 'Deer' first"
6433 );
6434 }
6435
6436 #[test]
6437 fn test_filter_and_sort_by_name_dataset_realistic_scenario() {
6438 // Realistic scenario: User searches for dataset "Deer" but multiple
6439 // datasets have "Deer" in their name
6440 let items = vec![
6441 "Deer Roundtrip".to_string(),
6442 "Deer".to_string(),
6443 "deer".to_string(),
6444 "White-Tailed Deer".to_string(),
6445 "Deer-V2".to_string(),
6446 ];
6447
6448 let result = filter_and_sort_by_name(items, "Deer", |s| s.as_str());
6449
6450 // Exact case-sensitive match must be first
6451 assert_eq!(result[0], "Deer");
6452 // Case-insensitive exact match should be second
6453 assert_eq!(result[1], "deer");
6454 // Shorter names should come before longer names
6455 assert!(
6456 result.iter().position(|s| s == "Deer-V2").unwrap()
6457 < result.iter().position(|s| s == "Deer Roundtrip").unwrap()
6458 );
6459 }
6460
6461 #[test]
6462 fn test_filter_and_sort_by_name_first_result_is_always_best_match() {
6463 // CRITICAL: The first result should ALWAYS be the best match
6464 // This is essential for deterministic test behavior
6465 let scenarios = vec![
6466 // (items, filter, expected_first)
6467 (vec!["Deer Dataset", "Deer", "deer"], "Deer", "Deer"),
6468 (vec!["test", "TEST", "Test Data"], "test", "test"),
6469 (vec!["ABC", "ABCD", "abc"], "ABC", "ABC"),
6470 ];
6471
6472 for (items, filter, expected_first) in scenarios {
6473 let items: Vec<String> = items.iter().map(|s| s.to_string()).collect();
6474 let result = filter_and_sort_by_name(items, filter, |s| s.as_str());
6475
6476 assert_eq!(
6477 result.first().map(|s| s.as_str()),
6478 Some(expected_first),
6479 "For filter '{}', expected first result '{}', got: {:?}",
6480 filter,
6481 expected_first,
6482 result.first()
6483 );
6484 }
6485 }
6486
6487 #[test]
6488 fn test_with_server_clears_storage() {
6489 use crate::storage::MemoryTokenStorage;
6490
6491 // Create client with memory storage and a token
6492 let storage = Arc::new(MemoryTokenStorage::new());
6493 storage.store("test-token").unwrap();
6494
6495 let client = Client::new().unwrap().with_storage(storage.clone());
6496
6497 // Verify token is loaded
6498 assert_eq!(storage.load().unwrap(), Some("test-token".to_string()));
6499
6500 // Change server - should clear storage
6501 let _new_client = client.with_server("test").unwrap();
6502
6503 // Verify storage was cleared
6504 assert_eq!(storage.load().unwrap(), None);
6505 }
6506
6507 #[test]
6508 fn test_with_server_clears_storage_even_for_full_url() {
6509 // Regression: `with_server` used to short-circuit to `with_url`
6510 // when given a full URL, which preserved the bearer token. The
6511 // contract for `with_server` is that switching servers means
6512 // the token from the old server is no longer trusted.
6513 use crate::storage::MemoryTokenStorage;
6514
6515 let storage = Arc::new(MemoryTokenStorage::new());
6516 storage.store("token-from-old-server").unwrap();
6517 let client = Client::new().unwrap().with_storage(storage.clone());
6518 assert_eq!(
6519 storage.load().unwrap(),
6520 Some("token-from-old-server".to_string())
6521 );
6522
6523 // Switch to a self-hosted Studio (full URL). Storage must be
6524 // cleared, and the new client must have a blank in-memory token.
6525 let new_client = client
6526 .with_server("https://studio.example.com")
6527 .expect("https full URL through with_server");
6528 assert_eq!(storage.load().unwrap(), None);
6529 assert_eq!(new_client.url(), "https://studio.example.com");
6530
6531 // The new client should not carry the old token in memory either.
6532 let in_mem = tokio::runtime::Runtime::new()
6533 .unwrap()
6534 .block_on(async { new_client.token.read().await.clone() });
6535 assert!(in_mem.is_empty(), "expected blank token, got {in_mem:?}");
6536 }
6537
6538 #[test]
6539 fn test_with_server_rejects_insecure_full_url() {
6540 // `with_server` validates full URLs through `with_url`, so the
6541 // HTTPS rule applies uniformly. Plain http to a public host
6542 // must be rejected — the bearer token would otherwise leak in
6543 // plaintext when the caller next authenticates.
6544 let client = Client::new().unwrap();
6545 let err = client.with_server("http://studio.example.com").unwrap_err();
6546 assert!(matches!(err, Error::InsecureUrl(_)));
6547 }
6548
6549 // ===== with_url HTTPS enforcement =====
6550 //
6551 // The bearer token rides in the Authorization header, so plain
6552 // http:// to a public host would leak it in the clear. The function
6553 // must reject those URLs, but still let wiremock / local-dev URLs
6554 // through (loopback addresses, "localhost", "*.localhost").
6555
6556 #[test]
6557 fn with_url_accepts_https_public_host() {
6558 let client = Client::new().unwrap();
6559 let out = client
6560 .with_url("https://studio.example.com")
6561 .expect("https public host must be accepted");
6562 assert_eq!(out.url(), "https://studio.example.com");
6563 }
6564
6565 #[test]
6566 fn with_url_accepts_http_loopback_ipv4() {
6567 let client = Client::new().unwrap();
6568 let out = client
6569 .with_url("http://127.0.0.1:8080")
6570 .expect("http://127.0.0.1 must be accepted (loopback)");
6571 assert_eq!(out.url(), "http://127.0.0.1:8080");
6572 }
6573
6574 #[test]
6575 fn with_url_accepts_http_loopback_ipv6() {
6576 let client = Client::new().unwrap();
6577 let out = client
6578 .with_url("http://[::1]:8080")
6579 .expect("http://[::1] must be accepted (loopback)");
6580 assert!(out.url().starts_with("http://[::1]"));
6581 }
6582
6583 #[test]
6584 fn with_url_accepts_http_localhost() {
6585 let client = Client::new().unwrap();
6586 client
6587 .with_url("http://localhost:8080")
6588 .expect("http://localhost must be accepted");
6589 client
6590 .with_url("http://LOCALHOST")
6591 .expect("http://LOCALHOST must be accepted (case-insensitive)");
6592 client
6593 .with_url("http://wiremock.localhost")
6594 .expect("http://*.localhost must be accepted");
6595 }
6596
6597 #[test]
6598 fn with_url_rejects_http_public_host() {
6599 let client = Client::new().unwrap();
6600 let err = client.with_url("http://studio.example.com").unwrap_err();
6601 match err {
6602 Error::InsecureUrl(u) => assert_eq!(u, "http://studio.example.com"),
6603 other => panic!("expected InsecureUrl, got {other:?}"),
6604 }
6605 }
6606
6607 #[test]
6608 fn with_url_rejects_http_public_ip() {
6609 let client = Client::new().unwrap();
6610 // 8.8.8.8 is not loopback; must be rejected.
6611 let err = client.with_url("http://8.8.8.8").unwrap_err();
6612 assert!(matches!(err, Error::InsecureUrl(_)));
6613 }
6614
6615 #[test]
6616 fn with_url_rejects_non_http_scheme() {
6617 let client = Client::new().unwrap();
6618 // file:// would otherwise parse, but it's not a transport we
6619 // can use for RPC and we don't want to silently accept it.
6620 let err = client.with_url("file:///etc/passwd").unwrap_err();
6621 assert!(matches!(err, Error::InsecureUrl(_)));
6622 }
6623}
6624
#[cfg(test)]
mod tests_map_rpc_error {
    //! Unit tests for `map_rpc_error`, which turns a raw JSON-RPC
    //! `(code, message)` pair into a typed `Error` variant when the code is
    //! recognised, and into `Error::RpcError` otherwise.
    use super::*;
    use crate::api::TaskID;

    #[test]
    fn maps_not_found_with_task_id_to_typed_variant() {
        // Server code 101 + "not found" message + task_id present → TaskNotFound
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "task not found".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_cannot_find_phrasing_to_typed_variant() {
        // The DVE server emits "Cannot find task..." — a plain "not found"
        // substring match would miss it and surface a generic RpcError.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "Cannot find task with id 6789".to_string(),
            Some(task_id),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "'Cannot find task' should map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_does_not_exist_phrasing_to_typed_variant() {
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.chart.get",
            101,
            "task does not exist".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }

    #[test]
    fn maps_code_101_with_unknown_phrasing_when_task_id_supplied() {
        // Server contract for code 101 is "resource not found"; even if the
        // phrasing is novel, the typed variant should be returned so callers
        // can write a stable `match`.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "completely novel server message".to_string(),
            Some(task_id),
        );
        assert!(
            matches!(err, Error::TaskNotFound(_)),
            "code 101 + task_id should always map to TaskNotFound, got {err:?}"
        );
    }

    #[test]
    fn maps_permission_codes_to_typed_variant() {
        for code in [401, 403] {
            let err = map_rpc_error("task.chart.add", code, "denied".to_string(), None);
            // Include the actual variant in the failure message, matching the
            // diagnostics the other tests in this module provide.
            assert!(
                matches!(err, Error::PermissionDenied(_)),
                "code {code} did not map to PermissionDenied, got {err:?}"
            );
        }
    }

    #[test]
    fn permission_denied_records_method_for_diagnostics() {
        let err = map_rpc_error("task.data.upload", 403, "forbidden".to_string(), None);
        match err {
            Error::PermissionDenied(method) => assert_eq!(method, "task.data.upload"),
            other => panic!("expected PermissionDenied, got {other:?}"),
        }
    }

    #[test]
    fn maps_payload_too_large_to_typed_variant() {
        let err = map_rpc_error("val.data.upload", 413, "request too large".into(), None);
        match err {
            Error::PayloadTooLarge { method, size_hint } => {
                assert_eq!(method, "val.data.upload");
                assert!(size_hint.is_none());
            }
            other => panic!("expected PayloadTooLarge, got {other:?}"),
        }
    }

    #[test]
    fn falls_through_to_generic_rpc_error_for_unknown_codes() {
        let err = map_rpc_error("task.data.list", -99999, "weird".to_string(), None);
        match err {
            Error::RpcError(code, msg) => {
                assert_eq!(code, -99999);
                assert_eq!(msg, "weird");
            }
            other => panic!("expected RpcError, got {other:?}"),
        }
    }

    #[test]
    fn not_found_without_task_id_falls_through() {
        // Code 101 without task_id → generic RpcError (no task to name)
        let err = map_rpc_error("task.data.list", 101, "not found".to_string(), None);
        assert!(matches!(err, Error::RpcError(101, _)));
    }

    #[test]
    fn code_101_with_task_id_always_maps_even_with_unrelated_message() {
        // The contract for code 101 is "resource not found" (see api.go), so
        // when a task_id is present the typed variant is returned regardless
        // of the message text — even one that reads like a different error.
        let task_id = TaskID::try_from("task-1a2b").unwrap();
        let err = map_rpc_error(
            "task.data.list",
            101,
            "permission denied".to_string(),
            Some(task_id),
        );
        assert!(matches!(err, Error::TaskNotFound(_)));
    }
}
6757}
6758
6759#[cfg(test)]
6760mod tests_jobs {
6761 use super::*;
6762
6763 #[test]
6764 fn jobs_list_request_serializes_to_empty_object() {
6765 let req = JobsListRequest {};
6766 assert_eq!(serde_json::to_value(&req).unwrap(), serde_json::json!({}));
6767 }
6768
6769 #[test]
6770 fn job_deserializes_from_bk_batch_shape() {
6771 let json = r#"{
6772 "code": "edgefirst-validator:2.9.5",
6773 "title": "EdgeFirst Validator",
6774 "job_name": "smoke-test",
6775 "job_id": "aws-batch-abc",
6776 "state": "RUNNING",
6777 "launch": "2026-05-14T15:00:00Z",
6778 "task_id": 6789,
6779 "docker_task": {},
6780 "extra_field": "ignored"
6781 }"#;
6782 let job: crate::api::Job = serde_json::from_str(json).unwrap();
6783 assert_eq!(job.code, "edgefirst-validator:2.9.5");
6784 assert_eq!(job.state, "RUNNING");
6785 assert_eq!(job.task_id, 6789);
6786 assert_eq!(job.task_id().value(), 6789);
6787 }
6788}
6789
6790#[cfg(test)]
6791mod tests_job_run {
6792 use super::*;
6793 use crate::api::Parameter;
6794 use std::collections::HashMap;
6795
6796 #[test]
6797 fn job_run_request_serializes_with_expected_fields() {
6798 let req = JobRunRequest {
6799 name: "edgefirst-validator".into(),
6800 job_name: "post-profile-run".into(),
6801 env: HashMap::from([("LOG_LEVEL".into(), "info".into())]),
6802 data: HashMap::from([("validation_session_id".into(), Parameter::Integer(2707))]),
6803 };
6804 let json = serde_json::to_value(&req).unwrap();
6805 assert_eq!(json["name"], "edgefirst-validator");
6806 assert_eq!(json["job_name"], "post-profile-run");
6807 assert_eq!(json["env"]["LOG_LEVEL"], "info");
6808 assert_eq!(json["data"]["validation_session_id"], 2707);
6809 }
6810
6811 #[test]
6812 fn job_run_response_deserializes_as_job() {
6813 // job.run now returns the full BK_BATCH record; deserialize as Job.
6814 let json = r#"{
6815 "code": "edgefirst-validator:2.9.5",
6816 "title": "EdgeFirst Validator",
6817 "job_name": "post-profile-run",
6818 "job_id": "aws-batch-job-xxx",
6819 "state": "SUBMITTED",
6820 "task_id": 6789
6821 }"#;
6822 let job: crate::api::Job = serde_json::from_str(json).unwrap();
6823 assert_eq!(job.task_id, 6789);
6824 assert_eq!(job.job_id, "aws-batch-job-xxx");
6825 assert_eq!(job.state, "SUBMITTED");
6826 }
6827}
6828
6829#[cfg(test)]
6830mod tests_job_stop {
6831 use super::*;
6832 use crate::api::TaskID;
6833
6834 #[test]
6835 fn job_stop_request_serializes_with_task_id() {
6836 let task_id = TaskID::try_from("task-1a2b").unwrap();
6837 let req = JobStopRequest {
6838 task_id: task_id.value(),
6839 };
6840 let json = serde_json::to_value(&req).unwrap();
6841 assert_eq!(json["task_id"], task_id.value());
6842 }
6843}
6844
6845#[cfg(test)]
6846mod tests_task_data_list_request {
6847 use super::*;
6848 use crate::api::TaskID;
6849
6850 #[test]
6851 fn task_data_list_request_serializes_with_task_id() {
6852 let task_id = TaskID::try_from("task-1a2b").unwrap();
6853 let req = TaskDataListRequest {
6854 task_id: task_id.value(),
6855 };
6856 let json = serde_json::to_value(&req).unwrap();
6857 assert_eq!(json["task_id"], task_id.value());
6858 }
6859}
6860
6861#[cfg(test)]
6862mod tests_task_data_download {
6863 use super::*;
6864 use crate::api::TaskID;
6865
6866 #[test]
6867 fn task_data_download_request_serializes_with_all_fields() {
6868 let task_id = TaskID::try_from("task-1a2b").unwrap();
6869 let req = TaskDataDownloadRequest {
6870 task_id: task_id.value(),
6871 folder: "predictions".into(),
6872 file: "predictions.parquet".into(),
6873 };
6874 let json = serde_json::to_value(&req).unwrap();
6875 assert_eq!(json["task_id"], task_id.value());
6876 assert_eq!(json["folder"], "predictions");
6877 assert_eq!(json["file"], "predictions.parquet");
6878 }
6879}
6880
6881#[cfg(test)]
6882mod tests_task_chart_add {
6883 use super::*;
6884 use crate::api::{Parameter, TaskID};
6885
6886 #[test]
6887 fn task_chart_add_request_serializes_with_correct_fields() {
6888 let task_id = TaskID::try_from("task-1a2b").unwrap();
6889 let data = Parameter::Object(std::collections::HashMap::from([(
6890 "type".into(),
6891 Parameter::String("line".into()),
6892 )]));
6893 let req = TaskChartAddRequest {
6894 task_id: task_id.value(),
6895 group_name: "metrics".into(),
6896 chart_name: "loss".into(),
6897 params: None,
6898 data,
6899 };
6900 let json = serde_json::to_value(&req).unwrap();
6901 assert_eq!(json["task_id"], task_id.value());
6902 assert_eq!(json["group_name"], "metrics");
6903 assert_eq!(json["chart_name"], "loss");
6904 assert_eq!(json["data"]["type"], "line");
6905 assert!(json["params"].is_null());
6906 }
6907}
6908
6909#[cfg(test)]
6910mod tests_task_chart_list {
6911 use super::*;
6912 use crate::api::TaskID;
6913
6914 #[test]
6915 fn task_chart_list_request_omits_empty_group_name() {
6916 let task_id = TaskID::try_from("task-1a2b").unwrap();
6917 let req = TaskChartListRequest {
6918 task_id: task_id.value(),
6919 group_name: String::new(),
6920 };
6921 let json = serde_json::to_value(&req).unwrap();
6922 assert_eq!(json["task_id"], task_id.value());
6923 assert_eq!(json["group_name"], "");
6924 }
6925}
6926
6927#[cfg(test)]
6928mod tests_task_chart_get {
6929 use super::*;
6930 use crate::api::TaskID;
6931
6932 #[test]
6933 fn task_chart_get_request_serializes_with_all_fields() {
6934 let task_id = TaskID::try_from("task-1a2b").unwrap();
6935 let req = TaskChartGetRequest {
6936 task_id: task_id.value(),
6937 group_name: "metrics".into(),
6938 chart_name: "loss".into(),
6939 };
6940 let json = serde_json::to_value(&req).unwrap();
6941 assert_eq!(json["task_id"], task_id.value());
6942 assert_eq!(json["group_name"], "metrics");
6943 assert_eq!(json["chart_name"], "loss");
6944 }
6945}
6946
6947#[cfg(test)]
6948mod tests_val_data_download {
6949 use super::*;
6950
6951 #[test]
6952 fn val_data_download_request_serializes() {
6953 let req = ValDataDownloadRequest {
6954 session_id: 2707,
6955 filename: "trace/imx95.json".into(),
6956 };
6957 let json = serde_json::to_value(&req).unwrap();
6958 assert_eq!(json["session_id"], 2707);
6959 assert_eq!(json["filename"], "trace/imx95.json");
6960 }
6961}
6962
6963#[cfg(test)]
6964mod tests_val_data_list {
6965 use super::*;
6966
6967 #[test]
6968 fn val_data_list_request_serializes() {
6969 let req = ValDataListRequest { session_id: 2707 };
6970 assert_eq!(
6971 serde_json::to_value(&req).unwrap(),
6972 serde_json::json!({"session_id": 2707})
6973 );
6974 }
6975}
6976
6977#[cfg(test)]
6978mod tests_jsonrpc_envelope_detection {
6979 use super::*;
6980
6981 #[test]
6982 fn detects_real_envelope() {
6983 let v = serde_json::json!({
6984 "jsonrpc": "2.0",
6985 "id": 0,
6986 "error": { "code": 101, "message": "Cannot find task" },
6987 });
6988 assert!(is_jsonrpc_error_envelope(&v));
6989 }
6990
6991 #[test]
6992 fn rejects_plain_json_artifact_with_error_field() {
6993 // A diagnostics file with a free-form `error` object — must not be
6994 // misread as an RPC envelope just because the key collides.
6995 let v = serde_json::json!({
6996 "metric": "loss",
6997 "value": 0.42,
6998 "error": { "code": "ENV_NOT_FOUND", "message": "missing var" },
6999 });
7000 assert!(
7001 !is_jsonrpc_error_envelope(&v),
7002 "missing jsonrpc sentinel should mean 'not an envelope'"
7003 );
7004 }
7005
7006 #[test]
7007 fn rejects_envelope_missing_jsonrpc_sentinel() {
7008 // Bare `error` block without the protocol-version marker.
7009 let v = serde_json::json!({
7010 "id": 0,
7011 "error": { "code": 101, "message": "x" },
7012 });
7013 assert!(!is_jsonrpc_error_envelope(&v));
7014 }
7015
7016 #[test]
7017 fn rejects_envelope_with_non_object_error_field() {
7018 // A diagnostics file shaped like JSON-RPC accidentally but using
7019 // a string for `error`.
7020 let v = serde_json::json!({
7021 "jsonrpc": "2.0",
7022 "error": "something went wrong",
7023 });
7024 assert!(!is_jsonrpc_error_envelope(&v));
7025 }
7026
7027 #[test]
7028 fn rejects_envelope_without_error_code() {
7029 // Real envelopes always carry an integer error.code; missing one
7030 // is suspicious enough to refuse the envelope classification.
7031 let v = serde_json::json!({
7032 "jsonrpc": "2.0",
7033 "error": { "message": "no code" },
7034 });
7035 assert!(!is_jsonrpc_error_envelope(&v));
7036 }
7037
7038 #[test]
7039 fn rejects_envelope_with_non_numeric_error_code() {
7040 let v = serde_json::json!({
7041 "jsonrpc": "2.0",
7042 "error": { "code": "ENOENT", "message": "x" },
7043 });
7044 assert!(!is_jsonrpc_error_envelope(&v));
7045 }
7046
7047 #[test]
7048 fn rejects_non_object_root() {
7049 // A JSON file whose root is an array — common for metrics dumps —
7050 // must not be misread.
7051 let v = serde_json::json!([1, 2, 3]);
7052 assert!(!is_jsonrpc_error_envelope(&v));
7053 }
7054
7055 #[test]
7056 fn accepts_unsigned_error_code() {
7057 // The server's code is technically i32 but JSON has no signed/
7058 // unsigned distinction — accept both shapes.
7059 let v = serde_json::json!({
7060 "jsonrpc": "2.0",
7061 "error": { "code": 101u32, "message": "x" },
7062 });
7063 assert!(is_jsonrpc_error_envelope(&v));
7064 }
7065}
7066
7067#[cfg(test)]
7068mod tests_validate_chart_args {
7069 use super::*;
7070
7071 #[test]
7072 fn rejects_empty_group() {
7073 let err = validate_chart_args("", "name").unwrap_err();
7074 assert!(matches!(err, Error::InvalidParameters(_)));
7075 }
7076
7077 #[test]
7078 fn rejects_empty_name() {
7079 let err = validate_chart_args("group", "").unwrap_err();
7080 assert!(matches!(err, Error::InvalidParameters(_)));
7081 }
7082
7083 #[test]
7084 fn rejects_both_empty() {
7085 let err = validate_chart_args("", "").unwrap_err();
7086 assert!(matches!(err, Error::InvalidParameters(_)));
7087 }
7088
7089 #[test]
7090 fn accepts_valid_args() {
7091 assert!(validate_chart_args("group", "name").is_ok());
7092 }
7093
7094 #[test]
7095 fn accepts_unicode_args() {
7096 // Unicode names are allowed; only emptiness is rejected.
7097 assert!(validate_chart_args("metrics-集合", "损失").is_ok());
7098 }
7099}
7100
7101// ---------------------------------------------------------------------------
7102// Additional offline tests for request shapes + helpers added in DE-2565.
7103//
7104// These focus on the wire-shape and helper logic that does not require a
7105// live Studio server — they significantly boost coverage of client.rs.
7106// ---------------------------------------------------------------------------
7107
7108#[cfg(test)]
7109mod tests_job_run_request_shape {
7110 use super::*;
7111 use crate::api::Parameter;
7112 use std::collections::HashMap;
7113
7114 #[test]
7115 fn empty_env_and_data_serialize_as_empty_objects() {
7116 let req = JobRunRequest {
7117 name: "edgefirst-validator".into(),
7118 job_name: "smoke".into(),
7119 env: HashMap::new(),
7120 data: HashMap::new(),
7121 };
7122 let json = serde_json::to_value(&req).unwrap();
7123 assert_eq!(json["name"], "edgefirst-validator");
7124 assert_eq!(json["env"], serde_json::json!({}));
7125 assert_eq!(json["data"], serde_json::json!({}));
7126 }
7127
7128 #[test]
7129 fn data_passes_through_parameter_object_payloads() {
7130 // Confirms the Parameter wrapper survives JSON serialization round-trip
7131 // for the kind of structured chart payload that exercises Parameter
7132 // variants (Real, Integer, String, Array, Object, Boolean).
7133 let req = JobRunRequest {
7134 name: "edgefirst-validator".into(),
7135 job_name: "feat".into(),
7136 env: HashMap::new(),
7137 data: HashMap::from([
7138 ("flag".into(), Parameter::Boolean(true)),
7139 ("epochs".into(), Parameter::Integer(50)),
7140 ("lr".into(), Parameter::Real(1e-3)),
7141 ("name".into(), Parameter::String("hello".into())),
7142 ]),
7143 };
7144 let json = serde_json::to_value(&req).unwrap();
7145 assert_eq!(json["data"]["flag"], true);
7146 assert_eq!(json["data"]["epochs"], 50);
7147 assert!(json["data"]["lr"].as_f64().unwrap() > 0.0);
7148 assert_eq!(json["data"]["name"], "hello");
7149 }
7150}
7151
7152#[cfg(test)]
7153mod tests_task_data_chart_request_shape {
7154 use super::*;
7155 use crate::api::{Parameter, TaskID};
7156
7157 #[test]
7158 fn chart_add_request_with_params_serializes_object() {
7159 let task_id = TaskID::try_from("task-1a2b").unwrap();
7160 let params = Parameter::Object(std::collections::HashMap::from([(
7161 "y_axis".into(),
7162 Parameter::String("log".into()),
7163 )]));
7164 let data = Parameter::Object(std::collections::HashMap::from([(
7165 "type".into(),
7166 Parameter::String("line".into()),
7167 )]));
7168 let req = TaskChartAddRequest {
7169 task_id: task_id.value(),
7170 group_name: "metrics".into(),
7171 chart_name: "loss".into(),
7172 params: Some(params),
7173 data,
7174 };
7175 let json = serde_json::to_value(&req).unwrap();
7176 assert_eq!(json["params"]["y_axis"], "log");
7177 }
7178
7179 #[test]
7180 fn task_data_list_request_round_trips() {
7181 let task_id = TaskID::try_from("task-1a2b").unwrap();
7182 let req = TaskDataListRequest {
7183 task_id: task_id.value(),
7184 };
7185 let json = serde_json::to_string(&req).unwrap();
7186 // Field order is stable for a single-field struct, so an exact match
7187 // is meaningful here.
7188 assert_eq!(json, format!("{{\"task_id\":{}}}", task_id.value()));
7189 }
7190
7191 #[test]
7192 fn task_data_download_request_treats_folder_and_file_independently() {
7193 let task_id = TaskID::try_from("task-1a2b").unwrap();
7194 let req = TaskDataDownloadRequest {
7195 task_id: task_id.value(),
7196 folder: "validation/run-01".into(),
7197 file: "metrics.json".into(),
7198 };
7199 let json = serde_json::to_value(&req).unwrap();
7200 // Server takes folder + file separately (not a single combined path)
7201 // so callers don't have to escape slashes themselves.
7202 assert_eq!(json["folder"], "validation/run-01");
7203 assert_eq!(json["file"], "metrics.json");
7204 }
7205}
7206
7207#[cfg(test)]
7208mod tests_val_data_request_shape {
7209 use super::*;
7210
7211 #[test]
7212 fn val_data_list_round_trips() {
7213 let req = ValDataListRequest { session_id: 2707 };
7214 let s = serde_json::to_string(&req).unwrap();
7215 let back: serde_json::Value = serde_json::from_str(&s).unwrap();
7216 assert_eq!(back["session_id"], 2707);
7217 }
7218
7219 #[test]
7220 fn val_data_download_round_trips_with_nested_path() {
7221 let req = ValDataDownloadRequest {
7222 session_id: 2707,
7223 filename: "subfolder/imx95.json".into(),
7224 };
7225 let s = serde_json::to_string(&req).unwrap();
7226 let back: serde_json::Value = serde_json::from_str(&s).unwrap();
7227 assert_eq!(back["session_id"], 2707);
7228 assert_eq!(back["filename"], "subfolder/imx95.json");
7229 }
7230}
7231
7232#[cfg(test)]
7233mod tests_progress_struct {
7234 use super::*;
7235
7236 #[test]
7237 fn progress_can_be_constructed_with_zero_total() {
7238 // Servers sometimes omit Content-Length; progress events should still
7239 // be representable. This guards the public field-level API.
7240 let p = Progress {
7241 current: 0,
7242 total: 0,
7243 status: None,
7244 };
7245 assert_eq!(p.current, 0);
7246 assert_eq!(p.total, 0);
7247 assert!(p.status.is_none());
7248 }
7249
7250 #[test]
7251 fn progress_tracks_current_independently_of_total() {
7252 let p = Progress {
7253 current: 123,
7254 total: 456,
7255 status: Some("Downloading".into()),
7256 };
7257 assert_eq!(p.current, 123);
7258 assert_eq!(p.total, 456);
7259 assert_eq!(p.status.as_deref(), Some("Downloading"));
7260 }
7261
7262 #[test]
7263 fn progress_can_be_cloned() {
7264 // Progress is consumed by progress sinks which may need to retain a
7265 // copy independently of the channel — derive(Clone) must hold.
7266 let p = Progress {
7267 current: 10,
7268 total: 20,
7269 status: Some("phase".into()),
7270 };
7271 let q = p.clone();
7272 assert_eq!(q.current, p.current);
7273 assert_eq!(q.total, p.total);
7274 assert_eq!(q.status, p.status);
7275 }
7276}
7277
7278#[cfg(test)]
7279mod tests_bare_filename_parent {
7280 // Documents the empty-parent guard added for `rpc_download` so that
7281 // callers passing a bare filename like "metrics.json" download to the
7282 // current directory instead of erroring on `create_dir_all("")`.
7283 use std::path::Path;
7284
7285 #[test]
7286 fn bare_filename_parent_is_empty_path() {
7287 // This is the invariant our guard depends on. If a future Rust
7288 // release ever changed `Path::parent` for bare filenames, the guard
7289 // would need revisiting.
7290 let p = Path::new("metrics.json");
7291 let parent = p.parent().expect("bare filename always has Some parent");
7292 assert!(
7293 parent.as_os_str().is_empty(),
7294 "Path::parent for bare filename should be empty, got: {parent:?}"
7295 );
7296 }
7297
7298 #[test]
7299 fn path_with_directory_has_non_empty_parent() {
7300 // The companion case: when the path includes a directory, the
7301 // parent is non-empty and `create_dir_all` should be invoked.
7302 let p = Path::new("dir/metrics.json");
7303 let parent = p.parent().expect("path-with-dir always has Some parent");
7304 assert!(!parent.as_os_str().is_empty());
7305 assert_eq!(parent, Path::new("dir"));
7306 }
7307}