edgefirst_client/
client.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5    Annotation, Error, Sample, Task,
6    api::{
7        AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8        Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9        SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10        TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11        ValidationSession, ValidationSessionID,
12    },
13    dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23    collections::HashMap,
24    fs::create_dir_all,
25    io::{SeekFrom, Write as _},
26    path::{Path, PathBuf},
27    sync::{
28        Arc,
29        atomic::{AtomicUsize, Ordering},
30    },
31    time::Duration,
32    vec,
33};
34use tokio::{
35    fs::{self, File},
36    io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
37    sync::{RwLock, Semaphore, mpsc::Sender},
38};
39use tokio_util::codec::{BytesCodec, FramedRead};
40use walkdir::WalkDir;
41
42#[cfg(feature = "polars")]
43use polars::prelude::*;
44
/// Maximum number of concurrent transfer tasks (bounds the semaphore used by
/// the download/upload paths).
const MAX_TASKS: usize = 32;
/// Maximum number of retries for a failed transfer.
const MAX_RETRIES: u32 = 10;
/// Part size for multipart uploads: 100 MiB.
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Progress information for long-running operations.
///
/// This struct tracks the current progress of operations like file uploads,
/// downloads, or dataset processing. It provides the current count and total
/// count to enable progress reporting in applications.  Methods in this crate
/// that accept an optional progress channel send `Progress` values over a
/// tokio mpsc `Sender<Progress>` as work completes.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
/// };
/// let percentage = (progress.current as f64 / progress.total as f64) * 100.0;
/// println!(
///     "Progress: {:.1}% ({}/{})",
///     percentage, progress.current, progress.total
/// );
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items.
    pub current: usize,
    /// Total number of items to process.
    pub total: usize,
}
77
/// JSON-RPC 2.0 request envelope sent to the server.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier; echoed back by the server in the response.
    id: u64,
    // Protocol version, always "2.0".
    jsonrpc: String,
    // Name of the remote method to invoke, e.g. "dataset.list".
    method: String,
    // Optional method parameters, serialized as the "params" member.
    params: Option<Params>,
}
85
86impl<T> Default for RpcRequest<T> {
87    fn default() -> Self {
88        RpcRequest {
89            id: 0,
90            jsonrpc: "2.0".to_string(),
91            method: "".to_string(),
92            params: None,
93        }
94    }
95}
96
/// Error object embedded in a failed JSON-RPC response.
#[derive(Deserialize)]
struct RpcError {
    // Numeric error code reported by the server.
    code: i32,
    // Human-readable error message.
    message: String,
}

/// JSON-RPC 2.0 response envelope; one of `error` or `result` is populated.
///
/// NOTE(review): the response `id` deserializes as a String while the request
/// `id` is a u64 — presumably the server stringifies ids; confirm.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    // Present when the call failed.
    error: Option<RpcError>,
    // Present when the call succeeded.
    result: Option<RpcResult>,
}

/// Placeholder result type for RPC calls that return no payload.
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
116
/// Parameters for creating a snapshot with single-request uploads.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Storage keys of the files that make up the snapshot.
    keys: Vec<String>,
}

/// Result of snapshot creation: the new snapshot ID plus presigned upload
/// URLs (presumably one per requested key — confirm against the server API).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for creating a snapshot with multipart uploads; `file_sizes`
/// parallels `keys` so the server can issue the right number of part URLs.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}

/// One field of the multipart-create result: either the snapshot ID or a
/// per-file multipart descriptor.  `#[serde(untagged)]` means the JSON shape
/// alone decides which variant deserializes.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
144
/// Parameters to finalize one file of a multipart snapshot upload.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    // ETags of all uploaded parts, in part-number order.
    etag_list: Vec<EtagPart>,
}

/// ETag for one uploaded part.  The serde renames follow the S3
/// CompleteMultipartUpload field convention ("ETag" / "PartNumber").
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Per-file multipart upload descriptor returned by the server: the upload id
/// plus presigned URLs (presumably one per part — confirm with the API docs).
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
166
/// Parameters for updating a snapshot's status on the server.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    // New status value; the valid status strings are server-defined.
    status: String,
}

/// Snapshot record returned by the status-update call.  All fields are
/// currently unread by this client (hence the dead_code allows).
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
186
/// Parameters for listing images in a dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    // Additional key/value filters on image files; semantics are
    // server-defined.
    image_files_filter: HashMap<String, String>,
    // When true, the server returns only image IDs instead of full records.
    only_ids: bool,
}

/// Filter restricting an image listing to a single dataset.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
200
/// Main client for interacting with EdgeFirst Studio Server.
///
/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
/// and manages authentication, RPC calls, and data operations. It provides
/// methods for managing projects, datasets, experiments, training sessions,
/// and various utility functions for data processing.
///
/// The client supports multiple authentication methods and can work with both
/// SaaS and self-hosted EdgeFirst Studio instances.
///
/// # Features
///
/// - **Authentication**: Token-based authentication with automatic persistence
/// - **Dataset Management**: Upload, download, and manipulate datasets
/// - **Project Operations**: Create and manage projects and experiments
/// - **Training & Validation**: Submit and monitor ML training jobs
/// - **Data Integration**: Convert between EdgeFirst datasets and popular
///   formats
/// - **Progress Tracking**: Real-time progress updates for long-running
///   operations
///
/// # Examples
///
/// ```no_run
/// use edgefirst_client::{Client, DatasetID};
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), edgefirst_client::Error> {
/// // Create a new client and authenticate
/// let mut client = Client::new()?;
/// let client = client
///     .with_login("your-email@example.com", "password")
///     .await?;
///
/// // Or use an existing token
/// let base_client = Client::new()?;
/// let client = base_client.with_token("your-token-here")?;
///
/// // Get organization and projects
/// let org = client.organization().await?;
/// let projects = client.projects(None).await?;
///
/// // Work with datasets
/// let dataset_id = DatasetID::from_str("ds-abc123")?;
/// let dataset = client.dataset(dataset_id).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct Client {
    // Shared HTTP client; configured with a 60s read timeout in `new`.
    http: reqwest::Client,
    // Base URL of the server instance, e.g. "https://saas.edgefirst.studio".
    url: String,
    // Auth token shared across clones; RwLock allows renewal without &mut.
    token: Arc<RwLock<String>>,
    // Optional path where the token is persisted (see `with_token_path`).
    token_path: Option<PathBuf>,
}
256
257impl Client {
258    /// Create a new unauthenticated client with the default saas server.  To
259    /// connect to a different server use the `with_server` method or with the
260    /// `with_token` method to create a client with a token which includes the
261    /// server instance name (test, stage, saas).
262    ///
263    /// This client is created without a token and will need to login before
264    /// using any methods that require authentication.  Use the `with_token`
265    /// method to create a client with a token.
266    pub fn new() -> Result<Self, Error> {
267        Ok(Client {
268            http: reqwest::Client::builder()
269                .read_timeout(Duration::from_secs(60))
270                .build()?,
271            url: "https://edgefirst.studio".to_string(),
272            token: Arc::new(tokio::sync::RwLock::new("".to_string())),
273            token_path: None,
274        })
275    }
276
277    /// Returns a new client connected to the specified server instance.  If a
278    /// token is already set in the client then it will be dropped as the token
279    /// is specific to the server instance.
280    pub fn with_server(&self, server: &str) -> Result<Self, Error> {
281        Ok(Client {
282            url: format!("https://{}.edgefirst.studio", server),
283            ..self.clone()
284        })
285    }
286
287    /// Returns a new client authenticated with the provided username and
288    /// password.
289    pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
290        let params = HashMap::from([("username", username), ("password", password)]);
291        let login: LoginResult = self
292            .rpc_without_auth("auth.login".to_owned(), Some(params))
293            .await?;
294        Ok(Client {
295            token: Arc::new(tokio::sync::RwLock::new(login.token)),
296            ..self.clone()
297        })
298    }
299
300    /// Returns a new client which will load and save the token to the specified
301    /// path.
302    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
303        let token_path = match token_path {
304            Some(path) => path.to_path_buf(),
305            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
306                .unwrap()
307                .config_dir()
308                .join("token"),
309        };
310
311        debug!("Using token path: {:?}", token_path);
312
313        let token = match token_path.exists() {
314            true => std::fs::read_to_string(&token_path)?,
315            false => "".to_string(),
316        };
317
318        if !token.is_empty() {
319            let client = self.with_token(&token)?;
320            Ok(Client {
321                token_path: Some(token_path),
322                ..client
323            })
324        } else {
325            Ok(Client {
326                token_path: Some(token_path),
327                ..self.clone()
328            })
329        }
330    }
331
332    /// Returns a new client authenticated with the provided token.
333    pub fn with_token(&self, token: &str) -> Result<Self, Error> {
334        if token.is_empty() {
335            return Ok(self.clone());
336        }
337
338        let token_parts: Vec<&str> = token.split('.').collect();
339        if token_parts.len() != 3 {
340            return Err(Error::InvalidToken);
341        }
342
343        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
344            .decode(token_parts[1])
345            .unwrap();
346        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
347        let server = match payload.get("database") {
348            Some(value) => Ok(value.as_str().unwrap().to_string()),
349            None => Err(Error::InvalidToken),
350        }?;
351
352        Ok(Client {
353            url: format!("https://{}.edgefirst.studio", server),
354            token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
355            ..self.clone()
356        })
357    }
358
    /// Persist the current token to disk.
    ///
    /// Writes to the configured `token_path`, or to the platform config
    /// directory (`.../EdgeFirst Studio/config/token`) when none was set.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the directory or file cannot be created or
    /// written.
    pub async fn save_token(&self) -> Result<(), Error> {
        let path = self.token_path.clone().unwrap_or_else(|| {
            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .unwrap()
                .config_dir()
                .join("token")
        });

        // NOTE(review): parent() is assumed Some because the path always ends
        // in a "token" file component — would panic for a bare root path.
        create_dir_all(path.parent().unwrap())?;
        let mut file = std::fs::File::create(&path)?;
        file.write_all(self.token.read().await.as_bytes())?;

        debug!("Saved token to {:?}", path);

        Ok(())
    }
375
376    /// Return the version of the EdgeFirst Studio server for the current
377    /// client connection.
378    pub async fn version(&self) -> Result<String, Error> {
379        let version: HashMap<String, String> = self
380            .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
381            .await?;
382        let version = version.get("version").ok_or(Error::InvalidResponse)?;
383        Ok(version.to_owned())
384    }
385
386    /// Clear the token used to authenticate the client with the server.  If an
387    /// optional path was provided when creating the client, the token file
388    /// will also be cleared.
389    pub async fn logout(&self) -> Result<(), Error> {
390        {
391            let mut token = self.token.write().await;
392            *token = "".to_string();
393        }
394
395        if let Some(path) = &self.token_path
396            && path.exists()
397        {
398            fs::remove_file(path).await?;
399        }
400
401        Ok(())
402    }
403
404    /// Return the token used to authenticate the client with the server.  When
405    /// logging into the server using a username and password, the token is
406    /// returned by the server and stored in the client for future interactions.
407    pub async fn token(&self) -> String {
408        self.token.read().await.clone()
409    }
410
411    /// Verify the token used to authenticate the client with the server.  This
412    /// method is used to ensure that the token is still valid and has not
413    /// expired.  If the token is invalid, the server will return an error and
414    /// the client will need to login again.
415    pub async fn verify_token(&self) -> Result<(), Error> {
416        self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
417            .await?;
418        Ok::<(), Error>(())
419    }
420
421    /// Renew the token used to authenticate the client with the server.  This
422    /// method is used to refresh the token before it expires.  If the token
423    /// has already expired, the server will return an error and the client
424    /// will need to login again.
425    pub async fn renew_token(&self) -> Result<(), Error> {
426        let params = HashMap::from([("username".to_string(), self.username().await?)]);
427        let result: LoginResult = self
428            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
429            .await?;
430
431        {
432            let mut token = self.token.write().await;
433            *token = result.token;
434        }
435
436        if self.token_path.is_some() {
437            self.save_token().await?;
438        }
439
440        Ok(())
441    }
442
443    async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
444        let token = self.token.read().await;
445        if token.is_empty() {
446            return Err(Error::EmptyToken);
447        }
448
449        let token_parts: Vec<&str> = token.split('.').collect();
450        if token_parts.len() != 3 {
451            return Err(Error::InvalidToken);
452        }
453
454        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
455            .decode(token_parts[1])
456            .unwrap();
457        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
458        match payload.get(field) {
459            Some(value) => Ok(value.to_owned()),
460            None => Err(Error::InvalidToken),
461        }
462    }
463
464    /// Returns the URL of the EdgeFirst Studio server for the current client.
465    pub fn url(&self) -> &str {
466        &self.url
467    }
468
469    /// Returns the username associated with the current token.
470    pub async fn username(&self) -> Result<String, Error> {
471        match self.token_field("username").await? {
472            serde_json::Value::String(username) => Ok(username),
473            _ => Err(Error::InvalidToken),
474        }
475    }
476
477    /// Returns the expiration time for the current token.
478    pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
479        let ts = match self.token_field("exp").await? {
480            serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
481            _ => return Err(Error::InvalidToken),
482        };
483
484        match DateTime::<Utc>::from_timestamp_secs(ts) {
485            Some(dt) => Ok(dt),
486            None => Err(Error::InvalidToken),
487        }
488    }
489
490    /// Returns the organization information for the current user.
491    pub async fn organization(&self) -> Result<Organization, Error> {
492        self.rpc::<(), Organization>("org.get".to_owned(), None)
493            .await
494    }
495
496    /// Returns a list of projects available to the user.  The projects are
497    /// returned as a vector of Project objects.  If a name filter is
498    /// provided, only projects matching the filter are returned.
499    ///
500    /// Projects are the top-level organizational unit in EdgeFirst Studio.
501    /// Projects contain datasets, trainers, and trainer sessions.  Projects
502    /// are used to group related datasets and trainers together.
503    pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
504        let projects = self
505            .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
506            .await?;
507        if let Some(name) = name {
508            Ok(projects
509                .into_iter()
510                .filter(|p| p.name().contains(name))
511                .collect())
512        } else {
513            Ok(projects)
514        }
515    }
516
517    /// Return the project with the specified project ID.  If the project does
518    /// not exist, an error is returned.
519    pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
520        let params = HashMap::from([("project_id", project_id)]);
521        self.rpc("project.get".to_owned(), Some(params)).await
522    }
523
524    /// Returns a list of datasets available to the user.  The datasets are
525    /// returned as a vector of Dataset objects.  If a name filter is
526    /// provided, only datasets matching the filter are returned.
527    pub async fn datasets(
528        &self,
529        project_id: ProjectID,
530        name: Option<&str>,
531    ) -> Result<Vec<Dataset>, Error> {
532        let params = HashMap::from([("project_id", project_id)]);
533        let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
534        if let Some(name) = name {
535            Ok(datasets
536                .into_iter()
537                .filter(|d| d.name().contains(name))
538                .collect())
539        } else {
540            Ok(datasets)
541        }
542    }
543
544    /// Return the dataset with the specified dataset ID.  If the dataset does
545    /// not exist, an error is returned.
546    pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
547        let params = HashMap::from([("dataset_id", dataset_id)]);
548        self.rpc("dataset.get".to_owned(), Some(params)).await
549    }
550
551    /// Lists the labels for the specified dataset.
552    pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
553        let params = HashMap::from([("dataset_id", dataset_id)]);
554        self.rpc("label.list".to_owned(), Some(params)).await
555    }
556
557    /// Add a new label to the dataset with the specified name.
558    pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
559        let new_label = NewLabel {
560            dataset_id,
561            labels: vec![NewLabelObject {
562                name: name.to_owned(),
563            }],
564        };
565        let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
566        Ok(())
567    }
568
569    /// Removes the label with the specified ID from the dataset.  Label IDs are
570    /// globally unique so the dataset_id is not required.
571    pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
572        let params = HashMap::from([("label_id", label_id)]);
573        let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
574        Ok(())
575    }
576
577    /// Creates a new dataset in the specified project.
578    ///
579    /// # Arguments
580    ///
581    /// * `project_id` - The ID of the project to create the dataset in
582    /// * `name` - The name of the new dataset
583    /// * `description` - Optional description for the dataset
584    ///
585    /// # Returns
586    ///
587    /// Returns the dataset ID of the newly created dataset.
588    pub async fn create_dataset(
589        &self,
590        project_id: &str,
591        name: &str,
592        description: Option<&str>,
593    ) -> Result<DatasetID, Error> {
594        let mut params = HashMap::new();
595        params.insert("project_id", project_id);
596        params.insert("name", name);
597        if let Some(desc) = description {
598            params.insert("description", desc);
599        }
600
601        #[derive(Deserialize)]
602        struct CreateDatasetResult {
603            id: DatasetID,
604        }
605
606        let result: CreateDatasetResult =
607            self.rpc("dataset.create".to_owned(), Some(params)).await?;
608        Ok(result.id)
609    }
610
611    /// Deletes a dataset by marking it as deleted.
612    ///
613    /// # Arguments
614    ///
615    /// * `dataset_id` - The ID of the dataset to delete
616    ///
617    /// # Returns
618    ///
619    /// Returns `Ok(())` if the dataset was successfully marked as deleted.
620    pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
621        let params = HashMap::from([("id", dataset_id)]);
622        let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
623        Ok(())
624    }
625
626    /// Updates the label with the specified ID to have the new name or index.
627    /// Label IDs cannot be changed.  Label IDs are globally unique so the
628    /// dataset_id is not required.
629    pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
630        #[derive(Serialize)]
631        struct Params {
632            dataset_id: DatasetID,
633            label_id: u64,
634            label_name: String,
635            label_index: u64,
636        }
637
638        let _: String = self
639            .rpc(
640                "label.update".to_owned(),
641                Some(Params {
642                    dataset_id: label.dataset_id(),
643                    label_id: label.id(),
644                    label_name: label.name().to_owned(),
645                    label_index: label.index(),
646                }),
647            )
648            .await?;
649        Ok(())
650    }
651
652    pub async fn download_dataset(
653        &self,
654        dataset_id: DatasetID,
655        groups: &[String],
656        file_types: &[FileType],
657        output: PathBuf,
658        progress: Option<Sender<Progress>>,
659    ) -> Result<(), Error> {
660        let samples = self
661            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
662            .await?;
663        fs::create_dir_all(&output).await?;
664
665        let total = samples.len();
666        let current = Arc::new(AtomicUsize::new(0));
667        let sem = Arc::new(Semaphore::new(MAX_TASKS));
668
669        let tasks = samples
670            .into_iter()
671            .map(|sample| {
672                let sem = sem.clone();
673                let client = self.clone();
674                let current = current.clone();
675                let progress = progress.clone();
676                let file_types = file_types.to_vec();
677                let output = output.clone();
678
679                tokio::spawn(async move {
680                    let _permit = sem.acquire().await.unwrap();
681
682                    for file_type in file_types {
683                        if let Some(data) = sample.download(&client, file_type.clone()).await? {
684                            let file_ext = match file_type {
685                                FileType::Image => infer::get(&data)
686                                    .expect("Failed to identify image file format for sample")
687                                    .extension()
688                                    .to_string(),
689                                t => t.to_string(),
690                            };
691
692                            let file_name = format!(
693                                "{}.{}",
694                                sample.name().unwrap_or_else(|| "unknown".to_string()),
695                                file_ext
696                            );
697                            let file_path = output.join(&file_name);
698
699                            let mut file = File::create(&file_path).await?;
700                            file.write_all(&data).await?;
701                        } else {
702                            warn!(
703                                "No data for sample: {}",
704                                sample
705                                    .id()
706                                    .map(|id| id.to_string())
707                                    .unwrap_or_else(|| "unknown".to_string())
708                            );
709                        }
710                    }
711
712                    if let Some(progress) = &progress {
713                        let current = current.fetch_add(1, Ordering::SeqCst);
714                        progress
715                            .send(Progress {
716                                current: current + 1,
717                                total,
718                            })
719                            .await
720                            .unwrap();
721                    }
722
723                    Ok::<(), Error>(())
724                })
725            })
726            .collect::<Vec<_>>();
727
728        join_all(tasks)
729            .await
730            .into_iter()
731            .collect::<Result<Vec<_>, _>>()?;
732
733        if let Some(progress) = progress {
734            drop(progress);
735        }
736
737        Ok(())
738    }
739
740    /// List available annotation sets for the specified dataset.
741    pub async fn annotation_sets(
742        &self,
743        dataset_id: DatasetID,
744    ) -> Result<Vec<AnnotationSet>, Error> {
745        let params = HashMap::from([("dataset_id", dataset_id)]);
746        self.rpc("annset.list".to_owned(), Some(params)).await
747    }
748
749    /// Create a new annotation set for the specified dataset.
750    ///
751    /// # Arguments
752    ///
753    /// * `dataset_id` - The ID of the dataset to create the annotation set in
754    /// * `name` - The name of the new annotation set
755    /// * `description` - Optional description for the annotation set
756    ///
757    /// # Returns
758    ///
759    /// Returns the annotation set ID of the newly created annotation set.
760    pub async fn create_annotation_set(
761        &self,
762        dataset_id: DatasetID,
763        name: &str,
764        description: Option<&str>,
765    ) -> Result<AnnotationSetID, Error> {
766        #[derive(Serialize)]
767        struct Params<'a> {
768            dataset_id: DatasetID,
769            name: &'a str,
770            operator: &'a str,
771            #[serde(skip_serializing_if = "Option::is_none")]
772            description: Option<&'a str>,
773        }
774
775        #[derive(Deserialize)]
776        struct CreateAnnotationSetResult {
777            id: AnnotationSetID,
778        }
779
780        let username = self.username().await?;
781        let result: CreateAnnotationSetResult = self
782            .rpc(
783                "annset.add".to_owned(),
784                Some(Params {
785                    dataset_id,
786                    name,
787                    operator: &username,
788                    description,
789                }),
790            )
791            .await?;
792        Ok(result.id)
793    }
794
795    /// Deletes an annotation set by marking it as deleted.
796    ///
797    /// # Arguments
798    ///
799    /// * `annotation_set_id` - The ID of the annotation set to delete
800    ///
801    /// # Returns
802    ///
803    /// Returns `Ok(())` if the annotation set was successfully marked as
804    /// deleted.
805    pub async fn delete_annotation_set(
806        &self,
807        annotation_set_id: AnnotationSetID,
808    ) -> Result<(), Error> {
809        let params = HashMap::from([("id", annotation_set_id)]);
810        let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
811        Ok(())
812    }
813
814    /// Retrieve the annotation set with the specified ID.
815    pub async fn annotation_set(
816        &self,
817        annotation_set_id: AnnotationSetID,
818    ) -> Result<AnnotationSet, Error> {
819        let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
820        self.rpc("annset.get".to_owned(), Some(params)).await
821    }
822
    /// Get the annotations for the specified annotation set with the
    /// requested annotation types.  The annotation types are used to filter
    /// the annotations returned.  The groups parameter is used to filter for
    /// dataset groups (train, val, test).  Images which do not have any
    /// annotations are also included in the result as long as they are in the
    /// requested groups (when specified).
    ///
    /// The result is a vector of Annotations objects which contain the
    /// full dataset along with the annotations for the specified types.
    ///
    /// To get the annotations as a DataFrame, use the `annotations_dataframe`
    /// method instead.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        // Resolve the owning dataset, then build a label-name -> index map so
        // each annotation can carry its numeric label index below.
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total matching sample count; used for early-out and progress totals.
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(annotations);
        }

        // Page through samples.list using the server's continuation token.
        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id: Some(annotation_set_id),
                types: annotation_types.iter().map(|t| t.to_string()).collect(),
                group_names: groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            for sample in result.samples {
                // If there are no annotations for the sample, create an empty
                // annotation for the sample so that it is included in the result.
                if sample.annotations().is_empty() {
                    let mut annotation = Annotation::new();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    annotations.push(annotation);
                    continue;
                }

                // Copy sample-level metadata onto each of its annotations.
                sample.annotations().iter().for_each(|annotation| {
                    let mut annotation = annotation.clone();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    // NOTE(review): this panics if the annotation has no label
                    // or its label is absent from the dataset's label list —
                    // confirm the server guarantees label consistency here.
                    annotation.set_label_index(Some(labels[annotation.label().unwrap().as_str()]));
                    annotations.push(annotation);
                });
            }

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // An empty or missing continuation token marks the last page.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender closes the progress channel for receivers.
        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(annotations)
    }
925
926    pub async fn samples_count(
927        &self,
928        dataset_id: DatasetID,
929        annotation_set_id: Option<AnnotationSetID>,
930        annotation_types: &[AnnotationType],
931        groups: &[String],
932        types: &[FileType],
933    ) -> Result<SamplesCountResult, Error> {
934        let types = annotation_types
935            .iter()
936            .map(|t| t.to_string())
937            .chain(types.iter().map(|t| t.to_string()))
938            .collect::<Vec<_>>();
939
940        let params = SamplesListParams {
941            dataset_id,
942            annotation_set_id,
943            group_names: groups.to_vec(),
944            types,
945            continue_token: None,
946        };
947
948        self.rpc("samples.count".to_owned(), Some(params)).await
949    }
950
    /// List the samples in a dataset along with their annotations, resolving
    /// each annotation's label to its index in the dataset's label table.
    ///
    /// The optional annotation set selects which annotations are attached.
    /// `annotation_types` and `types` filter by annotation type and file
    /// type, and `groups` limits results to the named dataset groups.
    /// Progress updates are sent over the optional channel as pages arrive.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // The server takes annotation types and file types through a single
        // `types` list, so merge the two filters.
        let types = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> label index map used to resolve annotation labels.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // NOTE(review): the count is taken without the file-type filter
        // (`&[]`), so `total` may not match the number of samples actually
        // listed when file types are supplied — confirm this is intended.
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(samples);
        }

        // Page through samples.list until an empty page or an exhausted
        // continue token ends the listing.
        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id,
                types: types.clone(),
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            // Attach label indices to every labeled annotation before
            // keeping the sample; unlabeled annotations are left unindexed.
            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            if let Some(label) = ann.label() {
                                ann.set_label_index(Some(labels[label.as_str()]));
                            }
                        }
                        s.with_annotations(anns)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        // Dropping the sender signals completion to the progress receiver.
        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(samples)
    }
1034
1035    /// Populates (imports) samples into a dataset using the `samples.populate`
1036    /// API.
1037    ///
1038    /// This method creates new samples in the specified dataset, optionally
1039    /// with annotations and sensor data files. For each sample, the `files`
1040    /// field is checked for local file paths. If a filename is a valid path
1041    /// to an existing file, the file will be automatically uploaded to S3
1042    /// using presigned URLs returned by the server. The filename in the
1043    /// request is replaced with the basename (path removed) before sending
1044    /// to the server.
1045    ///
1046    /// # Important Notes
1047    ///
1048    /// - **`annotation_set_id` is REQUIRED** when importing samples with
1049    ///   annotations. Without it, the server will accept the request but will
1050    ///   not save the annotation data. Use [`Client::annotation_sets`] to query
1051    ///   available annotation sets for a dataset, or create a new one via the
1052    ///   Studio UI.
1053    /// - **Box2d coordinates must be normalized** (0.0-1.0 range) for bounding
1054    ///   boxes. Divide pixel coordinates by image width/height before creating
1055    ///   [`Box2d`](crate::Box2d) annotations.
1056    /// - **Files are uploaded automatically** when the filename is a valid
1057    ///   local path. The method will replace the full path with just the
1058    ///   basename before sending to the server.
1059    /// - **Image dimensions are extracted automatically** for image files using
1060    ///   the `imagesize` crate. The width/height are sent to the server, but
1061    ///   note that the server currently doesn't return these fields when
1062    ///   fetching samples back.
1063    /// - **UUIDs are generated automatically** if not provided. If you need
1064    ///   deterministic UUIDs, set `sample.uuid` explicitly before calling. Note
1065    ///   that the server doesn't currently return UUIDs in sample queries.
1066    ///
1067    /// # Arguments
1068    ///
1069    /// * `dataset_id` - The ID of the dataset to populate
1070    /// * `annotation_set_id` - **Required** if samples contain annotations,
1071    ///   otherwise they will be ignored. Query with
1072    ///   [`Client::annotation_sets`].
1073    /// * `samples` - Vector of samples to import with metadata and file
1074    ///   references. For files, use the full local path - it will be uploaded
1075    ///   automatically. UUIDs and image dimensions will be
1076    ///   auto-generated/extracted if not provided.
1077    ///
1078    /// # Returns
1079    ///
1080    /// Returns the API result with sample UUIDs and upload status.
1081    ///
1082    /// # Example
1083    ///
1084    /// ```no_run
1085    /// use edgefirst_client::{Annotation, Box2d, Client, DatasetID, Sample, SampleFile};
1086    ///
1087    /// # async fn example() -> Result<(), edgefirst_client::Error> {
1088    /// # let client = Client::new()?.with_login("user", "pass").await?;
1089    /// # let dataset_id = DatasetID::from(1);
1090    /// // Query available annotation sets for the dataset
1091    /// let annotation_sets = client.annotation_sets(dataset_id).await?;
1092    /// let annotation_set_id = annotation_sets
1093    ///     .first()
1094    ///     .ok_or_else(|| {
1095    ///         edgefirst_client::Error::InvalidParameters("No annotation sets found".to_string())
1096    ///     })?
1097    ///     .id();
1098    ///
1099    /// // Create sample with annotation (UUID will be auto-generated)
1100    /// let mut sample = Sample::new();
1101    /// sample.width = Some(1920);
1102    /// sample.height = Some(1080);
1103    /// sample.group = Some("train".to_string());
1104    ///
1105    /// // Add file - use full path to local file, it will be uploaded automatically
1106    /// sample.files = vec![SampleFile::with_filename(
1107    ///     "image".to_string(),
1108    ///     "/path/to/image.jpg".to_string(),
1109    /// )];
1110    ///
1111    /// // Add bounding box annotation with NORMALIZED coordinates (0.0-1.0)
1112    /// let mut annotation = Annotation::new();
1113    /// annotation.set_label(Some("person".to_string()));
1114    /// // Normalize pixel coordinates by dividing by image dimensions
1115    /// let bbox = Box2d::new(0.5, 0.5, 0.25, 0.25); // (x, y, w, h) normalized
1116    /// annotation.set_box2d(Some(bbox));
1117    /// sample.annotations = vec![annotation];
1118    ///
1119    /// // Populate with annotation_set_id (REQUIRED for annotations)
1120    /// let result = client
1121    ///     .populate_samples(dataset_id, Some(annotation_set_id), vec![sample], None)
1122    ///     .await?;
1123    /// # Ok(())
1124    /// # }
1125    /// ```
1126    pub async fn populate_samples(
1127        &self,
1128        dataset_id: DatasetID,
1129        annotation_set_id: Option<AnnotationSetID>,
1130        samples: Vec<Sample>,
1131        progress: Option<Sender<Progress>>,
1132    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
1133        use crate::api::SamplesPopulateParams;
1134        use std::path::Path;
1135
1136        let total = samples.len();
1137
1138        // Track which files need to be uploaded: (sample_uuid, file_type, local_path,
1139        // basename)
1140        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
1141
1142        // Process samples to detect local files, extract basenames, and generate UUIDs
1143        let samples: Vec<Sample> = samples
1144            .into_iter()
1145            .map(|mut sample| {
1146                // Generate UUID if not provided
1147                if sample.uuid.is_none() {
1148                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
1149                }
1150
1151                let sample_uuid = sample.uuid.clone().unwrap();
1152
1153                // Process files: detect local paths and queue for upload
1154                let updated_files: Vec<crate::SampleFile> = sample
1155                    .files
1156                    .iter()
1157                    .map(|file| {
1158                        if let Some(filename) = file.filename() {
1159                            let path = Path::new(filename);
1160
1161                            // Check if this is a valid local file path
1162                            if path.exists() && path.is_file() {
1163                                // Get the basename
1164                                if let Some(basename) = path.file_name().and_then(|s| s.to_str()) {
1165                                    // For image files, try to extract dimensions if not already set
1166                                    if file.file_type() == "image"
1167                                        && (sample.width.is_none() || sample.height.is_none())
1168                                        && let Ok(size) = imagesize::size(path)
1169                                    {
1170                                        sample.width = Some(size.width as u32);
1171                                        sample.height = Some(size.height as u32);
1172                                    }
1173
1174                                    // Store the full path for later upload
1175                                    files_to_upload.push((
1176                                        sample_uuid.clone(),
1177                                        file.file_type().to_string(),
1178                                        path.to_path_buf(),
1179                                        basename.to_string(),
1180                                    ));
1181
1182                                    // Return SampleFile with just the basename
1183                                    return crate::SampleFile::with_filename(
1184                                        file.file_type().to_string(),
1185                                        basename.to_string(),
1186                                    );
1187                                }
1188                            }
1189                        }
1190                        // Return the file unchanged if not a local path
1191                        file.clone()
1192                    })
1193                    .collect();
1194
1195                sample.files = updated_files;
1196                sample
1197            })
1198            .collect();
1199
1200        let has_files_to_upload = !files_to_upload.is_empty();
1201
1202        // Call populate API with presigned_urls=true if we have files to upload
1203        let params = SamplesPopulateParams {
1204            dataset_id,
1205            annotation_set_id,
1206            presigned_urls: if has_files_to_upload {
1207                Some(true)
1208            } else {
1209                Some(false)
1210            },
1211            samples,
1212        };
1213
1214        let results: Vec<crate::SamplesPopulateResult> = self
1215            .rpc("samples.populate".to_owned(), Some(params))
1216            .await?;
1217
1218        // Upload files if we have any
1219        if has_files_to_upload {
1220            // Build a map from (sample_uuid, basename) -> local_path
1221            let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
1222            for (uuid, _file_type, path, basename) in files_to_upload {
1223                upload_map.insert((uuid, basename), path);
1224            }
1225
1226            let current = Arc::new(AtomicUsize::new(0));
1227            let sem = Arc::new(Semaphore::new(MAX_TASKS));
1228
1229            // Upload each sample's files in parallel
1230            let upload_tasks = results
1231                .iter()
1232                .map(|result| {
1233                    let sem = sem.clone();
1234                    let http = self.http.clone();
1235                    let current = current.clone();
1236                    let progress = progress.clone();
1237                    let result_uuid = result.uuid.clone();
1238                    let urls = result.urls.clone();
1239                    let upload_map = upload_map.clone();
1240
1241                    tokio::spawn(async move {
1242                        let _permit = sem.acquire().await.unwrap();
1243
1244                        // Upload all files for this sample
1245                        for url_info in &urls {
1246                            if let Some(local_path) =
1247                                upload_map.get(&(result_uuid.clone(), url_info.filename.clone()))
1248                            {
1249                                // Upload the file
1250                                upload_file_to_presigned_url(
1251                                    http.clone(),
1252                                    &url_info.url,
1253                                    local_path.clone(),
1254                                )
1255                                .await?;
1256                            }
1257                        }
1258
1259                        // Update progress after uploading all files for this sample
1260                        if let Some(progress) = &progress {
1261                            let current = current.fetch_add(1, Ordering::SeqCst);
1262                            progress
1263                                .send(Progress {
1264                                    current: current + 1,
1265                                    total,
1266                                })
1267                                .await
1268                                .unwrap();
1269                        }
1270
1271                        Ok::<(), Error>(())
1272                    })
1273                })
1274                .collect::<Vec<_>>();
1275
1276            join_all(upload_tasks)
1277                .await
1278                .into_iter()
1279                .collect::<Result<Vec<_>, _>>()?;
1280        }
1281
1282        if let Some(progress) = progress {
1283            drop(progress);
1284        }
1285
1286        Ok(results)
1287    }
1288
1289    pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1290        for attempt in 1..MAX_RETRIES {
1291            let resp = match self.http.get(url).send().await {
1292                Ok(resp) => resp,
1293                Err(err) => {
1294                    warn!(
1295                        "Socket Error [retry {}/{}]: {:?}",
1296                        attempt, MAX_RETRIES, err
1297                    );
1298                    tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1299                    continue;
1300                }
1301            };
1302
1303            match resp.bytes().await {
1304                Ok(body) => return Ok(body.to_vec()),
1305                Err(err) => {
1306                    warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
1307                    tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1308                    continue;
1309                }
1310            };
1311        }
1312
1313        Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1314    }
1315
1316    /// Get the AnnotationGroup for the specified annotation set with the
1317    /// requested annotation types.  The annotation type is used to filter
1318    /// the annotations returned.  Images which do not have any annotations
1319    /// are included in the result.
1320    ///
1321    /// The result is a DataFrame following the EdgeFirst Dataset Format
1322    /// definition.
1323    ///
1324    /// To get the annotations as a vector of AnnotationGroup objects, use the
1325    /// `annotations` method instead.
1326    #[cfg(feature = "polars")]
1327    pub async fn annotations_dataframe(
1328        &self,
1329        annotation_set_id: AnnotationSetID,
1330        groups: &[String],
1331        types: &[AnnotationType],
1332        progress: Option<Sender<Progress>>,
1333    ) -> Result<DataFrame, Error> {
1334        use crate::dataset::annotations_dataframe;
1335
1336        let annotations = self
1337            .annotations(annotation_set_id, groups, types, progress)
1338            .await?;
1339        Ok(annotations_dataframe(&annotations))
1340    }
1341
1342    /// List available snapshots.  If a name is provided, only snapshots
1343    /// containing that name are returned.
1344    pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1345        let snapshots: Vec<Snapshot> = self
1346            .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1347            .await?;
1348        if let Some(name) = name {
1349            Ok(snapshots
1350                .into_iter()
1351                .filter(|s| s.description().contains(name))
1352                .collect())
1353        } else {
1354            Ok(snapshots)
1355        }
1356    }
1357
1358    /// Get the snapshot with the specified id.
1359    pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1360        let params = HashMap::from([("snapshot_id", snapshot_id)]);
1361        self.rpc("snapshots.get".to_owned(), Some(params)).await
1362    }
1363
    /// Create a new snapshot from the file at the specified path.  If the path
    /// is a directory then all the files in the directory are uploaded.  The
    /// snapshot name will be the specified path, either file or directory.
    ///
    /// The progress callback can be used to monitor the progress of the upload
    /// over a watch channel.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories are handled by the folder variant, which uploads every
        // file under the tree.
        if path.is_dir() {
            return self
                .create_snapshot_folder(path.to_str().unwrap(), progress)
                .await;
        }

        // Single-file snapshot: the snapshot is named after the file's
        // basename and progress is tracked in bytes.
        let name = path.file_name().unwrap().to_str().unwrap();
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        // Ask the server for multipart upload URLs for the single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id entry with per-key part
        // descriptors.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // Part entries are keyed "<storage prefix>/<file name>", where the
        // prefix is taken from the snapshot's storage path after "::/".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload the file contents, then tell the server the multipart
        // upload is complete.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Mark the snapshot as available now that the upload has finished.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Dropping the sender signals completion to the progress receiver.
        if let Some(progress) = progress {
            drop(progress);
        }

        // Return the snapshot record reflecting its final state.
        self.snapshot(snapshot_id).await
    }
1450
1451    async fn create_snapshot_folder(
1452        &self,
1453        path: &str,
1454        progress: Option<Sender<Progress>>,
1455    ) -> Result<Snapshot, Error> {
1456        let path = Path::new(path);
1457        let name = path.file_name().unwrap().to_str().unwrap();
1458
1459        let files = WalkDir::new(path)
1460            .into_iter()
1461            .filter_map(|entry| entry.ok())
1462            .filter(|entry| entry.file_type().is_file())
1463            .map(|entry| entry.path().strip_prefix(path).unwrap().to_owned())
1464            .collect::<Vec<_>>();
1465
1466        let total = files
1467            .iter()
1468            .map(|file| path.join(file).metadata().unwrap().len() as usize)
1469            .sum();
1470        let current = Arc::new(AtomicUsize::new(0));
1471
1472        if let Some(progress) = &progress {
1473            progress.send(Progress { current: 0, total }).await.unwrap();
1474        }
1475
1476        let keys = files
1477            .iter()
1478            .map(|key| key.to_str().unwrap().to_owned())
1479            .collect::<Vec<_>>();
1480        let file_sizes = files
1481            .iter()
1482            .map(|key| path.join(key).metadata().unwrap().len() as usize)
1483            .collect::<Vec<_>>();
1484
1485        let params = SnapshotCreateMultipartParams {
1486            snapshot_name: name.to_owned(),
1487            keys,
1488            file_sizes,
1489        };
1490
1491        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
1492            .rpc(
1493                "snapshots.create_upload_url_multipart".to_owned(),
1494                Some(params),
1495            )
1496            .await?;
1497
1498        let snapshot_id = match multipart.get("snapshot_id") {
1499            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
1500            _ => return Err(Error::InvalidResponse),
1501        };
1502
1503        let snapshot = self.snapshot(snapshot_id).await?;
1504        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
1505
1506        for file in files {
1507            let part_key = format!("{}/{}", part_prefix, file.to_str().unwrap());
1508            let mut part = match multipart.get(&part_key) {
1509                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
1510                _ => return Err(Error::InvalidResponse),
1511            }
1512            .clone();
1513            part.key = Some(part_key);
1514
1515            let params = upload_multipart(
1516                self.http.clone(),
1517                part.clone(),
1518                path.join(file),
1519                total,
1520                current.clone(),
1521                progress.clone(),
1522            )
1523            .await?;
1524
1525            let complete: String = self
1526                .rpc(
1527                    "snapshots.complete_multipart_upload".to_owned(),
1528                    Some(params),
1529                )
1530                .await?;
1531            debug!("Snapshot Part Complete: {:?}", complete);
1532        }
1533
1534        let params = SnapshotStatusParams {
1535            snapshot_id,
1536            status: "available".to_owned(),
1537        };
1538        let _: SnapshotStatusResult = self
1539            .rpc("snapshots.update".to_owned(), Some(params))
1540            .await?;
1541
1542        if let Some(progress) = progress {
1543            drop(progress);
1544        }
1545
1546        self.snapshot(snapshot_id).await
1547    }
1548
1549    /// Downloads a snapshot from the server.  The snapshot could be a single
1550    /// file or a directory of files.  The snapshot is downloaded to the
1551    /// specified path.  A progress callback can be provided to monitor the
1552    /// progress of the download over a watch channel.
1553    pub async fn download_snapshot(
1554        &self,
1555        snapshot_id: SnapshotID,
1556        output: PathBuf,
1557        progress: Option<Sender<Progress>>,
1558    ) -> Result<(), Error> {
1559        fs::create_dir_all(&output).await?;
1560
1561        let params = HashMap::from([("snapshot_id", snapshot_id)]);
1562        let items: HashMap<String, String> = self
1563            .rpc("snapshots.create_download_url".to_owned(), Some(params))
1564            .await?;
1565
1566        let total = Arc::new(AtomicUsize::new(0));
1567        let current = Arc::new(AtomicUsize::new(0));
1568        let sem = Arc::new(Semaphore::new(MAX_TASKS));
1569
1570        let tasks = items
1571            .iter()
1572            .map(|(key, url)| {
1573                let http = self.http.clone();
1574                let key = key.clone();
1575                let url = url.clone();
1576                let output = output.clone();
1577                let progress = progress.clone();
1578                let current = current.clone();
1579                let total = total.clone();
1580                let sem = sem.clone();
1581
1582                tokio::spawn(async move {
1583                    let _permit = sem.acquire().await.unwrap();
1584                    let res = http.get(url).send().await.unwrap();
1585                    let content_length = res.content_length().unwrap() as usize;
1586
1587                    if let Some(progress) = &progress {
1588                        let total = total.fetch_add(content_length, Ordering::SeqCst);
1589                        progress
1590                            .send(Progress {
1591                                current: current.load(Ordering::SeqCst),
1592                                total: total + content_length,
1593                            })
1594                            .await
1595                            .unwrap();
1596                    }
1597
1598                    let mut file = File::create(output.join(key)).await.unwrap();
1599                    let mut stream = res.bytes_stream();
1600
1601                    while let Some(chunk) = stream.next().await {
1602                        let chunk = chunk.unwrap();
1603                        file.write_all(&chunk).await.unwrap();
1604                        let len = chunk.len();
1605
1606                        if let Some(progress) = &progress {
1607                            let total = total.load(Ordering::SeqCst);
1608                            let current = current.fetch_add(len, Ordering::SeqCst);
1609
1610                            progress
1611                                .send(Progress {
1612                                    current: current + len,
1613                                    total,
1614                                })
1615                                .await
1616                                .unwrap();
1617                        }
1618                    }
1619                })
1620            })
1621            .collect::<Vec<_>>();
1622
1623        join_all(tasks)
1624            .await
1625            .into_iter()
1626            .collect::<Result<Vec<_>, _>>()
1627            .unwrap();
1628
1629        Ok(())
1630    }
1631
1632    /// The snapshot restore method is used to restore a snapshot to the server.
1633    /// The restore method can perform a few different operations depending on
1634    /// the snapshot type.
1635    ///
1636    /// The auto-annotation workflow is used to automatically annotate the
1637    /// dataset with 2D masks and boxes using the labels within the
1638    /// autolabel list. If autolabel is empty then the auto-annotation
1639    /// workflow is not used. If the MCAP includes radar or LiDAR data then
1640    /// the auto-annotation workflow will also generate 3D bounding boxes
1641    /// for detected objects.
1642    ///
1643    /// The autodepth flag is used to determine if a depthmap should be
1644    /// automatically generated for the dataset, this will currently only work
1645    /// accurately for Maivin or Raivin cameras.
1646    #[allow(clippy::too_many_arguments)]
1647    pub async fn restore_snapshot(
1648        &self,
1649        project_id: ProjectID,
1650        snapshot_id: SnapshotID,
1651        topics: &[String],
1652        autolabel: &[String],
1653        autodepth: bool,
1654        dataset_name: Option<&str>,
1655        dataset_description: Option<&str>,
1656    ) -> Result<SnapshotRestoreResult, Error> {
1657        let params = SnapshotRestore {
1658            project_id,
1659            snapshot_id,
1660            fps: 1,
1661            autodepth,
1662            agtg_pipeline: !autolabel.is_empty(),
1663            autolabel: autolabel.to_vec(),
1664            topics: topics.to_vec(),
1665            dataset_name: dataset_name.map(|s| s.to_owned()),
1666            dataset_description: dataset_description.map(|s| s.to_owned()),
1667        };
1668        self.rpc("snapshots.restore".to_owned(), Some(params)).await
1669    }
1670
1671    /// Returns a list of experiments available to the user.  The experiments
1672    /// are returned as a vector of Experiment objects.  If name is provided
1673    /// then only experiments containing this string are returned.
1674    ///
1675    /// Experiments provide a method of organizing training and validation
1676    /// sessions together and are akin to an Experiment in MLFlow terminology.  
1677    /// Each experiment can have multiple trainer sessions associated with it,
1678    /// these would be akin to runs in MLFlow terminology.
1679    pub async fn experiments(
1680        &self,
1681        project_id: ProjectID,
1682        name: Option<&str>,
1683    ) -> Result<Vec<Experiment>, Error> {
1684        let params = HashMap::from([("project_id", project_id)]);
1685        let experiments: Vec<Experiment> =
1686            self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1687        if let Some(name) = name {
1688            Ok(experiments
1689                .into_iter()
1690                .filter(|e| e.name().contains(name))
1691                .collect())
1692        } else {
1693            Ok(experiments)
1694        }
1695    }
1696
1697    /// Return the experiment with the specified experiment ID.  If the
1698    /// experiment does not exist, an error is returned.
1699    pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1700        let params = HashMap::from([("trainer_id", experiment_id)]);
1701        self.rpc("trainer.get".to_owned(), Some(params)).await
1702    }
1703
1704    /// Returns a list of trainer sessions available to the user.  The trainer
1705    /// sessions are returned as a vector of TrainingSession objects.  If name
1706    /// is provided then only trainer sessions containing this string are
1707    /// returned.
1708    ///
1709    /// Trainer sessions are akin to runs in MLFlow terminology.  These
1710    /// represent an actual training session which will produce metrics and
1711    /// model artifacts.
1712    pub async fn training_sessions(
1713        &self,
1714        experiment_id: ExperimentID,
1715        name: Option<&str>,
1716    ) -> Result<Vec<TrainingSession>, Error> {
1717        let params = HashMap::from([("trainer_id", experiment_id)]);
1718        let sessions: Vec<TrainingSession> = self
1719            .rpc("trainer.session.list".to_owned(), Some(params))
1720            .await?;
1721        if let Some(name) = name {
1722            Ok(sessions
1723                .into_iter()
1724                .filter(|s| s.name().contains(name))
1725                .collect())
1726        } else {
1727            Ok(sessions)
1728        }
1729    }
1730
1731    /// Return the trainer session with the specified trainer session ID.  If
1732    /// the trainer session does not exist, an error is returned.
1733    pub async fn training_session(
1734        &self,
1735        session_id: TrainingSessionID,
1736    ) -> Result<TrainingSession, Error> {
1737        let params = HashMap::from([("trainer_session_id", session_id)]);
1738        self.rpc("trainer.session.get".to_owned(), Some(params))
1739            .await
1740    }
1741
1742    /// List validation sessions for the given project.
1743    pub async fn validation_sessions(
1744        &self,
1745        project_id: ProjectID,
1746    ) -> Result<Vec<ValidationSession>, Error> {
1747        let params = HashMap::from([("project_id", project_id)]);
1748        self.rpc("validate.session.list".to_owned(), Some(params))
1749            .await
1750    }
1751
1752    /// Retrieve a specific validation session.
1753    pub async fn validation_session(
1754        &self,
1755        session_id: ValidationSessionID,
1756    ) -> Result<ValidationSession, Error> {
1757        let params = HashMap::from([("validate_session_id", session_id)]);
1758        self.rpc("validate.session.get".to_owned(), Some(params))
1759            .await
1760    }
1761
1762    /// List the artifacts for the specified trainer session.  The artifacts
1763    /// are returned as a vector of strings.
1764    pub async fn artifacts(
1765        &self,
1766        training_session_id: TrainingSessionID,
1767    ) -> Result<Vec<Artifact>, Error> {
1768        let params = HashMap::from([("training_session_id", training_session_id)]);
1769        self.rpc("trainer.get_artifacts".to_owned(), Some(params))
1770            .await
1771    }
1772
1773    /// Download the model artifact for the specified trainer session to the
1774    /// specified file path, if path is not provided it will be downloaded to
1775    /// the current directory with the same filename.  A progress callback can
1776    /// be provided to monitor the progress of the download over a watch
1777    /// channel.
1778    pub async fn download_artifact(
1779        &self,
1780        training_session_id: TrainingSessionID,
1781        modelname: &str,
1782        filename: Option<PathBuf>,
1783        progress: Option<Sender<Progress>>,
1784    ) -> Result<(), Error> {
1785        let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
1786        let resp = self
1787            .http
1788            .get(format!(
1789                "{}/download_model?training_session_id={}&file={}",
1790                self.url,
1791                training_session_id.value(),
1792                modelname
1793            ))
1794            .header("Authorization", format!("Bearer {}", self.token().await))
1795            .send()
1796            .await?;
1797        if !resp.status().is_success() {
1798            let err = resp.error_for_status_ref().unwrap_err();
1799            return Err(Error::HttpError(err));
1800        }
1801
1802        fs::create_dir_all(filename.parent().unwrap()).await?;
1803
1804        if let Some(progress) = progress {
1805            let total = resp.content_length().unwrap() as usize;
1806            progress.send(Progress { current: 0, total }).await.unwrap();
1807
1808            let mut file = File::create(filename).await?;
1809            let mut current = 0;
1810            let mut stream = resp.bytes_stream();
1811
1812            while let Some(item) = stream.next().await {
1813                let chunk = item?;
1814                file.write_all(&chunk).await?;
1815                current += chunk.len();
1816                progress.send(Progress { current, total }).await.unwrap();
1817            }
1818        } else {
1819            let body = resp.bytes().await?;
1820            fs::write(filename, body).await?;
1821        }
1822
1823        Ok(())
1824    }
1825
1826    /// Download the model checkpoint associated with the specified trainer
1827    /// session to the specified file path, if path is not provided it will be
1828    /// downloaded to the current directory with the same filename.  A progress
1829    /// callback can be provided to monitor the progress of the download over a
1830    /// watch channel.
1831    ///
1832    /// There is no API for listing checkpoints it is expected that trainers are
1833    /// aware of possible checkpoints and their names within the checkpoint
1834    /// folder on the server.
1835    pub async fn download_checkpoint(
1836        &self,
1837        training_session_id: TrainingSessionID,
1838        checkpoint: &str,
1839        filename: Option<PathBuf>,
1840        progress: Option<Sender<Progress>>,
1841    ) -> Result<(), Error> {
1842        let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
1843        let resp = self
1844            .http
1845            .get(format!(
1846                "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1847                self.url,
1848                training_session_id.value(),
1849                checkpoint
1850            ))
1851            .header("Authorization", format!("Bearer {}", self.token().await))
1852            .send()
1853            .await?;
1854        if !resp.status().is_success() {
1855            let err = resp.error_for_status_ref().unwrap_err();
1856            return Err(Error::HttpError(err));
1857        }
1858
1859        fs::create_dir_all(filename.parent().unwrap()).await?;
1860
1861        if let Some(progress) = progress {
1862            let total = resp.content_length().unwrap() as usize;
1863            progress.send(Progress { current: 0, total }).await.unwrap();
1864
1865            let mut file = File::create(filename).await?;
1866            let mut current = 0;
1867            let mut stream = resp.bytes_stream();
1868
1869            while let Some(item) = stream.next().await {
1870                let chunk = item?;
1871                file.write_all(&chunk).await?;
1872                current += chunk.len();
1873                progress.send(Progress { current, total }).await.unwrap();
1874            }
1875        } else {
1876            let body = resp.bytes().await?;
1877            fs::write(filename, body).await?;
1878        }
1879
1880        Ok(())
1881    }
1882
1883    /// Return a list of tasks for the current user.
1884    pub async fn tasks(
1885        &self,
1886        name: Option<&str>,
1887        workflow: Option<&str>,
1888        status: Option<&str>,
1889        manager: Option<&str>,
1890    ) -> Result<Vec<Task>, Error> {
1891        let mut params = TasksListParams {
1892            continue_token: None,
1893            status: status.map(|s| vec![s.to_owned()]),
1894            manager: manager.map(|m| vec![m.to_owned()]),
1895        };
1896        let mut tasks = Vec::new();
1897
1898        loop {
1899            let result = self
1900                .rpc::<_, TasksListResult>("task.list".to_owned(), Some(&params))
1901                .await?;
1902            tasks.extend(result.tasks);
1903
1904            if result.continue_token.is_none() || result.continue_token == Some("".into()) {
1905                params.continue_token = None;
1906            } else {
1907                params.continue_token = result.continue_token;
1908            }
1909
1910            if params.continue_token.is_none() {
1911                break;
1912            }
1913        }
1914
1915        if let Some(name) = name {
1916            tasks.retain(|t| t.name().contains(name));
1917        }
1918
1919        if let Some(workflow) = workflow {
1920            tasks.retain(|t| t.workflow().contains(workflow));
1921        }
1922
1923        Ok(tasks)
1924    }
1925
1926    /// Retrieve the task information and status.
1927    pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
1928        self.rpc(
1929            "task.get".to_owned(),
1930            Some(HashMap::from([("id", task_id)])),
1931        )
1932        .await
1933    }
1934
1935    /// Updates the tasks status.
1936    pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
1937        let status = TaskStatus {
1938            task_id,
1939            status: status.to_owned(),
1940        };
1941        self.rpc("docker.update.status".to_owned(), Some(status))
1942            .await
1943    }
1944
1945    /// Defines the stages for the task.  The stages are defined as a mapping
1946    /// from stage names to their descriptions.  Once stages are defined their
1947    /// status can be updated using the update_stage method.
1948    pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
1949        let stages: Vec<HashMap<String, String>> = stages
1950            .iter()
1951            .map(|(key, value)| {
1952                let mut stage_map = HashMap::new();
1953                stage_map.insert(key.to_string(), value.to_string());
1954                stage_map
1955            })
1956            .collect();
1957        let params = TaskStages { task_id, stages };
1958        let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
1959        Ok(())
1960    }
1961
1962    /// Updates the progress of the task for the provided stage and status
1963    /// information.
1964    pub async fn update_stage(
1965        &self,
1966        task_id: TaskID,
1967        stage: &str,
1968        status: &str,
1969        message: &str,
1970        percentage: u8,
1971    ) -> Result<(), Error> {
1972        let stage = Stage::new(
1973            Some(task_id),
1974            stage.to_owned(),
1975            Some(status.to_owned()),
1976            Some(message.to_owned()),
1977            percentage,
1978        );
1979        let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
1980        Ok(())
1981    }
1982
1983    /// Raw fetch from the Studio server is used for downloading files.
1984    pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
1985        let req = self
1986            .http
1987            .get(format!("{}/{}", self.url, query))
1988            .header("User-Agent", "EdgeFirst Client")
1989            .header("Authorization", format!("Bearer {}", self.token().await));
1990        let resp = req.send().await?;
1991
1992        if resp.status().is_success() {
1993            let body = resp.bytes().await?;
1994
1995            if log_enabled!(Level::Trace) {
1996                trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
1997            }
1998
1999            Ok(body.to_vec())
2000        } else {
2001            let err = resp.error_for_status_ref().unwrap_err();
2002            Err(Error::HttpError(err))
2003        }
2004    }
2005
2006    /// Sends a multipart post request to the server.  This is used by the
2007    /// upload and download APIs which do not use JSON-RPC but instead transfer
2008    /// files using multipart/form-data.
2009    pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2010        let req = self
2011            .http
2012            .post(format!("{}/api?method={}", self.url, method))
2013            .header("Accept", "application/json")
2014            .header("User-Agent", "EdgeFirst Client")
2015            .header("Authorization", format!("Bearer {}", self.token().await))
2016            .multipart(form);
2017        let resp = req.send().await?;
2018
2019        if resp.status().is_success() {
2020            let body = resp.bytes().await?;
2021
2022            if log_enabled!(Level::Trace) {
2023                trace!(
2024                    "POST Multipart Response: {}",
2025                    String::from_utf8_lossy(&body)
2026                );
2027            }
2028
2029            let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2030                Ok(response) => response,
2031                Err(err) => {
2032                    error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2033                    return Err(err.into());
2034                }
2035            };
2036
2037            if let Some(error) = response.error {
2038                Err(Error::RpcError(error.code, error.message))
2039            } else if let Some(result) = response.result {
2040                Ok(result)
2041            } else {
2042                Err(Error::InvalidResponse)
2043            }
2044        } else {
2045            let err = resp.error_for_status_ref().unwrap_err();
2046            Err(Error::HttpError(err))
2047        }
2048    }
2049
2050    /// Send a JSON-RPC request to the server.  The method is the name of the
2051    /// method to call on the server.  The params are the parameters to pass to
2052    /// the method.  The method and params are serialized into a JSON-RPC
2053    /// request and sent to the server.  The response is deserialized into
2054    /// the specified type and returned to the caller.
2055    ///
2056    /// NOTE: This API would generally not be called directly and instead users
2057    /// should use the higher-level methods provided by the client.
2058    pub async fn rpc<Params, RpcResult>(
2059        &self,
2060        method: String,
2061        params: Option<Params>,
2062    ) -> Result<RpcResult, Error>
2063    where
2064        Params: Serialize,
2065        RpcResult: DeserializeOwned,
2066    {
2067        let auth_expires = self.token_expiration().await?;
2068        if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2069            self.renew_token().await?;
2070        }
2071
2072        self.rpc_without_auth(method, params).await
2073    }
2074
2075    async fn rpc_without_auth<Params, RpcResult>(
2076        &self,
2077        method: String,
2078        params: Option<Params>,
2079    ) -> Result<RpcResult, Error>
2080    where
2081        Params: Serialize,
2082        RpcResult: DeserializeOwned,
2083    {
2084        let request = RpcRequest {
2085            method,
2086            params,
2087            ..Default::default()
2088        };
2089
2090        if log_enabled!(Level::Trace) {
2091            trace!(
2092                "RPC Request: {}",
2093                serde_json::ser::to_string_pretty(&request)?
2094            );
2095        }
2096
2097        for attempt in 0..MAX_RETRIES {
2098            let res = match self
2099                .http
2100                .post(format!("{}/api", self.url))
2101                .header("Accept", "application/json")
2102                .header("User-Agent", "EdgeFirst Client")
2103                .header("Authorization", format!("Bearer {}", self.token().await))
2104                .json(&request)
2105                .send()
2106                .await
2107            {
2108                Ok(res) => res,
2109                Err(err) => {
2110                    warn!("Socket Error: {:?}", err);
2111                    continue;
2112                }
2113            };
2114
2115            if res.status().is_success() {
2116                let body = res.bytes().await?;
2117
2118                if log_enabled!(Level::Trace) {
2119                    trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2120                }
2121
2122                let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2123                    Ok(response) => response,
2124                    Err(err) => {
2125                        error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2126                        return Err(err.into());
2127                    }
2128                };
2129
2130                // FIXME: Studio Server always returns 999 as the id.
2131                // if request.id.to_string() != response.id {
2132                //     return Err(Error::InvalidRpcId(response.id));
2133                // }
2134
2135                if let Some(error) = response.error {
2136                    return Err(Error::RpcError(error.code, error.message));
2137                } else if let Some(result) = response.result {
2138                    return Ok(result);
2139                } else {
2140                    return Err(Error::InvalidResponse);
2141                }
2142            } else {
2143                let err = res.error_for_status_ref().unwrap_err();
2144                warn!("HTTP Error {}: {}", err, res.text().await?);
2145            }
2146
2147            warn!(
2148                "Retrying RPC request (attempt {}/{})...",
2149                attempt + 1,
2150                MAX_RETRIES
2151            );
2152            tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2153        }
2154
2155        Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2156    }
2157}
2158
2159async fn upload_multipart(
2160    http: reqwest::Client,
2161    part: SnapshotPart,
2162    path: PathBuf,
2163    total: usize,
2164    current: Arc<AtomicUsize>,
2165    progress: Option<Sender<Progress>>,
2166) -> Result<SnapshotCompleteMultipartParams, Error> {
2167    let filesize = path.metadata()?.len() as usize;
2168    let n_parts = filesize.div_ceil(PART_SIZE);
2169    let sem = Arc::new(Semaphore::new(MAX_TASKS));
2170
2171    let key = part.key.unwrap();
2172    let upload_id = part.upload_id;
2173
2174    let urls = part.urls.clone();
2175    let etags = Arc::new(tokio::sync::Mutex::new(vec![
2176        EtagPart {
2177            etag: "".to_owned(),
2178            part_number: 0,
2179        };
2180        n_parts
2181    ]));
2182
2183    let tasks = (0..n_parts)
2184        .map(|part| {
2185            let http = http.clone();
2186            let url = urls[part].clone();
2187            let etags = etags.clone();
2188            let path = path.to_owned();
2189            let sem = sem.clone();
2190            let progress = progress.clone();
2191            let current = current.clone();
2192
2193            tokio::spawn(async move {
2194                let _permit = sem.acquire().await?;
2195                let mut etag = None;
2196
2197                for attempt in 0..MAX_RETRIES {
2198                    match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
2199                    {
2200                        Ok(v) => {
2201                            etag = Some(v);
2202                            break;
2203                        }
2204                        Err(err) => {
2205                            warn!("Upload Part Error: {:?}", err);
2206                            tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2207                        }
2208                    }
2209                }
2210
2211                if let Some(etag) = etag {
2212                    let mut etags = etags.lock().await;
2213                    etags[part] = EtagPart {
2214                        etag,
2215                        part_number: part + 1,
2216                    };
2217
2218                    let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2219                    if let Some(progress) = &progress {
2220                        progress
2221                            .send(Progress {
2222                                current: current + PART_SIZE,
2223                                total,
2224                            })
2225                            .await
2226                            .unwrap();
2227                    }
2228
2229                    Ok(())
2230                } else {
2231                    Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2232                }
2233            })
2234        })
2235        .collect::<Vec<_>>();
2236
2237    join_all(tasks)
2238        .await
2239        .into_iter()
2240        .collect::<Result<Vec<_>, _>>()?;
2241
2242    Ok(SnapshotCompleteMultipartParams {
2243        key,
2244        upload_id,
2245        etag_list: etags.lock().await.clone(),
2246    })
2247}
2248
2249async fn upload_part(
2250    http: reqwest::Client,
2251    url: String,
2252    path: PathBuf,
2253    part: usize,
2254    n_parts: usize,
2255) -> Result<String, Error> {
2256    let filesize = path.metadata()?.len() as usize;
2257    let mut file = File::open(path).await.unwrap();
2258    file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2259        .await
2260        .unwrap();
2261    let file = file.take(PART_SIZE as u64);
2262
2263    let body_length = if part + 1 == n_parts {
2264        filesize % PART_SIZE
2265    } else {
2266        PART_SIZE
2267    };
2268
2269    let stream = FramedRead::new(file, BytesCodec::new());
2270    let body = Body::wrap_stream(stream);
2271
2272    let resp = http
2273        .put(url.clone())
2274        .header(CONTENT_LENGTH, body_length)
2275        .body(body)
2276        .send()
2277        .await?
2278        .error_for_status()?;
2279    let etag = resp
2280        .headers()
2281        .get("etag")
2282        .unwrap()
2283        .to_str()
2284        .unwrap()
2285        .to_owned();
2286    // Studio Server requires etag without the quotes.
2287    Ok(etag
2288        .strip_prefix("\"")
2289        .unwrap()
2290        .strip_suffix("\"")
2291        .unwrap()
2292        .to_owned())
2293}
2294
2295/// Upload a complete file to a presigned S3 URL using HTTP PUT.
2296///
2297/// This is used for populate_samples to upload files to S3 after
2298/// receiving presigned URLs from the server.
2299async fn upload_file_to_presigned_url(
2300    http: reqwest::Client,
2301    url: &str,
2302    path: PathBuf,
2303) -> Result<(), Error> {
2304    // Read the entire file into memory
2305    let file_data = fs::read(&path).await?;
2306    let file_size = file_data.len();
2307
2308    // Upload with retry logic
2309    for attempt in 1..=MAX_RETRIES {
2310        match http
2311            .put(url)
2312            .header(CONTENT_LENGTH, file_size)
2313            .body(file_data.clone())
2314            .send()
2315            .await
2316        {
2317            Ok(resp) => {
2318                if resp.status().is_success() {
2319                    debug!(
2320                        "Successfully uploaded file: {:?} ({} bytes)",
2321                        path, file_size
2322                    );
2323                    return Ok(());
2324                } else {
2325                    let status = resp.status();
2326                    let error_text = resp.text().await.unwrap_or_default();
2327                    warn!(
2328                        "Upload failed [attempt {}/{}]: HTTP {} - {}",
2329                        attempt, MAX_RETRIES, status, error_text
2330                    );
2331                }
2332            }
2333            Err(err) => {
2334                warn!(
2335                    "Upload error [attempt {}/{}]: {:?}",
2336                    attempt, MAX_RETRIES, err
2337                );
2338            }
2339        }
2340
2341        if attempt < MAX_RETRIES {
2342            tokio::time::sleep(Duration::from_secs(attempt as u64)).await;
2343        }
2344    }
2345
2346    Err(Error::InvalidParameters(format!(
2347        "Failed to upload file {:?} after {} attempts",
2348        path, MAX_RETRIES
2349    )))
2350}