1use crate::{
5 Annotation, Error, Sample, Task,
6 api::{
7 AnnotationSetID, Artifact, DatasetID, Experiment, ExperimentID, LoginResult, Organization,
8 Project, ProjectID, SamplesCountResult, SamplesListParams, SamplesListResult, Snapshot,
9 SnapshotID, SnapshotRestore, SnapshotRestoreResult, Stage, TaskID, TaskInfo, TaskStages,
10 TaskStatus, TasksListParams, TasksListResult, TrainingSession, TrainingSessionID,
11 ValidationSession, ValidationSessionID,
12 },
13 dataset::{AnnotationSet, AnnotationType, Dataset, FileType, Label, NewLabel, NewLabelObject},
14};
15use base64::Engine as _;
16use chrono::{DateTime, Utc};
17use directories::ProjectDirs;
18use futures::{StreamExt as _, future::join_all};
19use log::{Level, debug, error, log_enabled, trace, warn};
20use reqwest::{Body, header::CONTENT_LENGTH, multipart::Form};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::{
23 collections::HashMap,
24 fs::create_dir_all,
25 io::{SeekFrom, Write as _},
26 path::{Path, PathBuf},
27 sync::{
28 Arc,
29 atomic::{AtomicUsize, Ordering},
30 },
31 time::Duration,
32 vec,
33};
34use tokio::{
35 fs::{self, File},
36 io::{AsyncReadExt as _, AsyncSeekExt as _, AsyncWriteExt as _},
37 sync::{RwLock, Semaphore, mpsc::Sender},
38};
39use tokio_util::codec::{BytesCodec, FramedRead};
40use walkdir::WalkDir;
41
42#[cfg(feature = "polars")]
43use polars::prelude::*;
44
/// Maximum number of concurrently running download/upload tasks.
const MAX_TASKS: usize = 32;
/// Maximum attempts for retried HTTP operations (see `Client::download`).
const MAX_RETRIES: u32 = 10;
/// Part size in bytes for multipart snapshot uploads (100 MiB).
const PART_SIZE: usize = 100 * 1024 * 1024;
48
/// Progress update sent over an optional `Sender<Progress>` channel by
/// long-running operations (sample downloads, snapshot uploads, listing).
/// Units depend on the operation: samples/pages for dataset calls, bytes for
/// snapshot transfers.
#[derive(Debug, Clone)]
pub struct Progress {
    // Units completed so far.
    pub current: usize,
    // Total units expected for the whole operation.
    pub total: usize,
}
77
/// JSON-RPC 2.0 request envelope. `params` is serialized as `null` when
/// `None` (no skip attribute is applied).
#[derive(Serialize)]
struct RpcRequest<Params> {
    id: u64,
    jsonrpc: String,
    method: String,
    params: Option<Params>,
}
85
86impl<T> Default for RpcRequest<T> {
87 fn default() -> Self {
88 RpcRequest {
89 id: 0,
90 jsonrpc: "2.0".to_string(),
91 method: "".to_string(),
92 params: None,
93 }
94 }
95}
96
/// JSON-RPC error object returned by the server inside a response envelope.
#[derive(Deserialize)]
struct RpcError {
    code: i32,
    message: String,
}
102
/// JSON-RPC 2.0 response envelope; exactly one of `error`/`result` is
/// expected to be populated.
// NOTE(review): `id` deserializes as a String here while requests serialize a
// u64 id — presumably the server stringifies ids; confirm against the API.
#[derive(Deserialize)]
struct RpcResponse<RpcResult> {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    jsonrpc: String,
    error: Option<RpcError>,
    result: Option<RpcResult>,
}
112
/// Placeholder result type for RPC calls whose response body carries no data.
/// Appears unused in this part of the file (kept behind `allow(dead_code)`).
#[derive(Deserialize)]
#[allow(dead_code)]
struct EmptyResult {}
116
/// Parameters for the non-multipart snapshot creation RPC. Unused in this
/// part of the file; the multipart variant below is what `create_snapshot`
/// uses.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct SnapshotCreateParams {
    snapshot_name: String,
    keys: Vec<String>,
}
123
/// Result of the non-multipart snapshot creation RPC: the new snapshot id and
/// one presigned upload URL per requested key. Unused in this part of the
/// file.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct SnapshotCreateResult {
    snapshot_id: SnapshotID,
    urls: Vec<String>,
}
130
/// Parameters for `snapshots.create_upload_url_multipart`: the snapshot name
/// plus parallel `keys`/`file_sizes` vectors (one entry per file to upload).
#[derive(Debug, Serialize)]
struct SnapshotCreateMultipartParams {
    snapshot_name: String,
    keys: Vec<String>,
    file_sizes: Vec<usize>,
}
137
/// One value of the heterogeneous map returned by
/// `snapshots.create_upload_url_multipart`: the "snapshot_id" entry is a bare
/// integer, every other entry is a per-key `SnapshotPart`. `untagged` lets
/// serde pick the matching variant by shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum SnapshotCreateMultipartResultField {
    Id(u64),
    Part(SnapshotPart),
}
144
/// Parameters for `snapshots.complete_multipart_upload`: identifies the
/// upload (key + upload id) and lists the ETag of every uploaded part.
#[derive(Debug, Serialize)]
struct SnapshotCompleteMultipartParams {
    key: String,
    upload_id: String,
    etag_list: Vec<EtagPart>,
}
151
/// ETag for one uploaded part, serialized with the S3-style field names
/// ("ETag"/"PartNumber") the completion endpoint expects.
#[derive(Debug, Clone, Serialize)]
struct EtagPart {
    #[serde(rename = "ETag")]
    etag: String,
    #[serde(rename = "PartNumber")]
    part_number: usize,
}
159
/// Multipart upload descriptor for a single key: the upload id and one
/// presigned URL per part. `key` is not present in the server response; the
/// callers fill it in from the map key before passing this to
/// `upload_multipart`.
#[derive(Debug, Clone, Deserialize)]
struct SnapshotPart {
    key: Option<String>,
    upload_id: String,
    urls: Vec<String>,
}
166
/// Parameters for `snapshots.update`, used to change a snapshot's status
/// (e.g. to "available" once all uploads complete).
#[derive(Debug, Serialize)]
struct SnapshotStatusParams {
    snapshot_id: SnapshotID,
    status: String,
}
172
/// Snapshot record returned by `snapshots.update`. Only deserialized to
/// validate the response; no field is read afterwards (hence the
/// `allow(dead_code)` markers).
#[derive(Deserialize, Debug)]
struct SnapshotStatusResult {
    #[allow(dead_code)]
    pub id: SnapshotID,
    #[allow(dead_code)]
    pub uid: String,
    #[allow(dead_code)]
    pub description: String,
    #[allow(dead_code)]
    pub date: String,
    #[allow(dead_code)]
    pub status: String,
}
186
/// Parameters for an image-listing RPC. Unused in this part of the file.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImageListParams {
    images_filter: ImagesFilter,
    image_files_filter: HashMap<String, String>,
    only_ids: bool,
}
194
/// Dataset filter nested inside `ImageListParams`. Unused in this part of the
/// file.
#[derive(Serialize)]
#[allow(dead_code)]
struct ImagesFilter {
    dataset_id: DatasetID,
}
200
/// Client for the EdgeFirst Studio JSON-RPC API.
///
/// Cloning is cheap: the reqwest client is internally pooled and the token is
/// behind a shared `Arc<RwLock<..>>`.
#[derive(Clone, Debug)]
pub struct Client {
    // Underlying HTTP client, also used directly for presigned-URL transfers.
    http: reqwest::Client,
    // Base server URL, e.g. "https://edgefirst.studio".
    url: String,
    // JWT auth token; empty string when unauthenticated.
    token: Arc<RwLock<String>>,
    // Where `save_token` persists the token; `None` falls back to the
    // platform config directory.
    token_path: Option<PathBuf>,
}
256
257impl Client {
258 pub fn new() -> Result<Self, Error> {
267 Ok(Client {
268 http: reqwest::Client::builder()
269 .read_timeout(Duration::from_secs(60))
270 .build()?,
271 url: "https://edgefirst.studio".to_string(),
272 token: Arc::new(tokio::sync::RwLock::new("".to_string())),
273 token_path: None,
274 })
275 }
276
277 pub fn with_server(&self, server: &str) -> Result<Self, Error> {
281 Ok(Client {
282 url: format!("https://{}.edgefirst.studio", server),
283 ..self.clone()
284 })
285 }
286
287 pub async fn with_login(&self, username: &str, password: &str) -> Result<Self, Error> {
290 let params = HashMap::from([("username", username), ("password", password)]);
291 let login: LoginResult = self
292 .rpc_without_auth("auth.login".to_owned(), Some(params))
293 .await?;
294 Ok(Client {
295 token: Arc::new(tokio::sync::RwLock::new(login.token)),
296 ..self.clone()
297 })
298 }
299
300 pub fn with_token_path(&self, token_path: Option<&Path>) -> Result<Self, Error> {
303 let token_path = match token_path {
304 Some(path) => path.to_path_buf(),
305 None => ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
306 .unwrap()
307 .config_dir()
308 .join("token"),
309 };
310
311 debug!("Using token path: {:?}", token_path);
312
313 let token = match token_path.exists() {
314 true => std::fs::read_to_string(&token_path)?,
315 false => "".to_string(),
316 };
317
318 if !token.is_empty() {
319 let client = self.with_token(&token)?;
320 Ok(Client {
321 token_path: Some(token_path),
322 ..client
323 })
324 } else {
325 Ok(Client {
326 token_path: Some(token_path),
327 ..self.clone()
328 })
329 }
330 }
331
332 pub fn with_token(&self, token: &str) -> Result<Self, Error> {
334 if token.is_empty() {
335 return Ok(self.clone());
336 }
337
338 let token_parts: Vec<&str> = token.split('.').collect();
339 if token_parts.len() != 3 {
340 return Err(Error::InvalidToken);
341 }
342
343 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
344 .decode(token_parts[1])
345 .unwrap();
346 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
347 let server = match payload.get("database") {
348 Some(value) => Ok(value.as_str().unwrap().to_string()),
349 None => Err(Error::InvalidToken),
350 }?;
351
352 Ok(Client {
353 url: format!("https://{}.edgefirst.studio", server),
354 token: Arc::new(tokio::sync::RwLock::new(token.to_string())),
355 ..self.clone()
356 })
357 }
358
359 pub async fn save_token(&self) -> Result<(), Error> {
360 let path = self.token_path.clone().unwrap_or_else(|| {
361 ProjectDirs::from("ai", "EdgeFirst", "EdgeFirst Studio")
362 .unwrap()
363 .config_dir()
364 .join("token")
365 });
366
367 create_dir_all(path.parent().unwrap())?;
368 let mut file = std::fs::File::create(&path)?;
369 file.write_all(self.token.read().await.as_bytes())?;
370
371 debug!("Saved token to {:?}", path);
372
373 Ok(())
374 }
375
376 pub async fn version(&self) -> Result<String, Error> {
379 let version: HashMap<String, String> = self
380 .rpc_without_auth::<(), HashMap<String, String>>("version".to_owned(), None)
381 .await?;
382 let version = version.get("version").ok_or(Error::InvalidResponse)?;
383 Ok(version.to_owned())
384 }
385
386 pub async fn logout(&self) -> Result<(), Error> {
390 {
391 let mut token = self.token.write().await;
392 *token = "".to_string();
393 }
394
395 if let Some(path) = &self.token_path
396 && path.exists()
397 {
398 fs::remove_file(path).await?;
399 }
400
401 Ok(())
402 }
403
404 pub async fn token(&self) -> String {
408 self.token.read().await.clone()
409 }
410
411 pub async fn verify_token(&self) -> Result<(), Error> {
416 self.rpc::<(), LoginResult>("auth.verify_token".to_owned(), None)
417 .await?;
418 Ok::<(), Error>(())
419 }
420
421 pub async fn renew_token(&self) -> Result<(), Error> {
426 let params = HashMap::from([("username".to_string(), self.username().await?)]);
427 let result: LoginResult = self
428 .rpc_without_auth("auth.refresh".to_owned(), Some(params))
429 .await?;
430
431 {
432 let mut token = self.token.write().await;
433 *token = result.token;
434 }
435
436 if self.token_path.is_some() {
437 self.save_token().await?;
438 }
439
440 Ok(())
441 }
442
443 async fn token_field(&self, field: &str) -> Result<serde_json::Value, Error> {
444 let token = self.token.read().await;
445 if token.is_empty() {
446 return Err(Error::EmptyToken);
447 }
448
449 let token_parts: Vec<&str> = token.split('.').collect();
450 if token_parts.len() != 3 {
451 return Err(Error::InvalidToken);
452 }
453
454 let decoded = base64::engine::general_purpose::STANDARD_NO_PAD
455 .decode(token_parts[1])
456 .unwrap();
457 let payload: HashMap<String, serde_json::Value> = serde_json::from_slice(&decoded)?;
458 match payload.get(field) {
459 Some(value) => Ok(value.to_owned()),
460 None => Err(Error::InvalidToken),
461 }
462 }
463
464 pub fn url(&self) -> &str {
466 &self.url
467 }
468
469 pub async fn username(&self) -> Result<String, Error> {
471 match self.token_field("username").await? {
472 serde_json::Value::String(username) => Ok(username),
473 _ => Err(Error::InvalidToken),
474 }
475 }
476
477 pub async fn token_expiration(&self) -> Result<DateTime<Utc>, Error> {
479 let ts = match self.token_field("exp").await? {
480 serde_json::Value::Number(exp) => exp.as_i64().ok_or(Error::InvalidToken)?,
481 _ => return Err(Error::InvalidToken),
482 };
483
484 match DateTime::<Utc>::from_timestamp_secs(ts) {
485 Some(dt) => Ok(dt),
486 None => Err(Error::InvalidToken),
487 }
488 }
489
490 pub async fn organization(&self) -> Result<Organization, Error> {
492 self.rpc::<(), Organization>("org.get".to_owned(), None)
493 .await
494 }
495
496 pub async fn projects(&self, name: Option<&str>) -> Result<Vec<Project>, Error> {
504 let projects = self
505 .rpc::<(), Vec<Project>>("project.list".to_owned(), None)
506 .await?;
507 if let Some(name) = name {
508 Ok(projects
509 .into_iter()
510 .filter(|p| p.name().contains(name))
511 .collect())
512 } else {
513 Ok(projects)
514 }
515 }
516
517 pub async fn project(&self, project_id: ProjectID) -> Result<Project, Error> {
520 let params = HashMap::from([("project_id", project_id)]);
521 self.rpc("project.get".to_owned(), Some(params)).await
522 }
523
524 pub async fn datasets(
528 &self,
529 project_id: ProjectID,
530 name: Option<&str>,
531 ) -> Result<Vec<Dataset>, Error> {
532 let params = HashMap::from([("project_id", project_id)]);
533 let datasets: Vec<Dataset> = self.rpc("dataset.list".to_owned(), Some(params)).await?;
534 if let Some(name) = name {
535 Ok(datasets
536 .into_iter()
537 .filter(|d| d.name().contains(name))
538 .collect())
539 } else {
540 Ok(datasets)
541 }
542 }
543
544 pub async fn dataset(&self, dataset_id: DatasetID) -> Result<Dataset, Error> {
547 let params = HashMap::from([("dataset_id", dataset_id)]);
548 self.rpc("dataset.get".to_owned(), Some(params)).await
549 }
550
551 pub async fn labels(&self, dataset_id: DatasetID) -> Result<Vec<Label>, Error> {
553 let params = HashMap::from([("dataset_id", dataset_id)]);
554 self.rpc("label.list".to_owned(), Some(params)).await
555 }
556
557 pub async fn add_label(&self, dataset_id: DatasetID, name: &str) -> Result<(), Error> {
559 let new_label = NewLabel {
560 dataset_id,
561 labels: vec![NewLabelObject {
562 name: name.to_owned(),
563 }],
564 };
565 let _: String = self.rpc("label.add2".to_owned(), Some(new_label)).await?;
566 Ok(())
567 }
568
569 pub async fn remove_label(&self, label_id: u64) -> Result<(), Error> {
572 let params = HashMap::from([("label_id", label_id)]);
573 let _: String = self.rpc("label.del".to_owned(), Some(params)).await?;
574 Ok(())
575 }
576
577 pub async fn create_dataset(
589 &self,
590 project_id: &str,
591 name: &str,
592 description: Option<&str>,
593 ) -> Result<DatasetID, Error> {
594 let mut params = HashMap::new();
595 params.insert("project_id", project_id);
596 params.insert("name", name);
597 if let Some(desc) = description {
598 params.insert("description", desc);
599 }
600
601 #[derive(Deserialize)]
602 struct CreateDatasetResult {
603 id: DatasetID,
604 }
605
606 let result: CreateDatasetResult =
607 self.rpc("dataset.create".to_owned(), Some(params)).await?;
608 Ok(result.id)
609 }
610
611 pub async fn delete_dataset(&self, dataset_id: DatasetID) -> Result<(), Error> {
621 let params = HashMap::from([("id", dataset_id)]);
622 let _: String = self.rpc("dataset.delete".to_owned(), Some(params)).await?;
623 Ok(())
624 }
625
626 pub async fn update_label(&self, label: &Label) -> Result<(), Error> {
630 #[derive(Serialize)]
631 struct Params {
632 dataset_id: DatasetID,
633 label_id: u64,
634 label_name: String,
635 label_index: u64,
636 }
637
638 let _: String = self
639 .rpc(
640 "label.update".to_owned(),
641 Some(Params {
642 dataset_id: label.dataset_id(),
643 label_id: label.id(),
644 label_name: label.name().to_owned(),
645 label_index: label.index(),
646 }),
647 )
648 .await?;
649 Ok(())
650 }
651
652 pub async fn download_dataset(
653 &self,
654 dataset_id: DatasetID,
655 groups: &[String],
656 file_types: &[FileType],
657 output: PathBuf,
658 progress: Option<Sender<Progress>>,
659 ) -> Result<(), Error> {
660 let samples = self
661 .samples(dataset_id, None, &[], groups, file_types, progress.clone())
662 .await?;
663 fs::create_dir_all(&output).await?;
664
665 let total = samples.len();
666 let current = Arc::new(AtomicUsize::new(0));
667 let sem = Arc::new(Semaphore::new(MAX_TASKS));
668
669 let tasks = samples
670 .into_iter()
671 .map(|sample| {
672 let sem = sem.clone();
673 let client = self.clone();
674 let current = current.clone();
675 let progress = progress.clone();
676 let file_types = file_types.to_vec();
677 let output = output.clone();
678
679 tokio::spawn(async move {
680 let _permit = sem.acquire().await.unwrap();
681
682 for file_type in file_types {
683 if let Some(data) = sample.download(&client, file_type.clone()).await? {
684 let file_ext = match file_type {
685 FileType::Image => infer::get(&data)
686 .expect("Failed to identify image file format for sample")
687 .extension()
688 .to_string(),
689 t => t.to_string(),
690 };
691
692 let file_name = format!(
693 "{}.{}",
694 sample.name().unwrap_or_else(|| "unknown".to_string()),
695 file_ext
696 );
697 let file_path = output.join(&file_name);
698
699 let mut file = File::create(&file_path).await?;
700 file.write_all(&data).await?;
701 } else {
702 warn!(
703 "No data for sample: {}",
704 sample
705 .id()
706 .map(|id| id.to_string())
707 .unwrap_or_else(|| "unknown".to_string())
708 );
709 }
710 }
711
712 if let Some(progress) = &progress {
713 let current = current.fetch_add(1, Ordering::SeqCst);
714 progress
715 .send(Progress {
716 current: current + 1,
717 total,
718 })
719 .await
720 .unwrap();
721 }
722
723 Ok::<(), Error>(())
724 })
725 })
726 .collect::<Vec<_>>();
727
728 join_all(tasks)
729 .await
730 .into_iter()
731 .collect::<Result<Vec<_>, _>>()?;
732
733 if let Some(progress) = progress {
734 drop(progress);
735 }
736
737 Ok(())
738 }
739
740 pub async fn annotation_sets(
742 &self,
743 dataset_id: DatasetID,
744 ) -> Result<Vec<AnnotationSet>, Error> {
745 let params = HashMap::from([("dataset_id", dataset_id)]);
746 self.rpc("annset.list".to_owned(), Some(params)).await
747 }
748
749 pub async fn create_annotation_set(
761 &self,
762 dataset_id: DatasetID,
763 name: &str,
764 description: Option<&str>,
765 ) -> Result<AnnotationSetID, Error> {
766 #[derive(Serialize)]
767 struct Params<'a> {
768 dataset_id: DatasetID,
769 name: &'a str,
770 operator: &'a str,
771 #[serde(skip_serializing_if = "Option::is_none")]
772 description: Option<&'a str>,
773 }
774
775 #[derive(Deserialize)]
776 struct CreateAnnotationSetResult {
777 id: AnnotationSetID,
778 }
779
780 let username = self.username().await?;
781 let result: CreateAnnotationSetResult = self
782 .rpc(
783 "annset.add".to_owned(),
784 Some(Params {
785 dataset_id,
786 name,
787 operator: &username,
788 description,
789 }),
790 )
791 .await?;
792 Ok(result.id)
793 }
794
795 pub async fn delete_annotation_set(
806 &self,
807 annotation_set_id: AnnotationSetID,
808 ) -> Result<(), Error> {
809 let params = HashMap::from([("id", annotation_set_id)]);
810 let _: String = self.rpc("annset.delete".to_owned(), Some(params)).await?;
811 Ok(())
812 }
813
814 pub async fn annotation_set(
816 &self,
817 annotation_set_id: AnnotationSetID,
818 ) -> Result<AnnotationSet, Error> {
819 let params = HashMap::from([("annotation_set_id", annotation_set_id)]);
820 self.rpc("annset.get".to_owned(), Some(params)).await
821 }
822
    /// Fetches every annotation in an annotation set by paging through
    /// `samples.list` until the server stops returning a continue token.
    ///
    /// Samples without annotations still produce one placeholder
    /// `Annotation` carrying only sample metadata; annotated samples produce
    /// one entry per annotation, tagged with the dataset's label index.
    /// A progress message is sent per page when `progress` is provided, and
    /// the sender is dropped at the end so receivers observe closure.
    pub async fn annotations(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        annotation_types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Annotation>, Error> {
        let dataset_id = self.annotation_set(annotation_set_id).await?.dataset_id();
        // Label name -> label index, for tagging annotations below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(
                dataset_id,
                Some(annotation_set_id),
                annotation_types,
                groups,
                &[],
            )
            .await?
            .total as usize;
        let mut annotations = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(annotations);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id: Some(annotation_set_id),
                types: annotation_types.iter().map(|t| t.to_string()).collect(),
                group_names: groups.to_vec(),
                continue_token,
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            for sample in result.samples {
                if sample.annotations().is_empty() {
                    // Record unannotated samples as a single empty annotation.
                    let mut annotation = Annotation::new();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    annotations.push(annotation);
                    continue;
                }

                sample.annotations().iter().for_each(|annotation| {
                    let mut annotation = annotation.clone();
                    annotation.set_sample_id(sample.id());
                    annotation.set_name(sample.name());
                    annotation.set_group(sample.group().cloned());
                    annotation.set_sequence_name(sample.sequence_name().cloned());
                    // NOTE(review): this indexing panics if the server returns
                    // an annotation label missing from label.list — confirm
                    // that invariant holds.
                    annotation.set_label_index(Some(labels[annotation.label().unwrap().as_str()]));
                    annotations.push(annotation);
                });
            }

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(annotations)
    }
925
926 pub async fn samples_count(
927 &self,
928 dataset_id: DatasetID,
929 annotation_set_id: Option<AnnotationSetID>,
930 annotation_types: &[AnnotationType],
931 groups: &[String],
932 types: &[FileType],
933 ) -> Result<SamplesCountResult, Error> {
934 let types = annotation_types
935 .iter()
936 .map(|t| t.to_string())
937 .chain(types.iter().map(|t| t.to_string()))
938 .collect::<Vec<_>>();
939
940 let params = SamplesListParams {
941 dataset_id,
942 annotation_set_id,
943 group_names: groups.to_vec(),
944 types,
945 continue_token: None,
946 };
947
948 self.rpc("samples.count".to_owned(), Some(params)).await
949 }
950
    /// Lists samples in a dataset by paging through `samples.list`, with each
    /// sample's annotations tagged with the dataset's label indices.
    ///
    /// `annotation_types` and `types` are merged into one server-side type
    /// filter. A progress message is sent per page when `progress` is
    /// provided, and the sender is dropped at the end so receivers observe
    /// closure.
    pub async fn samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        annotation_types: &[AnnotationType],
        groups: &[String],
        types: &[FileType],
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<Sample>, Error> {
        let types = annotation_types
            .iter()
            .map(|t| t.to_string())
            .chain(types.iter().map(|t| t.to_string()))
            .collect::<Vec<_>>();
        // Label name -> label index, used to resolve annotation labels below.
        let labels = self
            .labels(dataset_id)
            .await?
            .into_iter()
            .map(|label| (label.name().to_string(), label.index()))
            .collect::<HashMap<_, _>>();
        let total = self
            .samples_count(dataset_id, annotation_set_id, annotation_types, groups, &[])
            .await?
            .total as usize;

        let mut samples = vec![];
        let mut continue_token: Option<String> = None;
        let mut current = 0;

        if total == 0 {
            return Ok(samples);
        }

        loop {
            let params = SamplesListParams {
                dataset_id,
                annotation_set_id,
                types: types.clone(),
                group_names: groups.to_vec(),
                continue_token: continue_token.clone(),
            };

            let result: SamplesListResult =
                self.rpc("samples.list".to_owned(), Some(params)).await?;
            current += result.samples.len();
            continue_token = result.continue_token;

            if result.samples.is_empty() {
                break;
            }

            samples.append(
                &mut result
                    .samples
                    .into_iter()
                    .map(|s| {
                        let mut anns = s.annotations().to_vec();
                        for ann in &mut anns {
                            if let Some(label) = ann.label() {
                                // NOTE(review): panics if a sample references a
                                // label not returned by label.list — confirm.
                                ann.set_label_index(Some(labels[label.as_str()]));
                            }
                        }
                        s.with_annotations(anns)
                    })
                    .collect::<Vec<_>>(),
            );

            if let Some(progress) = &progress {
                progress.send(Progress { current, total }).await.unwrap();
            }

            // Keep paging while the server hands back a non-empty token.
            match &continue_token {
                Some(token) if !token.is_empty() => continue,
                _ => break,
            }
        }

        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(samples)
    }
1034
    /// Uploads new samples into a dataset (optionally attaching annotations
    /// to `annotation_set_id`) and uploads any referenced local files to the
    /// presigned URLs returned by the server.
    ///
    /// For each sample file whose filename points at an existing local file,
    /// the file is queued for upload and the sample's entry is rewritten to
    /// its basename; missing image dimensions are filled in from the local
    /// file. Uploads run concurrently, bounded by MAX_TASKS, with one
    /// progress unit per sample result.
    pub async fn populate_samples(
        &self,
        dataset_id: DatasetID,
        annotation_set_id: Option<AnnotationSetID>,
        samples: Vec<Sample>,
        progress: Option<Sender<Progress>>,
    ) -> Result<Vec<crate::SamplesPopulateResult>, Error> {
        use crate::api::SamplesPopulateParams;
        use std::path::Path;

        let total = samples.len();

        // (sample uuid, file type, local path, basename) queued for upload.
        let mut files_to_upload: Vec<(String, String, PathBuf, String)> = Vec::new();

        let samples: Vec<Sample> = samples
            .into_iter()
            .map(|mut sample| {
                // Assign a uuid so upload results can be matched back later.
                if sample.uuid.is_none() {
                    sample.uuid = Some(uuid::Uuid::new_v4().to_string());
                }

                let sample_uuid = sample.uuid.clone().unwrap();

                let updated_files: Vec<crate::SampleFile> = sample
                    .files
                    .iter()
                    .map(|file| {
                        if let Some(filename) = file.filename() {
                            let path = Path::new(filename);

                            if path.exists() && path.is_file() {
                                if let Some(basename) = path.file_name().and_then(|s| s.to_str()) {
                                    // Fill in missing image dimensions from
                                    // the local file.
                                    if file.file_type() == "image"
                                        && (sample.width.is_none() || sample.height.is_none())
                                        && let Ok(size) = imagesize::size(path)
                                    {
                                        sample.width = Some(size.width as u32);
                                        sample.height = Some(size.height as u32);
                                    }

                                    files_to_upload.push((
                                        sample_uuid.clone(),
                                        file.file_type().to_string(),
                                        path.to_path_buf(),
                                        basename.to_string(),
                                    ));

                                    // The server only needs the basename.
                                    return crate::SampleFile::with_filename(
                                        file.file_type().to_string(),
                                        basename.to_string(),
                                    );
                                }
                            }
                        }
                        file.clone()
                    })
                    .collect();

                sample.files = updated_files;
                sample
            })
            .collect();

        let has_files_to_upload = !files_to_upload.is_empty();

        let params = SamplesPopulateParams {
            dataset_id,
            annotation_set_id,
            // Ask for presigned URLs only when there are files to upload.
            presigned_urls: if has_files_to_upload {
                Some(true)
            } else {
                Some(false)
            },
            samples,
        };

        let results: Vec<crate::SamplesPopulateResult> = self
            .rpc("samples.populate".to_owned(), Some(params))
            .await?;

        if has_files_to_upload {
            // Index local paths by (sample uuid, basename) to match the
            // per-result upload URLs.
            let mut upload_map: HashMap<(String, String), PathBuf> = HashMap::new();
            for (uuid, _file_type, path, basename) in files_to_upload {
                upload_map.insert((uuid, basename), path);
            }

            let current = Arc::new(AtomicUsize::new(0));
            let sem = Arc::new(Semaphore::new(MAX_TASKS));

            let upload_tasks = results
                .iter()
                .map(|result| {
                    let sem = sem.clone();
                    let http = self.http.clone();
                    let current = current.clone();
                    let progress = progress.clone();
                    let result_uuid = result.uuid.clone();
                    let urls = result.urls.clone();
                    let upload_map = upload_map.clone();

                    tokio::spawn(async move {
                        // Bound upload concurrency across all tasks.
                        let _permit = sem.acquire().await.unwrap();

                        for url_info in &urls {
                            if let Some(local_path) =
                                upload_map.get(&(result_uuid.clone(), url_info.filename.clone()))
                            {
                                upload_file_to_presigned_url(
                                    http.clone(),
                                    &url_info.url,
                                    local_path.clone(),
                                )
                                .await?;
                            }
                        }

                        if let Some(progress) = &progress {
                            let current = current.fetch_add(1, Ordering::SeqCst);
                            progress
                                .send(Progress {
                                    current: current + 1,
                                    total,
                                })
                                .await
                                .unwrap();
                        }

                        Ok::<(), Error>(())
                    })
                })
                .collect::<Vec<_>>();

            // NOTE(review): only JoinError propagates here; an Err returned
            // inside a task body is discarded by this collect — confirm
            // whether upload failures should fail the call.
            join_all(upload_tasks)
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;
        }

        // Dropping the sender closes the progress channel for receivers.
        if let Some(progress) = progress {
            drop(progress);
        }

        Ok(results)
    }
1288
1289 pub async fn download(&self, url: &str) -> Result<Vec<u8>, Error> {
1290 for attempt in 1..MAX_RETRIES {
1291 let resp = match self.http.get(url).send().await {
1292 Ok(resp) => resp,
1293 Err(err) => {
1294 warn!(
1295 "Socket Error [retry {}/{}]: {:?}",
1296 attempt, MAX_RETRIES, err
1297 );
1298 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1299 continue;
1300 }
1301 };
1302
1303 match resp.bytes().await {
1304 Ok(body) => return Ok(body.to_vec()),
1305 Err(err) => {
1306 warn!("HTTP Error [retry {}/{}]: {:?}", attempt, MAX_RETRIES, err);
1307 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
1308 continue;
1309 }
1310 };
1311 }
1312
1313 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
1314 }
1315
    /// Fetches the annotations of `annotation_set_id` (see
    /// [`Client::annotations`]) and converts them into a polars `DataFrame`.
    /// Only compiled with the "polars" feature.
    #[cfg(feature = "polars")]
    pub async fn annotations_dataframe(
        &self,
        annotation_set_id: AnnotationSetID,
        groups: &[String],
        types: &[AnnotationType],
        progress: Option<Sender<Progress>>,
    ) -> Result<DataFrame, Error> {
        use crate::dataset::annotations_dataframe;

        let annotations = self
            .annotations(annotation_set_id, groups, types, progress)
            .await?;
        Ok(annotations_dataframe(&annotations))
    }
1341
1342 pub async fn snapshots(&self, name: Option<&str>) -> Result<Vec<Snapshot>, Error> {
1345 let snapshots: Vec<Snapshot> = self
1346 .rpc::<(), Vec<Snapshot>>("snapshots.list".to_owned(), None)
1347 .await?;
1348 if let Some(name) = name {
1349 Ok(snapshots
1350 .into_iter()
1351 .filter(|s| s.description().contains(name))
1352 .collect())
1353 } else {
1354 Ok(snapshots)
1355 }
1356 }
1357
1358 pub async fn snapshot(&self, snapshot_id: SnapshotID) -> Result<Snapshot, Error> {
1360 let params = HashMap::from([("snapshot_id", snapshot_id)]);
1361 self.rpc("snapshots.get".to_owned(), Some(params)).await
1362 }
1363
    /// Creates a snapshot from a local file via the multipart upload flow,
    /// then marks it "available" and returns it. Directories are delegated to
    /// `create_snapshot_folder`. Progress is measured in bytes.
    pub async fn create_snapshot(
        &self,
        path: &str,
        progress: Option<Sender<Progress>>,
    ) -> Result<Snapshot, Error> {
        let path = Path::new(path);

        if path.is_dir() {
            return self
                .create_snapshot_folder(path.to_str().unwrap(), progress)
                .await;
        }

        let name = path.file_name().unwrap().to_str().unwrap();
        let total = path.metadata()?.len() as usize;
        let current = Arc::new(AtomicUsize::new(0));

        if let Some(progress) = &progress {
            progress.send(Progress { current: 0, total }).await.unwrap();
        }

        let params = SnapshotCreateMultipartParams {
            snapshot_name: name.to_owned(),
            keys: vec![name.to_owned()],
            file_sizes: vec![total],
        };
        let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
            .rpc(
                "snapshots.create_upload_url_multipart".to_owned(),
                Some(params),
            )
            .await?;

        // The response map mixes the snapshot id with per-key part entries.
        let snapshot_id = match multipart.get("snapshot_id") {
            Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
            _ => return Err(Error::InvalidResponse),
        };

        let snapshot = self.snapshot(snapshot_id).await?;
        // Part keys are prefixed by the storage path after the "::/" marker.
        let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
        let part_key = format!("{}/{}", part_prefix, name);
        let mut part = match multipart.get(&part_key) {
            Some(SnapshotCreateMultipartResultField::Part(part)) => part,
            _ => return Err(Error::InvalidResponse),
        }
        .clone();
        part.key = Some(part_key);

        let params = upload_multipart(
            self.http.clone(),
            part.clone(),
            path.to_path_buf(),
            total,
            current,
            progress.clone(),
        )
        .await?;

        let complete: String = self
            .rpc(
                "snapshots.complete_multipart_upload".to_owned(),
                Some(params),
            )
            .await?;
        debug!("Snapshot Multipart Complete: {:?}", complete);

        // Flip the snapshot status so the server considers it usable.
        let params: SnapshotStatusParams = SnapshotStatusParams {
            snapshot_id,
            status: "available".to_owned(),
        };
        let _: SnapshotStatusResult = self
            .rpc("snapshots.update".to_owned(), Some(params))
            .await?;

        // Dropping the sender closes the progress channel for receivers.
        if let Some(progress) = progress {
            drop(progress);
        }

        self.snapshot(snapshot_id).await
    }
1450
1451 async fn create_snapshot_folder(
1452 &self,
1453 path: &str,
1454 progress: Option<Sender<Progress>>,
1455 ) -> Result<Snapshot, Error> {
1456 let path = Path::new(path);
1457 let name = path.file_name().unwrap().to_str().unwrap();
1458
1459 let files = WalkDir::new(path)
1460 .into_iter()
1461 .filter_map(|entry| entry.ok())
1462 .filter(|entry| entry.file_type().is_file())
1463 .map(|entry| entry.path().strip_prefix(path).unwrap().to_owned())
1464 .collect::<Vec<_>>();
1465
1466 let total = files
1467 .iter()
1468 .map(|file| path.join(file).metadata().unwrap().len() as usize)
1469 .sum();
1470 let current = Arc::new(AtomicUsize::new(0));
1471
1472 if let Some(progress) = &progress {
1473 progress.send(Progress { current: 0, total }).await.unwrap();
1474 }
1475
1476 let keys = files
1477 .iter()
1478 .map(|key| key.to_str().unwrap().to_owned())
1479 .collect::<Vec<_>>();
1480 let file_sizes = files
1481 .iter()
1482 .map(|key| path.join(key).metadata().unwrap().len() as usize)
1483 .collect::<Vec<_>>();
1484
1485 let params = SnapshotCreateMultipartParams {
1486 snapshot_name: name.to_owned(),
1487 keys,
1488 file_sizes,
1489 };
1490
1491 let multipart: HashMap<String, SnapshotCreateMultipartResultField> = self
1492 .rpc(
1493 "snapshots.create_upload_url_multipart".to_owned(),
1494 Some(params),
1495 )
1496 .await?;
1497
1498 let snapshot_id = match multipart.get("snapshot_id") {
1499 Some(SnapshotCreateMultipartResultField::Id(id)) => SnapshotID::from(*id),
1500 _ => return Err(Error::InvalidResponse),
1501 };
1502
1503 let snapshot = self.snapshot(snapshot_id).await?;
1504 let part_prefix = snapshot.path().split("::/").last().unwrap().to_owned();
1505
1506 for file in files {
1507 let part_key = format!("{}/{}", part_prefix, file.to_str().unwrap());
1508 let mut part = match multipart.get(&part_key) {
1509 Some(SnapshotCreateMultipartResultField::Part(part)) => part,
1510 _ => return Err(Error::InvalidResponse),
1511 }
1512 .clone();
1513 part.key = Some(part_key);
1514
1515 let params = upload_multipart(
1516 self.http.clone(),
1517 part.clone(),
1518 path.join(file),
1519 total,
1520 current.clone(),
1521 progress.clone(),
1522 )
1523 .await?;
1524
1525 let complete: String = self
1526 .rpc(
1527 "snapshots.complete_multipart_upload".to_owned(),
1528 Some(params),
1529 )
1530 .await?;
1531 debug!("Snapshot Part Complete: {:?}", complete);
1532 }
1533
1534 let params = SnapshotStatusParams {
1535 snapshot_id,
1536 status: "available".to_owned(),
1537 };
1538 let _: SnapshotStatusResult = self
1539 .rpc("snapshots.update".to_owned(), Some(params))
1540 .await?;
1541
1542 if let Some(progress) = progress {
1543 drop(progress);
1544 }
1545
1546 self.snapshot(snapshot_id).await
1547 }
1548
    /// Download every file of a snapshot into a local directory.
    ///
    /// Creates `output` (and parents) if missing, asks the server for a
    /// presigned download URL per file (`snapshots.create_download_url`),
    /// then streams each file to disk concurrently with at most `MAX_TASKS`
    /// transfers in flight (bounded by a semaphore).
    ///
    /// # Arguments
    /// * `snapshot_id` - snapshot to download.
    /// * `output` - destination directory; each downloaded file is written to
    ///   `output.join(key)` where `key` is the server-provided file key.
    /// * `progress` - optional channel receiving byte-level progress; `total`
    ///   grows as each task learns its response's Content-Length.
    ///
    /// # Panics
    /// The spawned tasks `unwrap()` HTTP errors, a missing Content-Length
    /// header, file-creation and write failures, and progress-send failures,
    /// so any of those aborts the task; the final `join_all(...).unwrap()`
    /// then panics the caller instead of returning `Err`.
    pub async fn download_snapshot(
        &self,
        snapshot_id: SnapshotID,
        output: PathBuf,
        progress: Option<Sender<Progress>>,
    ) -> Result<(), Error> {
        fs::create_dir_all(&output).await?;

        // Map of file key -> presigned URL for every file in the snapshot.
        let params = HashMap::from([("snapshot_id", snapshot_id)]);
        let items: HashMap<String, String> = self
            .rpc("snapshots.create_download_url".to_owned(), Some(params))
            .await?;

        // Shared across tasks: `total` accumulates Content-Lengths as they
        // become known, `current` accumulates bytes written so far.
        let total = Arc::new(AtomicUsize::new(0));
        let current = Arc::new(AtomicUsize::new(0));
        let sem = Arc::new(Semaphore::new(MAX_TASKS));

        let tasks = items
            .iter()
            .map(|(key, url)| {
                // Clone everything the task needs so the future is 'static.
                let http = self.http.clone();
                let key = key.clone();
                let url = url.clone();
                let output = output.clone();
                let progress = progress.clone();
                let current = current.clone();
                let total = total.clone();
                let sem = sem.clone();

                tokio::spawn(async move {
                    // Limit concurrency; permit is held until the task ends.
                    let _permit = sem.acquire().await.unwrap();
                    let res = http.get(url).send().await.unwrap();
                    let content_length = res.content_length().unwrap() as usize;

                    if let Some(progress) = &progress {
                        // fetch_add returns the PREVIOUS total, hence the
                        // explicit `total + content_length` below.
                        let total = total.fetch_add(content_length, Ordering::SeqCst);
                        progress
                            .send(Progress {
                                current: current.load(Ordering::SeqCst),
                                total: total + content_length,
                            })
                            .await
                            .unwrap();
                    }

                    // NOTE(review): keys containing '/' would need their
                    // parent directories created first — File::create does
                    // not create intermediate directories. Presumably keys
                    // are flat file names here; confirm against the server.
                    let mut file = File::create(output.join(key)).await.unwrap();
                    let mut stream = res.bytes_stream();

                    // Stream the body chunk-by-chunk, reporting progress
                    // after each write.
                    while let Some(chunk) = stream.next().await {
                        let chunk = chunk.unwrap();
                        file.write_all(&chunk).await.unwrap();
                        let len = chunk.len();

                        if let Some(progress) = &progress {
                            let total = total.load(Ordering::SeqCst);
                            // fetch_add returns the previous count.
                            let current = current.fetch_add(len, Ordering::SeqCst);

                            progress
                                .send(Progress {
                                    current: current + len,
                                    total,
                                })
                                .await
                                .unwrap();
                        }
                    }
                })
            })
            .collect::<Vec<_>>();

        // Wait for all downloads; a panicked task surfaces here as a panic.
        join_all(tasks)
            .await
            .into_iter()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        Ok(())
    }
1631
1632 #[allow(clippy::too_many_arguments)]
1647 pub async fn restore_snapshot(
1648 &self,
1649 project_id: ProjectID,
1650 snapshot_id: SnapshotID,
1651 topics: &[String],
1652 autolabel: &[String],
1653 autodepth: bool,
1654 dataset_name: Option<&str>,
1655 dataset_description: Option<&str>,
1656 ) -> Result<SnapshotRestoreResult, Error> {
1657 let params = SnapshotRestore {
1658 project_id,
1659 snapshot_id,
1660 fps: 1,
1661 autodepth,
1662 agtg_pipeline: !autolabel.is_empty(),
1663 autolabel: autolabel.to_vec(),
1664 topics: topics.to_vec(),
1665 dataset_name: dataset_name.map(|s| s.to_owned()),
1666 dataset_description: dataset_description.map(|s| s.to_owned()),
1667 };
1668 self.rpc("snapshots.restore".to_owned(), Some(params)).await
1669 }
1670
1671 pub async fn experiments(
1680 &self,
1681 project_id: ProjectID,
1682 name: Option<&str>,
1683 ) -> Result<Vec<Experiment>, Error> {
1684 let params = HashMap::from([("project_id", project_id)]);
1685 let experiments: Vec<Experiment> =
1686 self.rpc("trainer.list2".to_owned(), Some(params)).await?;
1687 if let Some(name) = name {
1688 Ok(experiments
1689 .into_iter()
1690 .filter(|e| e.name().contains(name))
1691 .collect())
1692 } else {
1693 Ok(experiments)
1694 }
1695 }
1696
1697 pub async fn experiment(&self, experiment_id: ExperimentID) -> Result<Experiment, Error> {
1700 let params = HashMap::from([("trainer_id", experiment_id)]);
1701 self.rpc("trainer.get".to_owned(), Some(params)).await
1702 }
1703
1704 pub async fn training_sessions(
1713 &self,
1714 experiment_id: ExperimentID,
1715 name: Option<&str>,
1716 ) -> Result<Vec<TrainingSession>, Error> {
1717 let params = HashMap::from([("trainer_id", experiment_id)]);
1718 let sessions: Vec<TrainingSession> = self
1719 .rpc("trainer.session.list".to_owned(), Some(params))
1720 .await?;
1721 if let Some(name) = name {
1722 Ok(sessions
1723 .into_iter()
1724 .filter(|s| s.name().contains(name))
1725 .collect())
1726 } else {
1727 Ok(sessions)
1728 }
1729 }
1730
1731 pub async fn training_session(
1734 &self,
1735 session_id: TrainingSessionID,
1736 ) -> Result<TrainingSession, Error> {
1737 let params = HashMap::from([("trainer_session_id", session_id)]);
1738 self.rpc("trainer.session.get".to_owned(), Some(params))
1739 .await
1740 }
1741
1742 pub async fn validation_sessions(
1744 &self,
1745 project_id: ProjectID,
1746 ) -> Result<Vec<ValidationSession>, Error> {
1747 let params = HashMap::from([("project_id", project_id)]);
1748 self.rpc("validate.session.list".to_owned(), Some(params))
1749 .await
1750 }
1751
1752 pub async fn validation_session(
1754 &self,
1755 session_id: ValidationSessionID,
1756 ) -> Result<ValidationSession, Error> {
1757 let params = HashMap::from([("validate_session_id", session_id)]);
1758 self.rpc("validate.session.get".to_owned(), Some(params))
1759 .await
1760 }
1761
1762 pub async fn artifacts(
1765 &self,
1766 training_session_id: TrainingSessionID,
1767 ) -> Result<Vec<Artifact>, Error> {
1768 let params = HashMap::from([("training_session_id", training_session_id)]);
1769 self.rpc("trainer.get_artifacts".to_owned(), Some(params))
1770 .await
1771 }
1772
1773 pub async fn download_artifact(
1779 &self,
1780 training_session_id: TrainingSessionID,
1781 modelname: &str,
1782 filename: Option<PathBuf>,
1783 progress: Option<Sender<Progress>>,
1784 ) -> Result<(), Error> {
1785 let filename = filename.unwrap_or_else(|| PathBuf::from(modelname));
1786 let resp = self
1787 .http
1788 .get(format!(
1789 "{}/download_model?training_session_id={}&file={}",
1790 self.url,
1791 training_session_id.value(),
1792 modelname
1793 ))
1794 .header("Authorization", format!("Bearer {}", self.token().await))
1795 .send()
1796 .await?;
1797 if !resp.status().is_success() {
1798 let err = resp.error_for_status_ref().unwrap_err();
1799 return Err(Error::HttpError(err));
1800 }
1801
1802 fs::create_dir_all(filename.parent().unwrap()).await?;
1803
1804 if let Some(progress) = progress {
1805 let total = resp.content_length().unwrap() as usize;
1806 progress.send(Progress { current: 0, total }).await.unwrap();
1807
1808 let mut file = File::create(filename).await?;
1809 let mut current = 0;
1810 let mut stream = resp.bytes_stream();
1811
1812 while let Some(item) = stream.next().await {
1813 let chunk = item?;
1814 file.write_all(&chunk).await?;
1815 current += chunk.len();
1816 progress.send(Progress { current, total }).await.unwrap();
1817 }
1818 } else {
1819 let body = resp.bytes().await?;
1820 fs::write(filename, body).await?;
1821 }
1822
1823 Ok(())
1824 }
1825
1826 pub async fn download_checkpoint(
1836 &self,
1837 training_session_id: TrainingSessionID,
1838 checkpoint: &str,
1839 filename: Option<PathBuf>,
1840 progress: Option<Sender<Progress>>,
1841 ) -> Result<(), Error> {
1842 let filename = filename.unwrap_or_else(|| PathBuf::from(checkpoint));
1843 let resp = self
1844 .http
1845 .get(format!(
1846 "{}/download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1847 self.url,
1848 training_session_id.value(),
1849 checkpoint
1850 ))
1851 .header("Authorization", format!("Bearer {}", self.token().await))
1852 .send()
1853 .await?;
1854 if !resp.status().is_success() {
1855 let err = resp.error_for_status_ref().unwrap_err();
1856 return Err(Error::HttpError(err));
1857 }
1858
1859 fs::create_dir_all(filename.parent().unwrap()).await?;
1860
1861 if let Some(progress) = progress {
1862 let total = resp.content_length().unwrap() as usize;
1863 progress.send(Progress { current: 0, total }).await.unwrap();
1864
1865 let mut file = File::create(filename).await?;
1866 let mut current = 0;
1867 let mut stream = resp.bytes_stream();
1868
1869 while let Some(item) = stream.next().await {
1870 let chunk = item?;
1871 file.write_all(&chunk).await?;
1872 current += chunk.len();
1873 progress.send(Progress { current, total }).await.unwrap();
1874 }
1875 } else {
1876 let body = resp.bytes().await?;
1877 fs::write(filename, body).await?;
1878 }
1879
1880 Ok(())
1881 }
1882
1883 pub async fn tasks(
1885 &self,
1886 name: Option<&str>,
1887 workflow: Option<&str>,
1888 status: Option<&str>,
1889 manager: Option<&str>,
1890 ) -> Result<Vec<Task>, Error> {
1891 let mut params = TasksListParams {
1892 continue_token: None,
1893 status: status.map(|s| vec![s.to_owned()]),
1894 manager: manager.map(|m| vec![m.to_owned()]),
1895 };
1896 let mut tasks = Vec::new();
1897
1898 loop {
1899 let result = self
1900 .rpc::<_, TasksListResult>("task.list".to_owned(), Some(¶ms))
1901 .await?;
1902 tasks.extend(result.tasks);
1903
1904 if result.continue_token.is_none() || result.continue_token == Some("".into()) {
1905 params.continue_token = None;
1906 } else {
1907 params.continue_token = result.continue_token;
1908 }
1909
1910 if params.continue_token.is_none() {
1911 break;
1912 }
1913 }
1914
1915 if let Some(name) = name {
1916 tasks.retain(|t| t.name().contains(name));
1917 }
1918
1919 if let Some(workflow) = workflow {
1920 tasks.retain(|t| t.workflow().contains(workflow));
1921 }
1922
1923 Ok(tasks)
1924 }
1925
1926 pub async fn task_info(&self, task_id: TaskID) -> Result<TaskInfo, Error> {
1928 self.rpc(
1929 "task.get".to_owned(),
1930 Some(HashMap::from([("id", task_id)])),
1931 )
1932 .await
1933 }
1934
1935 pub async fn task_status(&self, task_id: TaskID, status: &str) -> Result<Task, Error> {
1937 let status = TaskStatus {
1938 task_id,
1939 status: status.to_owned(),
1940 };
1941 self.rpc("docker.update.status".to_owned(), Some(status))
1942 .await
1943 }
1944
1945 pub async fn set_stages(&self, task_id: TaskID, stages: &[(&str, &str)]) -> Result<(), Error> {
1949 let stages: Vec<HashMap<String, String>> = stages
1950 .iter()
1951 .map(|(key, value)| {
1952 let mut stage_map = HashMap::new();
1953 stage_map.insert(key.to_string(), value.to_string());
1954 stage_map
1955 })
1956 .collect();
1957 let params = TaskStages { task_id, stages };
1958 let _: Task = self.rpc("status.stages".to_owned(), Some(params)).await?;
1959 Ok(())
1960 }
1961
1962 pub async fn update_stage(
1965 &self,
1966 task_id: TaskID,
1967 stage: &str,
1968 status: &str,
1969 message: &str,
1970 percentage: u8,
1971 ) -> Result<(), Error> {
1972 let stage = Stage::new(
1973 Some(task_id),
1974 stage.to_owned(),
1975 Some(status.to_owned()),
1976 Some(message.to_owned()),
1977 percentage,
1978 );
1979 let _: Task = self.rpc("status.update".to_owned(), Some(stage)).await?;
1980 Ok(())
1981 }
1982
1983 pub async fn fetch(&self, query: &str) -> Result<Vec<u8>, Error> {
1985 let req = self
1986 .http
1987 .get(format!("{}/{}", self.url, query))
1988 .header("User-Agent", "EdgeFirst Client")
1989 .header("Authorization", format!("Bearer {}", self.token().await));
1990 let resp = req.send().await?;
1991
1992 if resp.status().is_success() {
1993 let body = resp.bytes().await?;
1994
1995 if log_enabled!(Level::Trace) {
1996 trace!("Fetch Response: {}", String::from_utf8_lossy(&body));
1997 }
1998
1999 Ok(body.to_vec())
2000 } else {
2001 let err = resp.error_for_status_ref().unwrap_err();
2002 Err(Error::HttpError(err))
2003 }
2004 }
2005
2006 pub async fn post_multipart(&self, method: &str, form: Form) -> Result<String, Error> {
2010 let req = self
2011 .http
2012 .post(format!("{}/api?method={}", self.url, method))
2013 .header("Accept", "application/json")
2014 .header("User-Agent", "EdgeFirst Client")
2015 .header("Authorization", format!("Bearer {}", self.token().await))
2016 .multipart(form);
2017 let resp = req.send().await?;
2018
2019 if resp.status().is_success() {
2020 let body = resp.bytes().await?;
2021
2022 if log_enabled!(Level::Trace) {
2023 trace!(
2024 "POST Multipart Response: {}",
2025 String::from_utf8_lossy(&body)
2026 );
2027 }
2028
2029 let response: RpcResponse<String> = match serde_json::from_slice(&body) {
2030 Ok(response) => response,
2031 Err(err) => {
2032 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2033 return Err(err.into());
2034 }
2035 };
2036
2037 if let Some(error) = response.error {
2038 Err(Error::RpcError(error.code, error.message))
2039 } else if let Some(result) = response.result {
2040 Ok(result)
2041 } else {
2042 Err(Error::InvalidResponse)
2043 }
2044 } else {
2045 let err = resp.error_for_status_ref().unwrap_err();
2046 Err(Error::HttpError(err))
2047 }
2048 }
2049
2050 pub async fn rpc<Params, RpcResult>(
2059 &self,
2060 method: String,
2061 params: Option<Params>,
2062 ) -> Result<RpcResult, Error>
2063 where
2064 Params: Serialize,
2065 RpcResult: DeserializeOwned,
2066 {
2067 let auth_expires = self.token_expiration().await?;
2068 if auth_expires <= Utc::now() + Duration::from_secs(3600) {
2069 self.renew_token().await?;
2070 }
2071
2072 self.rpc_without_auth(method, params).await
2073 }
2074
2075 async fn rpc_without_auth<Params, RpcResult>(
2076 &self,
2077 method: String,
2078 params: Option<Params>,
2079 ) -> Result<RpcResult, Error>
2080 where
2081 Params: Serialize,
2082 RpcResult: DeserializeOwned,
2083 {
2084 let request = RpcRequest {
2085 method,
2086 params,
2087 ..Default::default()
2088 };
2089
2090 if log_enabled!(Level::Trace) {
2091 trace!(
2092 "RPC Request: {}",
2093 serde_json::ser::to_string_pretty(&request)?
2094 );
2095 }
2096
2097 for attempt in 0..MAX_RETRIES {
2098 let res = match self
2099 .http
2100 .post(format!("{}/api", self.url))
2101 .header("Accept", "application/json")
2102 .header("User-Agent", "EdgeFirst Client")
2103 .header("Authorization", format!("Bearer {}", self.token().await))
2104 .json(&request)
2105 .send()
2106 .await
2107 {
2108 Ok(res) => res,
2109 Err(err) => {
2110 warn!("Socket Error: {:?}", err);
2111 continue;
2112 }
2113 };
2114
2115 if res.status().is_success() {
2116 let body = res.bytes().await?;
2117
2118 if log_enabled!(Level::Trace) {
2119 trace!("RPC Response: {}", String::from_utf8_lossy(&body));
2120 }
2121
2122 let response: RpcResponse<RpcResult> = match serde_json::from_slice(&body) {
2123 Ok(response) => response,
2124 Err(err) => {
2125 error!("Invalid JSON Response: {}", String::from_utf8_lossy(&body));
2126 return Err(err.into());
2127 }
2128 };
2129
2130 if let Some(error) = response.error {
2136 return Err(Error::RpcError(error.code, error.message));
2137 } else if let Some(result) = response.result {
2138 return Ok(result);
2139 } else {
2140 return Err(Error::InvalidResponse);
2141 }
2142 } else {
2143 let err = res.error_for_status_ref().unwrap_err();
2144 warn!("HTTP Error {}: {}", err, res.text().await?);
2145 }
2146
2147 warn!(
2148 "Retrying RPC request (attempt {}/{})...",
2149 attempt + 1,
2150 MAX_RETRIES
2151 );
2152 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2153 }
2154
2155 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2156 }
2157}
2158
2159async fn upload_multipart(
2160 http: reqwest::Client,
2161 part: SnapshotPart,
2162 path: PathBuf,
2163 total: usize,
2164 current: Arc<AtomicUsize>,
2165 progress: Option<Sender<Progress>>,
2166) -> Result<SnapshotCompleteMultipartParams, Error> {
2167 let filesize = path.metadata()?.len() as usize;
2168 let n_parts = filesize.div_ceil(PART_SIZE);
2169 let sem = Arc::new(Semaphore::new(MAX_TASKS));
2170
2171 let key = part.key.unwrap();
2172 let upload_id = part.upload_id;
2173
2174 let urls = part.urls.clone();
2175 let etags = Arc::new(tokio::sync::Mutex::new(vec![
2176 EtagPart {
2177 etag: "".to_owned(),
2178 part_number: 0,
2179 };
2180 n_parts
2181 ]));
2182
2183 let tasks = (0..n_parts)
2184 .map(|part| {
2185 let http = http.clone();
2186 let url = urls[part].clone();
2187 let etags = etags.clone();
2188 let path = path.to_owned();
2189 let sem = sem.clone();
2190 let progress = progress.clone();
2191 let current = current.clone();
2192
2193 tokio::spawn(async move {
2194 let _permit = sem.acquire().await?;
2195 let mut etag = None;
2196
2197 for attempt in 0..MAX_RETRIES {
2198 match upload_part(http.clone(), url.clone(), path.clone(), part, n_parts).await
2199 {
2200 Ok(v) => {
2201 etag = Some(v);
2202 break;
2203 }
2204 Err(err) => {
2205 warn!("Upload Part Error: {:?}", err);
2206 tokio::time::sleep(Duration::from_secs(1) * attempt).await;
2207 }
2208 }
2209 }
2210
2211 if let Some(etag) = etag {
2212 let mut etags = etags.lock().await;
2213 etags[part] = EtagPart {
2214 etag,
2215 part_number: part + 1,
2216 };
2217
2218 let current = current.fetch_add(PART_SIZE, Ordering::SeqCst);
2219 if let Some(progress) = &progress {
2220 progress
2221 .send(Progress {
2222 current: current + PART_SIZE,
2223 total,
2224 })
2225 .await
2226 .unwrap();
2227 }
2228
2229 Ok(())
2230 } else {
2231 Err(Error::MaxRetriesExceeded(MAX_RETRIES))
2232 }
2233 })
2234 })
2235 .collect::<Vec<_>>();
2236
2237 join_all(tasks)
2238 .await
2239 .into_iter()
2240 .collect::<Result<Vec<_>, _>>()?;
2241
2242 Ok(SnapshotCompleteMultipartParams {
2243 key,
2244 upload_id,
2245 etag_list: etags.lock().await.clone(),
2246 })
2247}
2248
2249async fn upload_part(
2250 http: reqwest::Client,
2251 url: String,
2252 path: PathBuf,
2253 part: usize,
2254 n_parts: usize,
2255) -> Result<String, Error> {
2256 let filesize = path.metadata()?.len() as usize;
2257 let mut file = File::open(path).await.unwrap();
2258 file.seek(SeekFrom::Start((part * PART_SIZE) as u64))
2259 .await
2260 .unwrap();
2261 let file = file.take(PART_SIZE as u64);
2262
2263 let body_length = if part + 1 == n_parts {
2264 filesize % PART_SIZE
2265 } else {
2266 PART_SIZE
2267 };
2268
2269 let stream = FramedRead::new(file, BytesCodec::new());
2270 let body = Body::wrap_stream(stream);
2271
2272 let resp = http
2273 .put(url.clone())
2274 .header(CONTENT_LENGTH, body_length)
2275 .body(body)
2276 .send()
2277 .await?
2278 .error_for_status()?;
2279 let etag = resp
2280 .headers()
2281 .get("etag")
2282 .unwrap()
2283 .to_str()
2284 .unwrap()
2285 .to_owned();
2286 Ok(etag
2288 .strip_prefix("\"")
2289 .unwrap()
2290 .strip_suffix("\"")
2291 .unwrap()
2292 .to_owned())
2293}
2294
2295async fn upload_file_to_presigned_url(
2300 http: reqwest::Client,
2301 url: &str,
2302 path: PathBuf,
2303) -> Result<(), Error> {
2304 let file_data = fs::read(&path).await?;
2306 let file_size = file_data.len();
2307
2308 for attempt in 1..=MAX_RETRIES {
2310 match http
2311 .put(url)
2312 .header(CONTENT_LENGTH, file_size)
2313 .body(file_data.clone())
2314 .send()
2315 .await
2316 {
2317 Ok(resp) => {
2318 if resp.status().is_success() {
2319 debug!(
2320 "Successfully uploaded file: {:?} ({} bytes)",
2321 path, file_size
2322 );
2323 return Ok(());
2324 } else {
2325 let status = resp.status();
2326 let error_text = resp.text().await.unwrap_or_default();
2327 warn!(
2328 "Upload failed [attempt {}/{}]: HTTP {} - {}",
2329 attempt, MAX_RETRIES, status, error_text
2330 );
2331 }
2332 }
2333 Err(err) => {
2334 warn!(
2335 "Upload error [attempt {}/{}]: {:?}",
2336 attempt, MAX_RETRIES, err
2337 );
2338 }
2339 }
2340
2341 if attempt < MAX_RETRIES {
2342 tokio::time::sleep(Duration::from_secs(attempt as u64)).await;
2343 }
2344 }
2345
2346 Err(Error::InvalidParameters(format!(
2347 "Failed to upload file {:?} after {} attempts",
2348 path, MAX_RETRIES
2349 )))
2350}