1use crate::{AnnotationSet, Client, Dataset, Error, Sample, client};
5use chrono::{DateTime, Utc};
6use log::trace;
7use reqwest::multipart::{Form, Part};
8use serde::{Deserialize, Deserializer, Serialize};
9use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
10
/// Dynamically-typed value used for parameter and metric maps.
///
/// The enum is `#[serde(untagged)]`, so no variant tag is written and, when
/// deserializing, serde tries the variants in declaration order. The order of
/// the variants below is therefore part of the wire contract.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum Parameter {
    /// Signed 64-bit integer value.
    Integer(i64),
    /// 64-bit floating point value.
    Real(f64),
    /// Boolean value.
    Boolean(bool),
    /// Text value.
    String(String),
    /// Ordered list of nested parameters.
    Array(Vec<Parameter>),
    /// String-keyed map of nested parameters.
    Object(HashMap<String, Parameter>),
}
56
/// Deserialized result of a login call (per the type name); only the `token`
/// string is read from the payload.
#[derive(Deserialize)]
pub struct LoginResult {
    /// Token string taken from the `token` key; visible inside the crate only.
    pub(crate) token: String,
}
61
62#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
82pub struct OrganizationID(u64);
83
84impl Display for OrganizationID {
85 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
86 write!(f, "org-{:x}", self.0)
87 }
88}
89
90impl From<u64> for OrganizationID {
91 fn from(id: u64) -> Self {
92 OrganizationID(id)
93 }
94}
95
96impl From<OrganizationID> for u64 {
97 fn from(val: OrganizationID) -> Self {
98 val.0
99 }
100}
101
102impl OrganizationID {
103 pub fn value(&self) -> u64 {
104 self.0
105 }
106}
107
108impl TryFrom<&str> for OrganizationID {
109 type Error = Error;
110
111 fn try_from(s: &str) -> Result<Self, Self::Error> {
112 let hex_part = s.strip_prefix("org-").ok_or_else(|| {
113 Error::InvalidParameters("Organization ID must start with 'org-' prefix".to_string())
114 })?;
115 let id = u64::from_str_radix(hex_part, 16)?;
116 Ok(OrganizationID(id))
117 }
118}
119
120#[derive(Deserialize, Clone, Debug)]
141pub struct Organization {
142 id: OrganizationID,
143 name: String,
144 #[serde(rename = "latest_credit")]
145 credits: i64,
146}
147
148impl Display for Organization {
149 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
154impl Organization {
155 pub fn id(&self) -> OrganizationID {
156 self.id
157 }
158
159 pub fn name(&self) -> &str {
160 &self.name
161 }
162
163 pub fn credits(&self) -> i64 {
164 self.credits
165 }
166}
167
168#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
189pub struct ProjectID(u64);
190
191impl Display for ProjectID {
192 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
193 write!(f, "p-{:x}", self.0)
194 }
195}
196
197impl From<u64> for ProjectID {
198 fn from(id: u64) -> Self {
199 ProjectID(id)
200 }
201}
202
203impl From<ProjectID> for u64 {
204 fn from(val: ProjectID) -> Self {
205 val.0
206 }
207}
208
209impl ProjectID {
210 pub fn value(&self) -> u64 {
211 self.0
212 }
213}
214
215impl TryFrom<&str> for ProjectID {
216 type Error = Error;
217
218 fn try_from(s: &str) -> Result<Self, Self::Error> {
219 ProjectID::from_str(s)
220 }
221}
222
223impl TryFrom<String> for ProjectID {
224 type Error = Error;
225
226 fn try_from(s: String) -> Result<Self, Self::Error> {
227 ProjectID::from_str(&s)
228 }
229}
230
231impl FromStr for ProjectID {
232 type Err = Error;
233
234 fn from_str(s: &str) -> Result<Self, Self::Err> {
235 let hex_part = s.strip_prefix("p-").ok_or_else(|| {
236 Error::InvalidParameters("Project ID must start with 'p-' prefix".to_string())
237 })?;
238 let id = u64::from_str_radix(hex_part, 16)?;
239 Ok(ProjectID(id))
240 }
241}
242
243#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
264pub struct ExperimentID(u64);
265
266impl Display for ExperimentID {
267 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
268 write!(f, "exp-{:x}", self.0)
269 }
270}
271
272impl From<u64> for ExperimentID {
273 fn from(id: u64) -> Self {
274 ExperimentID(id)
275 }
276}
277
278impl From<ExperimentID> for u64 {
279 fn from(val: ExperimentID) -> Self {
280 val.0
281 }
282}
283
284impl ExperimentID {
285 pub fn value(&self) -> u64 {
286 self.0
287 }
288}
289
290impl TryFrom<&str> for ExperimentID {
291 type Error = Error;
292
293 fn try_from(s: &str) -> Result<Self, Self::Error> {
294 ExperimentID::from_str(s)
295 }
296}
297
298impl TryFrom<String> for ExperimentID {
299 type Error = Error;
300
301 fn try_from(s: String) -> Result<Self, Self::Error> {
302 ExperimentID::from_str(&s)
303 }
304}
305
306impl FromStr for ExperimentID {
307 type Err = Error;
308
309 fn from_str(s: &str) -> Result<Self, Self::Err> {
310 let hex_part = s.strip_prefix("exp-").ok_or_else(|| {
311 Error::InvalidParameters("Experiment ID must start with 'exp-' prefix".to_string())
312 })?;
313 let id = u64::from_str_radix(hex_part, 16)?;
314 Ok(ExperimentID(id))
315 }
316}
317
318#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
339pub struct TrainingSessionID(u64);
340
341impl Display for TrainingSessionID {
342 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
343 write!(f, "t-{:x}", self.0)
344 }
345}
346
347impl From<u64> for TrainingSessionID {
348 fn from(id: u64) -> Self {
349 TrainingSessionID(id)
350 }
351}
352
353impl From<TrainingSessionID> for u64 {
354 fn from(val: TrainingSessionID) -> Self {
355 val.0
356 }
357}
358
359impl TrainingSessionID {
360 pub fn value(&self) -> u64 {
361 self.0
362 }
363}
364
365impl TryFrom<&str> for TrainingSessionID {
366 type Error = Error;
367
368 fn try_from(s: &str) -> Result<Self, Self::Error> {
369 TrainingSessionID::from_str(s)
370 }
371}
372
373impl TryFrom<String> for TrainingSessionID {
374 type Error = Error;
375
376 fn try_from(s: String) -> Result<Self, Self::Error> {
377 TrainingSessionID::from_str(&s)
378 }
379}
380
381impl FromStr for TrainingSessionID {
382 type Err = Error;
383
384 fn from_str(s: &str) -> Result<Self, Self::Err> {
385 let hex_part = s.strip_prefix("t-").ok_or_else(|| {
386 Error::InvalidParameters("Training Session ID must start with 't-' prefix".to_string())
387 })?;
388 let id = u64::from_str_radix(hex_part, 16)?;
389 Ok(TrainingSessionID(id))
390 }
391}
392
393#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
413pub struct ValidationSessionID(u64);
414
415impl Display for ValidationSessionID {
416 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417 write!(f, "v-{:x}", self.0)
418 }
419}
420
421impl From<u64> for ValidationSessionID {
422 fn from(id: u64) -> Self {
423 ValidationSessionID(id)
424 }
425}
426
427impl From<ValidationSessionID> for u64 {
428 fn from(val: ValidationSessionID) -> Self {
429 val.0
430 }
431}
432
433impl ValidationSessionID {
434 pub fn value(&self) -> u64 {
435 self.0
436 }
437}
438
439impl TryFrom<&str> for ValidationSessionID {
440 type Error = Error;
441
442 fn try_from(s: &str) -> Result<Self, Self::Error> {
443 let hex_part = s.strip_prefix("v-").ok_or_else(|| {
444 Error::InvalidParameters(
445 "Validation Session ID must start with 'v-' prefix".to_string(),
446 )
447 })?;
448 let id = u64::from_str_radix(hex_part, 16)?;
449 Ok(ValidationSessionID(id))
450 }
451}
452
453impl TryFrom<String> for ValidationSessionID {
454 type Error = Error;
455
456 fn try_from(s: String) -> Result<Self, Self::Error> {
457 ValidationSessionID::try_from(s.as_str())
458 }
459}
460
461#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
462pub struct SnapshotID(u64);
463
464impl Display for SnapshotID {
465 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
466 write!(f, "ss-{:x}", self.0)
467 }
468}
469
470impl From<u64> for SnapshotID {
471 fn from(id: u64) -> Self {
472 SnapshotID(id)
473 }
474}
475
476impl From<SnapshotID> for u64 {
477 fn from(val: SnapshotID) -> Self {
478 val.0
479 }
480}
481
482impl SnapshotID {
483 pub fn value(&self) -> u64 {
484 self.0
485 }
486}
487
488impl TryFrom<&str> for SnapshotID {
489 type Error = Error;
490
491 fn try_from(s: &str) -> Result<Self, Self::Error> {
492 let hex_part = s.strip_prefix("ss-").ok_or_else(|| {
493 Error::InvalidParameters("Snapshot ID must start with 'ss-' prefix".to_string())
494 })?;
495 let id = u64::from_str_radix(hex_part, 16)?;
496 Ok(SnapshotID(id))
497 }
498}
499
500#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
501pub struct TaskID(u64);
502
503impl Display for TaskID {
504 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
505 write!(f, "task-{:x}", self.0)
506 }
507}
508
509impl From<u64> for TaskID {
510 fn from(id: u64) -> Self {
511 TaskID(id)
512 }
513}
514
515impl From<TaskID> for u64 {
516 fn from(val: TaskID) -> Self {
517 val.0
518 }
519}
520
521impl TaskID {
522 pub fn value(&self) -> u64 {
523 self.0
524 }
525}
526
527impl TryFrom<&str> for TaskID {
528 type Error = Error;
529
530 fn try_from(s: &str) -> Result<Self, Self::Error> {
531 TaskID::from_str(s)
532 }
533}
534
535impl TryFrom<String> for TaskID {
536 type Error = Error;
537
538 fn try_from(s: String) -> Result<Self, Self::Error> {
539 TaskID::from_str(&s)
540 }
541}
542
543impl FromStr for TaskID {
544 type Err = Error;
545
546 fn from_str(s: &str) -> Result<Self, Self::Err> {
547 let hex_part = s.strip_prefix("task-").ok_or_else(|| {
548 Error::InvalidParameters("Task ID must start with 'task-' prefix".to_string())
549 })?;
550 let id = u64::from_str_radix(hex_part, 16)?;
551 Ok(TaskID(id))
552 }
553}
554
555#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
576pub struct DatasetID(u64);
577
578impl Display for DatasetID {
579 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
580 write!(f, "ds-{:x}", self.0)
581 }
582}
583
584impl From<u64> for DatasetID {
585 fn from(id: u64) -> Self {
586 DatasetID(id)
587 }
588}
589
590impl From<DatasetID> for u64 {
591 fn from(val: DatasetID) -> Self {
592 val.0
593 }
594}
595
596impl DatasetID {
597 pub fn value(&self) -> u64 {
598 self.0
599 }
600}
601
602impl TryFrom<&str> for DatasetID {
603 type Error = Error;
604
605 fn try_from(s: &str) -> Result<Self, Self::Error> {
606 DatasetID::from_str(s)
607 }
608}
609
610impl TryFrom<String> for DatasetID {
611 type Error = Error;
612
613 fn try_from(s: String) -> Result<Self, Self::Error> {
614 DatasetID::from_str(&s)
615 }
616}
617
618impl FromStr for DatasetID {
619 type Err = Error;
620
621 fn from_str(s: &str) -> Result<Self, Self::Err> {
622 let hex_part = s.strip_prefix("ds-").ok_or_else(|| {
623 Error::InvalidParameters("Dataset ID must start with 'ds-' prefix".to_string())
624 })?;
625 let id = u64::from_str_radix(hex_part, 16)?;
626 Ok(DatasetID(id))
627 }
628}
629
630#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
631pub struct AnnotationSetID(u64);
632
633impl Display for AnnotationSetID {
634 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
635 write!(f, "as-{:x}", self.0)
636 }
637}
638
639impl From<u64> for AnnotationSetID {
640 fn from(id: u64) -> Self {
641 AnnotationSetID(id)
642 }
643}
644
645impl From<AnnotationSetID> for u64 {
646 fn from(val: AnnotationSetID) -> Self {
647 val.0
648 }
649}
650
651impl AnnotationSetID {
652 pub fn value(&self) -> u64 {
653 self.0
654 }
655}
656
657impl TryFrom<&str> for AnnotationSetID {
658 type Error = Error;
659
660 fn try_from(s: &str) -> Result<Self, Self::Error> {
661 AnnotationSetID::from_str(s)
662 }
663}
664
665impl TryFrom<String> for AnnotationSetID {
666 type Error = Error;
667
668 fn try_from(s: String) -> Result<Self, Self::Error> {
669 AnnotationSetID::from_str(&s)
670 }
671}
672
673impl FromStr for AnnotationSetID {
674 type Err = Error;
675
676 fn from_str(s: &str) -> Result<Self, Self::Err> {
677 let hex_part = s.strip_prefix("as-").ok_or_else(|| {
678 Error::InvalidParameters("Annotation Set ID must start with 'as-' prefix".to_string())
679 })?;
680 let id = u64::from_str_radix(hex_part, 16)?;
681 Ok(AnnotationSetID(id))
682 }
683}
684
685#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
686pub struct SampleID(u64);
687
688impl Display for SampleID {
689 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
690 write!(f, "s-{:x}", self.0)
691 }
692}
693
694impl From<u64> for SampleID {
695 fn from(id: u64) -> Self {
696 SampleID(id)
697 }
698}
699
700impl From<SampleID> for u64 {
701 fn from(val: SampleID) -> Self {
702 val.0
703 }
704}
705
706impl SampleID {
707 pub fn value(&self) -> u64 {
708 self.0
709 }
710}
711
712impl TryFrom<&str> for SampleID {
713 type Error = Error;
714
715 fn try_from(s: &str) -> Result<Self, Self::Error> {
716 let hex_part = s.strip_prefix("s-").ok_or_else(|| {
717 Error::InvalidParameters("Sample ID must start with 's-' prefix".to_string())
718 })?;
719 let id = u64::from_str_radix(hex_part, 16)?;
720 Ok(SampleID(id))
721 }
722}
723
724#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
725pub struct AppId(u64);
726
727impl Display for AppId {
728 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
729 write!(f, "app-{:x}", self.0)
730 }
731}
732
733impl From<u64> for AppId {
734 fn from(id: u64) -> Self {
735 AppId(id)
736 }
737}
738
739impl From<AppId> for u64 {
740 fn from(val: AppId) -> Self {
741 val.0
742 }
743}
744
745impl AppId {
746 pub fn value(&self) -> u64 {
747 self.0
748 }
749}
750
751impl TryFrom<&str> for AppId {
752 type Error = Error;
753
754 fn try_from(s: &str) -> Result<Self, Self::Error> {
755 let hex_part = s.strip_prefix("app-").ok_or_else(|| {
756 Error::InvalidParameters("App ID must start with 'app-' prefix".to_string())
757 })?;
758 let id = u64::from_str_radix(hex_part, 16)?;
759 Ok(AppId(id))
760 }
761}
762
763#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
764pub struct ImageId(u64);
765
766impl Display for ImageId {
767 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
768 write!(f, "im-{:x}", self.0)
769 }
770}
771
772impl From<u64> for ImageId {
773 fn from(id: u64) -> Self {
774 ImageId(id)
775 }
776}
777
778impl From<ImageId> for u64 {
779 fn from(val: ImageId) -> Self {
780 val.0
781 }
782}
783
784impl ImageId {
785 pub fn value(&self) -> u64 {
786 self.0
787 }
788}
789
790impl TryFrom<&str> for ImageId {
791 type Error = Error;
792
793 fn try_from(s: &str) -> Result<Self, Self::Error> {
794 let hex_part = s.strip_prefix("im-").ok_or_else(|| {
795 Error::InvalidParameters("Image ID must start with 'im-' prefix".to_string())
796 })?;
797 let id = u64::from_str_radix(hex_part, 16)?;
798 Ok(ImageId(id))
799 }
800}
801
802#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
803pub struct SequenceId(u64);
804
805impl Display for SequenceId {
806 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
807 write!(f, "se-{:x}", self.0)
808 }
809}
810
811impl From<u64> for SequenceId {
812 fn from(id: u64) -> Self {
813 SequenceId(id)
814 }
815}
816
817impl From<SequenceId> for u64 {
818 fn from(val: SequenceId) -> Self {
819 val.0
820 }
821}
822
823impl SequenceId {
824 pub fn value(&self) -> u64 {
825 self.0
826 }
827}
828
829impl TryFrom<&str> for SequenceId {
830 type Error = Error;
831
832 fn try_from(s: &str) -> Result<Self, Self::Error> {
833 let hex_part = s.strip_prefix("se-").ok_or_else(|| {
834 Error::InvalidParameters("Sequence ID must start with 'se-' prefix".to_string())
835 })?;
836 let id = u64::from_str_radix(hex_part, 16)?;
837 Ok(SequenceId(id))
838 }
839}
840
841#[derive(Deserialize, Clone, Debug)]
845pub struct Project {
846 id: ProjectID,
847 name: String,
848 description: String,
849}
850
851impl Display for Project {
852 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
853 write!(f, "{} {}", self.id(), self.name())
854 }
855}
856
857impl Project {
858 pub fn id(&self) -> ProjectID {
859 self.id
860 }
861
862 pub fn name(&self) -> &str {
863 &self.name
864 }
865
866 pub fn description(&self) -> &str {
867 &self.description
868 }
869
870 pub async fn datasets(
871 &self,
872 client: &client::Client,
873 name: Option<&str>,
874 ) -> Result<Vec<Dataset>, Error> {
875 client.datasets(self.id, name).await
876 }
877
878 pub async fn experiments(
879 &self,
880 client: &client::Client,
881 name: Option<&str>,
882 ) -> Result<Vec<Experiment>, Error> {
883 client.experiments(self.id, name).await
884 }
885}
886
/// Deserialized response carrying a single `total` count.
#[derive(Deserialize, Debug)]
pub struct SamplesCountResult {
    /// Value of the `total` key (presumably the number of matching samples —
    /// confirm against the server API).
    pub total: u64,
}
891
/// Request parameters for listing samples.
///
/// Optional fields that are `None` and vectors that are empty are omitted
/// from the serialized output entirely.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesListParams {
    /// Dataset to list; always serialized.
    pub dataset_id: DatasetID,
    /// Annotation set identifier; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Continuation token (presumably taken from a previous
    /// `SamplesListResult` — confirm); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Type names; omitted when the vector is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub types: Vec<String>,
    /// Group names; omitted when the vector is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub group_names: Vec<String>,
}
904
/// Deserialized response of a samples listing.
#[derive(Deserialize, Debug)]
pub struct SamplesListResult {
    /// Samples contained in this response.
    pub samples: Vec<Sample>,
    /// Optional continuation token (presumably present when more results are
    /// available — confirm against the server API).
    pub continue_token: Option<String>,
}
910
911#[derive(Deserialize)]
912pub struct Snapshot {
913 id: SnapshotID,
914 description: String,
915 status: String,
916 path: String,
917 #[serde(rename = "date")]
918 created: DateTime<Utc>,
919}
920
921impl Display for Snapshot {
922 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
923 write!(f, "{} {}", self.id, self.description)
924 }
925}
926
927impl Snapshot {
928 pub fn id(&self) -> SnapshotID {
929 self.id
930 }
931
932 pub fn description(&self) -> &str {
933 &self.description
934 }
935
936 pub fn status(&self) -> &str {
937 &self.status
938 }
939
940 pub fn path(&self) -> &str {
941 &self.path
942 }
943
944 pub fn created(&self) -> &DateTime<Utc> {
945 &self.created
946 }
947}
948
/// Request parameters for restoring a snapshot into a project.
///
/// Several fields are serialized under a different key than their Rust name;
/// see the per-field notes. The meaning of the individual options is defined
/// by the server API and is not visible here.
#[derive(Serialize, Debug)]
pub struct SnapshotRestore {
    /// Target project identifier.
    pub project_id: ProjectID,
    /// Snapshot to restore.
    pub snapshot_id: SnapshotID,
    /// Serialized as `fps`.
    pub fps: u64,
    /// Serialized as `enabled_topics`.
    #[serde(rename = "enabled_topics")]
    pub topics: Vec<String>,
    /// Serialized as `label_names`.
    #[serde(rename = "label_names")]
    pub autolabel: Vec<String>,
    /// Serialized as `depth_gen`.
    #[serde(rename = "depth_gen")]
    pub autodepth: bool,
    /// Serialized as `agtg_pipeline`.
    pub agtg_pipeline: bool,
    /// Optional dataset name; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_name: Option<String>,
    /// Optional dataset description; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_description: Option<String>,
}
966
/// Deserialized response of a snapshot restore request.
#[derive(Deserialize, Debug)]
pub struct SnapshotRestoreResult {
    /// Snapshot identifier from the `id` key.
    pub id: SnapshotID,
    /// Description string from the response.
    pub description: String,
    /// Dataset name from the response.
    pub dataset_name: String,
    /// Dataset identifier from the response.
    pub dataset_id: DatasetID,
    /// Annotation set identifier from the response.
    pub annotation_set_id: AnnotationSetID,
    /// Task identifier, read from the `docker_task_id` key.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// Timestamp from the `date` key.
    pub date: DateTime<Utc>,
}
978
979#[derive(Deserialize)]
980pub struct Experiment {
981 id: ExperimentID,
982 project_id: ProjectID,
983 name: String,
984 description: String,
985}
986
987impl Display for Experiment {
988 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
989 write!(f, "{} {}", self.uid(), self.name)
990 }
991}
992
993impl Experiment {
994 pub fn id(&self) -> ExperimentID {
995 self.id
996 }
997
998 pub fn uid(&self) -> String {
999 self.id.to_string()
1000 }
1001
1002 pub fn project_id(&self) -> ProjectID {
1003 self.project_id
1004 }
1005
1006 pub fn name(&self) -> &str {
1007 &self.name
1008 }
1009
1010 pub fn description(&self) -> &str {
1011 &self.description
1012 }
1013
1014 pub async fn project(&self, client: &client::Client) -> Result<Project, Error> {
1015 client.project(self.project_id).await
1016 }
1017
1018 pub async fn training_sessions(
1019 &self,
1020 client: &client::Client,
1021 name: Option<&str>,
1022 ) -> Result<Vec<TrainingSession>, Error> {
1023 client.training_sessions(self.id, name).await
1024 }
1025}
1026
/// Request payload used to publish metrics for a session.
///
/// Either session identifier is omitted from the serialized output when it
/// is `None`.
#[derive(Serialize, Debug)]
pub struct PublishMetrics {
    /// Training session identifier, serialized as `trainer_session_id`.
    #[serde(rename = "trainer_session_id", skip_serializing_if = "Option::is_none")]
    pub trainer_session_id: Option<TrainingSessionID>,
    /// Validation session identifier, serialized as `validate_session_id`.
    #[serde(
        rename = "validate_session_id",
        skip_serializing_if = "Option::is_none"
    )]
    pub validate_session_id: Option<ValidationSessionID>,
    /// Metric values keyed by name.
    pub metrics: HashMap<String, Parameter>,
}
1038
1039#[derive(Deserialize)]
1040struct TrainingSessionParams {
1041 model_params: HashMap<String, Parameter>,
1042 dataset_params: DatasetParams,
1043}
1044
1045#[derive(Deserialize)]
1046pub struct TrainingSession {
1047 id: TrainingSessionID,
1048 #[serde(rename = "trainer_id")]
1049 experiment_id: ExperimentID,
1050 model: String,
1051 name: String,
1052 description: String,
1053 params: TrainingSessionParams,
1054 #[serde(rename = "docker_task")]
1055 task: Task,
1056}
1057
1058impl Display for TrainingSession {
1059 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1060 write!(f, "{} {}", self.uid(), self.name())
1061 }
1062}
1063
1064impl TrainingSession {
1065 pub fn id(&self) -> TrainingSessionID {
1066 self.id
1067 }
1068
1069 pub fn uid(&self) -> String {
1070 self.id.to_string()
1071 }
1072
1073 pub fn name(&self) -> &str {
1074 &self.name
1075 }
1076
1077 pub fn description(&self) -> &str {
1078 &self.description
1079 }
1080
1081 pub fn model(&self) -> &str {
1082 &self.model
1083 }
1084
1085 pub fn experiment_id(&self) -> ExperimentID {
1086 self.experiment_id
1087 }
1088
1089 pub fn task(&self) -> Task {
1090 self.task.clone()
1091 }
1092
1093 pub fn model_params(&self) -> &HashMap<String, Parameter> {
1094 &self.params.model_params
1095 }
1096
1097 pub fn dataset_params(&self) -> &DatasetParams {
1098 &self.params.dataset_params
1099 }
1100
1101 pub fn train_group(&self) -> &str {
1102 &self.params.dataset_params.train_group
1103 }
1104
1105 pub fn val_group(&self) -> &str {
1106 &self.params.dataset_params.val_group
1107 }
1108
1109 pub async fn experiment(&self, client: &client::Client) -> Result<Experiment, Error> {
1110 client.experiment(self.experiment_id).await
1111 }
1112
1113 pub async fn dataset(&self, client: &client::Client) -> Result<Dataset, Error> {
1114 client.dataset(self.params.dataset_params.dataset_id).await
1115 }
1116
1117 pub async fn annotation_set(&self, client: &client::Client) -> Result<AnnotationSet, Error> {
1118 client
1119 .annotation_set(self.params.dataset_params.annotation_set_id)
1120 .await
1121 }
1122
1123 pub async fn artifacts(&self, client: &client::Client) -> Result<Vec<Artifact>, Error> {
1124 client.artifacts(self.id).await
1125 }
1126
1127 pub async fn metrics(
1128 &self,
1129 client: &client::Client,
1130 ) -> Result<HashMap<String, Parameter>, Error> {
1131 #[derive(Deserialize)]
1132 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1133 enum Response {
1134 Empty {},
1135 Map(HashMap<String, Parameter>),
1136 String(String),
1137 }
1138
1139 let params = HashMap::from([("trainer_session_id", self.id().value())]);
1140 let resp: Response = client
1141 .rpc("trainer.session.metrics".to_owned(), Some(params))
1142 .await?;
1143
1144 Ok(match resp {
1145 Response::String(metrics) => serde_json::from_str(&metrics)?,
1146 Response::Map(metrics) => metrics,
1147 Response::Empty {} => HashMap::new(),
1148 })
1149 }
1150
1151 pub async fn set_metrics(
1152 &self,
1153 client: &client::Client,
1154 metrics: HashMap<String, Parameter>,
1155 ) -> Result<(), Error> {
1156 let metrics = PublishMetrics {
1157 trainer_session_id: Some(self.id()),
1158 validate_session_id: None,
1159 metrics,
1160 };
1161
1162 let _: String = client
1163 .rpc("trainer.session.metrics".to_owned(), Some(metrics))
1164 .await?;
1165
1166 Ok(())
1167 }
1168
1169 pub async fn download_artifact(
1171 &self,
1172 client: &client::Client,
1173 filename: &str,
1174 ) -> Result<Vec<u8>, Error> {
1175 client
1176 .fetch(&format!(
1177 "download_model?training_session_id={}&file={}",
1178 self.id().value(),
1179 filename
1180 ))
1181 .await
1182 }
1183
1184 pub async fn upload_artifact(
1188 &self,
1189 client: &client::Client,
1190 filename: &str,
1191 path: PathBuf,
1192 ) -> Result<(), Error> {
1193 self.upload(client, &[(format!("artifacts/{}", filename), path)])
1194 .await
1195 }
1196
1197 pub async fn download_checkpoint(
1199 &self,
1200 client: &client::Client,
1201 filename: &str,
1202 ) -> Result<Vec<u8>, Error> {
1203 client
1204 .fetch(&format!(
1205 "download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1206 self.id().value(),
1207 filename
1208 ))
1209 .await
1210 }
1211
1212 pub async fn upload_checkpoint(
1216 &self,
1217 client: &client::Client,
1218 filename: &str,
1219 path: PathBuf,
1220 ) -> Result<(), Error> {
1221 self.upload(client, &[(format!("checkpoints/{}", filename), path)])
1222 .await
1223 }
1224
1225 pub async fn download(&self, client: &client::Client, filename: &str) -> Result<String, Error> {
1229 #[derive(Serialize)]
1230 struct DownloadRequest {
1231 session_id: TrainingSessionID,
1232 file_path: String,
1233 }
1234
1235 let params = DownloadRequest {
1236 session_id: self.id(),
1237 file_path: filename.to_string(),
1238 };
1239
1240 client
1241 .rpc("trainer.download.file".to_owned(), Some(params))
1242 .await
1243 }
1244
1245 pub async fn upload(
1246 &self,
1247 client: &client::Client,
1248 files: &[(String, PathBuf)],
1249 ) -> Result<(), Error> {
1250 let mut parts = Form::new().part(
1251 "params",
1252 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
1253 );
1254
1255 for (name, path) in files {
1256 let file_part = Part::file(path).await?.file_name(name.to_owned());
1257 parts = parts.part("file", file_part);
1258 }
1259
1260 let result = client.post_multipart("trainer.upload.files", parts).await?;
1261 trace!("TrainingSession::upload: {:?}", result);
1262 Ok(())
1263 }
1264}
1265
/// Validation session record as deserialized from a server response.
#[derive(Deserialize, Clone, Debug)]
pub struct ValidationSession {
    id: ValidationSessionID,
    description: String,
    dataset_id: DatasetID,
    experiment_id: ExperimentID,
    training_session_id: TrainingSessionID,
    /// Read from the `gt_annotation_set_id` key.
    #[serde(rename = "gt_annotation_set_id")]
    annotation_set_id: AnnotationSetID,
    /// Flattened by `validation_session_params` from the nested `params`
    /// object of the payload.
    #[serde(deserialize_with = "validation_session_params")]
    params: HashMap<String, Parameter>,
    /// Read from the `docker_task` key.
    #[serde(rename = "docker_task")]
    task: Task,
}
1280
1281fn validation_session_params<'de, D>(
1282 deserializer: D,
1283) -> Result<HashMap<String, Parameter>, D::Error>
1284where
1285 D: Deserializer<'de>,
1286{
1287 #[derive(Deserialize)]
1288 struct ModelParams {
1289 validation: Option<HashMap<String, Parameter>>,
1290 }
1291
1292 #[derive(Deserialize)]
1293 struct ValidateParams {
1294 model: String,
1295 }
1296
1297 #[derive(Deserialize)]
1298 struct Params {
1299 model_params: ModelParams,
1300 validate_params: ValidateParams,
1301 }
1302
1303 let params = Params::deserialize(deserializer)?;
1304 let params = match params.model_params.validation {
1305 Some(mut map) => {
1306 map.insert(
1307 "model".to_string(),
1308 Parameter::String(params.validate_params.model),
1309 );
1310 map
1311 }
1312 None => HashMap::from([(
1313 "model".to_string(),
1314 Parameter::String(params.validate_params.model),
1315 )]),
1316 };
1317
1318 Ok(params)
1319}
1320
1321impl ValidationSession {
1322 pub fn id(&self) -> ValidationSessionID {
1323 self.id
1324 }
1325
1326 pub fn uid(&self) -> String {
1327 self.id.to_string()
1328 }
1329
1330 pub fn name(&self) -> &str {
1331 self.task.name()
1332 }
1333
1334 pub fn description(&self) -> &str {
1335 &self.description
1336 }
1337
1338 pub fn dataset_id(&self) -> DatasetID {
1339 self.dataset_id
1340 }
1341
1342 pub fn experiment_id(&self) -> ExperimentID {
1343 self.experiment_id
1344 }
1345
1346 pub fn training_session_id(&self) -> TrainingSessionID {
1347 self.training_session_id
1348 }
1349
1350 pub fn annotation_set_id(&self) -> AnnotationSetID {
1351 self.annotation_set_id
1352 }
1353
1354 pub fn params(&self) -> &HashMap<String, Parameter> {
1355 &self.params
1356 }
1357
1358 pub fn task(&self) -> &Task {
1359 &self.task
1360 }
1361
1362 pub async fn metrics(
1363 &self,
1364 client: &client::Client,
1365 ) -> Result<HashMap<String, Parameter>, Error> {
1366 #[derive(Deserialize)]
1367 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1368 enum Response {
1369 Empty {},
1370 Map(HashMap<String, Parameter>),
1371 String(String),
1372 }
1373
1374 let params = HashMap::from([("validate_session_id", self.id().value())]);
1375 let resp: Response = client
1376 .rpc("validate.session.metrics".to_owned(), Some(params))
1377 .await?;
1378
1379 Ok(match resp {
1380 Response::String(metrics) => serde_json::from_str(&metrics)?,
1381 Response::Map(metrics) => metrics,
1382 Response::Empty {} => HashMap::new(),
1383 })
1384 }
1385
1386 pub async fn set_metrics(
1387 &self,
1388 client: &client::Client,
1389 metrics: HashMap<String, Parameter>,
1390 ) -> Result<(), Error> {
1391 let metrics = PublishMetrics {
1392 trainer_session_id: None,
1393 validate_session_id: Some(self.id()),
1394 metrics,
1395 };
1396
1397 let _: String = client
1398 .rpc("validate.session.metrics".to_owned(), Some(metrics))
1399 .await?;
1400
1401 Ok(())
1402 }
1403
1404 pub async fn upload(
1405 &self,
1406 client: &client::Client,
1407 files: &[(String, PathBuf)],
1408 ) -> Result<(), Error> {
1409 let mut parts = Form::new().part(
1410 "params",
1411 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
1412 );
1413
1414 for (name, path) in files {
1415 let file_part = Part::file(path).await?.file_name(name.to_owned());
1416 parts = parts.part("file", file_part);
1417 }
1418
1419 let result = client
1420 .post_multipart("validate.upload.files", parts)
1421 .await?;
1422 trace!("ValidationSession::upload: {:?}", result);
1423 Ok(())
1424 }
1425}
1426
1427#[derive(Deserialize, Clone, Debug)]
1428pub struct DatasetParams {
1429 dataset_id: DatasetID,
1430 annotation_set_id: AnnotationSetID,
1431 #[serde(rename = "train_group_name")]
1432 train_group: String,
1433 #[serde(rename = "val_group_name")]
1434 val_group: String,
1435}
1436
1437impl DatasetParams {
1438 pub fn dataset_id(&self) -> DatasetID {
1439 self.dataset_id
1440 }
1441
1442 pub fn annotation_set_id(&self) -> AnnotationSetID {
1443 self.annotation_set_id
1444 }
1445
1446 pub fn train_group(&self) -> &str {
1447 &self.train_group
1448 }
1449
1450 pub fn val_group(&self) -> &str {
1451 &self.val_group
1452 }
1453}
1454
/// Request parameters for listing tasks.
///
/// Every field is optional and omitted from the serialized output when
/// `None`.
#[derive(Serialize, Debug, Clone)]
pub struct TasksListParams {
    /// Continuation token (presumably taken from a previous
    /// `TasksListResult` — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Manager filter values, serialized as `manage_types`.
    #[serde(rename = "manage_types", skip_serializing_if = "Option::is_none")]
    pub manager: Option<Vec<String>>,
    /// Status filter values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<Vec<String>>,
}
1464
/// Deserialized response of a tasks listing.
#[derive(Deserialize, Debug, Clone)]
pub struct TasksListResult {
    /// Tasks contained in this response.
    pub tasks: Vec<Task>,
    /// Optional continuation token (presumably present when more results are
    /// available — confirm against the server API).
    pub continue_token: Option<String>,
}
1470
1471#[derive(Deserialize, Debug, Clone)]
1472pub struct Task {
1473 id: TaskID,
1474 name: String,
1475 #[serde(rename = "type")]
1476 workflow: String,
1477 status: String,
1478 #[serde(rename = "manage_type")]
1479 manager: Option<String>,
1480 #[serde(rename = "instance_type")]
1481 instance: String,
1482 #[serde(rename = "date")]
1483 created: DateTime<Utc>,
1484}
1485
1486impl Task {
1487 pub fn id(&self) -> TaskID {
1488 self.id
1489 }
1490
1491 pub fn uid(&self) -> String {
1492 self.id.to_string()
1493 }
1494
1495 pub fn name(&self) -> &str {
1496 &self.name
1497 }
1498
1499 pub fn workflow(&self) -> &str {
1500 &self.workflow
1501 }
1502
1503 pub fn status(&self) -> &str {
1504 &self.status
1505 }
1506
1507 pub fn manager(&self) -> Option<&str> {
1508 self.manager.as_deref()
1509 }
1510
1511 pub fn instance(&self) -> &str {
1512 &self.instance
1513 }
1514
1515 pub fn created(&self) -> &DateTime<Utc> {
1516 &self.created
1517 }
1518}
1519
1520impl Display for Task {
1521 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1522 write!(
1523 f,
1524 "{} [{:?} {}] {}",
1525 self.uid(),
1526 self.manager(),
1527 self.workflow(),
1528 self.name()
1529 )
1530 }
1531}
1532
/// Detailed task record as deserialized from a server response.
#[derive(Deserialize, Debug)]
pub struct TaskInfo {
    id: TaskID,
    /// Optional project identifier.
    project_id: Option<ProjectID>,
    /// Read from the `task_description` key.
    #[serde(rename = "task_description")]
    description: String,
    /// Read from the `type` key.
    #[serde(rename = "type")]
    workflow: String,
    /// Optional status string.
    status: Option<String>,
    /// Progress information; `TaskProgress` is declared elsewhere.
    progress: TaskProgress,
    /// Read from the `created_date` key.
    #[serde(rename = "created_date")]
    created: DateTime<Utc>,
    /// Read from the `end_date` key.
    #[serde(rename = "end_date")]
    completed: DateTime<Utc>,
}
1548
1549impl Display for TaskInfo {
1550 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1551 write!(
1552 f,
1553 "{} {}: {}",
1554 self.uid(),
1555 self.workflow(),
1556 self.description()
1557 )
1558 }
1559}
1560
1561impl TaskInfo {
1562 pub fn id(&self) -> TaskID {
1563 self.id
1564 }
1565
1566 pub fn uid(&self) -> String {
1567 self.id.to_string()
1568 }
1569
1570 pub fn project_id(&self) -> Option<ProjectID> {
1571 self.project_id
1572 }
1573
1574 pub fn description(&self) -> &str {
1575 &self.description
1576 }
1577
1578 pub fn workflow(&self) -> &str {
1579 &self.workflow
1580 }
1581
1582 pub fn status(&self) -> &Option<String> {
1583 &self.status
1584 }
1585
1586 pub async fn set_status(&mut self, client: &Client, status: &str) -> Result<(), Error> {
1587 let t = client.task_status(self.id(), status).await?;
1588 self.status = Some(t.status);
1589 Ok(())
1590 }
1591
1592 pub fn stages(&self) -> HashMap<String, Stage> {
1593 match &self.progress.stages {
1594 Some(stages) => stages.clone(),
1595 None => HashMap::new(),
1596 }
1597 }
1598
1599 pub async fn update_stage(
1600 &mut self,
1601 client: &Client,
1602 stage: &str,
1603 status: &str,
1604 message: &str,
1605 percentage: u8,
1606 ) -> Result<(), Error> {
1607 client
1608 .update_stage(self.id(), stage, status, message, percentage)
1609 .await?;
1610 let t = client.task_info(self.id()).await?;
1611 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1612 Ok(())
1613 }
1614
1615 pub async fn set_stages(
1616 &mut self,
1617 client: &Client,
1618 stages: &[(&str, &str)],
1619 ) -> Result<(), Error> {
1620 client.set_stages(self.id(), stages).await?;
1621 let t = client.task_info(self.id()).await?;
1622 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1623 Ok(())
1624 }
1625
1626 pub fn created(&self) -> &DateTime<Utc> {
1627 &self.created
1628 }
1629
1630 pub fn completed(&self) -> &DateTime<Utc> {
1631 &self.completed
1632 }
1633}
1634
1635#[derive(Deserialize, Debug)]
1636pub struct TaskProgress {
1637 stages: Option<HashMap<String, Stage>>,
1638}
1639
1640#[derive(Serialize, Debug, Clone)]
1641pub struct TaskStatus {
1642 #[serde(rename = "docker_task_id")]
1643 pub task_id: TaskID,
1644 pub status: String,
1645}
1646
1647#[derive(Serialize, Deserialize, Debug, Clone)]
1648pub struct Stage {
1649 #[serde(rename = "docker_task_id", skip_serializing_if = "Option::is_none")]
1650 task_id: Option<TaskID>,
1651 stage: String,
1652 #[serde(skip_serializing_if = "Option::is_none")]
1653 status: Option<String>,
1654 #[serde(skip_serializing_if = "Option::is_none")]
1655 description: Option<String>,
1656 #[serde(skip_serializing_if = "Option::is_none")]
1657 message: Option<String>,
1658 percentage: u8,
1659}
1660
1661impl Stage {
1662 pub fn new(
1663 task_id: Option<TaskID>,
1664 stage: String,
1665 status: Option<String>,
1666 message: Option<String>,
1667 percentage: u8,
1668 ) -> Self {
1669 Stage {
1670 task_id,
1671 stage,
1672 status,
1673 description: None,
1674 message,
1675 percentage,
1676 }
1677 }
1678
1679 pub fn task_id(&self) -> &Option<TaskID> {
1680 &self.task_id
1681 }
1682
1683 pub fn stage(&self) -> &str {
1684 &self.stage
1685 }
1686
1687 pub fn status(&self) -> &Option<String> {
1688 &self.status
1689 }
1690
1691 pub fn description(&self) -> &Option<String> {
1692 &self.description
1693 }
1694
1695 pub fn message(&self) -> &Option<String> {
1696 &self.message
1697 }
1698
1699 pub fn percentage(&self) -> u8 {
1700 self.percentage
1701 }
1702}
1703
1704#[derive(Serialize, Debug)]
1705pub struct TaskStages {
1706 #[serde(rename = "docker_task_id")]
1707 pub task_id: TaskID,
1708 pub stages: Vec<HashMap<String, String>>,
1709}
1710
1711#[derive(Deserialize, Debug)]
1712pub struct Artifact {
1713 name: String,
1714 #[serde(rename = "modelType")]
1715 model_type: String,
1716}
1717
1718impl Artifact {
1719 pub fn name(&self) -> &str {
1720 &self.name
1721 }
1722
1723 pub fn model_type(&self) -> &str {
1724 &self.model_type
1725 }
1726}