1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23 collections::HashMap,
24 fs::create_dir_all,
25 io::{SeekFrom, Write as _},
26 path::{Path, PathBuf},
27 sync::{
28 Arc,
29 atomic::{AtomicUsize, Ordering},
30 },
31 time::Duration,
32 vec,
33};
34use tokio::{
35 fs::{self, File},
36 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
37 sync::{RwLock, Semaphore, mpsc::Sender},
38};
39use tokio_util::codec::{BytesCodec, FramedRead};
40use walkdir::WalkDir;
41
42#[cfg(feature = "polars")]
43use polars::prelude::*;
44
/// Maximum number of concurrently running transfer tasks.
const MAX_TASKS: usize = 32;
/// Maximum attempts for a download before giving up.
const MAX_RETRIES: u32 = 10;
/// Multipart upload part size: 100 MiB.
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Progress report for a long-running operation; units depend on the
/// operation (samples for listings/downloads of datasets, bytes for
/// snapshot and artifact transfers).
#[derive(Debug, Clone)]
pub struct Progress {
    // Units completed so far.
    pub current: usize,
    // Total units expected.
    pub total: usize,
}
77
/// JSON-RPC 2.0 request envelope.
#[derive(Serialize)]
struct RpcRequest<Params> {
    id: u64,
    jsonrpc: String, // always "2.0" (see the Default impl)
    method: String,
    params: Option<Params>,
}
85
86impl<T> Default for RpcRequest<T> {
87 fn default() -> Self {
88 RpcRequest {
89 id: 0,
90 jsonrpc: "2.0".to_string(),
91 method: "".to_string(),
92 params: None,
93 }
94 }
95}
96
/// JSON-RPC error object returned by the server.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}
102
/// JSON-RPC 2.0 response envelope; `error` and `result` are mutually
/// exclusive in practice.
// NOTE(review): `id` is a String here while `RpcRequest.id` is a u64 —
// presumably the server stringifies ids; confirm against the wire format.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    id: String,
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}
110
/// Placeholder for RPC calls whose result payload is an empty object.
#[derive(Deserialize)]
struct EmptyResult {}
113
/// Params for creating a snapshot with plain (non-multipart) upload URLs.
// NOTE(review): not referenced in this chunk — presumably used by code
// outside the visible range.
#[derive(Debug, Serialize)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}
119
/// Result of a snapshot-create call: the new snapshot id plus one upload
/// URL per requested key.
#[derive(Debug, Deserialize)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
125
/// Params for `snapshots.create_upload_url_multipart`; `keys` and
/// `file_sizes` are parallel arrays (one entry per file).
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}
132
/// Value type of the `snapshots.create_upload_url_multipart` result map:
/// the `"snapshot_id"` entry is a bare id, every other entry is a per-file
/// multipart descriptor. Untagged so serde picks the variant by shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
139
/// Params for `snapshots.complete_multipart_upload`: identifies the upload
/// and lists the ETag of every uploaded part.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}
146
/// One uploaded part's ETag, serialized with S3-style field names.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
154
/// Per-file multipart upload descriptor returned by the server: the upload
/// id plus one presigned URL per part. `key` is absent in the response and
/// filled in client-side before completing the upload.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
161
/// Params for `snapshots.update`: sets a snapshot's status string
/// (e.g. "available" once an upload completes).
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}
167
/// Snapshot record returned by `snapshots.update`.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    pub id: SnapshotID,
    pub uid: String,
    pub description: String,
    pub date: String,
    pub status: String,
}
176
/// Params for an image-listing RPC.
// NOTE(review): not referenced in this chunk — presumably used by code
// outside the visible range.
#[derive(Serialize)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    only_ids: bool,
}
183
/// Dataset filter used inside `ImageListParams`.
#[derive(Serialize)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
188
/// Client for the EdgeFirst Studio JSON-RPC API.
///
/// Cheap to clone: the HTTP client and the token lock are shared.
#[derive(Clone, Debug)]
pub struct Client {
    http: reqwest::Client,
    // Base URL, e.g. "https://<server>.edgefirst.studio".
    url: String,
    // Bearer token (JWT-shaped: header.payload.signature); empty when
    // logged out. Shared so clones observe token refreshes.
    token: Arc<RwLock<String>>,
    // Where to persist the token on disk, if configured.
    token_path: Option<PathBuf>,
}
244
245impl Client {
246 pub fn new() -> Result<Self, Error> {
255 Ok(Client {
256 http: reqwest::Client::builder()
257 .read_timeout(Duration::from_secs(60))
258 .build()?,
259 url: "https://edgefirst.studio".to_string(),
260 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
261 token_path: None,
262 })
263 }
264
265 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
269 Ok(Client {
270 url: format!("https://{}.edgefirst.studio", server),
271 ..self.clone()
272 })
273 }
274
275 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
278 let params = HashMap::from([("username", username), ("password", password)]);
279 let login: LoginResult = self
280 .rpc_without_auth("auth.login".to_owned(), Some(params))
281 .await?;
282 Ok(Client {
283 token: Arc::new(tokio::sync::RwLock::new(login.token)),
284 ..self.clone()
285 })
286 }
287
288 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
291 let token_path = match token_path {
292 Some(path) => path.to_path_buf(),
293 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
294 .unwrap()
295 .config_dir()
296 .join("token"),
297 };
298
299 debug!("Using token path: {:?}", token_path);
300
301 let token = match token_path.exists() {
302 true => std::fs::read_to_string(&token_path)?,
303 false => "".to_string(),
304 };
305
306 if !token.is_empty() {
307 let client = self.with_token(&token)?;
308 Ok(Client {
309 token_path: Some(token_path),
310 ..client
311 })
312 } else {
313 Ok(Client {
314 token_path: Some(token_path),
315 ..self.clone()
316 })
317 }
318 }
319
320 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
322 if token.is_empty() {
323 return Ok(self.clone());
324 }
325
326 let token_parts: Vec<&str> = token.split('.').collect();
327 if token_parts.len() != 3 {
328 return Err(Error::InvalidToken);
329 }
330
331 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
332 .decode(token_parts[1])
333 .unwrap();
334 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
335 let server = match payload.get("database") {
336 Some(value) => Ok(value.as_str().unwrap().to_string()),
337 None => Err(Error::InvalidToken),
338 }?;
339
340 Ok(Client {
341 url: format!("https://{}.edgefirst.studio", server),
342 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
343 ..self.clone()
344 })
345 }
346
347 pub async fn save_token(&self) -> Result<(), Error> {
348 let path = self.token_path.clone().unwrap_or_else(|| {
349 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
350 .unwrap()
351 .config_dir()
352 .join("token")
353 });
354
355 create_dir_all(path.parent().unwrap())?;
356 let mut file = std::fs::File::create(&path)?;
357 file.write_all(self.token.read().await.as_bytes())?;
358
359 debug!("Saved token to {:?}", path);
360
361 Ok(())
362 }
363
364 pub async fn version(&self) -> Result<String, Error> {
367 let version: HashMap<String, String> = self
368 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
369 .await?;
370 let version = version.get("version").ok_or(Error::InvalidResponse)?;
371 Ok(version.to_owned())
372 }
373
374 pub async fn logout(&self) -> Result<(), Error> {
378 {
379 let mut token = self.token.write().await;
380 *token = "".to_string();
381 }
382
383 if let Some(path) = &self.token_path
384 && path.exists()
385 {
386 fs::remove_file(path).await?;
387 }
388
389 Ok(())
390 }
391
392 pub async fn token(&self) -> String {
396 self.token.read().await.clone()
397 }
398
399 pub async fn verify_token(&self) -> Result<(), Error> {
404 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
405 .await?;
406 Ok::<(), Error>(())
407 }
408
409 pub async fn renew_token(&self) -> Result<(), Error> {
414 let params = HashMap::from([("username".to_string(), self.username().await?)]);
415 let result: LoginResult = self
416 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
417 .await?;
418
419 {
420 let mut token = self.token.write().await;
421 *token = result.token;
422 }
423
424 if self.token_path.is_some() {
425 self.save_token().await?;
426 }
427
428 Ok(())
429 }
430
431 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
432 let token = self.token.read().await;
433 if token.is_empty() {
434 return Err(Error::EmptyToken);
435 }
436
437 let token_parts: Vec<&str> = token.split('.').collect();
438 if token_parts.len() != 3 {
439 return Err(Error::InvalidToken);
440 }
441
442 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
443 .decode(token_parts[1])
444 .unwrap();
445 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
446 match payload.get(field) {
447 Some(value) => Ok(value.to_owned()),
448 None => Err(Error::InvalidToken),
449 }
450 }
451
452 pub fn url(&self) -> &str {
454 &self.url
455 }
456
457 pub async fn username(&self) -> Result<String, Error> {
459 match self.token_field("username").await? {
460 serde_json::Value::String(username) => Ok(username),
461 _ => Err(Error::InvalidToken),
462 }
463 }
464
465 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
467 let ts = match self.token_field("exp").await? {
468 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
469 _ => return Err(Error::InvalidToken),
470 };
471
472 match DateTime::<Utc>::from_timestamp_secs(ts) {
473 Some(dt) => Ok(dt),
474 None => Err(Error::InvalidToken),
475 }
476 }
477
478 pub async fn organization(&self) -> Result<Organization, Error> {
480 self.rpc::<(), Organization>("org.get".to_owned(), None)
481 .await
482 }
483
484 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
492 let projects = self
493 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
494 .await?;
495 if let Some(name) = name {
496 Ok(projects
497 .into_iter()
498 .filter(|p| p.name().contains(name))
499 .collect())
500 } else {
501 Ok(projects)
502 }
503 }
504
505 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
508 let params = HashMap::from([("project_id", project_id)]);
509 self.rpc("project.get".to_owned(), Some(params)).await
510 }
511
512 pub async fn datasets(
516 &self,
517 project_id: ProjectID,
518 name: Option<&str>,
519 ) -> Result<Vec<Dataset>, Error> {
520 let params = HashMap::from([("project_id", project_id)]);
521 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
522 if let Some(name) = name {
523 Ok(datasets
524 .into_iter()
525 .filter(|d| d.name().contains(name))
526 .collect())
527 } else {
528 Ok(datasets)
529 }
530 }
531
532 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
535 let params = HashMap::from([("dataset_id", dataset_id)]);
536 self.rpc("dataset.get".to_owned(), Some(params)).await
537 }
538
539 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
541 let params = HashMap::from([("dataset_id", dataset_id)]);
542 self.rpc("label.list".to_owned(), Some(params)).await
543 }
544
545 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
547 let new_label = NewLabel {
548 dataset_id,
549 labels: vec![NewLabelObject {
550 name: name.to_owned(),
551 }],
552 };
553 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
554 Ok(())
555 }
556
557 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
560 let params = HashMap::from([("label_id", label_id)]);
561 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
562 Ok(())
563 }
564
565 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
569 #[derive(Serialize)]
570 struct Params {
571 dataset_id: DatasetID,
572 label_id: u64,
573 label_name: String,
574 label_index: u64,
575 }
576
577 let _: String = self
578 .rpc(
579 "label.update".to_owned(),
580 Some(Params {
581 dataset_id: label.dataset_id(),
582 label_id: label.id(),
583 label_name: label.name().to_owned(),
584 label_index: label.index(),
585 }),
586 )
587 .await?;
588 Ok(())
589 }
590
591 pub async fn download_dataset(
592 &self,
593 dataset_id: DatasetID,
594 groups: &[String],
595 file_types: &[FileType],
596 output: PathBuf,
597 progress: Option<Sender<Progress>>,
598 ) -> Result<(), Error> {
599 let samples = self
600 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
601 .await?;
602 fs::create_dir_all(&output).await?;
603
604 let total = samples.len();
605 let current = Arc::new(AtomicUsize::new(0));
606 let sem = Arc::new(Semaphore::new(MAX_TASKS));
607
608 let tasks = samples
609 .into_iter()
610 .map(|sample| {
611 let sem = sem.clone();
612 let client = self.clone();
613 let current = current.clone();
614 let progress = progress.clone();
615 let file_types = file_types.to_vec();
616 let output = output.clone();
617
618 tokio::spawn(async move {
619 let _permit = sem.acquire().await.unwrap();
620
621 for file_type in file_types {
622 if let Some(data) = sample.download(&client, file_type.clone()).await? {
623 let file_ext = match file_type {
624 FileType::Image => infer::get(&data)
625 .expect("Failed to identify image file format for sample")
626 .extension()
627 .to_string(),
628 t => t.to_string(),
629 };
630
631 let file_name = format!("{}.{}", sample.name(), file_ext);
632 let file_path = output.join(&file_name);
633
634 let mut file = File::create(&file_path).await?;
635 file.write_all(&data).await?;
636 } else {
637 warn!("No data for sample: {}", sample.id());
638 }
639 }
640
641 if let Some(progress) = &progress {
642 let current = current.fetch_add(1, Ordering::SeqCst);
643 progress
644 .send(Progress {
645 current: current + 1,
646 total,
647 })
648 .await
649 .unwrap();
650 }
651
652 Ok::<(), Error>(())
653 })
654 })
655 .collect::<Vec<_>>();
656
657 join_all(tasks)
658 .await
659 .into_iter()
660 .collect::<Result<Vec<_>, _>>()?;
661
662 if let Some(progress) = progress {
663 drop(progress);
664 }
665
666 Ok(())
667 }
668
669 pub async fn annotation_sets(
671 &self,
672 dataset_id: DatasetID,
673 ) -> Result<Vec<AnnotationSet>, Error> {
674 let params = HashMap::from([("dataset_id", dataset_id)]);
675 self.rpc("annset.list".to_owned(), Some(params)).await
676 }
677
678 pub async fn annotation_set(
680 &self,
681 annotation_set_id: AnnotationSetID,
682 ) -> Result<AnnotationSet, Error> {
683 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
684 self.rpc("annset.get".to_owned(), Some(params)).await
685 }
686
    /// Fetch the annotations of an annotation set as a flat list.
    ///
    /// Pages through `samples.list` (restricted to `groups` and
    /// `annotation_types`) until the server stops returning a non-empty
    /// continue token. A sample with no annotations contributes one
    /// placeholder `Annotation` carrying just the sample metadata. When
    /// `progress` is given, one message (counted in samples) is sent per
    /// page.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Label name -> numeric index, used to stamp each annotation below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total sample count drives the progress reports.
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(annotations);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id: Some(annotation_set_id),
                types: annotation_types.iter().map(|t| t.to_string()).collect(),
                group_names: groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            for sample in result.samples {
                if sample.annotations().is_empty() {
                    // Unannotated sample: emit a placeholder so the sample
                    // still appears in the output.
                    let mut annotation = Annotation::new();
                    annotation.set_sample_id(Some(sample.id()));
                    annotation.set_name(Some(sample.name().to_string()));
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    annotations.push(annotation);
                    continue;
                }

                sample.annotations().iter().for_each(|annotation| {
                    let mut annotation = annotation.clone();
                    annotation.set_sample_id(Some(sample.id()));
                    annotation.set_name(Some(sample.name().to_string()));
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    // NOTE(review): indexes `labels` with an unwrapped label
                    // name — panics if the server returns an annotation whose
                    // label is missing from `label.list`; confirm intended.
                    annotation.set_label_index(Some(labels[annotation.label().unwrap().as_str()]));
                    annotations.push(annotation);
                });
            }

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(annotations)
    }
789
790 pub async fn samples_count(
791 &self,
792 dataset_id: DatasetID,
793 annotation_set_id: Option<AnnotationSetID>,
794 annotation_types: &[AnnotationType],
795 groups: &[String],
796 types: &[FileType],
797 ) -> Result<SamplesCountResult, Error> {
798 let types = annotation_types
799 .iter()
800 .map(|t| t.to_string())
801 .chain(types.iter().map(|t| t.to_string()))
802 .collect::<Vec<_>>();
803
804 let params = SamplesListParams {
805 dataset_id,
806 annotation_set_id,
807 group_names: groups.to_vec(),
808 types,
809 continue_token: None,
810 };
811
812 self.rpc("samples.count".to_owned(), Some(params)).await
813 }
814
    /// Fetch samples matching the given filters, paging through
    /// `samples.list` until the server stops returning a non-empty
    /// continue token.
    ///
    /// Each sample's annotations get their numeric label index filled in
    /// from the dataset's label list. When `progress` is given, one
    /// message (counted in samples) is sent per page.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // Annotation types and file types are merged into one list for the
        // server (annotation types first).
        let types = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> numeric index, used to stamp annotations below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        // Total sample count drives the progress reports.
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(samples);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id,
                types: types.clone(),
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        // Rebuild each sample with label indexes resolved.
                        // NOTE(review): `labels[label]` panics on a label
                        // name missing from `label.list`; confirm intended.
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            if let Some(label) = ann.label() {
                                ann.set_label_index(Some(labels[label.as_str()]));
                            }
                        }
                        s.with_annotations(anns)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(samples)
    }
898
899 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
900 for attempt in 1..MAX_RETRIES {
901 let resp = match self.http.get(url).send().await {
902 Ok(resp) => resp,
903 Err(err) => {
904 warn!(
905 "Socket Error [retry {}/{}]: {:?}",
906 attempt, MAX_RETRIES, err
907 );
908 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
909 continue;
910 }
911 };
912
913 match resp.bytes().await {
914 Ok(body) => return Ok(body.to_vec()),
915 Err(err) => {
916 warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
917 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
918 continue;
919 }
920 };
921 }
922
923 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
924 }
925
926 #[cfg(feature = "polars")]
937 pub async fn annotations_dataframe(
938 &self,
939 annotation_set_id: AnnotationSetID,
940 groups: &[String],
941 types: &[AnnotationType],
942 progress: Option<Sender<Progress>>,
943 ) -> Result<DataFrame, Error> {
944 use crate::dataset::annotations_dataframe;
945
946 let annotations = self
947 .annotations(annotation_set_id, groups, types, progress)
948 .await?;
949 Ok(annotations_dataframe(&annotations))
950 }
951
952 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
955 let snapshots: Vec<Snapshot> = self
956 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
957 .await?;
958 if let Some(name) = name {
959 Ok(snapshots
960 .into_iter()
961 .filter(|s| s.description().contains(name))
962 .collect())
963 } else {
964 Ok(snapshots)
965 }
966 }
967
968 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
970 let params = HashMap::from([("snapshot_id", snapshot_id)]);
971 self.rpc("snapshots.get".to_owned(), Some(params)).await
972 }
973
    /// Create a snapshot from a local file (delegating to
    /// `create_snapshot_folder` when `path` is a directory): request
    /// multipart upload URLs, upload the file, complete the multipart
    /// upload, then mark the snapshot "available".
    ///
    /// Progress is reported in bytes uploaded.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            return self
                .create_snapshot_folder(path.to_str().unwrap(), progress)
                .await;
        }

        let name = path.file_name().unwrap().to_str().unwrap();
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        // Ask the server for multipart upload URLs for the single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The result map keys parts by "<storage prefix>/<file name>";
        // derive the prefix from the snapshot's storage path.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot to "available" now that the upload finished.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1060
    /// Create a snapshot from a directory tree: walk `path` for regular
    /// files, request multipart upload URLs for all of them at once, upload
    /// each file in turn, then mark the snapshot "available".
    ///
    /// Progress is reported in bytes uploaded across all files.
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        let name = path.file_name().unwrap().to_str().unwrap();

        // Collect every regular file, keyed by its path relative to `path`.
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .map(|entry| entry.path().strip_prefix(path).unwrap().to_owned())
            .collect::<Vec<_>>();

        // Combined size of all files, for byte-based progress reporting.
        let total = files
            .iter()
            .map(|file| path.join(file).metadata().unwrap().len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        let keys = files
            .iter()
            .map(|key| key.to_str().unwrap().to_owned())
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .map(|key| path.join(key).metadata().unwrap().len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The result map keys parts by "<storage prefix>/<relative path>";
        // derive the prefix from the snapshot's storage path.
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();

        // Upload the files sequentially, completing each multipart upload
        // before moving on to the next file.
        for file in files {
            let part_key = format!("{}/{}", part_prefix, file.to_str().unwrap());
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // Flip the snapshot to "available" now that all uploads finished.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1158
    /// Download every file of a snapshot into `output`, fetching the
    /// presigned download URLs from the server and streaming the files
    /// concurrently (bounded by `MAX_TASKS`).
    ///
    /// Progress is reported in bytes; the total grows as each task learns
    /// its file's content length.
    // NOTE(review): the spawned tasks unwrap on network and filesystem
    // errors, so a single failed transfer panics the task and surfaces as
    // a JoinError-driven unwrap below — consider propagating errors.
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // Map of file key -> presigned download URL.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        let sem = Arc::new(Semaphore::new(MAX_TASKS));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    let _permit = sem.acquire().await.unwrap();
                    let res = http.get(url).send().await.unwrap();
                    let content_length = res.content_length().unwrap() as usize;

                    // Grow the shared total as this file's size becomes known.
                    if let Some(progress) = &progress {
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await
                            .unwrap();
                    }

                    let mut file = File::create(output.join(key)).await.unwrap();
                    let mut stream = res.bytes_stream();

                    // Stream the body to disk chunk by chunk, reporting
                    // byte progress as we go.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk.unwrap();
                        file.write_all(&chunk).await.unwrap();
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await
                                .unwrap();
                        }
                    }
                })
            })
            .collect::<Vec<_>>();

        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        Ok(())
    }
1241
1242 pub async fn restore_snapshot(
1257 &self,
1258 project_id: ProjectID,
1259 snapshot_id: SnapshotID,
1260 topics: &[String],
1261 autolabel: &[String],
1262 autodepth: bool,
1263 dataset_name: Option<&str>,
1264 dataset_description: Option<&str>,
1265 ) -> Result<SnapshotRestoreResult, Error> {
1266 let params = SnapshotRestore {
1267 project_id,
1268 snapshot_id,
1269 fps: 1,
1270 autodepth,
1271 agtg_pipeline: !autolabel.is_empty(),
1272 autolabel: autolabel.to_vec(),
1273 topics: topics.to_vec(),
1274 dataset_name: dataset_name.map(|s| s.to_owned()),
1275 dataset_description: dataset_description.map(|s| s.to_owned()),
1276 };
1277 self.rpc("snapshots.restore".to_owned(), Some(params)).await
1278 }
1279
1280 pub async fn experiments(
1289 &self,
1290 project_id: ProjectID,
1291 name: Option<&str>,
1292 ) -> Result<Vec<Experiment>, Error> {
1293 let params = HashMap::from([("project_id", project_id)]);
1294 let experiments: Vec<Experiment> =
1295 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1296 if let Some(name) = name {
1297 Ok(experiments
1298 .into_iter()
1299 .filter(|e| e.name().contains(name))
1300 .collect())
1301 } else {
1302 Ok(experiments)
1303 }
1304 }
1305
1306 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1309 let params = HashMap::from([("trainer_id", experiment_id)]);
1310 self.rpc("trainer.get".to_owned(), Some(params)).await
1311 }
1312
1313 pub async fn training_sessions(
1322 &self,
1323 experiment_id: ExperimentID,
1324 name: Option<&str>,
1325 ) -> Result<Vec<TrainingSession>, Error> {
1326 let params = HashMap::from([("trainer_id", experiment_id)]);
1327 let sessions: Vec<TrainingSession> = self
1328 .rpc("trainer.session.list".to_owned(), Some(params))
1329 .await?;
1330 if let Some(name) = name {
1331 Ok(sessions
1332 .into_iter()
1333 .filter(|s| s.name().contains(name))
1334 .collect())
1335 } else {
1336 Ok(sessions)
1337 }
1338 }
1339
1340 pub async fn training_session(
1343 &self,
1344 session_id: TrainingSessionID,
1345 ) -> Result<TrainingSession, Error> {
1346 let params = HashMap::from([("trainer_session_id", session_id)]);
1347 self.rpc("trainer.session.get".to_owned(), Some(params))
1348 .await
1349 }
1350
1351 pub async fn validation_sessions(
1353 &self,
1354 project_id: ProjectID,
1355 ) -> Result<Vec<ValidationSession>, Error> {
1356 let params = HashMap::from([("project_id", project_id)]);
1357 self.rpc("validate.session.list".to_owned(), Some(params))
1358 .await
1359 }
1360
1361 pub async fn validation_session(
1363 &self,
1364 session_id: ValidationSessionID,
1365 ) -> Result<ValidationSession, Error> {
1366 let params = HashMap::from([("validate_session_id", session_id)]);
1367 self.rpc("validate.session.get".to_owned(), Some(params))
1368 .await
1369 }
1370
1371 pub async fn artifacts(
1374 &self,
1375 training_session_id: TrainingSessionID,
1376 ) -> Result<Vec<Artifact>, Error> {
1377 let params = HashMap::from([("training_session_id", training_session_id)]);
1378 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
1379 .await
1380 }
1381
    /// Download a model artifact from a training session.
    ///
    /// Saves to `filename` (default: `modelname` in the current directory).
    /// With `progress`, the body is streamed to disk with byte-level
    /// progress reports; otherwise it is buffered and written in one go.
    // NOTE(review): the progress path unwraps `content_length()` — panics
    // if the server omits Content-Length; confirm the endpoint always
    // sets it.
    pub async fn download_artifact(
        &self,
        training_session_id: TrainingSessionID,
        modelname: &str,
        filename: Option<PathBuf>,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
        // Artifact downloads go through a plain HTTP endpoint (not RPC),
        // authenticated with the bearer token.
        let resp = self
            .http
            .get(format!(
                "{}/download_model?training_session_id={}&file={}",
                self.url,
                training_session_id.value(),
                modelname
            ))
            .header("Authorization", format!("Bearer {}", self.token().await))
            .send()
            .await?;
        if !resp.status().is_success() {
            let err = resp.error_for_status_ref().unwrap_err();
            return Err(Error::HttpError(err));
        }

        fs::create_dir_all(filename.parent().unwrap()).await?;

        if let Some(progress) = progress {
            let total = resp.content_length().unwrap() as usize;
            progress.send(Progress { current: 0, total }).await.unwrap();

            let mut file = File::create(filename).await?;
            let mut current = 0;
            let mut stream = resp.bytes_stream();

            while let Some(item) = stream.next().await {
                let chunk = item?;
                file.write_all(&chunk).await?;
                current += chunk.len();
                progress.send(Progress { current, total }).await.unwrap();
            }
        } else {
            let body = resp.bytes().await?;
            fs::write(filename, body).await?;
        }

        Ok(())
    }
1434
1435 pub async fn download_checkpoint(
1445 &self,
1446 training_session_id: TrainingSessionID,
1447 checkpoint: &str,
1448 filename: Option<PathBuf>,
1449 progress: Option<Sender<Progress>>,
1450 ) -> Result<(), Error> {
1451 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
1452 let resp = self
1453 .http
1454 .get(format!(
1455 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1456 self.url,
1457 training_session_id.value(),
1458 checkpoint
1459 ))
1460 .header("Authorization", format!("Bearer {}", self.token().await))
1461 .send()
1462 .await?;
1463 if !resp.status().is_success() {
1464 let err = resp.error_for_status_ref().unwrap_err();
1465 return Err(Error::HttpError(err));
1466 }
1467
1468 fs::create_dir_all(filename.parent().unwrap()).await?;
1469
1470 if let Some(progress) = progress {
1471 let total = resp.content_length().unwrap() as usize;
1472 progress.send(Progress { current: 0, total }).await.unwrap();
1473
1474 let mut file = File::create(filename).await?;
1475 let mut current = 0;
1476 let mut stream = resp.bytes_stream();
1477
1478 while let Some(item) = stream.next().await {
1479 let chunk = item?;
1480 file.write_all(&chunk).await?;
1481 current += chunk.len();
1482 progress.send(Progress { current, total }).await.unwrap();
1483 }
1484 } else {
1485 let body = resp.bytes().await?;
1486 fs::write(filename, body).await?;
1487 }
1488
1489 Ok(())
1490 }
1491
1492 pub async fn tasks(
1494 &self,
1495 name: Option<&str>,
1496 workflow: Option<&str>,
1497 status: Option<&str>,
1498 manager: Option<&str>,
1499 ) -> Result<Vec<Task>, Error> {
1500 let mut params = TasksListParams {
1501 continue_token: None,
1502 status: status.map(|s| vec![s.to_owned()]),
1503 manager: manager.map(|m| vec![m.to_owned()]),
1504 };
1505 let mut tasks = Vec::new();
1506
1507 loop {
1508 let result = self
1509 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
1510 .await?;
1511 tasks.extend(result.tasks);
1512
1513 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
1514 params.continue_token = None;
1515 } else {
1516 params.continue_token = result.continue_token;
1517 }
1518
1519 if params.continue_token.is_none() {
1520 break;
1521 }
1522 }
1523
1524 if let Some(name) = name {
1525 tasks.retain(|t| t.name().contains(name));
1526 }
1527
1528 if let Some(workflow) = workflow {
1529 tasks.retain(|t| t.workflow().contains(workflow));
1530 }
1531
1532 Ok(tasks)
1533 }
1534
1535 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
1537 self.rpc(
1538 "task.get".to_owned(),
1539 Some(HashMap::from([("id", task_id)])),
1540 )
1541 .await
1542 }
1543
1544 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
1546 let status = TaskStatus {
1547 task_id,
1548 status: status.to_owned(),
1549 };
1550 self.rpc("docker.update.status".to_owned(), Some(status))
1551 .await
1552 }
1553
1554 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
1558 let stages: Vec<HashMap<String, String>> = stages
1559 .iter()
1560 .map(|(key, value)| {
1561 let mut stage_map = HashMap::new();
1562 stage_map.insert(key.to_string(), value.to_string());
1563 stage_map
1564 })
1565 .collect();
1566 let params = TaskStages { task_id, stages };
1567 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
1568 Ok(())
1569 }
1570
1571 pub async fn update_stage(
1574 &self,
1575 task_id: TaskID,
1576 stage: &str,
1577 status: &str,
1578 message: &str,
1579 percentage: u8,
1580 ) -> Result<(), Error> {
1581 let stage = Stage::new(
1582 Some(task_id),
1583 stage.to_owned(),
1584 Some(status.to_owned()),
1585 Some(message.to_owned()),
1586 percentage,
1587 );
1588 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
1589 Ok(())
1590 }
1591
1592 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
1594 let req = self
1595 .http
1596 .get(format!("{}/{}", self.url, query))
1597 .header("User-Agent", "EdgeFirst Client")
1598 .header("Authorization", format!("Bearer {}", self.token().await));
1599 let resp = req.send().await?;
1600
1601 if resp.status().is_success() {
1602 let body = resp.bytes().await?;
1603
1604 if log_enabled!(Level::Trace) {
1605 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
1606 }
1607
1608 Ok(body.to_vec())
1609 } else {
1610 let err = resp.error_for_status_ref().unwrap_err();
1611 Err(Error::HttpError(err))
1612 }
1613 }
1614
1615 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
1619 let req = self
1620 .http
1621 .post(format!("{}/api?method={}", self.url, method))
1622 .header("Accept", "application/json")
1623 .header("User-Agent", "EdgeFirst Client")
1624 .header("Authorization", format!("Bearer {}", self.token().await))
1625 .multipart(form);
1626 let resp = req.send().await?;
1627
1628 if resp.status().is_success() {
1629 let body = resp.bytes().await?;
1630
1631 if log_enabled!(Level::Trace) {
1632 trace!(
1633 "POST Multipart Response: {}",
1634 String::from_utf8_lossy(&body)
1635 );
1636 }
1637
1638 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
1639 Ok(response) => response,
1640 Err(err) => {
1641 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1642 return Err(err.into());
1643 }
1644 };
1645
1646 if let Some(error) = response.error {
1647 Err(Error::RpcError(error.code, error.message))
1648 } else if let Some(result) = response.result {
1649 Ok(result)
1650 } else {
1651 Err(Error::InvalidResponse)
1652 }
1653 } else {
1654 let err = resp.error_for_status_ref().unwrap_err();
1655 Err(Error::HttpError(err))
1656 }
1657 }
1658
1659 pub async fn rpc<Params, RpcResult>(
1668 &self,
1669 method: String,
1670 params: Option<Params>,
1671 ) -> Result<RpcResult, Error>
1672 where
1673 Params: Serialize,
1674 RpcResult: DeserializeOwned,
1675 {
1676 let auth_expires = self.token_expiration().await?;
1677 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
1678 self.renew_token().await?;
1679 }
1680
1681 self.rpc_without_auth(method, params).await
1682 }
1683
1684 async fn rpc_without_auth<Params, RpcResult>(
1685 &self,
1686 method: String,
1687 params: Option<Params>,
1688 ) -> Result<RpcResult, Error>
1689 where
1690 Params: Serialize,
1691 RpcResult: DeserializeOwned,
1692 {
1693 let request = RpcRequest {
1694 method,
1695 params,
1696 ..Default::default()
1697 };
1698
1699 if log_enabled!(Level::Trace) {
1700 trace!(
1701 "RPC Request: {}",
1702 serde_json::ser::to_string_pretty(&request)?
1703 );
1704 }
1705
1706 for attempt in 0..MAX_RETRIES {
1707 let res = match self
1708 .http
1709 .post(format!("{}/api", self.url))
1710 .header("Accept", "application/json")
1711 .header("User-Agent", "EdgeFirst Client")
1712 .header("Authorization", format!("Bearer {}", self.token().await))
1713 .json(&request)
1714 .send()
1715 .await
1716 {
1717 Ok(res) => res,
1718 Err(err) => {
1719 warn!("Socket Error: {:?}", err);
1720 continue;
1721 }
1722 };
1723
1724 if res.status().is_success() {
1725 let body = res.bytes().await?;
1726
1727 if log_enabled!(Level::Trace) {
1728 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
1729 }
1730
1731 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
1732 Ok(response) => response,
1733 Err(err) => {
1734 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1735 return Err(err.into());
1736 }
1737 };
1738
1739 if let Some(error) = response.error {
1745 return Err(Error::RpcError(error.code, error.message));
1746 } else if let Some(result) = response.result {
1747 return Ok(result);
1748 } else {
1749 return Err(Error::InvalidResponse);
1750 }
1751 } else {
1752 let err = res.error_for_status_ref().unwrap_err();
1753 warn!("HTTP Error {}: {}", err, res.text().await?);
1754 }
1755
1756 warn!(
1757 "Retrying RPC request (attempt {}/{})...",
1758 attempt + 1,
1759 MAX_RETRIES
1760 );
1761 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1762 }
1763
1764 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1765 }
1766}
1767
1768async fn upload_multipart(
1769 http: reqwest::Client,
1770 part: SnapshotPart,
1771 path: PathBuf,
1772 total: usize,
1773 current: Arc<AtomicUsize>,
1774 progress: Option<Sender<Progress>>,
1775) -> Result<SnapshotCompleteMultipartParams, Error> {
1776 let filesize = path.metadata()?.len() as usize;
1777 let n_parts = filesize.div_ceil(PART_SIZE);
1778 let sem = Arc::new(Semaphore::new(MAX_TASKS));
1779
1780 let key = part.key.unwrap();
1781 let upload_id = part.upload_id;
1782
1783 let urls = part.urls.clone();
1784 let etags = Arc::new(tokio::sync::Mutex::new(vec![
1785 EtagPart {
1786 etag: "".to_owned(),
1787 part_number: 0,
1788 };
1789 n_parts
1790 ]));
1791
1792 let tasks = (0..n_parts)
1793 .map(|part| {
1794 let http = http.clone();
1795 let url = urls[part].clone();
1796 let etags = etags.clone();
1797 let path = path.to_owned();
1798 let sem = sem.clone();
1799 let progress = progress.clone();
1800 let current = current.clone();
1801
1802 tokio::spawn(async move {
1803 let _permit = sem.acquire().await?;
1804 let mut etag = None;
1805
1806 for attempt in 0..MAX_RETRIES {
1807 match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
1808 {
1809 Ok(v) => {
1810 etag = Some(v);
1811 break;
1812 }
1813 Err(err) => {
1814 warn!("Upload Part Error: {:?}", err);
1815 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1816 }
1817 }
1818 }
1819
1820 if let Some(etag) = etag {
1821 let mut etags = etags.lock().await;
1822 etags[part] = EtagPart {
1823 etag,
1824 part_number: part + 1,
1825 };
1826
1827 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
1828 if let Some(progress) = &progress {
1829 progress
1830 .send(Progress {
1831 current: current + PART_SIZE,
1832 total,
1833 })
1834 .await
1835 .unwrap();
1836 }
1837
1838 Ok(())
1839 } else {
1840 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1841 }
1842 })
1843 })
1844 .collect::<Vec<_>>();
1845
1846 join_all(tasks)
1847 .await
1848 .into_iter()
1849 .collect::<Result<Vec<_>, _>>()?;
1850
1851 Ok(SnapshotCompleteMultipartParams {
1852 key,
1853 upload_id,
1854 etag_list: etags.lock().await.clone(),
1855 })
1856}
1857
1858async fn upload_part(
1859 http: reqwest::Client,
1860 url: String,
1861 path: PathBuf,
1862 part: usize,
1863 n_parts: usize,
1864) -> Result<String, Error> {
1865 let filesize = path.metadata()?.len() as usize;
1866 let mut file = File::open(path).await.unwrap();
1867 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
1868 .await
1869 .unwrap();
1870 let file = file.take(PART_SIZE as u64);
1871
1872 let body_length = if part + 1 == n_parts {
1873 filesize % PART_SIZE
1874 } else {
1875 PART_SIZE
1876 };
1877
1878 let stream = FramedRead::new(file, BytesCodec::new());
1879 let body = Body::wrap_stream(stream);
1880
1881 let resp = http
1882 .put(url.clone())
1883 .header(CONTENT_LENGTH, body_length)
1884 .body(body)
1885 .send()
1886 .await?
1887 .error_for_status()?;
1888 let etag = resp
1889 .headers()
1890 .get("etag")
1891 .unwrap()
1892 .to_str()
1893 .unwrap()
1894 .to_owned();
1895 Ok(etag
1897 .strip_prefix("\"")
1898 .unwrap()
1899 .strip_suffix("\"")
1900 .unwrap()
1901 .to_owned())
1902}