1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23 collections::HashMap,
24 fs::create_dir_all,
25 io::{SeekFrom, Write as _},
26 path::{Path, PathBuf},
27 sync::{
28 Arc,
29 atomic::{AtomicUsize, Ordering},
30 },
31 time::Duration,
32 vec,
33};
34use tokio::{
35 fs::{self, File},
36 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
37 sync::{RwLock, Semaphore, mpsc::Sender},
38};
39use tokio_util::codec::{BytesCodec, FramedRead};
40use walkdir::WalkDir;
41
42#[cfg(feature = "polars")]
43use polars::prelude::*;
44
/// Maximum number of concurrent upload/download tasks.
const MAX_TASKS: usize = 32;
/// Maximum number of attempts for transient HTTP failures in `download`.
const MAX_RETRIES: u32 = 10;
/// Size of each part in a multipart snapshot upload (100 MiB).
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Progress report sent over an optional channel by long-running transfers.
///
/// Units are operation-dependent: item counts for sample transfers, byte
/// counts for snapshot transfers.
#[derive(Debug, Clone)]
pub struct Progress {
    /// Units completed so far.
    pub current: usize,
    /// Total units expected.
    pub total: usize,
}
77
/// JSON-RPC 2.0 request envelope; `Params` is the method-specific payload.
#[derive(Serialize)]
struct RpcRequest<Params> {
    id: u64,
    jsonrpc: String,
    method: String,
    params: Option<Params>,
}
85
86impl<T> Default for RpcRequest<T> {
87 fn default() -> Self {
88 RpcRequest {
89 id: 0,
90 jsonrpc: "2.0".to_string(),
91 method: "".to_string(),
92 params: None,
93 }
94 }
95}
96
/// JSON-RPC error object returned by the server.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}

/// JSON-RPC 2.0 response envelope; exactly one of `error` / `result` is
/// expected to be populated.
// NOTE(review): the response `id` is a String while the request `id` is a
// u64 — presumably the server echoes ids as strings; confirm against the
// server implementation.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    id: String,
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}

/// Placeholder for RPC methods whose result carries no data.
#[derive(Deserialize)]
struct EmptyResult {}
113
/// Params for single-part snapshot upload URL creation.
#[derive(Debug, Serialize)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}

/// Result of single-part snapshot creation: the new snapshot id plus one
/// presigned URL per requested key.
#[derive(Debug, Deserialize)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}

/// Params for `snapshots.create_upload_url_multipart`; `file_sizes` is
/// parallel to `keys`.
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}

/// One value of the multipart-create response map: the numeric snapshot id
/// (under the `"snapshot_id"` key) or a per-object part descriptor (keyed
/// by the object key).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}

/// Params for `snapshots.complete_multipart_upload`.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}

/// ETag / part-number pair, serialized with S3-style field names.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}

/// Presigned multipart descriptor for one object: the upload id and one URL
/// per part. `key` is absent in the RPC response and filled in locally by
/// the caller before completing the upload.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
161
/// Params for `snapshots.update` (sets a snapshot's status string).
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}

/// Result of `snapshots.update`.
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    pub id: SnapshotID,
    pub uid: String,
    pub description: String,
    pub date: String,
    pub status: String,
}

/// Params for image listing RPCs.
#[derive(Serialize)]
struct ImageListParams {
    images_filter: ImagesFilter,
    // Extra key/value filters applied to image files.
    image_files_filter: HashMap<String, String>,
    // presumably restricts the response to ids only — confirm against server
    only_ids: bool,
}

/// Filter selecting the images belonging to one dataset.
#[derive(Serialize)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
188
/// Asynchronous client for the EdgeFirst Studio JSON-RPC API.
///
/// Clones share the underlying HTTP connection pool and the session token
/// (held behind an `Arc<RwLock<_>>`).
#[derive(Clone, Debug)]
pub struct Client {
    http: reqwest::Client,
    // Base server URL, e.g. "https://edgefirst.studio".
    url: String,
    // Session JWT; an empty string means unauthenticated.
    token: Arc<RwLock<String>>,
    // Where to persist the token, when configured via `with_token_path`.
    token_path: Option<PathBuf>,
}
244
245impl Client {
246 pub fn new() -> Result<Self, Error> {
255 Ok(Client {
256 http: reqwest::Client::builder()
257 .read_timeout(Duration::from_secs(60))
258 .build()?,
259 url: "https://edgefirst.studio".to_string(),
260 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
261 token_path: None,
262 })
263 }
264
265 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
269 Ok(Client {
270 url: format!("https://{}.edgefirst.studio", server),
271 ..self.clone()
272 })
273 }
274
275 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
278 let params = HashMap::from([("username", username), ("password", password)]);
279 let login: LoginResult = self
280 .rpc_without_auth("auth.login".to_owned(), Some(params))
281 .await?;
282 Ok(Client {
283 token: Arc::new(tokio::sync::RwLock::new(login.token)),
284 ..self.clone()
285 })
286 }
287
288 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
291 let token_path = match token_path {
292 Some(path) => path.to_path_buf(),
293 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
294 .unwrap()
295 .config_dir()
296 .join("token"),
297 };
298
299 debug!("Using token path: {:?}", token_path);
300
301 let token = match token_path.exists() {
302 true => std::fs::read_to_string(&token_path)?,
303 false => "".to_string(),
304 };
305
306 if !token.is_empty() {
307 let client = self.with_token(&token)?;
308 Ok(Client {
309 token_path: Some(token_path),
310 ..client
311 })
312 } else {
313 Ok(Client {
314 token_path: Some(token_path),
315 ..self.clone()
316 })
317 }
318 }
319
320 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
322 if token.is_empty() {
323 return Ok(self.clone());
324 }
325
326 let token_parts: Vec<&str> = token.split('.').collect();
327 if token_parts.len() != 3 {
328 return Err(Error::InvalidToken);
329 }
330
331 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
332 .decode(token_parts[1])
333 .unwrap();
334 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
335 let server = match payload.get("database") {
336 Some(value) => Ok(value.as_str().unwrap().to_string()),
337 None => Err(Error::InvalidToken),
338 }?;
339
340 Ok(Client {
341 url: format!("https://{}.edgefirst.studio", server),
342 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
343 ..self.clone()
344 })
345 }
346
    /// Persists the current token to the configured token path (or to the
    /// platform config directory when none was set), creating parent
    /// directories as needed.
    ///
    /// NOTE(review): uses blocking `std::fs` calls inside an async fn;
    /// acceptable for a small one-shot write but worth confirming.
    pub async fn save_token(&self) -> Result<(), Error> {
        let path = self.token_path.clone().unwrap_or_else(|| {
            ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
                .unwrap()
                .config_dir()
                .join("token")
        });

        create_dir_all(path.parent().unwrap())?;
        let mut file = std::fs::File::create(&path)?;
        file.write_all(self.token.read().await.as_bytes())?;

        debug!("Saved token to {:?}", path);

        Ok(())
    }
363
364 pub async fn version(&self) -> Result<String, Error> {
367 let version: HashMap<String, String> = self
368 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
369 .await?;
370 let version = version.get("version").ok_or(Error::InvalidResponse)?;
371 Ok(version.to_owned())
372 }
373
374 pub async fn logout(&self) -> Result<(), Error> {
378 {
379 let mut token = self.token.write().await;
380 *token = "".to_string();
381 }
382
383 if let Some(path) = &self.token_path
384 && path.exists()
385 {
386 fs::remove_file(path).await?;
387 }
388
389 Ok(())
390 }
391
    /// Returns a copy of the current session token (empty when logged out).
    pub async fn token(&self) -> String {
        self.token.read().await.clone()
    }
398
399 pub async fn verify_token(&self) -> Result<(), Error> {
404 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
405 .await?;
406 Ok::<(), Error>(())
407 }
408
    /// Refreshes the session token via `auth.refresh` and persists it when
    /// a token path is configured.
    pub async fn renew_token(&self) -> Result<(), Error> {
        let params = HashMap::from([("username".to_string(), self.username().await?)]);
        let result: LoginResult = self
            .rpc_without_auth("auth.refresh".to_owned(), Some(params))
            .await?;

        // Scope the write lock so it is released before save_token re-reads.
        {
            let mut token = self.token.write().await;
            *token = result.token;
        }

        if self.token_path.is_some() {
            self.save_token().await?;
        }

        Ok(())
    }
430
431 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
432 let token = self.token.read().await;
433 if token.is_empty() {
434 return Err(Error::EmptyToken);
435 }
436
437 let token_parts: Vec<&str> = token.split('.').collect();
438 if token_parts.len() != 3 {
439 return Err(Error::InvalidToken);
440 }
441
442 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
443 .decode(token_parts[1])
444 .unwrap();
445 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
446 match payload.get(field) {
447 Some(value) => Ok(value.to_owned()),
448 None => Err(Error::InvalidToken),
449 }
450 }
451
    /// Returns the base server URL this client targets.
    pub fn url(&self) -> &str {
        &self.url
    }
456
457 pub async fn username(&self) -> Result<String, Error> {
459 match self.token_field("username").await? {
460 serde_json::Value::String(username) => Ok(username),
461 _ => Err(Error::InvalidToken),
462 }
463 }
464
465 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
467 let ts = match self.token_field("exp").await? {
468 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
469 _ => return Err(Error::InvalidToken),
470 };
471
472 match DateTime::<Utc>::from_timestamp_secs(ts) {
473 Some(dt) => Ok(dt),
474 None => Err(Error::InvalidToken),
475 }
476 }
477
    /// Fetches the organization the authenticated user belongs to.
    pub async fn organization(&self) -> Result<Organization, Error> {
        self.rpc::<(), Organization>("org.get".to_owned(), None)
            .await
    }
483
484 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
492 let projects = self
493 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
494 .await?;
495 if let Some(name) = name {
496 Ok(projects
497 .into_iter()
498 .filter(|p| p.name().contains(name))
499 .collect())
500 } else {
501 Ok(projects)
502 }
503 }
504
    /// Fetches a single project by id.
    pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
        let params = HashMap::from([("project_id", project_id)]);
        self.rpc("project.get".to_owned(), Some(params)).await
    }
511
512 pub async fn datasets(
516 &self,
517 project_id: ProjectID,
518 name: Option<&str>,
519 ) -> Result<Vec<Dataset>, Error> {
520 let params = HashMap::from([("project_id", project_id)]);
521 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
522 if let Some(name) = name {
523 Ok(datasets
524 .into_iter()
525 .filter(|d| d.name().contains(name))
526 .collect())
527 } else {
528 Ok(datasets)
529 }
530 }
531
    /// Fetches a single dataset by id.
    pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("dataset.get".to_owned(), Some(params)).await
    }

    /// Lists the labels defined on `dataset_id`.
    pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("label.list".to_owned(), Some(params)).await
    }
544
545 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
547 let new_label = NewLabel {
548 dataset_id,
549 labels: vec![NewLabelObject {
550 name: name.to_owned(),
551 }],
552 };
553 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
554 Ok(())
555 }
556
    /// Deletes the label with id `label_id`.
    pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
        let params = HashMap::from([("label_id", label_id)]);
        let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
        Ok(())
    }
564
    /// Pushes `label`'s current name and index to the server.
    pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
        // Ad-hoc parameter shape for the `label.update` RPC.
        #[derive(Serialize)]
        struct Params {
            dataset_id: DatasetID,
            label_id: u64,
            label_name: String,
            label_index: u64,
        }

        let _: String = self
            .rpc(
                "label.update".to_owned(),
                Some(Params {
                    dataset_id: label.dataset_id(),
                    label_id: label.id(),
                    label_name: label.name().to_owned(),
                    label_index: label.index(),
                }),
            )
            .await?;
        Ok(())
    }
590
591 pub async fn download_dataset(
592 &self,
593 dataset_id: DatasetID,
594 groups: &[String],
595 file_types: &[FileType],
596 output: PathBuf,
597 progress: Option<Sender<Progress>>,
598 ) -> Result<(), Error> {
599 let samples = self
600 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
601 .await?;
602 fs::create_dir_all(&output).await?;
603
604 let total = samples.len();
605 let current = Arc::new(AtomicUsize::new(0));
606 let sem = Arc::new(Semaphore::new(MAX_TASKS));
607
608 let tasks = samples
609 .into_iter()
610 .map(|sample| {
611 let sem = sem.clone();
612 let client = self.clone();
613 let current = current.clone();
614 let progress = progress.clone();
615 let file_types = file_types.to_vec();
616 let output = output.clone();
617
618 tokio::spawn(async move {
619 let _permit = sem.acquire().await.unwrap();
620
621 for file_type in file_types {
622 if let Some(data) = sample.download(&client, file_type.clone()).await? {
623 let file_ext = match file_type {
624 FileType::Image => infer::get(&data)
625 .expect("Failed to identify image file format for sample")
626 .extension()
627 .to_string(),
628 t => t.to_string(),
629 };
630
631 let file_name = format!(
632 "{}.{}",
633 sample.name().unwrap_or_else(|| "unknown".to_string()),
634 file_ext
635 );
636 let file_path = output.join(&file_name);
637
638 let mut file = File::create(&file_path).await?;
639 file.write_all(&data).await?;
640 } else {
641 warn!(
642 "No data for sample: {}",
643 sample
644 .id()
645 .map(|id| id.to_string())
646 .unwrap_or_else(|| "unknown".to_string())
647 );
648 }
649 }
650
651 if let Some(progress) = &progress {
652 let current = current.fetch_add(1, Ordering::SeqCst);
653 progress
654 .send(Progress {
655 current: current + 1,
656 total,
657 })
658 .await
659 .unwrap();
660 }
661
662 Ok::<(), Error>(())
663 })
664 })
665 .collect::<Vec<_>>();
666
667 join_all(tasks)
668 .await
669 .into_iter()
670 .collect::<Result<Vec<_>, _>>()?;
671
672 if let Some(progress) = progress {
673 drop(progress);
674 }
675
676 Ok(())
677 }
678
    /// Lists the annotation sets of `dataset_id`.
    pub async fn annotation_sets(
        &self,
        dataset_id: DatasetID,
    ) -> Result<Vec<AnnotationSet>, Error> {
        let params = HashMap::from([("dataset_id", dataset_id)]);
        self.rpc("annset.list".to_owned(), Some(params)).await
    }

    /// Fetches a single annotation set by id.
    pub async fn annotation_set(
        &self,
        annotation_set_id: AnnotationSetID,
    ) -> Result<AnnotationSet, Error> {
        let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
        self.rpc("annset.get".to_owned(), Some(params)).await
    }
696
    /// Fetches all annotations in `annotation_set_id`, filtered by `groups`
    /// and `annotation_types`, paging through `samples.list` until the
    /// continue token runs out.
    ///
    /// Samples without annotations contribute one placeholder annotation
    /// carrying only the sample metadata. Progress (samples processed /
    /// total) is reported on `progress` when provided.
    ///
    /// NOTE(review): `labels[...]` and `label().unwrap()` panic if a sample
    /// annotation references a label unknown to the dataset or carries no
    /// label at all — presumably the server guarantees both; confirm.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Map label name -> label index for annotation enrichment below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(annotations);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id: Some(annotation_set_id),
                types: annotation_types.iter().map(|t| t.to_string()).collect(),
                group_names: groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            for sample in result.samples {
                if sample.annotations().is_empty() {
                    // Emit a placeholder so un-annotated samples still show
                    // up in the result with their metadata.
                    let mut annotation = Annotation::new();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    annotations.push(annotation);
                    continue;
                }

                sample.annotations().iter().for_each(|annotation| {
                    let mut annotation = annotation.clone();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    annotation.set_label_index(Some(labels[annotation.label().unwrap().as_str()]));
                    annotations.push(annotation);
                });
            }

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(annotations)
    }
799
800 pub async fn samples_count(
801 &self,
802 dataset_id: DatasetID,
803 annotation_set_id: Option<AnnotationSetID>,
804 annotation_types: &[AnnotationType],
805 groups: &[String],
806 types: &[FileType],
807 ) -> Result<SamplesCountResult, Error> {
808 let types = annotation_types
809 .iter()
810 .map(|t| t.to_string())
811 .chain(types.iter().map(|t| t.to_string()))
812 .collect::<Vec<_>>();
813
814 let params = SamplesListParams {
815 dataset_id,
816 annotation_set_id,
817 group_names: groups.to_vec(),
818 types,
819 continue_token: None,
820 };
821
822 self.rpc("samples.count".to_owned(), Some(params)).await
823 }
824
    /// Fetches all samples of `dataset_id` matching the given annotation
    /// types, groups, and file types, paging through `samples.list` until
    /// the continue token runs out.
    ///
    /// Each returned sample's annotations are enriched with the dataset's
    /// label indices. Progress (samples fetched / total) is reported on
    /// `progress` when provided.
    ///
    /// NOTE(review): `labels[label.as_str()]` panics if an annotation
    /// references a label unknown to the dataset — presumably guaranteed by
    /// the server; confirm.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        // The RPC takes one flat list mixing annotation and file type names.
        let types = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Map label name -> label index for annotation enrichment below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(samples);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id,
                types: types.clone(),
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            if let Some(label) = ann.label() {
                                ann.set_label_index(Some(labels[label.as_str()]));
                            }
                        }
                        s.with_annotations(anns)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(samples)
    }
908
909 pub async fn populate_samples(
1001 &self,
1002 dataset_id: DatasetID,
1003 annotation_set_id: Option<AnnotationSetID>,
1004 samples: Vec<Sample>,
1005 progress: Option<Sender<Progress>>,
1006 ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
1007 use crate::api::SamplesPopulateParams;
1008 use std::path::Path;
1009
1010 let total = samples.len();
1011
1012 let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();
1015
1016 let samples: Vec<Sample> = samples
1018 .into_iter()
1019 .map(|mut sample| {
1020 if sample.uuid.is_none() {
1022 sample.uuid = Some(uuid::Uuid::new_v4().to_string());
1023 }
1024
1025 let sample_uuid = sample.uuid.clone().unwrap();
1026
1027 let updated_files: Vec<crate::SampleFile> = sample
1029 .files
1030 .iter()
1031 .map(|file| {
1032 if let Some(filename) = file.filename() {
1033 let path = Path::new(filename);
1034
1035 if path.exists() && path.is_file() {
1037 if let Some(basename) = path.file_name().and_then(|s| s.to_str()) {
1039 if file.file_type() == "image"
1041 && (sample.width.is_none() || sample.height.is_none())
1042 && let Ok(size) = imagesize::size(path)
1043 {
1044 sample.width = Some(size.width as u32);
1045 sample.height = Some(size.height as u32);
1046 }
1047
1048 files_to_upload.push((
1050 sample_uuid.clone(),
1051 file.file_type().to_string(),
1052 path.to_path_buf(),
1053 basename.to_string(),
1054 ));
1055
1056 return crate::SampleFile::with_filename(
1058 file.file_type().to_string(),
1059 basename.to_string(),
1060 );
1061 }
1062 }
1063 }
1064 file.clone()
1066 })
1067 .collect();
1068
1069 sample.files = updated_files;
1070 sample
1071 })
1072 .collect();
1073
1074 let has_files_to_upload = !files_to_upload.is_empty();
1075
1076 let params = SamplesPopulateParams {
1078 dataset_id,
1079 annotation_set_id,
1080 presigned_urls: if has_files_to_upload {
1081 Some(true)
1082 } else {
1083 Some(false)
1084 },
1085 samples,
1086 };
1087
1088 let results: Vec<crate::SamplesPopulateResult> = self
1089 .rpc("samples.populate".to_owned(), Some(params))
1090 .await?;
1091
1092 if has_files_to_upload {
1094 let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
1096 for (uuid, _file_type, path, basename) in files_to_upload {
1097 upload_map.insert((uuid, basename), path);
1098 }
1099
1100 let current = Arc::new(AtomicUsize::new(0));
1101 let sem = Arc::new(Semaphore::new(MAX_TASKS));
1102
1103 let upload_tasks = results
1105 .iter()
1106 .map(|result| {
1107 let sem = sem.clone();
1108 let http = self.http.clone();
1109 let current = current.clone();
1110 let progress = progress.clone();
1111 let result_uuid = result.uuid.clone();
1112 let urls = result.urls.clone();
1113 let upload_map = upload_map.clone();
1114
1115 tokio::spawn(async move {
1116 let _permit = sem.acquire().await.unwrap();
1117
1118 for url_info in &urls {
1120 if let Some(local_path) =
1121 upload_map.get(&(result_uuid.clone(), url_info.filename.clone()))
1122 {
1123 upload_file_to_presigned_url(
1125 http.clone(),
1126 &url_info.url,
1127 local_path.clone(),
1128 )
1129 .await?;
1130 }
1131 }
1132
1133 if let Some(progress) = &progress {
1135 let current = current.fetch_add(1, Ordering::SeqCst);
1136 progress
1137 .send(Progress {
1138 current: current + 1,
1139 total,
1140 })
1141 .await
1142 .unwrap();
1143 }
1144
1145 Ok::<(), Error>(())
1146 })
1147 })
1148 .collect::<Vec<_>>();
1149
1150 join_all(upload_tasks)
1151 .await
1152 .into_iter()
1153 .collect::<Result<Vec<_>, _>>()?;
1154 }
1155
1156 if let Some(progress) = progress {
1157 drop(progress);
1158 }
1159
1160 Ok(results)
1161 }
1162
1163 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1164 for attempt in 1..MAX_RETRIES {
1165 let resp = match self.http.get(url).send().await {
1166 Ok(resp) => resp,
1167 Err(err) => {
1168 warn!(
1169 "Socket Error [retry {}/{}]: {:?}",
1170 attempt, MAX_RETRIES, err
1171 );
1172 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1173 continue;
1174 }
1175 };
1176
1177 match resp.bytes().await {
1178 Ok(body) => return Ok(body.to_vec()),
1179 Err(err) => {
1180 warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
1181 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1182 continue;
1183 }
1184 };
1185 }
1186
1187 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1188 }
1189
    #[cfg(feature = "polars")]
    /// Fetches annotations (see [`Client::annotations`]) and converts them
    /// into a polars [`DataFrame`]. Only available with the `polars`
    /// feature.
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        Ok(annotations_dataframe(&annotations))
    }
1215
1216 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1219 let snapshots: Vec<Snapshot> = self
1220 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1221 .await?;
1222 if let Some(name) = name {
1223 Ok(snapshots
1224 .into_iter()
1225 .filter(|s| s.description().contains(name))
1226 .collect())
1227 } else {
1228 Ok(snapshots)
1229 }
1230 }
1231
    /// Fetches a single snapshot by id.
    pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        self.rpc("snapshots.get".to_owned(), Some(params)).await
    }
1237
    /// Creates a snapshot from a local file (or delegates to
    /// [`Self::create_snapshot_folder`] when `path` is a directory),
    /// uploading the contents via multipart upload and marking the snapshot
    /// `available` on completion.
    ///
    /// Byte-level progress is reported on `progress` when provided.
    ///
    /// NOTE(review): panics on paths without a file name or with non-UTF-8
    /// components (`unwrap` on `file_name`/`to_str`) — confirm callers only
    /// pass regular UTF-8 file paths.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            return self
                .create_snapshot_folder(path.to_str().unwrap(), progress)
                .await;
        }

        let name = path.file_name().unwrap().to_str().unwrap();
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        // Ask the server for a multipart upload slot for this single file.
        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The response map is keyed by "<snapshot path suffix>/<file name>".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        // Upload the file parts, then tell the server the upload is done.
        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot to "available" so it becomes usable.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1324
    /// Creates a snapshot from every file under the directory `path`,
    /// uploading each file via multipart upload (sequentially) and marking
    /// the snapshot `available` on completion.
    ///
    /// Byte-level progress across all files is reported on `progress`.
    ///
    /// NOTE(review): panics on non-UTF-8 path components and on files whose
    /// metadata cannot be read (`unwrap`s below) — confirm acceptable for
    /// the expected inputs.
    async fn create_snapshot_folder(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);
        let name = path.file_name().unwrap().to_str().unwrap();

        // Collect all regular files, as paths relative to the folder root.
        let files = WalkDir::new(path)
            .into_iter()
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.file_type().is_file())
            .map(|entry| entry.path().strip_prefix(path).unwrap().to_owned())
            .collect::<Vec<_>>();

        let total = files
            .iter()
            .map(|file| path.join(file).metadata().unwrap().len() as usize)
            .sum();
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        let keys = files
            .iter()
            .map(|key| key.to_str().unwrap().to_owned())
            .collect::<Vec<_>>();
        let file_sizes = files
            .iter()
            .map(|key| path.join(key).metadata().unwrap().len() as usize)
            .collect::<Vec<_>>();

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys,
            file_sizes,
        };

        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        // The response map is keyed by "<snapshot path suffix>/<rel path>".
        let snapshot = self.snapshot(snapshot_id).await?;
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();

        // Upload and complete each file in turn.
        for file in files {
            let part_key = format!("{}/{}", part_prefix, file.to_str().unwrap());
            let mut part = match multipart.get(&part_key) {
                Some(SnapshotCreateMultipartResultField::Part(part)) => part,
                _ => return Err(Error::InvalidResponse),
            }
            .clone();
            part.key = Some(part_key);

            let params = upload_multipart(
                self.http.clone(),
                part.clone(),
                path.join(file),
                total,
                current.clone(),
                progress.clone(),
            )
            .await?;

            let complete: String = self
                .rpc(
                    "snapshots.complete_multipart_upload".to_owned(),
                    Some(params),
                )
                .await?;
            debug!("Snapshot Part Complete: {:?}", complete);
        }

        // Flip the snapshot to "available" so it becomes usable.
        let params = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1422
    /// Downloads every file of `snapshot_id` into `output`, streaming each
    /// object from its presigned URL with at most `MAX_TASKS` concurrent
    /// downloads. Byte-level progress is reported on `progress`; `total`
    /// grows as each task learns its content length.
    ///
    /// NOTE(review): the task bodies `unwrap` on request errors, missing
    /// Content-Length, and file I/O, so a failed transfer panics the task
    /// and the final `unwrap` propagates it as a panic rather than an
    /// `Err` — consider converting to `Error` variants.
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // Map of object key -> presigned download URL.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        let sem = Arc::new(Semaphore::new(MAX_TASKS));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    let _permit = sem.acquire().await.unwrap();
                    let res = http.get(url).send().await.unwrap();
                    let content_length = res.content_length().unwrap() as usize;

                    // Grow the shared total as sizes become known.
                    if let Some(progress) = &progress {
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await
                            .unwrap();
                    }

                    let mut file = File::create(output.join(key)).await.unwrap();
                    let mut stream = res.bytes_stream();

                    // Stream chunks to disk, reporting bytes as they land.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk.unwrap();
                        file.write_all(&chunk).await.unwrap();
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await
                                .unwrap();
                        }
                    }
                })
            })
            .collect::<Vec<_>>();

        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        Ok(())
    }
1505
1506 pub async fn restore_snapshot(
1521 &self,
1522 project_id: ProjectID,
1523 snapshot_id: SnapshotID,
1524 topics: &[String],
1525 autolabel: &[String],
1526 autodepth: bool,
1527 dataset_name: Option<&str>,
1528 dataset_description: Option<&str>,
1529 ) -> Result<SnapshotRestoreResult, Error> {
1530 let params = SnapshotRestore {
1531 project_id,
1532 snapshot_id,
1533 fps: 1,
1534 autodepth,
1535 agtg_pipeline: !autolabel.is_empty(),
1536 autolabel: autolabel.to_vec(),
1537 topics: topics.to_vec(),
1538 dataset_name: dataset_name.map(|s| s.to_owned()),
1539 dataset_description: dataset_description.map(|s| s.to_owned()),
1540 };
1541 self.rpc("snapshots.restore".to_owned(), Some(params)).await
1542 }
1543
1544 pub async fn experiments(
1553 &self,
1554 project_id: ProjectID,
1555 name: Option<&str>,
1556 ) -> Result<Vec<Experiment>, Error> {
1557 let params = HashMap::from([("project_id", project_id)]);
1558 let experiments: Vec<Experiment> =
1559 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1560 if let Some(name) = name {
1561 Ok(experiments
1562 .into_iter()
1563 .filter(|e| e.name().contains(name))
1564 .collect())
1565 } else {
1566 Ok(experiments)
1567 }
1568 }
1569
1570 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1573 let params = HashMap::from([("trainer_id", experiment_id)]);
1574 self.rpc("trainer.get".to_owned(), Some(params)).await
1575 }
1576
1577 pub async fn training_sessions(
1586 &self,
1587 experiment_id: ExperimentID,
1588 name: Option<&str>,
1589 ) -> Result<Vec<TrainingSession>, Error> {
1590 let params = HashMap::from([("trainer_id", experiment_id)]);
1591 let sessions: Vec<TrainingSession> = self
1592 .rpc("trainer.session.list".to_owned(), Some(params))
1593 .await?;
1594 if let Some(name) = name {
1595 Ok(sessions
1596 .into_iter()
1597 .filter(|s| s.name().contains(name))
1598 .collect())
1599 } else {
1600 Ok(sessions)
1601 }
1602 }
1603
1604 pub async fn training_session(
1607 &self,
1608 session_id: TrainingSessionID,
1609 ) -> Result<TrainingSession, Error> {
1610 let params = HashMap::from([("trainer_session_id", session_id)]);
1611 self.rpc("trainer.session.get".to_owned(), Some(params))
1612 .await
1613 }
1614
1615 pub async fn validation_sessions(
1617 &self,
1618 project_id: ProjectID,
1619 ) -> Result<Vec<ValidationSession>, Error> {
1620 let params = HashMap::from([("project_id", project_id)]);
1621 self.rpc("validate.session.list".to_owned(), Some(params))
1622 .await
1623 }
1624
1625 pub async fn validation_session(
1627 &self,
1628 session_id: ValidationSessionID,
1629 ) -> Result<ValidationSession, Error> {
1630 let params = HashMap::from([("validate_session_id", session_id)]);
1631 self.rpc("validate.session.get".to_owned(), Some(params))
1632 .await
1633 }
1634
1635 pub async fn artifacts(
1638 &self,
1639 training_session_id: TrainingSessionID,
1640 ) -> Result<Vec<Artifact>, Error> {
1641 let params = HashMap::from([("training_session_id", training_session_id)]);
1642 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
1643 .await
1644 }
1645
1646 pub async fn download_artifact(
1652 &self,
1653 training_session_id: TrainingSessionID,
1654 modelname: &str,
1655 filename: Option<PathBuf>,
1656 progress: Option<Sender<Progress>>,
1657 ) -> Result<(), Error> {
1658 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
1659 let resp = self
1660 .http
1661 .get(format!(
1662 "{}/download_model?training_session_id={}&file={}",
1663 self.url,
1664 training_session_id.value(),
1665 modelname
1666 ))
1667 .header("Authorization", format!("Bearer {}", self.token().await))
1668 .send()
1669 .await?;
1670 if !resp.status().is_success() {
1671 let err = resp.error_for_status_ref().unwrap_err();
1672 return Err(Error::HttpError(err));
1673 }
1674
1675 fs::create_dir_all(filename.parent().unwrap()).await?;
1676
1677 if let Some(progress) = progress {
1678 let total = resp.content_length().unwrap() as usize;
1679 progress.send(Progress { current: 0, total }).await.unwrap();
1680
1681 let mut file = File::create(filename).await?;
1682 let mut current = 0;
1683 let mut stream = resp.bytes_stream();
1684
1685 while let Some(item) = stream.next().await {
1686 let chunk = item?;
1687 file.write_all(&chunk).await?;
1688 current += chunk.len();
1689 progress.send(Progress { current, total }).await.unwrap();
1690 }
1691 } else {
1692 let body = resp.bytes().await?;
1693 fs::write(filename, body).await?;
1694 }
1695
1696 Ok(())
1697 }
1698
1699 pub async fn download_checkpoint(
1709 &self,
1710 training_session_id: TrainingSessionID,
1711 checkpoint: &str,
1712 filename: Option<PathBuf>,
1713 progress: Option<Sender<Progress>>,
1714 ) -> Result<(), Error> {
1715 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
1716 let resp = self
1717 .http
1718 .get(format!(
1719 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1720 self.url,
1721 training_session_id.value(),
1722 checkpoint
1723 ))
1724 .header("Authorization", format!("Bearer {}", self.token().await))
1725 .send()
1726 .await?;
1727 if !resp.status().is_success() {
1728 let err = resp.error_for_status_ref().unwrap_err();
1729 return Err(Error::HttpError(err));
1730 }
1731
1732 fs::create_dir_all(filename.parent().unwrap()).await?;
1733
1734 if let Some(progress) = progress {
1735 let total = resp.content_length().unwrap() as usize;
1736 progress.send(Progress { current: 0, total }).await.unwrap();
1737
1738 let mut file = File::create(filename).await?;
1739 let mut current = 0;
1740 let mut stream = resp.bytes_stream();
1741
1742 while let Some(item) = stream.next().await {
1743 let chunk = item?;
1744 file.write_all(&chunk).await?;
1745 current += chunk.len();
1746 progress.send(Progress { current, total }).await.unwrap();
1747 }
1748 } else {
1749 let body = resp.bytes().await?;
1750 fs::write(filename, body).await?;
1751 }
1752
1753 Ok(())
1754 }
1755
1756 pub async fn tasks(
1758 &self,
1759 name: Option<&str>,
1760 workflow: Option<&str>,
1761 status: Option<&str>,
1762 manager: Option<&str>,
1763 ) -> Result<Vec<Task>, Error> {
1764 let mut params = TasksListParams {
1765 continue_token: None,
1766 status: status.map(|s| vec![s.to_owned()]),
1767 manager: manager.map(|m| vec![m.to_owned()]),
1768 };
1769 let mut tasks = Vec::new();
1770
1771 loop {
1772 let result = self
1773 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
1774 .await?;
1775 tasks.extend(result.tasks);
1776
1777 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
1778 params.continue_token = None;
1779 } else {
1780 params.continue_token = result.continue_token;
1781 }
1782
1783 if params.continue_token.is_none() {
1784 break;
1785 }
1786 }
1787
1788 if let Some(name) = name {
1789 tasks.retain(|t| t.name().contains(name));
1790 }
1791
1792 if let Some(workflow) = workflow {
1793 tasks.retain(|t| t.workflow().contains(workflow));
1794 }
1795
1796 Ok(tasks)
1797 }
1798
1799 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
1801 self.rpc(
1802 "task.get".to_owned(),
1803 Some(HashMap::from([("id", task_id)])),
1804 )
1805 .await
1806 }
1807
1808 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
1810 let status = TaskStatus {
1811 task_id,
1812 status: status.to_owned(),
1813 };
1814 self.rpc("docker.update.status".to_owned(), Some(status))
1815 .await
1816 }
1817
1818 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
1822 let stages: Vec<HashMap<String, String>> = stages
1823 .iter()
1824 .map(|(key, value)| {
1825 let mut stage_map = HashMap::new();
1826 stage_map.insert(key.to_string(), value.to_string());
1827 stage_map
1828 })
1829 .collect();
1830 let params = TaskStages { task_id, stages };
1831 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
1832 Ok(())
1833 }
1834
1835 pub async fn update_stage(
1838 &self,
1839 task_id: TaskID,
1840 stage: &str,
1841 status: &str,
1842 message: &str,
1843 percentage: u8,
1844 ) -> Result<(), Error> {
1845 let stage = Stage::new(
1846 Some(task_id),
1847 stage.to_owned(),
1848 Some(status.to_owned()),
1849 Some(message.to_owned()),
1850 percentage,
1851 );
1852 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
1853 Ok(())
1854 }
1855
1856 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
1858 let req = self
1859 .http
1860 .get(format!("{}/{}", self.url, query))
1861 .header("User-Agent", "EdgeFirst Client")
1862 .header("Authorization", format!("Bearer {}", self.token().await));
1863 let resp = req.send().await?;
1864
1865 if resp.status().is_success() {
1866 let body = resp.bytes().await?;
1867
1868 if log_enabled!(Level::Trace) {
1869 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
1870 }
1871
1872 Ok(body.to_vec())
1873 } else {
1874 let err = resp.error_for_status_ref().unwrap_err();
1875 Err(Error::HttpError(err))
1876 }
1877 }
1878
1879 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
1883 let req = self
1884 .http
1885 .post(format!("{}/api?method={}", self.url, method))
1886 .header("Accept", "application/json")
1887 .header("User-Agent", "EdgeFirst Client")
1888 .header("Authorization", format!("Bearer {}", self.token().await))
1889 .multipart(form);
1890 let resp = req.send().await?;
1891
1892 if resp.status().is_success() {
1893 let body = resp.bytes().await?;
1894
1895 if log_enabled!(Level::Trace) {
1896 trace!(
1897 "POST Multipart Response: {}",
1898 String::from_utf8_lossy(&body)
1899 );
1900 }
1901
1902 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
1903 Ok(response) => response,
1904 Err(err) => {
1905 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1906 return Err(err.into());
1907 }
1908 };
1909
1910 if let Some(error) = response.error {
1911 Err(Error::RpcError(error.code, error.message))
1912 } else if let Some(result) = response.result {
1913 Ok(result)
1914 } else {
1915 Err(Error::InvalidResponse)
1916 }
1917 } else {
1918 let err = resp.error_for_status_ref().unwrap_err();
1919 Err(Error::HttpError(err))
1920 }
1921 }
1922
1923 pub async fn rpc<Params, RpcResult>(
1932 &self,
1933 method: String,
1934 params: Option<Params>,
1935 ) -> Result<RpcResult, Error>
1936 where
1937 Params: Serialize,
1938 RpcResult: DeserializeOwned,
1939 {
1940 let auth_expires = self.token_expiration().await?;
1941 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
1942 self.renew_token().await?;
1943 }
1944
1945 self.rpc_without_auth(method, params).await
1946 }
1947
1948 async fn rpc_without_auth<Params, RpcResult>(
1949 &self,
1950 method: String,
1951 params: Option<Params>,
1952 ) -> Result<RpcResult, Error>
1953 where
1954 Params: Serialize,
1955 RpcResult: DeserializeOwned,
1956 {
1957 let request = RpcRequest {
1958 method,
1959 params,
1960 ..Default::default()
1961 };
1962
1963 if log_enabled!(Level::Trace) {
1964 trace!(
1965 "RPC Request: {}",
1966 serde_json::ser::to_string_pretty(&request)?
1967 );
1968 }
1969
1970 for attempt in 0..MAX_RETRIES {
1971 let res = match self
1972 .http
1973 .post(format!("{}/api", self.url))
1974 .header("Accept", "application/json")
1975 .header("User-Agent", "EdgeFirst Client")
1976 .header("Authorization", format!("Bearer {}", self.token().await))
1977 .json(&request)
1978 .send()
1979 .await
1980 {
1981 Ok(res) => res,
1982 Err(err) => {
1983 warn!("Socket Error: {:?}", err);
1984 continue;
1985 }
1986 };
1987
1988 if res.status().is_success() {
1989 let body = res.bytes().await?;
1990
1991 if log_enabled!(Level::Trace) {
1992 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
1993 }
1994
1995 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
1996 Ok(response) => response,
1997 Err(err) => {
1998 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
1999 return Err(err.into());
2000 }
2001 };
2002
2003 if let Some(error) = response.error {
2009 return Err(Error::RpcError(error.code, error.message));
2010 } else if let Some(result) = response.result {
2011 return Ok(result);
2012 } else {
2013 return Err(Error::InvalidResponse);
2014 }
2015 } else {
2016 let err = res.error_for_status_ref().unwrap_err();
2017 warn!("HTTP Error {}: {}", err, res.text().await?);
2018 }
2019
2020 warn!(
2021 "Retrying RPC request (attempt {}/{})...",
2022 attempt + 1,
2023 MAX_RETRIES
2024 );
2025 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2026 }
2027
2028 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2029 }
2030}
2031
2032async fn upload_multipart(
2033 http: reqwest::Client,
2034 part: SnapshotPart,
2035 path: PathBuf,
2036 total: usize,
2037 current: Arc<AtomicUsize>,
2038 progress: Option<Sender<Progress>>,
2039) -> Result<SnapshotCompleteMultipartParams, Error> {
2040 let filesize = path.metadata()?.len() as usize;
2041 let n_parts = filesize.div_ceil(PART_SIZE);
2042 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2043
2044 let key = part.key.unwrap();
2045 let upload_id = part.upload_id;
2046
2047 let urls = part.urls.clone();
2048 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2049 EtagPart {
2050 etag: "".to_owned(),
2051 part_number: 0,
2052 };
2053 n_parts
2054 ]));
2055
2056 let tasks = (0..n_parts)
2057 .map(|part| {
2058 let http = http.clone();
2059 let url = urls[part].clone();
2060 let etags = etags.clone();
2061 let path = path.to_owned();
2062 let sem = sem.clone();
2063 let progress = progress.clone();
2064 let current = current.clone();
2065
2066 tokio::spawn(async move {
2067 let _permit = sem.acquire().await?;
2068 let mut etag = None;
2069
2070 for attempt in 0..MAX_RETRIES {
2071 match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
2072 {
2073 Ok(v) => {
2074 etag = Some(v);
2075 break;
2076 }
2077 Err(err) => {
2078 warn!("Upload Part Error: {:?}", err);
2079 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2080 }
2081 }
2082 }
2083
2084 if let Some(etag) = etag {
2085 let mut etags = etags.lock().await;
2086 etags[part] = EtagPart {
2087 etag,
2088 part_number: part + 1,
2089 };
2090
2091 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2092 if let Some(progress) = &progress {
2093 progress
2094 .send(Progress {
2095 current: current + PART_SIZE,
2096 total,
2097 })
2098 .await
2099 .unwrap();
2100 }
2101
2102 Ok(())
2103 } else {
2104 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2105 }
2106 })
2107 })
2108 .collect::<Vec<_>>();
2109
2110 join_all(tasks)
2111 .await
2112 .into_iter()
2113 .collect::<Result<Vec<_>, _>>()?;
2114
2115 Ok(SnapshotCompleteMultipartParams {
2116 key,
2117 upload_id,
2118 etag_list: etags.lock().await.clone(),
2119 })
2120}
2121
2122async fn upload_part(
2123 http: reqwest::Client,
2124 url: String,
2125 path: PathBuf,
2126 part: usize,
2127 n_parts: usize,
2128) -> Result<String, Error> {
2129 let filesize = path.metadata()?.len() as usize;
2130 let mut file = File::open(path).await.unwrap();
2131 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2132 .await
2133 .unwrap();
2134 let file = file.take(PART_SIZE as u64);
2135
2136 let body_length = if part + 1 == n_parts {
2137 filesize % PART_SIZE
2138 } else {
2139 PART_SIZE
2140 };
2141
2142 let stream = FramedRead::new(file, BytesCodec::new());
2143 let body = Body::wrap_stream(stream);
2144
2145 let resp = http
2146 .put(url.clone())
2147 .header(CONTENT_LENGTH, body_length)
2148 .body(body)
2149 .send()
2150 .await?
2151 .error_for_status()?;
2152 let etag = resp
2153 .headers()
2154 .get("etag")
2155 .unwrap()
2156 .to_str()
2157 .unwrap()
2158 .to_owned();
2159 Ok(etag
2161 .strip_prefix("\"")
2162 .unwrap()
2163 .strip_suffix("\"")
2164 .unwrap()
2165 .to_owned())
2166}
2167
2168async fn upload_file_to_presigned_url(
2173 http: reqwest::Client,
2174 url: &str,
2175 path: PathBuf,
2176) -> Result<(), Error> {
2177 let file_data = fs::read(&path).await?;
2179 let file_size = file_data.len();
2180
2181 for attempt in 1..=MAX_RETRIES {
2183 match http
2184 .put(url)
2185 .header(CONTENT_LENGTH, file_size)
2186 .body(file_data.clone())
2187 .send()
2188 .await
2189 {
2190 Ok(resp) => {
2191 if resp.status().is_success() {
2192 debug!(
2193 "Successfully uploaded file: {:?} ({} bytes)",
2194 path, file_size
2195 );
2196 return Ok(());
2197 } else {
2198 let status = resp.status();
2199 let error_text = resp.text().await.unwrap_or_default();
2200 warn!(
2201 "Upload failed [attempt {}/{}]: HTTP {} - {}",
2202 attempt, MAX_RETRIES, status, error_text
2203 );
2204 }
2205 }
2206 Err(err) => {
2207 warn!(
2208 "Upload error [attempt {}/{}]: {:?}",
2209 attempt, MAX_RETRIES, err
2210 );
2211 }
2212 }
2213
2214 if attempt < MAX_RETRIES {
2215 tokio::time::sleep(Duration::from_secs(attempt as u64)).await;
2216 }
2217 }
2218
2219 Err(Error::InvalidParameters(format!(
2220 "Failed to upload file {:?} after {} attempts",
2221 path, MAX_RETRIES
2222 )))
2223}