// edgefirst_client/client.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright © 2025 Au-Zone Technologies. All Rights Reserved.
3
4use crate::{
5    Annotation, Error, Sample, Task,
6    api::{
7        AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8        Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9        SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10        TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11        ValidationSession, ValidationSessionID,
12    },
13    dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23    collections::HashMap,
24    fs::create_dir_all,
25    io::{SeekFrom, Write as _},
26    path::{Path, PathBuf},
27    sync::{
28        Arc,
29        atomic::{AtomicUsize, Ordering},
30    },
31    time::Duration,
32    vec,
33};
34use tokio::{
35    fs::{self, File},
36    io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
37    sync::{RwLock, Semaphore, mpsc::Sender},
38};
39use tokio_util::codec::{BytesCodec, FramedRead};
40use walkdir::WalkDir;
41
42#[cfg(feature = "polars")]
43use polars::prelude::*;
44
/// Maximum number of concurrent download/upload tasks.
const MAX_TASKS: usize = 32;
/// Maximum number of attempts for a failed HTTP transfer.
const MAX_RETRIES: u32 = 10;
/// Part size (100 MiB) used when splitting files for multipart uploads.
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Progress information for long-running operations.
///
/// Reports how far an operation such as an upload, download, or dataset
/// export has advanced.  `current` counts the items finished so far while
/// `total` is the number of items the operation will process overall, which
/// makes percentage reporting straightforward.
///
/// # Examples
///
/// ```rust
/// use edgefirst_client::Progress;
///
/// let progress = Progress {
///     current: 25,
///     total: 100,
/// };
/// let percent = 100.0 * progress.current as f64 / progress.total as f64;
/// println!("{:.1}% ({} of {})", percent, progress.current, progress.total);
/// ```
#[derive(Debug, Clone)]
pub struct Progress {
    /// Current number of completed items.
    pub current: usize,
    /// Total number of items to process.
    pub total: usize,
}
77
/// JSON-RPC 2.0 request envelope; `params` is omitted when `None`.
#[derive(Serialize)]
struct RpcRequest<Params> {
    // Request identifier echoed back by the server.
    id: u64,
    // Always "2.0" per the JSON-RPC 2.0 specification.
    jsonrpc: String,
    // Name of the remote method, e.g. "project.list".
    method: String,
    // Method parameters; serialized as-is when present.
    params: Option<Params>,
}

impl<T> Default for RpcRequest<T> {
    fn default() -> Self {
        RpcRequest {
            id: 0,
            jsonrpc: "2.0".to_string(),
            method: "".to_string(),
            params: None,
        }
    }
}
96
/// JSON-RPC error object returned by the server.
#[derive(Deserialize)]
struct RpcError {
    // Numeric JSON-RPC error code.
    code: i32,
    // Human-readable error description.
    message: String,
}

/// JSON-RPC 2.0 response envelope; exactly one of `error`/`result` is
/// expected to be populated.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    // NOTE(review): the server returns the id as a string even though the
    // request sends a u64 — confirm this asymmetry is intentional.
    id: String,
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Placeholder result type for RPC methods whose result payload is ignored.
#[derive(Deserialize)]
struct EmptyResult {}
113
/// Parameters for creating a snapshot with simple (single-part) uploads.
#[derive(Debug, Serialize)]
struct SnapshotCreateParams {
    snapshot_name: String,
    // Object keys (file names) that will be uploaded into the snapshot.
    keys: Vec<String>,
}

/// Result of snapshot creation: the new snapshot ID plus one presigned
/// upload URL per requested key.
#[derive(Debug, Deserialize)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Parameters for creating a snapshot whose files are uploaded in multiple
/// parts; `file_sizes` parallels `keys`.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}

/// One field of the multipart-create result: either the snapshot ID or a
/// per-file multipart descriptor.  `untagged` lets serde pick by shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
139
/// Parameters for completing a multipart upload: the uploaded parts' ETags
/// are sent back so the server can assemble the object.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}

/// A single uploaded part, identified by its ETag and 1-based part number.
/// Field names are renamed to match the S3-style completion API.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Multipart upload descriptor for one file: its upload session ID and one
/// presigned URL per part.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
161
/// Parameters for updating a snapshot's status on the server.
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Snapshot record returned by the status update call.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    pub id: SnapshotID,
    pub uid: String,
    pub description: String,
    pub date: String,
    pub status: String,
}

/// Parameters for listing images, combining an image filter with a
/// per-file filter map; `only_ids` limits the response to identifiers.
#[derive(Serialize)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    only_ids: bool,
}

/// Image filter restricting results to a single dataset.
#[derive(Serialize)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
188
189/// Main client for interacting with EdgeFirst Studio Server.
190///
191/// The EdgeFirst Client handles the connection to the EdgeFirst Studio Server
192/// and manages authentication, RPC calls, and data operations. It provides
193/// methods for managing projects, datasets, experiments, training sessions,
194/// and various utility functions for data processing.
195///
196/// The client supports multiple authentication methods and can work with both
197/// SaaS and self-hosted EdgeFirst Studio instances.
198///
199/// # Features
200///
201/// - **Authentication**: Token-based authentication with automatic persistence
202/// - **Dataset Management**: Upload, download, and manipulate datasets
203/// - **Project Operations**: Create and manage projects and experiments
204/// - **Training & Validation**: Submit and monitor ML training jobs
205/// - **Data Integration**: Convert between EdgeFirst datasets and popular
206///   formats
207/// - **Progress Tracking**: Real-time progress updates for long-running
208///   operations
209///
210/// # Examples
211///
212/// ```no_run
213/// use edgefirst_client::{Client, DatasetID};
214/// use std::str::FromStr;
215///
216/// # async fn example() -> Result<(), edgefirst_client::Error> {
217/// // Create a new client and authenticate
218/// let mut client = Client::new()?;
219/// let client = client
220///     .with_login("your-email@example.com", "password")
221///     .await?;
222///
223/// // Or use an existing token
224/// let base_client = Client::new()?;
225/// let client = base_client.with_token("your-token-here")?;
226///
227/// // Get organization and projects
228/// let org = client.organization().await?;
229/// let projects = client.projects(None).await?;
230///
231/// // Work with datasets
232/// let dataset_id = DatasetID::from_str("ds-abc123")?;
233/// let dataset = client.dataset(dataset_id).await?;
234/// # Ok(())
235/// # }
236/// ```
#[derive(Clone, Debug)]
pub struct Client {
    // Shared HTTP client used for all RPC calls and data transfers.
    http: reqwest::Client,
    // Base URL of the EdgeFirst Studio server instance.
    url: String,
    // Authentication token, shared across clones of this client.
    token: Arc<RwLock<String>>,
    // Optional path used to persist the token between sessions.
    token_path: Option<PathBuf>,
}
244
245impl Client {
246    /// Create a new unauthenticated client with the default saas server.  To
247    /// connect to a different server use the `with_server` method or with the
248    /// `with_token` method to create a client with a token which includes the
249    /// server instance name (test, stage, saas).
250    ///
251    /// This client is created without a token and will need to login before
252    /// using any methods that require authentication.  Use the `with_token`
253    /// method to create a client with a token.
254    pub fn new() -> Result<Self, Error> {
255        Ok(Client {
256            http: reqwest::Client::builder()
257                .read_timeout(Duration::from_secs(60))
258                .build()?,
259            url: "https://edgefirst.studio".to_string(),
260            token: Arc::new(tokio::sync::RwLock::new("".to_string())),
261            token_path: None,
262        })
263    }
264
265    /// Returns a new client connected to the specified server instance.  If a
266    /// token is already set in the client then it will be dropped as the token
267    /// is specific to the server instance.
268    pub fn with_server(&self, server: &str) -> Result<Self, Error> {
269        Ok(Client {
270            url: format!("https://{}.edgefirst.studio", server),
271            ..self.clone()
272        })
273    }
274
275    /// Returns a new client authenticated with the provided username and
276    /// password.
277    pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
278        let params = HashMap::from([("username", username), ("password", password)]);
279        let login: LoginResult = self
280            .rpc_without_auth("auth.login".to_owned(), Some(params))
281            .await?;
282        Ok(Client {
283            token: Arc::new(tokio::sync::RwLock::new(login.token)),
284            ..self.clone()
285        })
286    }
287
288    /// Returns a new client which will load and save the token to the specified
289    /// path.
290    pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
291        let token_path = match token_path {
292            Some(path) => path.to_path_buf(),
293            None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
294                .unwrap()
295                .config_dir()
296                .join("token"),
297        };
298
299        debug!("Using token path: {:?}", token_path);
300
301        let token = match token_path.exists() {
302            true => std::fs::read_to_string(&token_path)?,
303            false => "".to_string(),
304        };
305
306        if !token.is_empty() {
307            let client = self.with_token(&token)?;
308            Ok(Client {
309                token_path: Some(token_path),
310                ..client
311            })
312        } else {
313            Ok(Client {
314                token_path: Some(token_path),
315                ..self.clone()
316            })
317        }
318    }
319
320    /// Returns a new client authenticated with the provided token.
321    pub fn with_token(&self, token: &str) -> Result<Self, Error> {
322        if token.is_empty() {
323            return Ok(self.clone());
324        }
325
326        let token_parts: Vec<&str> = token.split('.').collect();
327        if token_parts.len() != 3 {
328            return Err(Error::InvalidToken);
329        }
330
331        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
332            .decode(token_parts[1])
333            .unwrap();
334        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
335        let server = match payload.get("database") {
336            Some(value) => Ok(value.as_str().unwrap().to_string()),
337            None => Err(Error::InvalidToken),
338        }?;
339
340        Ok(Client {
341            url: format!("https://{}.edgefirst.studio", server),
342            token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
343            ..self.clone()
344        })
345    }
346
347    pub async fn save_token(&self) -> Result<(), Error> {
348        let path = self.token_path.clone().unwrap_or_else(|| {
349            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
350                .unwrap()
351                .config_dir()
352                .join("token")
353        });
354
355        create_dir_all(path.parent().unwrap())?;
356        let mut file = std::fs::File::create(&path)?;
357        file.write_all(self.token.read().await.as_bytes())?;
358
359        debug!("Saved token to {:?}", path);
360
361        Ok(())
362    }
363
364    /// Return the version of the EdgeFirst Studio server for the current
365    /// client connection.
366    pub async fn version(&self) -> Result<String, Error> {
367        let version: HashMap<String, String> = self
368            .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
369            .await?;
370        let version = version.get("version").ok_or(Error::InvalidResponse)?;
371        Ok(version.to_owned())
372    }
373
374    /// Clear the token used to authenticate the client with the server.  If an
375    /// optional path was provided when creating the client, the token file
376    /// will also be cleared.
377    pub async fn logout(&self) -> Result<(), Error> {
378        {
379            let mut token = self.token.write().await;
380            *token = "".to_string();
381        }
382
383        if let Some(path) = &self.token_path
384            && path.exists()
385        {
386            fs::remove_file(path).await?;
387        }
388
389        Ok(())
390    }
391
392    /// Return the token used to authenticate the client with the server.  When
393    /// logging into the server using a username and password, the token is
394    /// returned by the server and stored in the client for future interactions.
395    pub async fn token(&self) -> String {
396        self.token.read().await.clone()
397    }
398
399    /// Verify the token used to authenticate the client with the server.  This
400    /// method is used to ensure that the token is still valid and has not
401    /// expired.  If the token is invalid, the server will return an error and
402    /// the client will need to login again.
403    pub async fn verify_token(&self) -> Result<(), Error> {
404        self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
405            .await?;
406        Ok::<(), Error>(())
407    }
408
409    /// Renew the token used to authenticate the client with the server.  This
410    /// method is used to refresh the token before it expires.  If the token
411    /// has already expired, the server will return an error and the client
412    /// will need to login again.
413    pub async fn renew_token(&self) -> Result<(), Error> {
414        let params = HashMap::from([("username".to_string(), self.username().await?)]);
415        let result: LoginResult = self
416            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
417            .await?;
418
419        {
420            let mut token = self.token.write().await;
421            *token = result.token;
422        }
423
424        if self.token_path.is_some() {
425            self.save_token().await?;
426        }
427
428        Ok(())
429    }
430
431    async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
432        let token = self.token.read().await;
433        if token.is_empty() {
434            return Err(Error::EmptyToken);
435        }
436
437        let token_parts: Vec<&str> = token.split('.').collect();
438        if token_parts.len() != 3 {
439            return Err(Error::InvalidToken);
440        }
441
442        let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
443            .decode(token_parts[1])
444            .unwrap();
445        let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
446        match payload.get(field) {
447            Some(value) => Ok(value.to_owned()),
448            None => Err(Error::InvalidToken),
449        }
450    }
451
452    /// Returns the URL of the EdgeFirst Studio server for the current client.
453    pub fn url(&self) -> &str {
454        &self.url
455    }
456
457    /// Returns the username associated with the current token.
458    pub async fn username(&self) -> Result<String, Error> {
459        match self.token_field("username").await? {
460            serde_json::Value::String(username) => Ok(username),
461            _ => Err(Error::InvalidToken),
462        }
463    }
464
465    /// Returns the expiration time for the current token.
466    pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
467        let ts = match self.token_field("exp").await? {
468            serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
469            _ => return Err(Error::InvalidToken),
470        };
471
472        match DateTime::<Utc>::from_timestamp_secs(ts) {
473            Some(dt) => Ok(dt),
474            None => Err(Error::InvalidToken),
475        }
476    }
477
478    /// Returns the organization information for the current user.
479    pub async fn organization(&self) -> Result<Organization, Error> {
480        self.rpc::<(), Organization>("org.get".to_owned(), None)
481            .await
482    }
483
484    /// Returns a list of projects available to the user.  The projects are
485    /// returned as a vector of Project objects.  If a name filter is
486    /// provided, only projects matching the filter are returned.
487    ///
488    /// Projects are the top-level organizational unit in EdgeFirst Studio.
489    /// Projects contain datasets, trainers, and trainer sessions.  Projects
490    /// are used to group related datasets and trainers together.
491    pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
492        let projects = self
493            .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
494            .await?;
495        if let Some(name) = name {
496            Ok(projects
497                .into_iter()
498                .filter(|p| p.name().contains(name))
499                .collect())
500        } else {
501            Ok(projects)
502        }
503    }
504
505    /// Return the project with the specified project ID.  If the project does
506    /// not exist, an error is returned.
507    pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
508        let params = HashMap::from([("project_id", project_id)]);
509        self.rpc("project.get".to_owned(), Some(params)).await
510    }
511
512    /// Returns a list of datasets available to the user.  The datasets are
513    /// returned as a vector of Dataset objects.  If a name filter is
514    /// provided, only datasets matching the filter are returned.
515    pub async fn datasets(
516        &self,
517        project_id: ProjectID,
518        name: Option<&str>,
519    ) -> Result<Vec<Dataset>, Error> {
520        let params = HashMap::from([("project_id", project_id)]);
521        let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
522        if let Some(name) = name {
523            Ok(datasets
524                .into_iter()
525                .filter(|d| d.name().contains(name))
526                .collect())
527        } else {
528            Ok(datasets)
529        }
530    }
531
532    /// Return the dataset with the specified dataset ID.  If the dataset does
533    /// not exist, an error is returned.
534    pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
535        let params = HashMap::from([("dataset_id", dataset_id)]);
536        self.rpc("dataset.get".to_owned(), Some(params)).await
537    }
538
539    /// Lists the labels for the specified dataset.
540    pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
541        let params = HashMap::from([("dataset_id", dataset_id)]);
542        self.rpc("label.list".to_owned(), Some(params)).await
543    }
544
545    /// Add a new label to the dataset with the specified name.
546    pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
547        let new_label = NewLabel {
548            dataset_id,
549            labels: vec![NewLabelObject {
550                name: name.to_owned(),
551            }],
552        };
553        let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
554        Ok(())
555    }
556
557    /// Removes the label with the specified ID from the dataset.  Label IDs are
558    /// globally unique so the dataset_id is not required.
559    pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
560        let params = HashMap::from([("label_id", label_id)]);
561        let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
562        Ok(())
563    }
564
565    /// Updates the label with the specified ID to have the new name or index.
566    /// Label IDs cannot be changed.  Label IDs are globally unique so the
567    /// dataset_id is not required.
568    pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
569        #[derive(Serialize)]
570        struct Params {
571            dataset_id: DatasetID,
572            label_id: u64,
573            label_name: String,
574            label_index: u64,
575        }
576
577        let _: String = self
578            .rpc(
579                "label.update".to_owned(),
580                Some(Params {
581                    dataset_id: label.dataset_id(),
582                    label_id: label.id(),
583                    label_name: label.name().to_owned(),
584                    label_index: label.index(),
585                }),
586            )
587            .await?;
588        Ok(())
589    }
590
591    pub async fn download_dataset(
592        &self,
593        dataset_id: DatasetID,
594        groups: &[String],
595        file_types: &[FileType],
596        output: PathBuf,
597        progress: Option<Sender<Progress>>,
598    ) -> Result<(), Error> {
599        let samples = self
600            .samples(dataset_id, None, &[], groups, file_types, progress.clone())
601            .await?;
602        fs::create_dir_all(&output).await?;
603
604        let total = samples.len();
605        let current = Arc::new(AtomicUsize::new(0));
606        let sem = Arc::new(Semaphore::new(MAX_TASKS));
607
608        let tasks = samples
609            .into_iter()
610            .map(|sample| {
611                let sem = sem.clone();
612                let client = self.clone();
613                let current = current.clone();
614                let progress = progress.clone();
615                let file_types = file_types.to_vec();
616                let output = output.clone();
617
618                tokio::spawn(async move {
619                    let _permit = sem.acquire().await.unwrap();
620
621                    for file_type in file_types {
622                        if let Some(data) = sample.download(&client, file_type.clone()).await? {
623                            let file_ext = match file_type {
624                                FileType::Image => infer::get(&data)
625                                    .expect("Failed to identify image file format for sample")
626                                    .extension()
627                                    .to_string(),
628                                t => t.to_string(),
629                            };
630
631                            let file_name = format!("{}.{}", sample.name(), file_ext);
632                            let file_path = output.join(&file_name);
633
634                            let mut file = File::create(&file_path).await?;
635                            file.write_all(&data).await?;
636                        } else {
637                            warn!("No data for sample: {}", sample.id());
638                        }
639                    }
640
641                    if let Some(progress) = &progress {
642                        let current = current.fetch_add(1, Ordering::SeqCst);
643                        progress
644                            .send(Progress {
645                                current: current + 1,
646                                total,
647                            })
648                            .await
649                            .unwrap();
650                    }
651
652                    Ok::<(), Error>(())
653                })
654            })
655            .collect::<Vec<_>>();
656
657        join_all(tasks)
658            .await
659            .into_iter()
660            .collect::<Result<Vec<_>, _>>()?;
661
662        if let Some(progress) = progress {
663            drop(progress);
664        }
665
666        Ok(())
667    }
668
669    /// List available annotation sets for the specified dataset.
670    pub async fn annotation_sets(
671        &self,
672        dataset_id: DatasetID,
673    ) -> Result<Vec<AnnotationSet>, Error> {
674        let params = HashMap::from([("dataset_id", dataset_id)]);
675        self.rpc("annset.list".to_owned(), Some(params)).await
676    }
677
678    /// Retrieve the annotation set with the specified ID.
679    pub async fn annotation_set(
680        &self,
681        annotation_set_id: AnnotationSetID,
682    ) -> Result<AnnotationSet, Error> {
683        let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
684        self.rpc("annset.get".to_owned(), Some(params)).await
685    }
686
687    /// Get the annotations for the specified annotation set with the
688    /// requested annotation types.  The annotation types are used to filter
689    /// the annotations returned.  The groups parameter is used to filter for
690    /// dataset groups (train, val, test).  Images which do not have any
691    /// annotations are also included in the result as long as they are in the
692    /// requested groups (when specified).
693    ///
694    /// The result is a vector of Annotations objects which contain the
695    /// full dataset along with the annotations for the specified types.
696    ///
697    /// To get the annotations as a DataFrame, use the `annotations_dataframe`
698    /// method instead.
699    pub async fn annotations(
700        &self,
701        annotation_set_id: AnnotationSetID,
702        groups: &[String],
703        annotation_types: &[AnnotationType],
704        progress: Option<Sender<Progress>>,
705    ) -> Result<Vec<Annotation>, Error> {
706        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
707        let labels = self
708            .labels(dataset_id)
709            .await?
710            .into_iter()
711            .map(|label| (label.name().to_string(), label.index()))
712            .collect::<HashMap<_, _>>();
713        let total = self
714            .samples_count(
715                dataset_id,
716                Some(annotation_set_id),
717                annotation_types,
718                groups,
719                &[],
720            )
721            .await?
722            .total as usize;
723        let mut annotations = vec![];
724        let mut continue_token: Option<String> = None;
725        let mut current = 0;
726
727        if total == 0 {
728            return Ok(annotations);
729        }
730
731        loop {
732            let params = SamplesListParams {
733                dataset_id,
734                annotation_set_id: Some(annotation_set_id),
735                types: annotation_types.iter().map(|t| t.to_string()).collect(),
736                group_names: groups.to_vec(),
737                continue_token,
738            };
739
740            let result: SamplesListResult =
741                self.rpc("samples.list".to_owned(), Some(params)).await?;
742            current += result.samples.len();
743            continue_token = result.continue_token;
744
745            if result.samples.is_empty() {
746                break;
747            }
748
749            for sample in result.samples {
750                // If there are no annotations for the sample, create an empty
751                // annotation for the sample so that it is included in the result.
752                if sample.annotations().is_empty() {
753                    let mut annotation = Annotation::new();
754                    annotation.set_sample_id(Some(sample.id()));
755                    annotation.set_name(Some(sample.name().to_string()));
756                    annotation.set_group(sample.group().cloned());
757                    annotation.set_sequence_name(sample.sequence_name().cloned());
758                    annotations.push(annotation);
759                    continue;
760                }
761
762                sample.annotations().iter().for_each(|annotation| {
763                    let mut annotation = annotation.clone();
764                    annotation.set_sample_id(Some(sample.id()));
765                    annotation.set_name(Some(sample.name().to_string()));
766                    annotation.set_group(sample.group().cloned());
767                    annotation.set_sequence_name(sample.sequence_name().cloned());
768                    annotation.set_label_index(Some(labels[annotation.label().unwrap().as_str()]));
769                    annotations.push(annotation);
770                });
771            }
772
773            if let Some(progress) = &progress {
774                progress.send(Progress { current, total }).await.unwrap();
775            }
776
777            match &continue_token {
778                Some(token) if !token.is_empty() => continue,
779                _ => break,
780            }
781        }
782
783        if let Some(progress) = progress {
784            drop(progress);
785        }
786
787        Ok(annotations)
788    }
789
790    pub async fn samples_count(
791        &self,
792        dataset_id: DatasetID,
793        annotation_set_id: Option<AnnotationSetID>,
794        annotation_types: &[AnnotationType],
795        groups: &[String],
796        types: &[FileType],
797    ) -> Result<SamplesCountResult, Error> {
798        let types = annotation_types
799            .iter()
800            .map(|t| t.to_string())
801            .chain(types.iter().map(|t| t.to_string()))
802            .collect::<Vec<_>>();
803
804        let params = SamplesListParams {
805            dataset_id,
806            annotation_set_id,
807            group_names: groups.to_vec(),
808            types,
809            continue_token: None,
810        };
811
812        self.rpc("samples.count".to_owned(), Some(params)).await
813    }
814
    /// Return the samples for a dataset, with annotation labels resolved to
    /// label indices.
    ///
    /// `annotation_types` and `types` are combined into one server-side type
    /// filter; `groups` restricts results to the named dataset groups.
    /// Progress updates are sent on the optional `progress` channel, which is
    /// dropped when listing completes so receivers observe the close.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // Combined annotation-type and file-type filter strings.
        let types = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> label index lookup used to tag annotations below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // NOTE(review): the count is taken without the file-type filter
        // (`&[]`), so `total` may differ from the number of samples actually
        // listed below — confirm against server semantics.
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(samples);
        }

        // Page through samples.list until the server stops returning a
        // non-empty continuation token.
        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id,
                types: types.clone(),
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            // Attach label indices to every annotation that has a known label.
            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            if let Some(label) = ann.label() {
                                ann.set_label_index(Some(labels[label.as_str()]));
                            }
                        }
                        s.with_annotations(anns)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(samples)
    }
898
899    pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
900        for attempt in 1..MAX_RETRIES {
901            let resp = match self.http.get(url).send().await {
902                Ok(resp) => resp,
903                Err(err) => {
904                    warn!(
905                        "Socket Error [retry {}/{}]: {:?}",
906                        attempt, MAX_RETRIES, err
907                    );
908                    tokio::time::sleep(Duration::from_secs(1) * attempt).await;
909                    continue;
910                }
911            };
912
913            match resp.bytes().await {
914                Ok(body) => return Ok(body.to_vec()),
915                Err(err) => {
916                    warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
917                    tokio::time::sleep(Duration::from_secs(1) * attempt).await;
918                    continue;
919                }
920            };
921        }
922
923        Err(Error::MaxRetriesExceeded(MAX_RETRIES))
924    }
925
926    /// Get the AnnotationGroup for the specified annotation set with the
927    /// requested annotation types.  The annotation type is used to filter
928    /// the annotations returned.  Images which do not have any annotations
929    /// are included in the result.
930    ///
931    /// The result is a DataFrame following the EdgeFirst Dataset Format
932    /// definition.
933    ///
934    /// To get the annotations as a vector of AnnotationGroup objects, use the
935    /// `annotations` method instead.
936    #[cfg(feature = "polars")]
937    pub async fn annotations_dataframe(
938        &self,
939        annotation_set_id: AnnotationSetID,
940        groups: &[String],
941        types: &[AnnotationType],
942        progress: Option<Sender<Progress>>,
943    ) -> Result<DataFrame, Error> {
944        use crate::dataset::annotations_dataframe;
945
946        let annotations = self
947            .annotations(annotation_set_id, groups, types, progress)
948            .await?;
949        Ok(annotations_dataframe(&annotations))
950    }
951
952    /// List available snapshots.  If a name is provided, only snapshots
953    /// containing that name are returned.
954    pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
955        let snapshots: Vec<Snapshot> = self
956            .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
957            .await?;
958        if let Some(name) = name {
959            Ok(snapshots
960                .into_iter()
961                .filter(|s| s.description().contains(name))
962                .collect())
963        } else {
964            Ok(snapshots)
965        }
966    }
967
968    /// Get the snapshot with the specified id.
969    pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
970        let params = HashMap::from([("snapshot_id", snapshot_id)]);
971        self.rpc("snapshots.get".to_owned(), Some(params)).await
972    }
973
    /// Create a new snapshot from the file at the specified path.  If the path
    /// is a directory then all the files in the directory are uploaded.  The
    /// snapshot name will be the specified path, either file or directory.
    ///
    /// The progress callback can be used to monitor the progress of the upload
    /// over a watch channel.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        // Directories use a dedicated multi-file upload path.
        if path.is_dir() {
            return self
                .create_snapshot_folder(path.to_str().unwrap(), progress)
                .await;
        }

        // NOTE(review): file_name()/to_str() panic on paths ending in ".." or
        // containing invalid UTF-8 — assumed not to occur for caller input.
        let name = path.file_name().unwrap().to_str().unwrap();
        let total = path.metadata()?.len() as usize;
        // Shared upload counter consumed by upload_multipart for progress.
        let current = Arc::new(AtomicUsize::new(0));

        // Emit an initial 0/total progress event before the upload starts.
        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        // Ask the server to create the snapshot record and pre-signed
        // multipart upload URLs for the single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-file part entries;
        // the id lives under the well-known "snapshot_id" key.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The part entry for our file is keyed by "<storage prefix>/<name>",
        // where the prefix is the tail of the snapshot's storage path after
        // the "::/" separator.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload the file in parts; the returned params identify the upload
        // so the server can stitch the parts together.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Mark the snapshot as available now that the upload has completed.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Close the progress channel so receivers observe completion.
        if let Some(progress) = progress {
            drop(progress);
        }

        // Re-fetch so the caller sees the updated ("available") snapshot.
        self.snapshot(snapshot_id).await
    }
1060
1061    async fn create_snapshot_folder(
1062        &self,
1063        path: &str,
1064        progress: Option<Sender<Progress>>,
1065    ) -> Result<Snapshot, Error> {
1066        let path = Path::new(path);
1067        let name = path.file_name().unwrap().to_str().unwrap();
1068
1069        let files = WalkDir::new(path)
1070            .into_iter()
1071            .filter_map(|entry| entry.ok())
1072            .filter(|entry| entry.file_type().is_file())
1073            .map(|entry| entry.path().strip_prefix(path).unwrap().to_owned())
1074            .collect::<Vec<_>>();
1075
1076        let total = files
1077            .iter()
1078            .map(|file| path.join(file).metadata().unwrap().len() as usize)
1079            .sum();
1080        let current = Arc::new(AtomicUsize::new(0));
1081
1082        if let Some(progress) = &progress {
1083            progress.send(Progress { current: 0, total }).await.unwrap();
1084        }
1085
1086        let keys = files
1087            .iter()
1088            .map(|key| key.to_str().unwrap().to_owned())
1089            .collect::<Vec<_>>();
1090        let file_sizes = files
1091            .iter()
1092            .map(|key| path.join(key).metadata().unwrap().len() as usize)
1093            .collect::<Vec<_>>();
1094
1095        let params = SnapshotCreateMultipartParams {
1096            snapshot_name: name.to_owned(),
1097            keys,
1098            file_sizes,
1099        };
1100
1101        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
1102            .rpc(
1103                "snapshots.create_upload_url_multipart".to_owned(),
1104                Some(params),
1105            )
1106            .await?;
1107
1108        let snapshot_id = match multipart.get("snapshot_id") {
1109            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
1110            _ => return Err(Error::InvalidResponse),
1111        };
1112
1113        let snapshot = self.snapshot(snapshot_id).await?;
1114        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
1115
1116        for file in files {
1117            let part_key = format!("{}/{}", part_prefix, file.to_str().unwrap());
1118            let mut part = match multipart.get(&part_key) {
1119                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
1120                _ => return Err(Error::InvalidResponse),
1121            }
1122            .clone();
1123            part.key = Some(part_key);
1124
1125            let params = upload_multipart(
1126                self.http.clone(),
1127                part.clone(),
1128                path.join(file),
1129                total,
1130                current.clone(),
1131                progress.clone(),
1132            )
1133            .await?;
1134
1135            let complete: String = self
1136                .rpc(
1137                    "snapshots.complete_multipart_upload".to_owned(),
1138                    Some(params),
1139                )
1140                .await?;
1141            debug!("Snapshot Part Complete: {:?}", complete);
1142        }
1143
1144        let params = SnapshotStatusParams {
1145            snapshot_id,
1146            status: "available".to_owned(),
1147        };
1148        let _: SnapshotStatusResult = self
1149            .rpc("snapshots.update".to_owned(), Some(params))
1150            .await?;
1151
1152        if let Some(progress) = progress {
1153            drop(progress);
1154        }
1155
1156        self.snapshot(snapshot_id).await
1157    }
1158
    /// Downloads a snapshot from the server.  The snapshot could be a single
    /// file or a directory of files.  The snapshot is downloaded to the
    /// specified path.  A progress callback can be provided to monitor the
    /// progress of the download over a watch channel.
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // Request one pre-signed download URL per file in the snapshot,
        // keyed by the file's name within the snapshot.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // Shared counters for progress reporting across concurrent tasks.
        // `total` grows as each task discovers its file's content length.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        // Cap the number of simultaneous downloads at MAX_TASKS.
        let sem = Arc::new(Semaphore::new(MAX_TASKS));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                // Clone everything the spawned task needs ('static future).
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    // NOTE(review): errors inside the task (HTTP failure,
                    // missing Content-Length, file I/O) panic the task; the
                    // panic surfaces via the unwrap on join_all below.
                    let _permit = sem.acquire().await.unwrap();
                    let res = http.get(url).send().await.unwrap();
                    let content_length = res.content_length().unwrap() as usize;

                    if let Some(progress) = &progress {
                        // fetch_add returns the previous value, so add this
                        // file's length back to report the updated total.
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await
                            .unwrap();
                    }

                    let mut file = File::create(output.join(key)).await.unwrap();
                    let mut stream = res.bytes_stream();

                    // Stream the body to disk chunk by chunk, updating the
                    // shared counters as bytes arrive.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk.unwrap();
                        file.write_all(&chunk).await.unwrap();
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await
                                .unwrap();
                        }
                    }
                })
            })
            .collect::<Vec<_>>();

        // Wait for every download; a JoinError (task panic) is re-raised by
        // the unwrap so failures are not silently ignored.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        Ok(())
    }
1241
1242    /// The snapshot restore method is used to restore a snapshot to the server.
1243    /// The restore method can perform a few different operations depending on
1244    /// the snapshot type.
1245    ///
1246    /// The auto-annotation workflow is used to automatically annotate the
1247    /// dataset with 2D masks and boxes using the labels within the
1248    /// autolabel list. If autolabel is empty then the auto-annotation
1249    /// workflow is not used. If the MCAP includes radar or LiDAR data then
1250    /// the auto-annotation workflow will also generate 3D bounding boxes
1251    /// for detected objects.
1252    ///
1253    /// The autodepth flag is used to determine if a depthmap should be
1254    /// automatically generated for the dataset, this will currently only work
1255    /// accurately for Maivin or Raivin cameras.
1256    pub async fn restore_snapshot(
1257        &self,
1258        project_id: ProjectID,
1259        snapshot_id: SnapshotID,
1260        topics: &[String],
1261        autolabel: &[String],
1262        autodepth: bool,
1263        dataset_name: Option<&str>,
1264        dataset_description: Option<&str>,
1265    ) -> Result<SnapshotRestoreResult, Error> {
1266        let params = SnapshotRestore {
1267            project_id,
1268            snapshot_id,
1269            fps: 1,
1270            autodepth,
1271            agtg_pipeline: !autolabel.is_empty(),
1272            autolabel: autolabel.to_vec(),
1273            topics: topics.to_vec(),
1274            dataset_name: dataset_name.map(|s| s.to_owned()),
1275            dataset_description: dataset_description.map(|s| s.to_owned()),
1276        };
1277        self.rpc("snapshots.restore".to_owned(), Some(params)).await
1278    }
1279
1280    /// Returns a list of experiments available to the user.  The experiments
1281    /// are returned as a vector of Experiment objects.  If name is provided
1282    /// then only experiments containing this string are returned.
1283    ///
1284    /// Experiments provide a method of organizing training and validation
1285    /// sessions together and are akin to an Experiment in MLFlow terminology.  
1286    /// Each experiment can have multiple trainer sessions associated with it,
1287    /// these would be akin to runs in MLFlow terminology.
1288    pub async fn experiments(
1289        &self,
1290        project_id: ProjectID,
1291        name: Option<&str>,
1292    ) -> Result<Vec<Experiment>, Error> {
1293        let params = HashMap::from([("project_id", project_id)]);
1294        let experiments: Vec<Experiment> =
1295            self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1296        if let Some(name) = name {
1297            Ok(experiments
1298                .into_iter()
1299                .filter(|e| e.name().contains(name))
1300                .collect())
1301        } else {
1302            Ok(experiments)
1303        }
1304    }
1305
1306    /// Return the experiment with the specified experiment ID.  If the
1307    /// experiment does not exist, an error is returned.
1308    pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1309        let params = HashMap::from([("trainer_id", experiment_id)]);
1310        self.rpc("trainer.get".to_owned(), Some(params)).await
1311    }
1312
1313    /// Returns a list of trainer sessions available to the user.  The trainer
1314    /// sessions are returned as a vector of TrainingSession objects.  If name
1315    /// is provided then only trainer sessions containing this string are
1316    /// returned.
1317    ///
1318    /// Trainer sessions are akin to runs in MLFlow terminology.  These
1319    /// represent an actual training session which will produce metrics and
1320    /// model artifacts.
1321    pub async fn training_sessions(
1322        &self,
1323        experiment_id: ExperimentID,
1324        name: Option<&str>,
1325    ) -> Result<Vec<TrainingSession>, Error> {
1326        let params = HashMap::from([("trainer_id", experiment_id)]);
1327        let sessions: Vec<TrainingSession> = self
1328            .rpc("trainer.session.list".to_owned(), Some(params))
1329            .await?;
1330        if let Some(name) = name {
1331            Ok(sessions
1332                .into_iter()
1333                .filter(|s| s.name().contains(name))
1334                .collect())
1335        } else {
1336            Ok(sessions)
1337        }
1338    }
1339
1340    /// Return the trainer session with the specified trainer session ID.  If
1341    /// the trainer session does not exist, an error is returned.
1342    pub async fn training_session(
1343        &self,
1344        session_id: TrainingSessionID,
1345    ) -> Result<TrainingSession, Error> {
1346        let params = HashMap::from([("trainer_session_id", session_id)]);
1347        self.rpc("trainer.session.get".to_owned(), Some(params))
1348            .await
1349    }
1350
1351    /// List validation sessions for the given project.
1352    pub async fn validation_sessions(
1353        &self,
1354        project_id: ProjectID,
1355    ) -> Result<Vec<ValidationSession>, Error> {
1356        let params = HashMap::from([("project_id", project_id)]);
1357        self.rpc("validate.session.list".to_owned(), Some(params))
1358            .await
1359    }
1360
1361    /// Retrieve a specific validation session.
1362    pub async fn validation_session(
1363        &self,
1364        session_id: ValidationSessionID,
1365    ) -> Result<ValidationSession, Error> {
1366        let params = HashMap::from([("validate_session_id", session_id)]);
1367        self.rpc("validate.session.get".to_owned(), Some(params))
1368            .await
1369    }
1370
1371    /// List the artifacts for the specified trainer session.  The artifacts
1372    /// are returned as a vector of strings.
1373    pub async fn artifacts(
1374        &self,
1375        training_session_id: TrainingSessionID,
1376    ) -> Result<Vec<Artifact>, Error> {
1377        let params = HashMap::from([("training_session_id", training_session_id)]);
1378        self.rpc("trainer.get_artifacts".to_owned(), Some(params))
1379            .await
1380    }
1381
1382    /// Download the model artifact for the specified trainer session to the
1383    /// specified file path, if path is not provided it will be downloaded to
1384    /// the current directory with the same filename.  A progress callback can
1385    /// be provided to monitor the progress of the download over a watch
1386    /// channel.
1387    pub async fn download_artifact(
1388        &self,
1389        training_session_id: TrainingSessionID,
1390        modelname: &str,
1391        filename: Option<PathBuf>,
1392        progress: Option<Sender<Progress>>,
1393    ) -> Result<(), Error> {
1394        let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
1395        let resp = self
1396            .http
1397            .get(format!(
1398                "{}/download_model?training_session_id={}&file={}",
1399                self.url,
1400                training_session_id.value(),
1401                modelname
1402            ))
1403            .header("Authorization", format!("Bearer {}", self.token().await))
1404            .send()
1405            .await?;
1406        if !resp.status().is_success() {
1407            let err = resp.error_for_status_ref().unwrap_err();
1408            return Err(Error::HttpError(err));
1409        }
1410
1411        fs::create_dir_all(filename.parent().unwrap()).await?;
1412
1413        if let Some(progress) = progress {
1414            let total = resp.content_length().unwrap() as usize;
1415            progress.send(Progress { current: 0, total }).await.unwrap();
1416
1417            let mut file = File::create(filename).await?;
1418            let mut current = 0;
1419            let mut stream = resp.bytes_stream();
1420
1421            while let Some(item) = stream.next().await {
1422                let chunk = item?;
1423                file.write_all(&chunk).await?;
1424                current += chunk.len();
1425                progress.send(Progress { current, total }).await.unwrap();
1426            }
1427        } else {
1428            let body = resp.bytes().await?;
1429            fs::write(filename, body).await?;
1430        }
1431
1432        Ok(())
1433    }
1434
1435    /// Download the model checkpoint associated with the specified trainer
1436    /// session to the specified file path, if path is not provided it will be
1437    /// downloaded to the current directory with the same filename.  A progress
1438    /// callback can be provided to monitor the progress of the download over a
1439    /// watch channel.
1440    ///
1441    /// There is no API for listing checkpoints it is expected that trainers are
1442    /// aware of possible checkpoints and their names within the checkpoint
1443    /// folder on the server.
1444    pub async fn download_checkpoint(
1445        &self,
1446        training_session_id: TrainingSessionID,
1447        checkpoint: &str,
1448        filename: Option<PathBuf>,
1449        progress: Option<Sender<Progress>>,
1450    ) -> Result<(), Error> {
1451        let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
1452        let resp = self
1453            .http
1454            .get(format!(
1455                "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1456                self.url,
1457                training_session_id.value(),
1458                checkpoint
1459            ))
1460            .header("Authorization", format!("Bearer {}", self.token().await))
1461            .send()
1462            .await?;
1463        if !resp.status().is_success() {
1464            let err = resp.error_for_status_ref().unwrap_err();
1465            return Err(Error::HttpError(err));
1466        }
1467
1468        fs::create_dir_all(filename.parent().unwrap()).await?;
1469
1470        if let Some(progress) = progress {
1471            let total = resp.content_length().unwrap() as usize;
1472            progress.send(Progress { current: 0, total }).await.unwrap();
1473
1474            let mut file = File::create(filename).await?;
1475            let mut current = 0;
1476            let mut stream = resp.bytes_stream();
1477
1478            while let Some(item) = stream.next().await {
1479                let chunk = item?;
1480                file.write_all(&chunk).await?;
1481                current += chunk.len();
1482                progress.send(Progress { current, total }).await.unwrap();
1483            }
1484        } else {
1485            let body = resp.bytes().await?;
1486            fs::write(filename, body).await?;
1487        }
1488
1489        Ok(())
1490    }
1491
1492    /// Return a list of tasks for the current user.
1493    pub async fn tasks(
1494        &self,
1495        name: Option<&str>,
1496        workflow: Option<&str>,
1497        status: Option<&str>,
1498        manager: Option<&str>,
1499    ) -> Result<Vec<Task>, Error> {
1500        let mut params = TasksListParams {
1501            continue_token: None,
1502            status: status.map(|s| vec![s.to_owned()]),
1503            manager: manager.map(|m| vec![m.to_owned()]),
1504        };
1505        let mut tasks = Vec::new();
1506
1507        loop {
1508            let result = self
1509                .rpc::<_, TasksListResult>("task.list".to_owned(), Some(&params))
1510                .await?;
1511            tasks.extend(result.tasks);
1512
1513            if result.continue_token.is_none() || result.continue_token == Some("".into()) {
1514                params.continue_token = None;
1515            } else {
1516                params.continue_token = result.continue_token;
1517            }
1518
1519            if params.continue_token.is_none() {
1520                break;
1521            }
1522        }
1523
1524        if let Some(name) = name {
1525            tasks.retain(|t| t.name().contains(name));
1526        }
1527
1528        if let Some(workflow) = workflow {
1529            tasks.retain(|t| t.workflow().contains(workflow));
1530        }
1531
1532        Ok(tasks)
1533    }
1534
1535    /// Retrieve the task information and status.
1536    pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
1537        self.rpc(
1538            "task.get".to_owned(),
1539            Some(HashMap::from([("id", task_id)])),
1540        )
1541        .await
1542    }
1543
1544    /// Updates the tasks status.
1545    pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
1546        let status = TaskStatus {
1547            task_id,
1548            status: status.to_owned(),
1549        };
1550        self.rpc("docker.update.status".to_owned(), Some(status))
1551            .await
1552    }
1553
1554    /// Defines the stages for the task.  The stages are defined as a mapping
1555    /// from stage names to their descriptions.  Once stages are defined their
1556    /// status can be updated using the update_stage method.
1557    pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
1558        let stages: Vec<HashMap<String, String>> = stages
1559            .iter()
1560            .map(|(key, value)| {
1561                let mut stage_map = HashMap::new();
1562                stage_map.insert(key.to_string(), value.to_string());
1563                stage_map
1564            })
1565            .collect();
1566        let params = TaskStages { task_id, stages };
1567        let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
1568        Ok(())
1569    }
1570
1571    /// Updates the progress of the task for the provided stage and status
1572    /// information.
1573    pub async fn update_stage(
1574        &self,
1575        task_id: TaskID,
1576        stage: &str,
1577        status: &str,
1578        message: &str,
1579        percentage: u8,
1580    ) -> Result<(), Error> {
1581        let stage = Stage::new(
1582            Some(task_id),
1583            stage.to_owned(),
1584            Some(status.to_owned()),
1585            Some(message.to_owned()),
1586            percentage,
1587        );
1588        let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
1589        Ok(())
1590    }
1591
1592    /// Raw fetch from the Studio server is used for downloading files.
1593    pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
1594        let req = self
1595            .http
1596            .get(format!("{}/{}", self.url, query))
1597            .header("User-Agent", "EdgeFirst Client")
1598            .header("Authorization", format!("Bearer {}", self.token().await));
1599        let resp = req.send().await?;
1600
1601        if resp.status().is_success() {
1602            let body = resp.bytes().await?;
1603
1604            if log_enabled!(Level::Trace) {
1605                trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
1606            }
1607
1608            Ok(body.to_vec())
1609        } else {
1610            let err = resp.error_for_status_ref().unwrap_err();
1611            Err(Error::HttpError(err))
1612        }
1613    }
1614
1615    /// Sends a multipart post request to the server.  This is used by the
1616    /// upload and download APIs which do not use JSON-RPC but instead transfer
1617    /// files using multipart/form-data.
1618    pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
1619        let req = self
1620            .http
1621            .post(format!("{}/api?method={}", self.url, method))
1622            .header("Accept", "application/json")
1623            .header("User-Agent", "EdgeFirst Client")
1624            .header("Authorization", format!("Bearer {}", self.token().await))
1625            .multipart(form);
1626        let resp = req.send().await?;
1627
1628        if resp.status().is_success() {
1629            let body = resp.bytes().await?;
1630
1631            if log_enabled!(Level::Trace) {
1632                trace!(
1633                    "POST Multipart Response: {}",
1634                    String::from_utf8_lossy(&body)
1635                );
1636            }
1637
1638            let response: RpcResponse<String> = match serde_json::from_slice(&body) {
1639                Ok(response) => response,
1640                Err(err) => {
1641                    error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1642                    return Err(err.into());
1643                }
1644            };
1645
1646            if let Some(error) = response.error {
1647                Err(Error::RpcError(error.code, error.message))
1648            } else if let Some(result) = response.result {
1649                Ok(result)
1650            } else {
1651                Err(Error::InvalidResponse)
1652            }
1653        } else {
1654            let err = resp.error_for_status_ref().unwrap_err();
1655            Err(Error::HttpError(err))
1656        }
1657    }
1658
1659    /// Send a JSON-RPC request to the server.  The method is the name of the
1660    /// method to call on the server.  The params are the parameters to pass to
1661    /// the method.  The method and params are serialized into a JSON-RPC
1662    /// request and sent to the server.  The response is deserialized into
1663    /// the specified type and returned to the caller.
1664    ///
1665    /// NOTE: This API would generally not be called directly and instead users
1666    /// should use the higher-level methods provided by the client.
1667    pub async fn rpc<Params, RpcResult>(
1668        &self,
1669        method: String,
1670        params: Option<Params>,
1671    ) -> Result<RpcResult, Error>
1672    where
1673        Params: Serialize,
1674        RpcResult: DeserializeOwned,
1675    {
1676        let auth_expires = self.token_expiration().await?;
1677        if auth_expires <= Utc::now() + Duration::from_secs(3600) {
1678            self.renew_token().await?;
1679        }
1680
1681        self.rpc_without_auth(method, params).await
1682    }
1683
1684    async fn rpc_without_auth<Params, RpcResult>(
1685        &self,
1686        method: String,
1687        params: Option<Params>,
1688    ) -> Result<RpcResult, Error>
1689    where
1690        Params: Serialize,
1691        RpcResult: DeserializeOwned,
1692    {
1693        let request = RpcRequest {
1694            method,
1695            params,
1696            ..Default::default()
1697        };
1698
1699        if log_enabled!(Level::Trace) {
1700            trace!(
1701                "RPC Request: {}",
1702                serde_json::ser::to_string_pretty(&request)?
1703            );
1704        }
1705
1706        for attempt in 0..MAX_RETRIES {
1707            let res = match self
1708                .http
1709                .post(format!("{}/api", self.url))
1710                .header("Accept", "application/json")
1711                .header("User-Agent", "EdgeFirst Client")
1712                .header("Authorization", format!("Bearer {}", self.token().await))
1713                .json(&request)
1714                .send()
1715                .await
1716            {
1717                Ok(res) => res,
1718                Err(err) => {
1719                    warn!("Socket Error: {:?}", err);
1720                    continue;
1721                }
1722            };
1723
1724            if res.status().is_success() {
1725                let body = res.bytes().await?;
1726
1727                if log_enabled!(Level::Trace) {
1728                    trace!("RPC Response: {}", String::from_utf8_lossy(&body));
1729                }
1730
1731                let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
1732                    Ok(response) => response,
1733                    Err(err) => {
1734                        error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1735                        return Err(err.into());
1736                    }
1737                };
1738
1739                // FIXME: Studio Server always returns 999 as the id.
1740                // if request.id.to_string() != response.id {
1741                //     return Err(Error::InvalidRpcId(response.id));
1742                // }
1743
1744                if let Some(error) = response.error {
1745                    return Err(Error::RpcError(error.code, error.message));
1746                } else if let Some(result) = response.result {
1747                    return Ok(result);
1748                } else {
1749                    return Err(Error::InvalidResponse);
1750                }
1751            } else {
1752                let err = res.error_for_status_ref().unwrap_err();
1753                warn!("HTTP Error {}: {}", err, res.text().await?);
1754            }
1755
1756            warn!(
1757                "Retrying RPC request (attempt {}/{})...",
1758                attempt + 1,
1759                MAX_RETRIES
1760            );
1761            tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1762        }
1763
1764        Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1765    }
1766}
1767
1768async fn upload_multipart(
1769    http: reqwest::Client,
1770    part: SnapshotPart,
1771    path: PathBuf,
1772    total: usize,
1773    current: Arc<AtomicUsize>,
1774    progress: Option<Sender<Progress>>,
1775) -> Result<SnapshotCompleteMultipartParams, Error> {
1776    let filesize = path.metadata()?.len() as usize;
1777    let n_parts = filesize.div_ceil(PART_SIZE);
1778    let sem = Arc::new(Semaphore::new(MAX_TASKS));
1779
1780    let key = part.key.unwrap();
1781    let upload_id = part.upload_id;
1782
1783    let urls = part.urls.clone();
1784    let etags = Arc::new(tokio::sync::Mutex::new(vec![
1785        EtagPart {
1786            etag: "".to_owned(),
1787            part_number: 0,
1788        };
1789        n_parts
1790    ]));
1791
1792    let tasks = (0..n_parts)
1793        .map(|part| {
1794            let http = http.clone();
1795            let url = urls[part].clone();
1796            let etags = etags.clone();
1797            let path = path.to_owned();
1798            let sem = sem.clone();
1799            let progress = progress.clone();
1800            let current = current.clone();
1801
1802            tokio::spawn(async move {
1803                let _permit = sem.acquire().await?;
1804                let mut etag = None;
1805
1806                for attempt in 0..MAX_RETRIES {
1807                    match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
1808                    {
1809                        Ok(v) => {
1810                            etag = Some(v);
1811                            break;
1812                        }
1813                        Err(err) => {
1814                            warn!("Upload Part Error: {:?}", err);
1815                            tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1816                        }
1817                    }
1818                }
1819
1820                if let Some(etag) = etag {
1821                    let mut etags = etags.lock().await;
1822                    etags[part] = EtagPart {
1823                        etag,
1824                        part_number: part + 1,
1825                    };
1826
1827                    let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
1828                    if let Some(progress) = &progress {
1829                        progress
1830                            .send(Progress {
1831                                current: current + PART_SIZE,
1832                                total,
1833                            })
1834                            .await
1835                            .unwrap();
1836                    }
1837
1838                    Ok(())
1839                } else {
1840                    Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1841                }
1842            })
1843        })
1844        .collect::<Vec<_>>();
1845
1846    join_all(tasks)
1847        .await
1848        .into_iter()
1849        .collect::<Result<Vec<_>, _>>()?;
1850
1851    Ok(SnapshotCompleteMultipartParams {
1852        key,
1853        upload_id,
1854        etag_list: etags.lock().await.clone(),
1855    })
1856}
1857
1858async fn upload_part(
1859    http: reqwest::Client,
1860    url: String,
1861    path: PathBuf,
1862    part: usize,
1863    n_parts: usize,
1864) -> Result<String, Error> {
1865    let filesize = path.metadata()?.len() as usize;
1866    let mut file = File::open(path).await.unwrap();
1867    file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
1868        .await
1869        .unwrap();
1870    let file = file.take(PART_SIZE as u64);
1871
1872    let body_length = if part + 1 == n_parts {
1873        filesize % PART_SIZE
1874    } else {
1875        PART_SIZE
1876    };
1877
1878    let stream = FramedRead::new(file, BytesCodec::new());
1879    let body = Body::wrap_stream(stream);
1880
1881    let resp = http
1882        .put(url.clone())
1883        .header(CONTENT_LENGTH, body_length)
1884        .body(body)
1885        .send()
1886        .await?
1887        .error_for_status()?;
1888    let etag = resp
1889        .headers()
1890        .get("etag")
1891        .unwrap()
1892        .to_str()
1893        .unwrap()
1894        .to_owned();
1895    // Studio Server requires etag without the quotes.
1896    Ok(etag
1897        .strip_prefix("\"")
1898        .unwrap()
1899        .strip_suffix("\"")
1900        .unwrap()
1901        .to_owned())
1902}