1use crate::{AnnotationSet, Client, Dataset, Error, Sample, client};
5use chrono::{DateTime, Utc};
6use log::trace;
7use reqwest::multipart::{Form, Part};
8use serde::{Deserialize, Deserializer, Serialize};
9use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
10
/// A dynamically-typed value exchanged with the server (used for model
/// parameters, validation parameters and metrics in this module).
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in
/// declaration order when deserializing: a number that fits in an `i64`
/// becomes [`Parameter::Integer`], any other number falls through to
/// [`Parameter::Real`]. The variant order is therefore significant and must
/// not be changed. No variant accepts JSON `null`, so a `null` value fails to
/// deserialize.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum Parameter {
    /// Signed integer value.
    Integer(i64),
    /// Floating-point value (also receives numbers that do not fit `i64`).
    Real(f64),
    /// Boolean value.
    Boolean(bool),
    /// String value.
    String(String),
    /// Ordered list of nested parameters.
    Array(Vec<Parameter>),
    /// String-keyed map of nested parameters.
    Object(HashMap<String, Parameter>),
}
56
/// Deserialized response of a login request (presumably the authentication
/// RPC issued by the client — confirm against `Client`).
#[derive(Deserialize)]
pub struct LoginResult {
    /// Token returned by the server; crate-private so only crate code (the
    /// client) can read it.
    pub(crate) token: String,
}
61
62#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
82pub struct OrganizationID(u64);
83
84impl Display for OrganizationID {
85 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
86 write!(f, "org-{:x}", self.0)
87 }
88}
89
90impl From<u64> for OrganizationID {
91 fn from(id: u64) -> Self {
92 OrganizationID(id)
93 }
94}
95
96impl From<OrganizationID> for u64 {
97 fn from(val: OrganizationID) -> Self {
98 val.0
99 }
100}
101
102impl OrganizationID {
103 pub fn value(&self) -> u64 {
104 self.0
105 }
106}
107
108impl TryFrom<&str> for OrganizationID {
109 type Error = Error;
110
111 fn try_from(s: &str) -> Result<Self, Self::Error> {
112 let hex_part = s.strip_prefix("org-").ok_or_else(|| {
113 Error::InvalidParameters("Organization ID must start with 'org-' prefix".to_string())
114 })?;
115 let id = u64::from_str_radix(hex_part, 16)?;
116 Ok(OrganizationID(id))
117 }
118}
119
120#[derive(Deserialize, Clone, Debug)]
141pub struct Organization {
142 id: OrganizationID,
143 name: String,
144 #[serde(rename = "latest_credit")]
145 credits: i64,
146}
147
148impl Display for Organization {
149 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
150 write!(f, "{}", self.name())
151 }
152}
153
154impl Organization {
155 pub fn id(&self) -> OrganizationID {
156 self.id
157 }
158
159 pub fn name(&self) -> &str {
160 &self.name
161 }
162
163 pub fn credits(&self) -> i64 {
164 self.credits
165 }
166}
167
168#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
189pub struct ProjectID(u64);
190
191impl Display for ProjectID {
192 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
193 write!(f, "p-{:x}", self.0)
194 }
195}
196
197impl From<u64> for ProjectID {
198 fn from(id: u64) -> Self {
199 ProjectID(id)
200 }
201}
202
203impl From<ProjectID> for u64 {
204 fn from(val: ProjectID) -> Self {
205 val.0
206 }
207}
208
209impl ProjectID {
210 pub fn value(&self) -> u64 {
211 self.0
212 }
213}
214
215impl TryFrom<&str> for ProjectID {
216 type Error = Error;
217
218 fn try_from(s: &str) -> Result<Self, Self::Error> {
219 ProjectID::from_str(s)
220 }
221}
222
223impl TryFrom<String> for ProjectID {
224 type Error = Error;
225
226 fn try_from(s: String) -> Result<Self, Self::Error> {
227 ProjectID::from_str(&s)
228 }
229}
230
231impl FromStr for ProjectID {
232 type Err = Error;
233
234 fn from_str(s: &str) -> Result<Self, Self::Err> {
235 let hex_part = s.strip_prefix("p-").ok_or_else(|| {
236 Error::InvalidParameters("Project ID must start with 'p-' prefix".to_string())
237 })?;
238 let id = u64::from_str_radix(hex_part, 16)?;
239 Ok(ProjectID(id))
240 }
241}
242
243#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
264pub struct ExperimentID(u64);
265
266impl Display for ExperimentID {
267 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
268 write!(f, "exp-{:x}", self.0)
269 }
270}
271
272impl From<u64> for ExperimentID {
273 fn from(id: u64) -> Self {
274 ExperimentID(id)
275 }
276}
277
278impl From<ExperimentID> for u64 {
279 fn from(val: ExperimentID) -> Self {
280 val.0
281 }
282}
283
284impl ExperimentID {
285 pub fn value(&self) -> u64 {
286 self.0
287 }
288}
289
290impl TryFrom<&str> for ExperimentID {
291 type Error = Error;
292
293 fn try_from(s: &str) -> Result<Self, Self::Error> {
294 ExperimentID::from_str(s)
295 }
296}
297
298impl TryFrom<String> for ExperimentID {
299 type Error = Error;
300
301 fn try_from(s: String) -> Result<Self, Self::Error> {
302 ExperimentID::from_str(&s)
303 }
304}
305
306impl FromStr for ExperimentID {
307 type Err = Error;
308
309 fn from_str(s: &str) -> Result<Self, Self::Err> {
310 let hex_part = s.strip_prefix("exp-").ok_or_else(|| {
311 Error::InvalidParameters("Experiment ID must start with 'exp-' prefix".to_string())
312 })?;
313 let id = u64::from_str_radix(hex_part, 16)?;
314 Ok(ExperimentID(id))
315 }
316}
317
318#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
339pub struct TrainingSessionID(u64);
340
341impl Display for TrainingSessionID {
342 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
343 write!(f, "t-{:x}", self.0)
344 }
345}
346
347impl From<u64> for TrainingSessionID {
348 fn from(id: u64) -> Self {
349 TrainingSessionID(id)
350 }
351}
352
353impl From<TrainingSessionID> for u64 {
354 fn from(val: TrainingSessionID) -> Self {
355 val.0
356 }
357}
358
359impl TrainingSessionID {
360 pub fn value(&self) -> u64 {
361 self.0
362 }
363}
364
365impl TryFrom<&str> for TrainingSessionID {
366 type Error = Error;
367
368 fn try_from(s: &str) -> Result<Self, Self::Error> {
369 TrainingSessionID::from_str(s)
370 }
371}
372
373impl TryFrom<String> for TrainingSessionID {
374 type Error = Error;
375
376 fn try_from(s: String) -> Result<Self, Self::Error> {
377 TrainingSessionID::from_str(&s)
378 }
379}
380
381impl FromStr for TrainingSessionID {
382 type Err = Error;
383
384 fn from_str(s: &str) -> Result<Self, Self::Err> {
385 let hex_part = s.strip_prefix("t-").ok_or_else(|| {
386 Error::InvalidParameters("Training Session ID must start with 't-' prefix".to_string())
387 })?;
388 let id = u64::from_str_radix(hex_part, 16)?;
389 Ok(TrainingSessionID(id))
390 }
391}
392
393#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
413pub struct ValidationSessionID(u64);
414
415impl Display for ValidationSessionID {
416 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
417 write!(f, "v-{:x}", self.0)
418 }
419}
420
421impl From<u64> for ValidationSessionID {
422 fn from(id: u64) -> Self {
423 ValidationSessionID(id)
424 }
425}
426
427impl From<ValidationSessionID> for u64 {
428 fn from(val: ValidationSessionID) -> Self {
429 val.0
430 }
431}
432
433impl ValidationSessionID {
434 pub fn value(&self) -> u64 {
435 self.0
436 }
437}
438
439impl TryFrom<&str> for ValidationSessionID {
440 type Error = Error;
441
442 fn try_from(s: &str) -> Result<Self, Self::Error> {
443 let hex_part = s.strip_prefix("v-").ok_or_else(|| {
444 Error::InvalidParameters(
445 "Validation Session ID must start with 'v-' prefix".to_string(),
446 )
447 })?;
448 let id = u64::from_str_radix(hex_part, 16)?;
449 Ok(ValidationSessionID(id))
450 }
451}
452
453#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
454pub struct SnapshotID(u64);
455
456impl Display for SnapshotID {
457 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
458 write!(f, "ss-{:x}", self.0)
459 }
460}
461
462impl From<u64> for SnapshotID {
463 fn from(id: u64) -> Self {
464 SnapshotID(id)
465 }
466}
467
468impl From<SnapshotID> for u64 {
469 fn from(val: SnapshotID) -> Self {
470 val.0
471 }
472}
473
474impl SnapshotID {
475 pub fn value(&self) -> u64 {
476 self.0
477 }
478}
479
480impl TryFrom<&str> for SnapshotID {
481 type Error = Error;
482
483 fn try_from(s: &str) -> Result<Self, Self::Error> {
484 let hex_part = s.strip_prefix("ss-").ok_or_else(|| {
485 Error::InvalidParameters("Snapshot ID must start with 'ss-' prefix".to_string())
486 })?;
487 let id = u64::from_str_radix(hex_part, 16)?;
488 Ok(SnapshotID(id))
489 }
490}
491
492#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
493pub struct TaskID(u64);
494
495impl Display for TaskID {
496 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
497 write!(f, "task-{:x}", self.0)
498 }
499}
500
501impl From<u64> for TaskID {
502 fn from(id: u64) -> Self {
503 TaskID(id)
504 }
505}
506
507impl From<TaskID> for u64 {
508 fn from(val: TaskID) -> Self {
509 val.0
510 }
511}
512
513impl TaskID {
514 pub fn value(&self) -> u64 {
515 self.0
516 }
517}
518
519impl TryFrom<&str> for TaskID {
520 type Error = Error;
521
522 fn try_from(s: &str) -> Result<Self, Self::Error> {
523 TaskID::from_str(s)
524 }
525}
526
527impl TryFrom<String> for TaskID {
528 type Error = Error;
529
530 fn try_from(s: String) -> Result<Self, Self::Error> {
531 TaskID::from_str(&s)
532 }
533}
534
535impl FromStr for TaskID {
536 type Err = Error;
537
538 fn from_str(s: &str) -> Result<Self, Self::Err> {
539 let hex_part = s.strip_prefix("task-").ok_or_else(|| {
540 Error::InvalidParameters("Task ID must start with 'task-' prefix".to_string())
541 })?;
542 let id = u64::from_str_radix(hex_part, 16)?;
543 Ok(TaskID(id))
544 }
545}
546
547#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
568pub struct DatasetID(u64);
569
570impl Display for DatasetID {
571 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
572 write!(f, "ds-{:x}", self.0)
573 }
574}
575
576impl From<u64> for DatasetID {
577 fn from(id: u64) -> Self {
578 DatasetID(id)
579 }
580}
581
582impl From<DatasetID> for u64 {
583 fn from(val: DatasetID) -> Self {
584 val.0
585 }
586}
587
588impl DatasetID {
589 pub fn value(&self) -> u64 {
590 self.0
591 }
592}
593
594impl TryFrom<&str> for DatasetID {
595 type Error = Error;
596
597 fn try_from(s: &str) -> Result<Self, Self::Error> {
598 DatasetID::from_str(s)
599 }
600}
601
602impl TryFrom<String> for DatasetID {
603 type Error = Error;
604
605 fn try_from(s: String) -> Result<Self, Self::Error> {
606 DatasetID::from_str(&s)
607 }
608}
609
610impl FromStr for DatasetID {
611 type Err = Error;
612
613 fn from_str(s: &str) -> Result<Self, Self::Err> {
614 let hex_part = s.strip_prefix("ds-").ok_or_else(|| {
615 Error::InvalidParameters("Dataset ID must start with 'ds-' prefix".to_string())
616 })?;
617 let id = u64::from_str_radix(hex_part, 16)?;
618 Ok(DatasetID(id))
619 }
620}
621
622#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
623pub struct AnnotationSetID(u64);
624
625impl Display for AnnotationSetID {
626 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
627 write!(f, "as-{:x}", self.0)
628 }
629}
630
631impl From<u64> for AnnotationSetID {
632 fn from(id: u64) -> Self {
633 AnnotationSetID(id)
634 }
635}
636
637impl From<AnnotationSetID> for u64 {
638 fn from(val: AnnotationSetID) -> Self {
639 val.0
640 }
641}
642
643impl AnnotationSetID {
644 pub fn value(&self) -> u64 {
645 self.0
646 }
647}
648
649impl TryFrom<&str> for AnnotationSetID {
650 type Error = Error;
651
652 fn try_from(s: &str) -> Result<Self, Self::Error> {
653 AnnotationSetID::from_str(s)
654 }
655}
656
657impl TryFrom<String> for AnnotationSetID {
658 type Error = Error;
659
660 fn try_from(s: String) -> Result<Self, Self::Error> {
661 AnnotationSetID::from_str(&s)
662 }
663}
664
665impl FromStr for AnnotationSetID {
666 type Err = Error;
667
668 fn from_str(s: &str) -> Result<Self, Self::Err> {
669 let hex_part = s.strip_prefix("as-").ok_or_else(|| {
670 Error::InvalidParameters("Annotation Set ID must start with 'as-' prefix".to_string())
671 })?;
672 let id = u64::from_str_radix(hex_part, 16)?;
673 Ok(AnnotationSetID(id))
674 }
675}
676
677#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
678pub struct SampleID(u64);
679
680impl Display for SampleID {
681 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
682 write!(f, "s-{:x}", self.0)
683 }
684}
685
686impl From<u64> for SampleID {
687 fn from(id: u64) -> Self {
688 SampleID(id)
689 }
690}
691
692impl From<SampleID> for u64 {
693 fn from(val: SampleID) -> Self {
694 val.0
695 }
696}
697
698impl SampleID {
699 pub fn value(&self) -> u64 {
700 self.0
701 }
702}
703
704impl TryFrom<&str> for SampleID {
705 type Error = Error;
706
707 fn try_from(s: &str) -> Result<Self, Self::Error> {
708 let hex_part = s.strip_prefix("s-").ok_or_else(|| {
709 Error::InvalidParameters("Sample ID must start with 's-' prefix".to_string())
710 })?;
711 let id = u64::from_str_radix(hex_part, 16)?;
712 Ok(SampleID(id))
713 }
714}
715
716#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
717pub struct AppId(u64);
718
719impl Display for AppId {
720 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
721 write!(f, "app-{:x}", self.0)
722 }
723}
724
725impl From<u64> for AppId {
726 fn from(id: u64) -> Self {
727 AppId(id)
728 }
729}
730
731impl From<AppId> for u64 {
732 fn from(val: AppId) -> Self {
733 val.0
734 }
735}
736
737impl AppId {
738 pub fn value(&self) -> u64 {
739 self.0
740 }
741}
742
743impl TryFrom<&str> for AppId {
744 type Error = Error;
745
746 fn try_from(s: &str) -> Result<Self, Self::Error> {
747 let hex_part = s.strip_prefix("app-").ok_or_else(|| {
748 Error::InvalidParameters("App ID must start with 'app-' prefix".to_string())
749 })?;
750 let id = u64::from_str_radix(hex_part, 16)?;
751 Ok(AppId(id))
752 }
753}
754
755#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
756pub struct ImageId(u64);
757
758impl Display for ImageId {
759 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
760 write!(f, "im-{:x}", self.0)
761 }
762}
763
764impl From<u64> for ImageId {
765 fn from(id: u64) -> Self {
766 ImageId(id)
767 }
768}
769
770impl From<ImageId> for u64 {
771 fn from(val: ImageId) -> Self {
772 val.0
773 }
774}
775
776impl ImageId {
777 pub fn value(&self) -> u64 {
778 self.0
779 }
780}
781
782impl TryFrom<&str> for ImageId {
783 type Error = Error;
784
785 fn try_from(s: &str) -> Result<Self, Self::Error> {
786 let hex_part = s.strip_prefix("im-").ok_or_else(|| {
787 Error::InvalidParameters("Image ID must start with 'im-' prefix".to_string())
788 })?;
789 let id = u64::from_str_radix(hex_part, 16)?;
790 Ok(ImageId(id))
791 }
792}
793
794#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
795pub struct SequenceId(u64);
796
797impl Display for SequenceId {
798 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
799 write!(f, "se-{:x}", self.0)
800 }
801}
802
803impl From<u64> for SequenceId {
804 fn from(id: u64) -> Self {
805 SequenceId(id)
806 }
807}
808
809impl From<SequenceId> for u64 {
810 fn from(val: SequenceId) -> Self {
811 val.0
812 }
813}
814
815impl SequenceId {
816 pub fn value(&self) -> u64 {
817 self.0
818 }
819}
820
821impl TryFrom<&str> for SequenceId {
822 type Error = Error;
823
824 fn try_from(s: &str) -> Result<Self, Self::Error> {
825 let hex_part = s.strip_prefix("se-").ok_or_else(|| {
826 Error::InvalidParameters("Sequence ID must start with 'se-' prefix".to_string())
827 })?;
828 let id = u64::from_str_radix(hex_part, 16)?;
829 Ok(SequenceId(id))
830 }
831}
832
833#[derive(Deserialize, Clone, Debug)]
837pub struct Project {
838 id: ProjectID,
839 name: String,
840 description: String,
841}
842
843impl Display for Project {
844 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
845 write!(f, "{} {}", self.id(), self.name())
846 }
847}
848
849impl Project {
850 pub fn id(&self) -> ProjectID {
851 self.id
852 }
853
854 pub fn name(&self) -> &str {
855 &self.name
856 }
857
858 pub fn description(&self) -> &str {
859 &self.description
860 }
861
862 pub async fn datasets(
863 &self,
864 client: &client::Client,
865 name: Option<&str>,
866 ) -> Result<Vec<Dataset>, Error> {
867 client.datasets(self.id, name).await
868 }
869
870 pub async fn experiments(
871 &self,
872 client: &client::Client,
873 name: Option<&str>,
874 ) -> Result<Vec<Experiment>, Error> {
875 client.experiments(self.id, name).await
876 }
877}
878
/// Deserialized result of a sample-count request.
#[derive(Deserialize, Debug)]
pub struct SamplesCountResult {
    /// Total number of samples reported by the server.
    pub total: u64,
}
883
/// Request parameters for listing the samples of a dataset.
///
/// `Option` fields are omitted from the serialized request when `None`, and
/// the `Vec` fields are omitted when empty.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesListParams {
    /// Dataset whose samples are listed.
    pub dataset_id: DatasetID,
    /// Optional annotation set; server behaviour when omitted is not visible
    /// here — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Presumably the pagination cursor returned in
    /// [`SamplesListResult::continue_token`] — confirm against the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Presumably a sample/annotation type filter; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub types: Vec<String>,
    /// Presumably a group-name filter; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub group_names: Vec<String>,
}
896
/// A batch of samples returned by a samples listing request.
#[derive(Deserialize, Debug)]
pub struct SamplesListResult {
    /// Samples contained in this response.
    pub samples: Vec<Sample>,
    /// Cursor for a follow-up request; presumably `None` once the listing is
    /// exhausted — confirm against the server API.
    pub continue_token: Option<String>,
}
902
903#[derive(Deserialize)]
904pub struct Snapshot {
905 id: SnapshotID,
906 description: String,
907 status: String,
908 path: String,
909 #[serde(rename = "date")]
910 created: DateTime<Utc>,
911}
912
913impl Display for Snapshot {
914 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
915 write!(f, "{} {}", self.id, self.description)
916 }
917}
918
919impl Snapshot {
920 pub fn id(&self) -> SnapshotID {
921 self.id
922 }
923
924 pub fn description(&self) -> &str {
925 &self.description
926 }
927
928 pub fn status(&self) -> &str {
929 &self.status
930 }
931
932 pub fn path(&self) -> &str {
933 &self.path
934 }
935
936 pub fn created(&self) -> &DateTime<Utc> {
937 &self.created
938 }
939}
940
/// Request body for restoring a snapshot into a project.
///
/// Several fields are sent under different wire names (see the `rename`
/// attributes); `dataset_name` and `dataset_description` are omitted from the
/// request when `None`.
#[derive(Serialize, Debug)]
pub struct SnapshotRestore {
    /// Project to restore the snapshot into.
    pub project_id: ProjectID,
    /// Snapshot to restore.
    pub snapshot_id: SnapshotID,
    /// Presumably the frame rate used when extracting samples — confirm.
    pub fps: u64,
    /// Topics to enable; sent as `enabled_topics`.
    #[serde(rename = "enabled_topics")]
    pub topics: Vec<String>,
    /// Presumably the labels for automatic labelling; sent as `label_names`.
    #[serde(rename = "label_names")]
    pub autolabel: Vec<String>,
    /// Presumably toggles depth generation; sent as `depth_gen`.
    #[serde(rename = "depth_gen")]
    pub autodepth: bool,
    /// Toggles the server-side `agtg` pipeline; semantics not visible here —
    /// TODO document.
    pub agtg_pipeline: bool,
    /// Name for the resulting dataset; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_name: Option<String>,
    /// Description for the resulting dataset; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_description: Option<String>,
}
958
/// Deserialized response, presumably to a [`SnapshotRestore`] request.
#[derive(Deserialize, Debug)]
pub struct SnapshotRestoreResult {
    /// Snapshot identifier.
    pub id: SnapshotID,
    /// Snapshot description.
    pub description: String,
    /// Name of the dataset the restore targets (presumably newly created).
    pub dataset_name: String,
    /// Identifier of that dataset.
    pub dataset_id: DatasetID,
    /// Annotation set associated with that dataset.
    pub annotation_set_id: AnnotationSetID,
    /// Task handling the restore; read from `docker_task_id`.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// Timestamp reported by the server.
    pub date: DateTime<Utc>,
}
970
971#[derive(Deserialize)]
972pub struct Experiment {
973 id: ExperimentID,
974 project_id: ProjectID,
975 name: String,
976 description: String,
977}
978
979impl Display for Experiment {
980 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
981 write!(f, "{} {}", self.uid(), self.name)
982 }
983}
984
985impl Experiment {
986 pub fn id(&self) -> ExperimentID {
987 self.id
988 }
989
990 pub fn uid(&self) -> String {
991 self.id.to_string()
992 }
993
994 pub fn project_id(&self) -> ProjectID {
995 self.project_id
996 }
997
998 pub fn name(&self) -> &str {
999 &self.name
1000 }
1001
1002 pub fn description(&self) -> &str {
1003 &self.description
1004 }
1005
1006 pub async fn project(&self, client: &client::Client) -> Result<Project, Error> {
1007 client.project(self.project_id).await
1008 }
1009
1010 pub async fn training_sessions(
1011 &self,
1012 client: &client::Client,
1013 name: Option<&str>,
1014 ) -> Result<Vec<TrainingSession>, Error> {
1015 client.training_sessions(self.id, name).await
1016 }
1017}
1018
1019#[derive(Serialize, Debug)]
1020pub struct PublishMetrics {
1021 #[serde(rename = "trainer_session_id", skip_serializing_if = "Option::is_none")]
1022 pub trainer_session_id: Option<TrainingSessionID>,
1023 #[serde(
1024 rename = "validate_session_id",
1025 skip_serializing_if = "Option::is_none"
1026 )]
1027 pub validate_session_id: Option<ValidationSessionID>,
1028 pub metrics: HashMap<String, Parameter>,
1029}
1030
/// Nested `params` object of a [`TrainingSession`] payload.
#[derive(Deserialize)]
struct TrainingSessionParams {
    /// Model parameters, exposed through `TrainingSession::model_params`.
    model_params: HashMap<String, Parameter>,
    /// Dataset selection, exposed through `TrainingSession::dataset_params`.
    dataset_params: DatasetParams,
}
1036
/// A training session as deserialized from the server.
///
/// Wire-name mapping: the experiment id is read from `trainer_id` and the
/// backing task from `docker_task`.
#[derive(Deserialize)]
pub struct TrainingSession {
    id: TrainingSessionID,
    // The server calls the owning experiment a "trainer".
    #[serde(rename = "trainer_id")]
    experiment_id: ExperimentID,
    model: String,
    name: String,
    description: String,
    // Nested model and dataset parameters.
    params: TrainingSessionParams,
    #[serde(rename = "docker_task")]
    task: Task,
}
1049
1050impl Display for TrainingSession {
1051 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1052 write!(f, "{} {}", self.uid(), self.name())
1053 }
1054}
1055
1056impl TrainingSession {
1057 pub fn id(&self) -> TrainingSessionID {
1058 self.id
1059 }
1060
1061 pub fn uid(&self) -> String {
1062 self.id.to_string()
1063 }
1064
1065 pub fn name(&self) -> &str {
1066 &self.name
1067 }
1068
1069 pub fn description(&self) -> &str {
1070 &self.description
1071 }
1072
1073 pub fn model(&self) -> &str {
1074 &self.model
1075 }
1076
1077 pub fn experiment_id(&self) -> ExperimentID {
1078 self.experiment_id
1079 }
1080
1081 pub fn task(&self) -> Task {
1082 self.task.clone()
1083 }
1084
1085 pub fn model_params(&self) -> &HashMap<String, Parameter> {
1086 &self.params.model_params
1087 }
1088
1089 pub fn dataset_params(&self) -> &DatasetParams {
1090 &self.params.dataset_params
1091 }
1092
1093 pub fn train_group(&self) -> &str {
1094 &self.params.dataset_params.train_group
1095 }
1096
1097 pub fn val_group(&self) -> &str {
1098 &self.params.dataset_params.val_group
1099 }
1100
1101 pub async fn experiment(&self, client: &client::Client) -> Result<Experiment, Error> {
1102 client.experiment(self.experiment_id).await
1103 }
1104
1105 pub async fn dataset(&self, client: &client::Client) -> Result<Dataset, Error> {
1106 client.dataset(self.params.dataset_params.dataset_id).await
1107 }
1108
1109 pub async fn annotation_set(&self, client: &client::Client) -> Result<AnnotationSet, Error> {
1110 client
1111 .annotation_set(self.params.dataset_params.annotation_set_id)
1112 .await
1113 }
1114
1115 pub async fn artifacts(&self, client: &client::Client) -> Result<Vec<Artifact>, Error> {
1116 client.artifacts(self.id).await
1117 }
1118
1119 pub async fn metrics(
1120 &self,
1121 client: &client::Client,
1122 ) -> Result<HashMap<String, Parameter>, Error> {
1123 #[derive(Deserialize)]
1124 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1125 enum Response {
1126 Empty {},
1127 Map(HashMap<String, Parameter>),
1128 String(String),
1129 }
1130
1131 let params = HashMap::from([("trainer_session_id", self.id().value())]);
1132 let resp: Response = client
1133 .rpc("trainer.session.metrics".to_owned(), Some(params))
1134 .await?;
1135
1136 Ok(match resp {
1137 Response::String(metrics) => serde_json::from_str(&metrics)?,
1138 Response::Map(metrics) => metrics,
1139 Response::Empty {} => HashMap::new(),
1140 })
1141 }
1142
1143 pub async fn set_metrics(
1144 &self,
1145 client: &client::Client,
1146 metrics: HashMap<String, Parameter>,
1147 ) -> Result<(), Error> {
1148 let metrics = PublishMetrics {
1149 trainer_session_id: Some(self.id()),
1150 validate_session_id: None,
1151 metrics,
1152 };
1153
1154 let _: String = client
1155 .rpc("trainer.session.metrics".to_owned(), Some(metrics))
1156 .await?;
1157
1158 Ok(())
1159 }
1160
1161 pub async fn download_artifact(
1163 &self,
1164 client: &client::Client,
1165 filename: &str,
1166 ) -> Result<Vec<u8>, Error> {
1167 client
1168 .fetch(&format!(
1169 "download_model?training_session_id={}&file={}",
1170 self.id().value(),
1171 filename
1172 ))
1173 .await
1174 }
1175
1176 pub async fn upload_artifact(
1180 &self,
1181 client: &client::Client,
1182 filename: &str,
1183 path: PathBuf,
1184 ) -> Result<(), Error> {
1185 self.upload(client, &[(format!("artifacts/{}", filename), path)])
1186 .await
1187 }
1188
1189 pub async fn download_checkpoint(
1191 &self,
1192 client: &client::Client,
1193 filename: &str,
1194 ) -> Result<Vec<u8>, Error> {
1195 client
1196 .fetch(&format!(
1197 "download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
1198 self.id().value(),
1199 filename
1200 ))
1201 .await
1202 }
1203
1204 pub async fn upload_checkpoint(
1208 &self,
1209 client: &client::Client,
1210 filename: &str,
1211 path: PathBuf,
1212 ) -> Result<(), Error> {
1213 self.upload(client, &[(format!("checkpoints/{}", filename), path)])
1214 .await
1215 }
1216
1217 pub async fn download(&self, client: &client::Client, filename: &str) -> Result<String, Error> {
1221 #[derive(Serialize)]
1222 struct DownloadRequest {
1223 session_id: TrainingSessionID,
1224 file_path: String,
1225 }
1226
1227 let params = DownloadRequest {
1228 session_id: self.id(),
1229 file_path: filename.to_string(),
1230 };
1231
1232 client
1233 .rpc("trainer.download.file".to_owned(), Some(params))
1234 .await
1235 }
1236
1237 pub async fn upload(
1238 &self,
1239 client: &client::Client,
1240 files: &[(String, PathBuf)],
1241 ) -> Result<(), Error> {
1242 let mut parts = Form::new().part(
1243 "params",
1244 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
1245 );
1246
1247 for (name, path) in files {
1248 let file_part = Part::file(path).await?.file_name(name.to_owned());
1249 parts = parts.part("file", file_part);
1250 }
1251
1252 let result = client.post_multipart("trainer.upload.files", parts).await?;
1253 trace!("TrainingSession::upload: {:?}", result);
1254 Ok(())
1255 }
1256}
1257
/// A validation session as deserialized from the server.
///
/// Wire-name mapping: the annotation set is read from `gt_annotation_set_id`
/// (`gt` presumably meaning ground truth), and the backing task from
/// `docker_task`. The `params` map is flattened by the
/// `validation_session_params` deserializer.
#[derive(Deserialize, Clone, Debug)]
pub struct ValidationSession {
    id: ValidationSessionID,
    description: String,
    dataset_id: DatasetID,
    experiment_id: ExperimentID,
    training_session_id: TrainingSessionID,
    #[serde(rename = "gt_annotation_set_id")]
    annotation_set_id: AnnotationSetID,
    // Merges `model_params.validation` with `validate_params.model`.
    #[serde(deserialize_with = "validation_session_params")]
    params: HashMap<String, Parameter>,
    #[serde(rename = "docker_task")]
    task: Task,
}
1272
1273fn validation_session_params<'de, D>(
1274 deserializer: D,
1275) -> Result<HashMap<String, Parameter>, D::Error>
1276where
1277 D: Deserializer<'de>,
1278{
1279 #[derive(Deserialize)]
1280 struct ModelParams {
1281 validation: Option<HashMap<String, Parameter>>,
1282 }
1283
1284 #[derive(Deserialize)]
1285 struct ValidateParams {
1286 model: String,
1287 }
1288
1289 #[derive(Deserialize)]
1290 struct Params {
1291 model_params: ModelParams,
1292 validate_params: ValidateParams,
1293 }
1294
1295 let params = Params::deserialize(deserializer)?;
1296 let params = match params.model_params.validation {
1297 Some(mut map) => {
1298 map.insert(
1299 "model".to_string(),
1300 Parameter::String(params.validate_params.model),
1301 );
1302 map
1303 }
1304 None => HashMap::from([(
1305 "model".to_string(),
1306 Parameter::String(params.validate_params.model),
1307 )]),
1308 };
1309
1310 Ok(params)
1311}
1312
1313impl ValidationSession {
1314 pub fn id(&self) -> ValidationSessionID {
1315 self.id
1316 }
1317
1318 pub fn uid(&self) -> String {
1319 self.id.to_string()
1320 }
1321
1322 pub fn name(&self) -> &str {
1323 self.task.name()
1324 }
1325
1326 pub fn description(&self) -> &str {
1327 &self.description
1328 }
1329
1330 pub fn dataset_id(&self) -> DatasetID {
1331 self.dataset_id
1332 }
1333
1334 pub fn experiment_id(&self) -> ExperimentID {
1335 self.experiment_id
1336 }
1337
1338 pub fn training_session_id(&self) -> TrainingSessionID {
1339 self.training_session_id
1340 }
1341
1342 pub fn annotation_set_id(&self) -> AnnotationSetID {
1343 self.annotation_set_id
1344 }
1345
1346 pub fn params(&self) -> &HashMap<String, Parameter> {
1347 &self.params
1348 }
1349
1350 pub fn task(&self) -> &Task {
1351 &self.task
1352 }
1353
1354 pub async fn metrics(
1355 &self,
1356 client: &client::Client,
1357 ) -> Result<HashMap<String, Parameter>, Error> {
1358 #[derive(Deserialize)]
1359 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1360 enum Response {
1361 Empty {},
1362 Map(HashMap<String, Parameter>),
1363 String(String),
1364 }
1365
1366 let params = HashMap::from([("validate_session_id", self.id().value())]);
1367 let resp: Response = client
1368 .rpc("validate.session.metrics".to_owned(), Some(params))
1369 .await?;
1370
1371 Ok(match resp {
1372 Response::String(metrics) => serde_json::from_str(&metrics)?,
1373 Response::Map(metrics) => metrics,
1374 Response::Empty {} => HashMap::new(),
1375 })
1376 }
1377
1378 pub async fn set_metrics(
1379 &self,
1380 client: &client::Client,
1381 metrics: HashMap<String, Parameter>,
1382 ) -> Result<(), Error> {
1383 let metrics = PublishMetrics {
1384 trainer_session_id: None,
1385 validate_session_id: Some(self.id()),
1386 metrics,
1387 };
1388
1389 let _: String = client
1390 .rpc("validate.session.metrics".to_owned(), Some(metrics))
1391 .await?;
1392
1393 Ok(())
1394 }
1395
1396 pub async fn upload(
1397 &self,
1398 client: &client::Client,
1399 files: &[(String, PathBuf)],
1400 ) -> Result<(), Error> {
1401 let mut parts = Form::new().part(
1402 "params",
1403 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
1404 );
1405
1406 for (name, path) in files {
1407 let file_part = Part::file(path).await?.file_name(name.to_owned());
1408 parts = parts.part("file", file_part);
1409 }
1410
1411 let result = client
1412 .post_multipart("validate.upload.files", parts)
1413 .await?;
1414 trace!("ValidationSession::upload: {:?}", result);
1415 Ok(())
1416 }
1417}
1418
1419#[derive(Deserialize, Clone, Debug)]
1420pub struct DatasetParams {
1421 dataset_id: DatasetID,
1422 annotation_set_id: AnnotationSetID,
1423 #[serde(rename = "train_group_name")]
1424 train_group: String,
1425 #[serde(rename = "val_group_name")]
1426 val_group: String,
1427}
1428
1429impl DatasetParams {
1430 pub fn dataset_id(&self) -> DatasetID {
1431 self.dataset_id
1432 }
1433
1434 pub fn annotation_set_id(&self) -> AnnotationSetID {
1435 self.annotation_set_id
1436 }
1437
1438 pub fn train_group(&self) -> &str {
1439 &self.train_group
1440 }
1441
1442 pub fn val_group(&self) -> &str {
1443 &self.val_group
1444 }
1445}
1446
/// Request parameters for listing tasks.
///
/// Every field is omitted from the serialized request when `None`.
#[derive(Serialize, Debug, Clone)]
pub struct TasksListParams {
    /// Presumably the cursor returned in [`TasksListResult::continue_token`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Presumably a filter on manager types; sent as `manage_types`.
    #[serde(rename = "manage_types", skip_serializing_if = "Option::is_none")]
    pub manager: Option<Vec<String>>,
    /// Presumably a filter on task statuses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<Vec<String>>,
}
1456
/// A batch of tasks returned by a task listing request.
#[derive(Deserialize, Debug, Clone)]
pub struct TasksListResult {
    /// Tasks contained in this response.
    pub tasks: Vec<Task>,
    /// Cursor for a follow-up request; presumably `None` when no more tasks
    /// remain — confirm against the server API.
    pub continue_token: Option<String>,
}
1462
1463#[derive(Deserialize, Debug, Clone)]
1464pub struct Task {
1465 id: TaskID,
1466 name: String,
1467 #[serde(rename = "type")]
1468 workflow: String,
1469 status: String,
1470 #[serde(rename = "manage_type")]
1471 manager: Option<String>,
1472 #[serde(rename = "instance_type")]
1473 instance: String,
1474 #[serde(rename = "date")]
1475 created: DateTime<Utc>,
1476}
1477
1478impl Task {
1479 pub fn id(&self) -> TaskID {
1480 self.id
1481 }
1482
1483 pub fn uid(&self) -> String {
1484 self.id.to_string()
1485 }
1486
1487 pub fn name(&self) -> &str {
1488 &self.name
1489 }
1490
1491 pub fn workflow(&self) -> &str {
1492 &self.workflow
1493 }
1494
1495 pub fn status(&self) -> &str {
1496 &self.status
1497 }
1498
1499 pub fn manager(&self) -> Option<&str> {
1500 self.manager.as_deref()
1501 }
1502
1503 pub fn instance(&self) -> &str {
1504 &self.instance
1505 }
1506
1507 pub fn created(&self) -> &DateTime<Utc> {
1508 &self.created
1509 }
1510}
1511
1512impl Display for Task {
1513 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1514 write!(
1515 f,
1516 "{} [{:?} {}] {}",
1517 self.uid(),
1518 self.manager(),
1519 self.workflow(),
1520 self.name()
1521 )
1522 }
1523}
1524
/// Detailed task information, including per-stage progress.
///
/// Wire-name mapping: `task_description` maps to `description`, `type` to
/// `workflow`, `created_date` to `created`, and `end_date` to `completed`.
#[derive(Deserialize, Debug)]
pub struct TaskInfo {
    id: TaskID,
    // Owning project; `None` when absent or `null` in the payload.
    project_id: Option<ProjectID>,
    #[serde(rename = "task_description")]
    description: String,
    #[serde(rename = "type")]
    workflow: String,
    // Current status; `None` when absent or `null` in the payload.
    status: Option<String>,
    // Stage progress, exposed through `TaskInfo::stages`.
    progress: TaskProgress,
    #[serde(rename = "created_date")]
    created: DateTime<Utc>,
    // NOTE(review): `end_date` is required here. If the server sends `null`
    // or omits it for unfinished tasks, deserialization fails — confirm
    // whether this should be `Option<DateTime<Utc>>`.
    #[serde(rename = "end_date")]
    completed: DateTime<Utc>,
}
1540
1541impl Display for TaskInfo {
1542 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1543 write!(
1544 f,
1545 "{} {}: {}",
1546 self.uid(),
1547 self.workflow(),
1548 self.description()
1549 )
1550 }
1551}
1552
1553impl TaskInfo {
1554 pub fn id(&self) -> TaskID {
1555 self.id
1556 }
1557
1558 pub fn uid(&self) -> String {
1559 self.id.to_string()
1560 }
1561
1562 pub fn project_id(&self) -> Option<ProjectID> {
1563 self.project_id
1564 }
1565
1566 pub fn description(&self) -> &str {
1567 &self.description
1568 }
1569
1570 pub fn workflow(&self) -> &str {
1571 &self.workflow
1572 }
1573
1574 pub fn status(&self) -> &Option<String> {
1575 &self.status
1576 }
1577
1578 pub async fn set_status(&mut self, client: &Client, status: &str) -> Result<(), Error> {
1579 let t = client.task_status(self.id(), status).await?;
1580 self.status = Some(t.status);
1581 Ok(())
1582 }
1583
1584 pub fn stages(&self) -> HashMap<String, Stage> {
1585 match &self.progress.stages {
1586 Some(stages) => stages.clone(),
1587 None => HashMap::new(),
1588 }
1589 }
1590
1591 pub async fn update_stage(
1592 &mut self,
1593 client: &Client,
1594 stage: &str,
1595 status: &str,
1596 message: &str,
1597 percentage: u8,
1598 ) -> Result<(), Error> {
1599 client
1600 .update_stage(self.id(), stage, status, message, percentage)
1601 .await?;
1602 let t = client.task_info(self.id()).await?;
1603 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1604 Ok(())
1605 }
1606
1607 pub async fn set_stages(
1608 &mut self,
1609 client: &Client,
1610 stages: &[(&str, &str)],
1611 ) -> Result<(), Error> {
1612 client.set_stages(self.id(), stages).await?;
1613 let t = client.task_info(self.id()).await?;
1614 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1615 Ok(())
1616 }
1617
1618 pub fn created(&self) -> &DateTime<Utc> {
1619 &self.created
1620 }
1621
1622 pub fn completed(&self) -> &DateTime<Utc> {
1623 &self.completed
1624 }
1625}
1626
/// Progress section of a task-info response.
#[derive(Deserialize, Debug)]
pub struct TaskProgress {
    /// Stage records keyed by string, or `None` when the server sent none.
    /// NOTE(review): keys are presumably stage names — confirm against the API.
    stages: Option<HashMap<String, Stage>>,
}
1631
/// Serializable payload pairing a task id with a status string.
/// NOTE(review): presumably the request body for a status update — confirm
/// where the client serializes it.
#[derive(Serialize, Debug, Clone)]
pub struct TaskStatus {
    /// Task identifier; serialized under the JSON key `docker_task_id`.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// Status string to send.
    pub status: String,
}
1638
/// A single stage of a task's progress, used both when reading and when
/// writing stage data. Optional fields that are `None` are omitted when
/// serializing.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Stage {
    /// Owning task; JSON key `docker_task_id`, omitted from output when `None`.
    #[serde(rename = "docker_task_id", skip_serializing_if = "Option::is_none")]
    task_id: Option<TaskID>,
    /// Stage name.
    stage: String,
    /// Stage status string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    status: Option<String>,
    /// Stage description, if any. `Stage::new` always leaves it `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Free-form progress message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    message: Option<String>,
    /// Progress value stored as a `u8`.
    /// NOTE(review): presumably 0-100, but nothing here enforces a range.
    percentage: u8,
}
1652
1653impl Stage {
1654 pub fn new(
1655 task_id: Option<TaskID>,
1656 stage: String,
1657 status: Option<String>,
1658 message: Option<String>,
1659 percentage: u8,
1660 ) -> Self {
1661 Stage {
1662 task_id,
1663 stage,
1664 status,
1665 description: None,
1666 message,
1667 percentage,
1668 }
1669 }
1670
1671 pub fn task_id(&self) -> &Option<TaskID> {
1672 &self.task_id
1673 }
1674
1675 pub fn stage(&self) -> &str {
1676 &self.stage
1677 }
1678
1679 pub fn status(&self) -> &Option<String> {
1680 &self.status
1681 }
1682
1683 pub fn description(&self) -> &Option<String> {
1684 &self.description
1685 }
1686
1687 pub fn message(&self) -> &Option<String> {
1688 &self.message
1689 }
1690
1691 pub fn percentage(&self) -> u8 {
1692 self.percentage
1693 }
1694}
1695
/// Serializable payload pairing a task id with a list of stage maps.
/// NOTE(review): presumably the request body for setting a task's stages;
/// the key/value layout of each map is not visible here — confirm.
#[derive(Serialize, Debug)]
pub struct TaskStages {
    /// Task identifier; serialized under the JSON key `docker_task_id`.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// One string-to-string map per stage.
    pub stages: Vec<HashMap<String, String>>,
}
1702
/// An artifact record with a name and a model type.
#[derive(Deserialize, Debug)]
pub struct Artifact {
    /// Artifact name.
    name: String,
    /// Model type; deserialized from the camelCase JSON key `modelType`.
    #[serde(rename = "modelType")]
    model_type: String,
}
1709
1710impl Artifact {
1711 pub fn name(&self) -> &str {
1712 &self.name
1713 }
1714
1715 pub fn model_type(&self) -> &str {
1716 &self.model_type
1717 }
1718}