1use crate::{AnnotationSet, Client, Dataset, Error, Progress, Sample, client};
5use chrono::{DateTime, Utc};
6use log::trace;
7use reqwest::multipart::{Form, Part};
8use serde::{Deserialize, Deserializer, Serialize};
9use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
10
/// A dynamically-typed parameter value used for model parameters, metrics
/// and chart data exchanged with the server.
///
/// The enum is `#[serde(untagged)]`: it serializes as the bare value of the
/// active variant, and when deserializing serde tries the variants in
/// declaration order, picking the first that matches (so integral values are
/// tried as `Integer` before `Real`, and `Boolean` before `String`).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum Parameter {
    /// Signed integer value.
    Integer(i64),
    /// Floating-point value.
    Real(f64),
    /// Boolean value.
    Boolean(bool),
    /// String value.
    String(String),
    /// Ordered list of nested parameters.
    Array(Vec<Parameter>),
    /// String-keyed map of nested parameters.
    Object(HashMap<String, Parameter>),
}
56
/// Deserialized result of a login call; only the `token` field is read.
#[derive(Deserialize)]
pub struct LoginResult {
    /// Token returned by the server (crate-visible only).
    pub(crate) token: String,
}
61
/// Declares a strongly-typed `Copy` newtype ID wrapping a `u64`.
///
/// The generated type:
/// - displays as `<prefix>-<lowercase hex>`,
/// - parses from `<prefix>-<hex>` via [`FromStr`] and `TryFrom<&str>` /
///   `TryFrom<String>`,
/// - converts to and from a raw `u64`,
/// - serializes/deserializes as the bare integer (serde newtype struct).
///
/// Outer attributes (including `///` doc comments) written before the name
/// are forwarded onto the generated struct.
macro_rules! typeid {
    ($(#[$meta:meta])* $name:ident, $prefix:literal) => {
        $(#[$meta])*
        #[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
        pub struct $name(u64);

        impl $name {
            /// Returns the raw numeric value of this ID.
            pub fn value(&self) -> u64 {
                let Self(raw) = *self;
                raw
            }
        }

        impl From<u64> for $name {
            fn from(raw: u64) -> Self {
                Self(raw)
            }
        }

        impl From<$name> for u64 {
            fn from(id: $name) -> Self {
                id.value()
            }
        }

        impl Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(f, "{}-{:x}", $prefix, self.0)
            }
        }

        impl FromStr for $name {
            type Err = Error;

            /// Parses `<prefix>-<hex>`. A missing prefix yields
            /// `Error::InvalidParameters`; a bad hex remainder yields the
            /// integer parse error converted into `Error`.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                let Some(digits) = s.strip_prefix(concat!($prefix, "-")) else {
                    return Err(Error::InvalidParameters(
                        concat!(stringify!($name), " must start with '", $prefix, "-' prefix")
                            .to_owned(),
                    ));
                };
                Ok(Self(u64::from_str_radix(digits, 16)?))
            }
        }

        impl TryFrom<&str> for $name {
            type Error = Error;

            fn try_from(s: &str) -> Result<Self, Self::Error> {
                s.parse()
            }
        }

        impl TryFrom<String> for $name {
            type Error = Error;

            fn try_from(s: String) -> Result<Self, Self::Error> {
                s.as_str().parse()
            }
        }
    };
}
135
typeid!(
    /// Identifier of an [`Organization`]; displayed and parsed as `org-<hex>`.
    OrganizationID,
    "org"
);
159
160#[derive(Deserialize, Clone, Debug)]
181pub struct Organization {
182 id: OrganizationID,
183 name: String,
184 #[serde(rename = "latest_credit")]
185 credits: i64,
186}
187
188impl Display for Organization {
189 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
190 write!(f, "{}", self.name())
191 }
192}
193
194impl Organization {
195 pub fn id(&self) -> OrganizationID {
196 self.id
197 }
198
199 pub fn name(&self) -> &str {
200 &self.name
201 }
202
203 pub fn credits(&self) -> i64 {
204 self.credits
205 }
206}
207
typeid!(
    /// Identifier of a [`Project`]; displayed and parsed as `p-<hex>`.
    ProjectID,
    "p"
);

typeid!(
    /// Identifier of an [`Experiment`]; displayed and parsed as `exp-<hex>`.
    ExperimentID,
    "exp"
);

typeid!(
    /// Identifier of a [`TrainingSession`]; displayed and parsed as `t-<hex>`.
    TrainingSessionID,
    "t"
);

typeid!(
    /// Identifier of a [`ValidationSession`]; displayed and parsed as `v-<hex>`.
    ValidationSessionID,
    "v"
);

typeid!(
    /// Identifier of a [`Snapshot`]; displayed and parsed as `ss-<hex>`.
    SnapshotID,
    "ss"
);

typeid!(
    /// Identifier of a [`Task`]; displayed and parsed as `task-<hex>`.
    TaskID,
    "task"
);

typeid!(
    /// Identifier of a dataset; displayed and parsed as `ds-<hex>`.
    DatasetID,
    "ds"
);

typeid!(
    /// Identifier of an annotation set; displayed and parsed as `as-<hex>`.
    AnnotationSetID,
    "as"
);

typeid!(
    /// Identifier of a sample; displayed and parsed as `s-<hex>`.
    SampleID,
    "s"
);

typeid!(
    /// ID type with the `app` prefix; displayed and parsed as `app-<hex>`.
    AppId,
    "app"
);

typeid!(
    /// ID type with the `im` prefix; displayed and parsed as `im-<hex>`.
    ImageId,
    "im"
);

typeid!(
    /// ID type with the `se` prefix; displayed and parsed as `se-<hex>`.
    SequenceId,
    "se"
);
441
442#[derive(Deserialize, Clone, Debug)]
446pub struct Project {
447 id: ProjectID,
448 name: String,
449 description: String,
450}
451
452impl Display for Project {
453 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
454 write!(f, "{} {}", self.id(), self.name())
455 }
456}
457
458impl Project {
459 pub fn id(&self) -> ProjectID {
460 self.id
461 }
462
463 pub fn name(&self) -> &str {
464 &self.name
465 }
466
467 pub fn description(&self) -> &str {
468 &self.description
469 }
470
471 pub async fn datasets(
472 &self,
473 client: &client::Client,
474 name: Option<&str>,
475 ) -> Result<Vec<Dataset>, Error> {
476 client.datasets(self.id, name).await
477 }
478
479 pub async fn experiments(
480 &self,
481 client: &client::Client,
482 name: Option<&str>,
483 ) -> Result<Vec<Experiment>, Error> {
484 client.experiments(self.id, name).await
485 }
486}
487
/// Response carrying a sample count.
#[derive(Deserialize, Debug)]
pub struct SamplesCountResult {
    /// Total number of samples reported by the server.
    pub total: u64,
}

/// Request parameters for listing samples of a dataset one page at a time.
///
/// `None` options and empty vectors are omitted from the serialized request.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesListParams {
    /// Dataset whose samples are listed.
    pub dataset_id: DatasetID,
    /// Annotation set to use; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Pagination token (presumably the `continue_token` of a previous
    /// [`SamplesListResult`]); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Type filter; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub types: Vec<String>,
    /// Group-name filter; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub group_names: Vec<String>,
}

/// One page of samples returned by a list request.
#[derive(Deserialize, Debug)]
pub struct SamplesListResult {
    /// Samples in this page.
    pub samples: Vec<Sample>,
    /// Token for requesting the next page; presumably `None` on the last page.
    pub continue_token: Option<String>,
}
511
/// New width/height for a single sample.
#[derive(Serialize, Clone, Debug)]
pub struct SampleDimensionUpdate {
    /// Sample to update.
    pub id: SampleID,
    /// New width (units not shown here — TODO confirm, presumably pixels).
    pub width: u32,
    /// New height (units not shown here — TODO confirm, presumably pixels).
    pub height: u32,
}

/// Request to update the dimensions of samples within one dataset.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesUpdateDimensionsParams {
    /// Dataset containing the samples.
    pub dataset_id: DatasetID,
    /// Per-sample dimension updates.
    pub samples: Vec<SampleDimensionUpdate>,
}

/// Response to a dimensions update request.
#[derive(Deserialize, Debug)]
pub struct SamplesUpdateDimensionsResult {
    /// Number of samples the server reports as updated.
    pub updated: u64,
}
532
/// Request to populate a dataset with samples.
///
/// `None` options are omitted from the serialized request.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesPopulateParams {
    /// Target dataset.
    pub dataset_id: DatasetID,
    /// Annotation set for the samples' annotations; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Presumably asks the server to return presigned upload URLs
    /// (see [`SamplesPopulateResult`]); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presigned_urls: Option<bool>,
    /// Samples to add.
    pub samples: Vec<Sample>,
}

/// Response to a populate request.
#[derive(Deserialize, Debug, Clone)]
pub struct SamplesPopulateResult {
    /// Identifier string returned by the server.
    pub uuid: String,
    /// Presigned URLs returned by the server.
    pub urls: Vec<PresignedUrl>,
}

/// A presigned URL for one file.
#[derive(Deserialize, Debug, Clone)]
pub struct PresignedUrl {
    /// File name the URL refers to.
    pub filename: String,
    /// Storage key (presumably the object key on the server side).
    pub key: String,
    /// The presigned URL itself.
    pub url: String,
}
570
/// An annotation in the wire format used by the bulk annotation requests.
///
/// `None` options and an empty `polygon` are omitted when serialized.
#[derive(Serialize, Clone, Debug)]
pub struct ServerAnnotation {
    /// Label ID; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_id: Option<u64>,
    /// Label index; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_index: Option<u64>,
    /// Label name; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_name: Option<String>,
    /// Annotation kind, serialized as `type`.
    #[serde(rename = "type")]
    pub annotation_type: String,
    /// Box geometry: `x`, `y`, `w`, `h`. NOTE(review): the coordinate
    /// convention (normalized vs pixels, corner vs center) is not visible
    /// here — confirm against the server API.
    pub x: f64,
    pub y: f64,
    pub w: f64,
    pub h: f64,
    /// Confidence score.
    pub score: f64,
    /// Polygon encoded as a string (format not shown here); omitted when empty.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub polygon: String,
    /// Raw numeric image ID.
    pub image_id: u64,
    /// Raw numeric annotation set ID.
    pub annotation_set_id: u64,
    /// Optional object reference; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub object_reference: Option<String>,
}

/// Request body for adding annotations in bulk to one annotation set.
#[derive(Serialize, Debug)]
pub struct AnnotationAddBulkParams {
    /// Raw numeric annotation set ID.
    pub annotation_set_id: u64,
    /// Annotations to add.
    pub annotations: Vec<ServerAnnotation>,
}

/// Request body for deleting annotations in bulk.
#[derive(Serialize, Debug)]
pub struct AnnotationBulkDeleteParams {
    /// Raw numeric annotation set ID.
    pub annotation_set_id: u64,
    /// Annotation types to delete (always serialized).
    pub annotation_types: Vec<String>,
    /// Presumably restricts deletion to these images; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub image_ids: Vec<u64>,
    /// Delete-all flag; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delete_all: Option<bool>,
}
638
639#[derive(Deserialize)]
640pub struct Snapshot {
641 id: SnapshotID,
642 description: String,
643 status: String,
644 path: String,
645 #[serde(rename = "date")]
646 created: DateTime<Utc>,
647}
648
649impl Display for Snapshot {
650 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
651 write!(f, "{} {}", self.id, self.description)
652 }
653}
654
655impl Snapshot {
656 pub fn id(&self) -> SnapshotID {
657 self.id
658 }
659
660 pub fn description(&self) -> &str {
661 &self.description
662 }
663
664 pub fn status(&self) -> &str {
665 &self.status
666 }
667
668 pub fn path(&self) -> &str {
669 &self.path
670 }
671
672 pub fn created(&self) -> &DateTime<Utc> {
673 &self.created
674 }
675}
676
/// Request to restore a snapshot into a project.
///
/// Several fields use different names on the wire (see the `rename`s);
/// empty vectors and `None` options are omitted.
#[derive(Serialize, Debug)]
pub struct SnapshotRestore {
    /// Project to restore into.
    pub project_id: ProjectID,
    /// Snapshot to restore.
    pub snapshot_id: SnapshotID,
    /// Frame rate parameter sent as `fps`.
    pub fps: u64,
    /// Serialized as `enabled_topics`; omitted when empty.
    #[serde(rename = "enabled_topics", skip_serializing_if = "Vec::is_empty")]
    pub topics: Vec<String>,
    /// Serialized as `label_names`; omitted when empty.
    #[serde(rename = "label_names", skip_serializing_if = "Vec::is_empty")]
    pub autolabel: Vec<String>,
    /// Serialized as `depth_gen`.
    #[serde(rename = "depth_gen")]
    pub autodepth: bool,
    /// Flag sent as `agtg_pipeline`.
    pub agtg_pipeline: bool,
    /// Name for the resulting dataset; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_name: Option<String>,
    /// Description for the resulting dataset; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_description: Option<String>,
}

/// Response to a snapshot restore request.
#[derive(Deserialize, Debug)]
pub struct SnapshotRestoreResult {
    pub id: SnapshotID,
    pub description: String,
    pub dataset_name: String,
    pub dataset_id: DatasetID,
    pub annotation_set_id: AnnotationSetID,
    /// Task associated with the restore, if the server reported one.
    #[serde(default)]
    pub task_id: Option<TaskID>,
    pub date: DateTime<Utc>,
}

/// Request to create a snapshot from a dataset and one of its annotation sets.
#[derive(Serialize, Debug)]
pub struct SnapshotCreateFromDataset {
    /// Snapshot description.
    pub description: String,
    /// Source dataset.
    pub dataset_id: DatasetID,
    /// Source annotation set.
    pub annotation_set_id: AnnotationSetID,
}

/// Response to a create-snapshot-from-dataset request (presumably the reply
/// to [`SnapshotCreateFromDataset`]).
#[derive(Deserialize, Debug)]
pub struct SnapshotFromDatasetResult {
    /// Created snapshot; accepted under either `id` or `snapshot_id`.
    #[serde(alias = "snapshot_id")]
    pub id: SnapshotID,
    /// Associated task, if the server reported one.
    #[serde(default)]
    pub task_id: Option<TaskID>,
}
733
734#[derive(Deserialize)]
735pub struct Experiment {
736 id: ExperimentID,
737 project_id: ProjectID,
738 name: String,
739 description: String,
740}
741
742impl Display for Experiment {
743 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
744 write!(f, "{} {}", self.id, self.name)
745 }
746}
747
748impl Experiment {
749 pub fn id(&self) -> ExperimentID {
750 self.id
751 }
752
753 pub fn project_id(&self) -> ProjectID {
754 self.project_id
755 }
756
757 pub fn name(&self) -> &str {
758 &self.name
759 }
760
761 pub fn description(&self) -> &str {
762 &self.description
763 }
764
765 pub async fn project(&self, client: &client::Client) -> Result<Project, Error> {
766 client.project(self.project_id).await
767 }
768
769 pub async fn training_sessions(
770 &self,
771 client: &client::Client,
772 name: Option<&str>,
773 ) -> Result<Vec<TrainingSession>, Error> {
774 client.training_sessions(self.id, name).await
775 }
776}
777
/// Request body for publishing metrics to a training or validation session.
///
/// Only the session IDs that are `Some` are serialized; the callers in this
/// file set exactly one of them.
#[derive(Serialize, Debug)]
pub struct PublishMetrics {
    /// Target training session; omitted when `None`.
    #[serde(rename = "trainer_session_id", skip_serializing_if = "Option::is_none")]
    pub trainer_session_id: Option<TrainingSessionID>,
    /// Target validation session; omitted when `None`.
    #[serde(
        rename = "validate_session_id",
        skip_serializing_if = "Option::is_none"
    )]
    pub validate_session_id: Option<ValidationSessionID>,
    /// Metrics to publish.
    pub metrics: HashMap<String, Parameter>,
}
789
790#[derive(Deserialize)]
791struct TrainingSessionParams {
792 model_params: HashMap<String, Parameter>,
793 dataset_params: DatasetParams,
794}
795
796#[derive(Deserialize)]
797pub struct TrainingSession {
798 id: TrainingSessionID,
799 #[serde(rename = "trainer_id")]
800 experiment_id: ExperimentID,
801 model: String,
802 name: String,
803 description: String,
804 params: TrainingSessionParams,
805 #[serde(rename = "docker_task")]
806 task: Task,
807}
808
809impl Display for TrainingSession {
810 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
811 write!(f, "{} {}", self.id, self.name())
812 }
813}
814
815impl TrainingSession {
816 pub fn id(&self) -> TrainingSessionID {
817 self.id
818 }
819
820 pub fn name(&self) -> &str {
821 &self.name
822 }
823
824 pub fn description(&self) -> &str {
825 &self.description
826 }
827
828 pub fn model(&self) -> &str {
829 &self.model
830 }
831
832 pub fn experiment_id(&self) -> ExperimentID {
833 self.experiment_id
834 }
835
836 pub fn task(&self) -> Task {
837 self.task.clone()
838 }
839
840 pub fn model_params(&self) -> &HashMap<String, Parameter> {
841 &self.params.model_params
842 }
843
844 pub fn dataset_params(&self) -> &DatasetParams {
845 &self.params.dataset_params
846 }
847
848 pub fn train_group(&self) -> &str {
849 &self.params.dataset_params.train_group
850 }
851
852 pub fn val_group(&self) -> &str {
853 &self.params.dataset_params.val_group
854 }
855
856 pub async fn experiment(&self, client: &client::Client) -> Result<Experiment, Error> {
857 client.experiment(self.experiment_id).await
858 }
859
860 pub async fn dataset(&self, client: &client::Client) -> Result<Dataset, Error> {
861 client.dataset(self.params.dataset_params.dataset_id).await
862 }
863
864 pub async fn annotation_set(&self, client: &client::Client) -> Result<AnnotationSet, Error> {
865 client
866 .annotation_set(self.params.dataset_params.annotation_set_id)
867 .await
868 }
869
870 pub async fn artifacts(&self, client: &client::Client) -> Result<Vec<Artifact>, Error> {
871 client.artifacts(self.id).await
872 }
873
874 pub async fn metrics(
875 &self,
876 client: &client::Client,
877 ) -> Result<HashMap<String, Parameter>, Error> {
878 #[derive(Deserialize)]
879 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
880 enum Response {
881 Empty {},
882 Map(HashMap<String, Parameter>),
883 String(String),
884 }
885
886 let params = HashMap::from([("trainer_session_id", self.id().value())]);
887 let resp: Response = client
888 .rpc("trainer.session.metrics".to_owned(), Some(params))
889 .await?;
890
891 Ok(match resp {
892 Response::String(metrics) => serde_json::from_str(&metrics)?,
893 Response::Map(metrics) => metrics,
894 Response::Empty {} => HashMap::new(),
895 })
896 }
897
898 pub async fn set_metrics(
899 &self,
900 client: &client::Client,
901 metrics: HashMap<String, Parameter>,
902 ) -> Result<(), Error> {
903 let metrics = PublishMetrics {
904 trainer_session_id: Some(self.id()),
905 validate_session_id: None,
906 metrics,
907 };
908
909 let _: String = client
910 .rpc("trainer.session.metrics".to_owned(), Some(metrics))
911 .await?;
912
913 Ok(())
914 }
915
916 pub async fn download_artifact(
918 &self,
919 client: &client::Client,
920 filename: &str,
921 ) -> Result<Vec<u8>, Error> {
922 client
923 .fetch(&format!(
924 "download_model?training_session_id={}&file={}",
925 self.id().value(),
926 filename
927 ))
928 .await
929 }
930
931 pub async fn upload_artifact(
935 &self,
936 client: &client::Client,
937 filename: &str,
938 path: PathBuf,
939 ) -> Result<(), Error> {
940 self.upload(client, &[(format!("artifacts/{}", filename), path)])
941 .await
942 }
943
944 pub async fn download_checkpoint(
946 &self,
947 client: &client::Client,
948 filename: &str,
949 ) -> Result<Vec<u8>, Error> {
950 client
951 .fetch(&format!(
952 "download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
953 self.id().value(),
954 filename
955 ))
956 .await
957 }
958
959 pub async fn upload_checkpoint(
963 &self,
964 client: &client::Client,
965 filename: &str,
966 path: PathBuf,
967 ) -> Result<(), Error> {
968 self.upload(client, &[(format!("checkpoints/{}", filename), path)])
969 .await
970 }
971
972 pub async fn download(&self, client: &client::Client, filename: &str) -> Result<String, Error> {
976 #[derive(Serialize)]
977 struct DownloadRequest {
978 session_id: TrainingSessionID,
979 file_path: String,
980 }
981
982 let params = DownloadRequest {
983 session_id: self.id(),
984 file_path: filename.to_string(),
985 };
986
987 client
988 .rpc("trainer.download.file".to_owned(), Some(params))
989 .await
990 }
991
992 pub async fn upload(
993 &self,
994 client: &client::Client,
995 files: &[(String, PathBuf)],
996 ) -> Result<(), Error> {
997 let mut parts = Form::new().part(
998 "params",
999 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
1000 );
1001
1002 for (name, path) in files {
1003 let file_part = Part::file(path).await?.file_name(name.to_owned());
1004 parts = parts.part("file", file_part);
1005 }
1006
1007 let result = client.post_multipart("trainer.upload.files", parts).await?;
1008 trace!("TrainingSession::upload: {:?}", result);
1009 Ok(())
1010 }
1011}
1012
/// A validation session record.
///
/// It links a training session, its experiment, a dataset and a
/// ground-truth annotation set, together with the backing task.
#[derive(Deserialize, Clone, Debug)]
pub struct ValidationSession {
    id: ValidationSessionID,
    description: String,
    dataset_id: DatasetID,
    experiment_id: ExperimentID,
    training_session_id: TrainingSessionID,
    /// Read from the server's `gt_annotation_set_id` field.
    #[serde(rename = "gt_annotation_set_id")]
    annotation_set_id: AnnotationSetID,
    /// Flattened validation parameters plus a `"model"` entry, built by
    /// `validation_session_params`.
    #[serde(deserialize_with = "validation_session_params")]
    params: HashMap<String, Parameter>,
    /// Backing task, read from the server's `docker_task` field.
    #[serde(rename = "docker_task")]
    task: Task,
}
1027
1028fn validation_session_params<'de, D>(
1029 deserializer: D,
1030) -> Result<HashMap<String, Parameter>, D::Error>
1031where
1032 D: Deserializer<'de>,
1033{
1034 #[derive(Deserialize)]
1035 struct ModelParams {
1036 validation: Option<HashMap<String, Parameter>>,
1037 }
1038
1039 #[derive(Deserialize)]
1040 struct ValidateParams {
1041 model: String,
1042 }
1043
1044 #[derive(Deserialize)]
1045 struct Params {
1046 model_params: ModelParams,
1047 validate_params: ValidateParams,
1048 }
1049
1050 let params = Params::deserialize(deserializer)?;
1051 let params = match params.model_params.validation {
1052 Some(mut map) => {
1053 map.insert(
1054 "model".to_string(),
1055 Parameter::String(params.validate_params.model),
1056 );
1057 map
1058 }
1059 None => HashMap::from([(
1060 "model".to_string(),
1061 Parameter::String(params.validate_params.model),
1062 )]),
1063 };
1064
1065 Ok(params)
1066}
1067
1068impl ValidationSession {
1069 pub fn id(&self) -> ValidationSessionID {
1070 self.id
1071 }
1072
1073 pub fn name(&self) -> &str {
1074 self.task.name()
1075 }
1076
1077 pub fn description(&self) -> &str {
1078 &self.description
1079 }
1080
1081 pub fn dataset_id(&self) -> DatasetID {
1082 self.dataset_id
1083 }
1084
1085 pub fn experiment_id(&self) -> ExperimentID {
1086 self.experiment_id
1087 }
1088
1089 pub fn training_session_id(&self) -> TrainingSessionID {
1090 self.training_session_id
1091 }
1092
1093 pub fn annotation_set_id(&self) -> AnnotationSetID {
1094 self.annotation_set_id
1095 }
1096
1097 pub fn params(&self) -> &HashMap<String, Parameter> {
1098 &self.params
1099 }
1100
1101 pub fn task(&self) -> &Task {
1102 &self.task
1103 }
1104
1105 pub async fn metrics(
1106 &self,
1107 client: &client::Client,
1108 ) -> Result<HashMap<String, Parameter>, Error> {
1109 #[derive(Deserialize)]
1110 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1111 enum Response {
1112 Empty {},
1113 Map(HashMap<String, Parameter>),
1114 String(String),
1115 }
1116
1117 let params = HashMap::from([("validate_session_id", self.id().value())]);
1118 let resp: Response = client
1119 .rpc("validate.session.metrics".to_owned(), Some(params))
1120 .await?;
1121
1122 Ok(match resp {
1123 Response::String(metrics) => serde_json::from_str(&metrics)?,
1124 Response::Map(metrics) => metrics,
1125 Response::Empty {} => HashMap::new(),
1126 })
1127 }
1128
1129 pub async fn set_metrics(
1130 &self,
1131 client: &client::Client,
1132 metrics: HashMap<String, Parameter>,
1133 ) -> Result<(), Error> {
1134 let metrics = PublishMetrics {
1135 trainer_session_id: None,
1136 validate_session_id: Some(self.id()),
1137 metrics,
1138 };
1139
1140 let _: String = client
1141 .rpc("validate.session.metrics".to_owned(), Some(metrics))
1142 .await?;
1143
1144 Ok(())
1145 }
1146
1147 pub async fn upload_data(
1172 &self,
1173 client: &client::Client,
1174 files: &[(String, std::path::PathBuf)],
1175 folder: Option<&str>,
1176 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1177 ) -> Result<(), Error> {
1178 use futures::StreamExt;
1179 use std::sync::{
1180 Arc,
1181 atomic::{AtomicUsize, Ordering},
1182 };
1183 use tokio_util::io::ReaderStream;
1184
1185 let mut total: usize = 0;
1187 let mut file_meta = Vec::with_capacity(files.len());
1188 for (name, path) in files {
1189 let f = tokio::fs::File::open(path).await?;
1190 let len = f.metadata().await?.len() as usize;
1191 total += len;
1192 file_meta.push((name.clone(), f, len));
1193 }
1194
1195 let sent = Arc::new(AtomicUsize::new(0));
1197
1198 let mut form = Form::new().text("session_id", self.id().value().to_string());
1199 if let Some(folder) = folder.filter(|s| !s.is_empty()) {
1200 form = form.text("folder", folder.to_owned());
1201 }
1202
1203 for (name, file, len) in file_meta {
1204 let reader_stream = ReaderStream::new(file);
1205 let sent_clone = sent.clone();
1206 let progress_clone = progress.clone();
1207 let progress_stream = reader_stream.inspect(move |chunk_result| {
1208 if let Ok(chunk) = chunk_result {
1209 let current =
1210 sent_clone.fetch_add(chunk.len(), Ordering::Relaxed) + chunk.len();
1211 if let Some(tx) = &progress_clone {
1216 let _ = tx.try_send(Progress {
1217 current,
1218 total,
1219 status: None,
1220 });
1221 }
1222 }
1223 });
1224 let body = reqwest::Body::wrap_stream(progress_stream);
1225 let part = Part::stream_with_length(body, len as u64).file_name(name);
1226 form = form.part("file", part);
1227 }
1228
1229 let result = match client.post_multipart("val.data.upload", form).await {
1230 Ok(_) => Ok(()),
1231 Err(Error::RpcError(code, msg)) => {
1232 Err(client::map_rpc_error("val.data.upload", code, msg, None))
1233 }
1234 Err(e) => Err(e),
1235 };
1236
1237 if result.is_ok()
1242 && let Some(tx) = progress
1243 {
1244 let _ = tx
1245 .send(Progress {
1246 current: total,
1247 total,
1248 status: None,
1249 })
1250 .await;
1251 }
1252 result
1253 }
1254
1255 pub async fn download_data(
1275 &self,
1276 client: &client::Client,
1277 filename: &str,
1278 output_path: &std::path::Path,
1279 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1280 ) -> Result<(), Error> {
1281 let req = client::ValDataDownloadRequest {
1282 session_id: self.id().value(),
1283 filename: filename.to_owned(),
1284 };
1285 match client
1286 .rpc_download("val.data.download", &req, output_path, progress)
1287 .await
1288 {
1289 Ok(()) => Ok(()),
1290 Err(Error::RpcError(code, msg)) => {
1291 Err(client::map_rpc_error("val.data.download", code, msg, None))
1292 }
1293 Err(e) => Err(e),
1294 }
1295 }
1296
1297 pub async fn data_list(&self, client: &client::Client) -> Result<Vec<String>, Error> {
1312 let req = client::ValDataListRequest {
1313 session_id: self.id().value(),
1314 };
1315 match client.rpc("val.data.list".to_owned(), Some(&req)).await {
1316 Ok(r) => Ok(r),
1317 Err(Error::RpcError(code, msg)) => {
1318 Err(client::map_rpc_error("val.data.list", code, msg, None))
1319 }
1320 Err(e) => Err(e),
1321 }
1322 }
1323}
1324
/// Parameters for starting a validation of a training session's model.
///
/// This is a plain Rust struct with no serde derive here; it is presumably
/// converted into an RPC payload elsewhere — TODO confirm.
#[derive(Debug, Clone)]
pub struct StartValidationRequest {
    /// Project the validation runs in.
    pub project_id: ProjectID,
    /// Name for the validation.
    pub name: String,
    /// Training session whose model is validated.
    pub training_session_id: TrainingSessionID,
    /// Model file to validate (presumably an artifact name of the training
    /// session — confirm).
    pub model_file: String,
    /// Validation type string.
    pub val_type: String,
    /// Additional validation parameters.
    pub params: HashMap<String, Parameter>,
    /// Execution-location flag (semantics not visible here).
    pub is_local: bool,
    /// Execution-location flag (semantics not visible here).
    pub is_kubernetes: bool,
    /// Optional description.
    pub description: Option<String>,
    /// Optional dataset to validate against.
    pub dataset_id: Option<DatasetID>,
    /// Optional annotation set to validate against.
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Optional snapshot to validate against.
    pub snapshot_id: Option<SnapshotID>,
}
1358
1359#[derive(Deserialize, Debug, Clone)]
1374pub struct NewValidationSession {
1375 #[serde(rename = "id")]
1376 pub task_id: TaskID,
1377 #[serde(rename = "val_session_id", default)]
1378 pub session_id: Option<ValidationSessionID>,
1379}
1380
1381impl Display for NewValidationSession {
1382 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1383 match self.session_id {
1384 Some(id) => write!(f, "task {} session {}", self.task_id, id),
1385 None => write!(f, "task {} (no session)", self.task_id),
1386 }
1387 }
1388}
1389
1390#[derive(Deserialize, Clone, Debug)]
1391pub struct DatasetParams {
1392 dataset_id: DatasetID,
1393 annotation_set_id: AnnotationSetID,
1394 #[serde(rename = "train_group_name")]
1395 train_group: String,
1396 #[serde(rename = "val_group_name")]
1397 val_group: String,
1398}
1399
1400impl DatasetParams {
1401 pub fn dataset_id(&self) -> DatasetID {
1402 self.dataset_id
1403 }
1404
1405 pub fn annotation_set_id(&self) -> AnnotationSetID {
1406 self.annotation_set_id
1407 }
1408
1409 pub fn train_group(&self) -> &str {
1410 &self.train_group
1411 }
1412
1413 pub fn val_group(&self) -> &str {
1414 &self.val_group
1415 }
1416}
1417
/// Filters for listing tasks; `None` filters are omitted from the request.
#[derive(Serialize, Debug, Clone)]
pub struct TasksListParams {
    /// Pagination token; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    /// Task type filter; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub types: Option<Vec<String>>,
    /// Manager filter, serialized as `manage_types`; omitted when `None`.
    #[serde(rename = "manage_types", skip_serializing_if = "Option::is_none")]
    pub manager: Option<Vec<String>>,
    /// Status filter; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<Vec<String>>,
}

/// File listing returned by the `task.data.list` and `task.chart.list` RPCs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskDataList {
    /// Server string reported in the listing.
    pub server: String,
    /// Organization UID string.
    #[serde(rename = "organization_uid")]
    pub organization_uid: String,
    /// Trace entries; empty when absent.
    #[serde(default)]
    pub traces: Vec<String>,
    /// Grouped file names (presumably keyed by folder/group — confirm);
    /// empty when absent.
    #[serde(default)]
    pub data: std::collections::HashMap<String, Vec<String>>,
}
1445
1446#[derive(Debug, Clone, Serialize, Deserialize)]
1451pub struct Job {
1452 #[serde(default)]
1454 pub code: String,
1455 #[serde(default)]
1457 pub title: String,
1458 #[serde(default)]
1460 pub job_name: String,
1461 #[serde(default)]
1463 pub job_id: String,
1464 #[serde(default)]
1466 pub state: String,
1467 #[serde(default)]
1469 pub launch: Option<DateTime<Utc>>,
1470 pub task_id: i64,
1475}
1476
1477impl Job {
1478 pub fn task_id(&self) -> TaskID {
1484 TaskID::from(self.task_id.max(0) as u64)
1485 }
1486}
1487
/// One page of tasks returned by a task list request.
#[derive(Deserialize, Debug, Clone)]
pub struct TasksListResult {
    /// Tasks in this page.
    pub tasks: Vec<Task>,
    /// Token for the next page; presumably `None` on the last page.
    pub continue_token: Option<String>,
}
1493
1494#[derive(Deserialize, Debug, Clone)]
1495pub struct Task {
1496 id: TaskID,
1497 name: String,
1498 #[serde(rename = "type")]
1499 workflow: String,
1500 status: String,
1501 #[serde(rename = "manage_type")]
1502 manager: Option<String>,
1503 #[serde(rename = "instance_type")]
1504 instance: String,
1505 #[serde(rename = "date")]
1506 created: DateTime<Utc>,
1507}
1508
1509impl Task {
1510 pub fn id(&self) -> TaskID {
1511 self.id
1512 }
1513
1514 pub fn name(&self) -> &str {
1515 &self.name
1516 }
1517
1518 pub fn workflow(&self) -> &str {
1519 &self.workflow
1520 }
1521
1522 pub fn status(&self) -> &str {
1523 &self.status
1524 }
1525
1526 pub fn manager(&self) -> Option<&str> {
1527 self.manager.as_deref()
1528 }
1529
1530 pub fn instance(&self) -> &str {
1531 &self.instance
1532 }
1533
1534 pub fn created(&self) -> &DateTime<Utc> {
1535 &self.created
1536 }
1537}
1538
1539impl Display for Task {
1540 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1541 write!(
1542 f,
1543 "{} [{:?} {}] {}",
1544 self.id,
1545 self.manager(),
1546 self.workflow(),
1547 self.name()
1548 )
1549 }
1550}
1551
/// Detailed information about a task, including stage progress.
///
/// Several fields accept two server names (`rename` plus `alias`) and fall
/// back to defaults when absent.
#[derive(Deserialize, Debug, Clone)]
pub struct TaskInfo {
    id: TaskID,
    /// Owning project, if reported.
    project_id: Option<ProjectID>,
    /// Read from `task_description` (or `description`); empty when absent.
    #[serde(rename = "task_description", alias = "description", default)]
    description: String,
    /// Task type, read from `type`.
    #[serde(rename = "type")]
    workflow: String,
    /// Status, if reported.
    status: Option<String>,
    /// Stage progress; `TaskProgress::default()` when absent.
    #[serde(default)]
    progress: TaskProgress,
    /// Read from `created_date` (or `created`); Unix epoch when absent.
    #[serde(
        rename = "created_date",
        alias = "created",
        default = "default_datetime_utc"
    )]
    created: DateTime<Utc>,
    /// Read from `end_date` (or `completed`); Unix epoch when absent.
    #[serde(
        rename = "end_date",
        alias = "completed",
        default = "default_datetime_utc"
    )]
    completed: DateTime<Utc>,
}
1576
1577fn default_datetime_utc() -> DateTime<Utc> {
1578 DateTime::UNIX_EPOCH
1579}
1580
1581impl Display for TaskInfo {
1582 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1583 write!(f, "{} {}: {}", self.id, self.workflow(), self.description())
1584 }
1585}
1586
1587impl TaskInfo {
1588 pub fn id(&self) -> TaskID {
1589 self.id
1590 }
1591
1592 pub fn project_id(&self) -> Option<ProjectID> {
1593 self.project_id
1594 }
1595
1596 pub fn description(&self) -> &str {
1597 &self.description
1598 }
1599
1600 pub fn workflow(&self) -> &str {
1601 &self.workflow
1602 }
1603
1604 pub fn status(&self) -> &Option<String> {
1605 &self.status
1606 }
1607
1608 pub async fn set_status(&mut self, client: &Client, status: &str) -> Result<(), Error> {
1609 let t = client.task_status(self.id(), status).await?;
1610 self.status = Some(t.status);
1611 Ok(())
1612 }
1613
1614 pub fn stages(&self) -> HashMap<String, Stage> {
1615 match &self.progress.stages {
1616 Some(stages) => stages.clone(),
1617 None => HashMap::new(),
1618 }
1619 }
1620
1621 pub async fn update_stage(
1622 &mut self,
1623 client: &Client,
1624 stage: &str,
1625 status: &str,
1626 message: &str,
1627 percentage: u8,
1628 ) -> Result<(), Error> {
1629 client
1630 .update_stage(self.id(), stage, status, message, percentage)
1631 .await?;
1632 let t = client.task_info(self.id()).await?;
1633 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1634 Ok(())
1635 }
1636
1637 pub async fn set_stages(
1638 &mut self,
1639 client: &Client,
1640 stages: &[(&str, &str)],
1641 ) -> Result<(), Error> {
1642 client.set_stages(self.id(), stages).await?;
1643 let t = client.task_info(self.id()).await?;
1644 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1645 Ok(())
1646 }
1647
1648 pub async fn data_list(&self, client: &client::Client) -> Result<TaskDataList, Error> {
1664 let req = client::TaskDataListRequest {
1665 task_id: self.id().value(),
1666 };
1667 match client.rpc("task.data.list".to_owned(), Some(&req)).await {
1668 Ok(r) => Ok(r),
1669 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1670 "task.data.list",
1671 code,
1672 msg,
1673 Some(self.id()),
1674 )),
1675 Err(e) => Err(e),
1676 }
1677 }
1678
1679 pub async fn upload_data(
1700 &self,
1701 client: &client::Client,
1702 path: &std::path::Path,
1703 folder: Option<&str>,
1704 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1705 ) -> Result<(), Error> {
1706 use futures::StreamExt;
1707 use std::sync::{
1708 Arc,
1709 atomic::{AtomicUsize, Ordering},
1710 };
1711 use tokio_util::io::ReaderStream;
1712
1713 let file_name = path
1714 .file_name()
1715 .and_then(|s| s.to_str())
1716 .ok_or_else(|| Error::InvalidParameters("path must have a UTF-8 filename".into()))?
1717 .to_owned();
1718
1719 let file = tokio::fs::File::open(path).await?;
1720 let total = file.metadata().await?.len() as usize;
1721 let sent = Arc::new(AtomicUsize::new(0));
1722
1723 let reader_stream = ReaderStream::new(file);
1724 let sent_clone = sent.clone();
1725 let progress_clone = progress.clone();
1726 let progress_stream = reader_stream.inspect(move |chunk_result| {
1727 if let Ok(chunk) = chunk_result {
1728 let current = sent_clone.fetch_add(chunk.len(), Ordering::Relaxed) + chunk.len();
1729 if let Some(tx) = &progress_clone {
1735 let _ = tx.try_send(Progress {
1736 current,
1737 total,
1738 status: None,
1739 });
1740 }
1741 }
1742 });
1743
1744 let body = reqwest::Body::wrap_stream(progress_stream);
1745 let file_part = Part::stream_with_length(body, total as u64).file_name(file_name);
1746
1747 let mut form = Form::new().text("task_id", self.id().value().to_string());
1748 if let Some(folder) = folder.filter(|s| !s.is_empty()) {
1749 form = form.text("folder", folder.to_owned());
1750 }
1751 form = form.part("file", file_part);
1752
1753 let result = match client.post_multipart("task.data.upload", form).await {
1754 Ok(_) => Ok(()),
1755 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1756 "task.data.upload",
1757 code,
1758 msg,
1759 Some(self.id()),
1760 )),
1761 Err(e) => Err(e),
1762 };
1763
1764 if result.is_ok()
1768 && let Some(tx) = progress
1769 {
1770 let _ = tx
1771 .send(Progress {
1772 current: total,
1773 total,
1774 status: None,
1775 })
1776 .await;
1777 }
1778 result
1779 }
1780
1781 pub async fn download_data(
1810 &self,
1811 client: &client::Client,
1812 file: &str,
1813 folder: Option<&str>,
1814 output_path: &std::path::Path,
1815 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1816 ) -> Result<(), Error> {
1817 let folder = folder.unwrap_or("").to_owned();
1818 let req = client::TaskDataDownloadRequest {
1819 task_id: self.id().value(),
1820 folder,
1821 file: file.to_owned(),
1822 };
1823 match client
1824 .rpc_download("task.data.download", &req, output_path, progress)
1825 .await
1826 {
1827 Ok(()) => Ok(()),
1828 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1829 "task.data.download",
1830 code,
1831 msg,
1832 Some(self.id()),
1833 )),
1834 Err(e) => Err(e),
1835 }
1836 }
1837
1838 pub async fn add_chart(
1866 &self,
1867 client: &client::Client,
1868 group: &str,
1869 name: &str,
1870 data: Parameter,
1871 params: Option<Parameter>,
1872 ) -> Result<(), Error> {
1873 client::validate_chart_args(group, name)?;
1874 let req = client::TaskChartAddRequest {
1875 task_id: self.id().value(),
1876 group_name: group.to_owned(),
1877 chart_name: name.to_owned(),
1878 params,
1879 data,
1880 };
1881 let _resp: serde_json::Value =
1882 match client.rpc("task.chart.add".to_owned(), Some(&req)).await {
1883 Ok(r) => r,
1884 Err(Error::RpcError(code, msg)) => {
1885 return Err(client::map_rpc_error(
1886 "task.chart.add",
1887 code,
1888 msg,
1889 Some(self.id()),
1890 ));
1891 }
1892 Err(e) => return Err(e),
1893 };
1894 Ok(())
1895 }
1896
1897 pub async fn list_charts(
1914 &self,
1915 client: &client::Client,
1916 group: Option<&str>,
1917 ) -> Result<TaskDataList, Error> {
1918 let req = client::TaskChartListRequest {
1919 task_id: self.id().value(),
1920 group_name: group.unwrap_or("").to_owned(),
1921 };
1922 match client.rpc("task.chart.list".to_owned(), Some(&req)).await {
1923 Ok(r) => Ok(r),
1924 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1925 "task.chart.list",
1926 code,
1927 msg,
1928 Some(self.id()),
1929 )),
1930 Err(e) => Err(e),
1931 }
1932 }
1933
1934 pub async fn get_chart(
1953 &self,
1954 client: &client::Client,
1955 group: &str,
1956 name: &str,
1957 ) -> Result<Parameter, Error> {
1958 client::validate_chart_args(group, name)?;
1959 let req = client::TaskChartGetRequest {
1960 task_id: self.id().value(),
1961 group_name: group.to_owned(),
1962 chart_name: name.to_owned(),
1963 };
1964 match client.rpc("task.chart.get".to_owned(), Some(&req)).await {
1965 Ok(r) => Ok(r),
1966 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1967 "task.chart.get",
1968 code,
1969 msg,
1970 Some(self.id()),
1971 )),
1972 Err(e) => Err(e),
1973 }
1974 }
1975
1976 pub fn created(&self) -> &DateTime<Utc> {
1977 &self.created
1978 }
1979
1980 pub fn completed(&self) -> &DateTime<Utc> {
1981 &self.completed
1982 }
1983}
1984
1985#[derive(Deserialize, Debug, Default, Clone)]
1986pub struct TaskProgress {
1987 stages: Option<HashMap<String, Stage>>,
1988}
1989
1990#[derive(Serialize, Debug, Clone)]
1991pub struct TaskStatus {
1992 #[serde(rename = "docker_task_id")]
1993 pub task_id: TaskID,
1994 pub status: String,
1995}
1996
1997#[derive(Serialize, Deserialize, Debug, Clone)]
1998pub struct Stage {
1999 #[serde(rename = "docker_task_id", skip_serializing_if = "Option::is_none")]
2000 task_id: Option<TaskID>,
2001 stage: String,
2002 #[serde(skip_serializing_if = "Option::is_none")]
2003 status: Option<String>,
2004 #[serde(skip_serializing_if = "Option::is_none")]
2005 description: Option<String>,
2006 #[serde(skip_serializing_if = "Option::is_none")]
2007 message: Option<String>,
2008 percentage: u8,
2009}
2010
2011impl Stage {
2012 pub fn new(
2013 task_id: Option<TaskID>,
2014 stage: String,
2015 status: Option<String>,
2016 message: Option<String>,
2017 percentage: u8,
2018 ) -> Self {
2019 Stage {
2020 task_id,
2021 stage,
2022 status,
2023 description: None,
2024 message,
2025 percentage,
2026 }
2027 }
2028
2029 pub fn task_id(&self) -> &Option<TaskID> {
2030 &self.task_id
2031 }
2032
2033 pub fn stage(&self) -> &str {
2034 &self.stage
2035 }
2036
2037 pub fn status(&self) -> &Option<String> {
2038 &self.status
2039 }
2040
2041 pub fn description(&self) -> &Option<String> {
2042 &self.description
2043 }
2044
2045 pub fn message(&self) -> &Option<String> {
2046 &self.message
2047 }
2048
2049 pub fn percentage(&self) -> u8 {
2050 self.percentage
2051 }
2052}
2053
2054#[derive(Serialize, Debug)]
2055pub struct TaskStages {
2056 #[serde(rename = "docker_task_id")]
2057 pub task_id: TaskID,
2058 #[serde(skip_serializing_if = "Vec::is_empty")]
2059 pub stages: Vec<HashMap<String, String>>,
2060}
2061
2062#[derive(Deserialize, Debug)]
2063pub struct Artifact {
2064 name: String,
2065 #[serde(rename = "modelType")]
2066 model_type: String,
2067}
2068
2069impl Artifact {
2070 pub fn name(&self) -> &str {
2071 &self.name
2072 }
2073
2074 pub fn model_type(&self) -> &str {
2075 &self.model_type
2076 }
2077}
2078
#[cfg(test)]
mod tests {
    // Unit tests for the prefixed hex ID newtypes generated by `typeid!` and
    // for the `Parameter` enum.
    use super::*;

    // ---- OrganizationID ("org-") ----

    #[test]
    fn test_organization_id_from_u64() {
        let id = OrganizationID::from(12345);
        assert_eq!(id.value(), 12345);
    }

    #[test]
    fn test_organization_id_display() {
        let id = OrganizationID::from(0xabc123);
        assert_eq!(format!("{}", id), "org-abc123");
    }

    #[test]
    fn test_organization_id_try_from_str_valid() {
        let id = OrganizationID::try_from("org-abc123").unwrap();
        assert_eq!(id.value(), 0xabc123);
    }

    #[test]
    fn test_organization_id_try_from_str_invalid_prefix() {
        let result = OrganizationID::try_from("invalid-abc123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'org-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_organization_id_try_from_str_invalid_hex() {
        let result = OrganizationID::try_from("org-xyz");
        assert!(result.is_err());
    }

    #[test]
    fn test_organization_id_try_from_str_empty() {
        // A bare prefix leaves an empty hex part, which must not parse.
        let result = OrganizationID::try_from("org-");
        assert!(result.is_err());
    }

    #[test]
    fn test_organization_id_into_u64() {
        let id = OrganizationID::from(54321);
        let value: u64 = id.into();
        assert_eq!(value, 54321);
    }

    // ---- ProjectID ("p-") ----

    #[test]
    fn test_project_id_from_u64() {
        let id = ProjectID::from(78910);
        assert_eq!(id.value(), 78910);
    }

    #[test]
    fn test_project_id_display() {
        let id = ProjectID::from(0xdef456);
        assert_eq!(format!("{}", id), "p-def456");
    }

    #[test]
    fn test_project_id_from_str_valid() {
        let id = ProjectID::from_str("p-def456").unwrap();
        assert_eq!(id.value(), 0xdef456);
    }

    #[test]
    fn test_project_id_try_from_str_valid() {
        let id = ProjectID::try_from("p-123abc").unwrap();
        assert_eq!(id.value(), 0x123abc);
    }

    #[test]
    fn test_project_id_try_from_string_valid() {
        let id = ProjectID::try_from("p-456def".to_string()).unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_project_id_from_str_invalid_prefix() {
        let result = ProjectID::from_str("proj-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'p-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_project_id_from_str_invalid_hex() {
        let result = ProjectID::from_str("p-notahex");
        assert!(result.is_err());
    }

    #[test]
    fn test_project_id_into_u64() {
        let id = ProjectID::from(99999);
        let value: u64 = id.into();
        assert_eq!(value, 99999);
    }

    // ---- ExperimentID ("exp-") ----

    #[test]
    fn test_experiment_id_from_u64() {
        let id = ExperimentID::from(1193046);
        assert_eq!(id.value(), 1193046);
    }

    #[test]
    fn test_experiment_id_display() {
        let id = ExperimentID::from(0x123abc);
        assert_eq!(format!("{}", id), "exp-123abc");
    }

    #[test]
    fn test_experiment_id_from_str_valid() {
        let id = ExperimentID::from_str("exp-456def").unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_experiment_id_try_from_str_valid() {
        let id = ExperimentID::try_from("exp-789abc").unwrap();
        assert_eq!(id.value(), 0x789abc);
    }

    #[test]
    fn test_experiment_id_try_from_string_valid() {
        let id = ExperimentID::try_from("exp-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_experiment_id_from_str_invalid_prefix() {
        let result = ExperimentID::from_str("experiment-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'exp-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_experiment_id_from_str_invalid_hex() {
        let result = ExperimentID::from_str("exp-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_experiment_id_into_u64() {
        let id = ExperimentID::from(777777);
        let value: u64 = id.into();
        assert_eq!(value, 777777);
    }

    // ---- TrainingSessionID ("t-") ----

    #[test]
    fn test_training_session_id_from_u64() {
        let id = TrainingSessionID::from(7901234);
        assert_eq!(id.value(), 7901234);
    }

    #[test]
    fn test_training_session_id_display() {
        let id = TrainingSessionID::from(0xabc123);
        assert_eq!(format!("{}", id), "t-abc123");
    }

    #[test]
    fn test_training_session_id_from_str_valid() {
        let id = TrainingSessionID::from_str("t-abc123").unwrap();
        assert_eq!(id.value(), 0xabc123);
    }

    #[test]
    fn test_training_session_id_try_from_str_valid() {
        let id = TrainingSessionID::try_from("t-deadbeef").unwrap();
        assert_eq!(id.value(), 0xdeadbeef);
    }

    #[test]
    fn test_training_session_id_try_from_string_valid() {
        let id = TrainingSessionID::try_from("t-cafebabe".to_string()).unwrap();
        assert_eq!(id.value(), 0xcafebabe);
    }

    #[test]
    fn test_training_session_id_from_str_invalid_prefix() {
        let result = TrainingSessionID::from_str("training-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 't-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_training_session_id_from_str_invalid_hex() {
        let result = TrainingSessionID::from_str("t-qqq");
        assert!(result.is_err());
    }

    #[test]
    fn test_training_session_id_into_u64() {
        let id = TrainingSessionID::from(123456);
        let value: u64 = id.into();
        assert_eq!(value, 123456);
    }

    // ---- ValidationSessionID ("v-") ----

    #[test]
    fn test_validation_session_id_from_u64() {
        let id = ValidationSessionID::from(3456789);
        assert_eq!(id.value(), 3456789);
    }

    #[test]
    fn test_validation_session_id_display() {
        let id = ValidationSessionID::from(0x34c985);
        assert_eq!(format!("{}", id), "v-34c985");
    }

    #[test]
    fn test_validation_session_id_try_from_str_valid() {
        let id = ValidationSessionID::try_from("v-deadbeef").unwrap();
        assert_eq!(id.value(), 0xdeadbeef);
    }

    #[test]
    fn test_validation_session_id_try_from_string_valid() {
        let id = ValidationSessionID::try_from("v-12345678".to_string()).unwrap();
        assert_eq!(id.value(), 0x12345678);
    }

    #[test]
    fn test_validation_session_id_try_from_str_invalid_prefix() {
        let result = ValidationSessionID::try_from("validation-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'v-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_validation_session_id_try_from_str_invalid_hex() {
        let result = ValidationSessionID::try_from("v-xyz");
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_session_id_into_u64() {
        let id = ValidationSessionID::from(987654);
        let value: u64 = id.into();
        assert_eq!(value, 987654);
    }

    // ---- SnapshotID ("ss-") ----

    #[test]
    fn test_snapshot_id_from_u64() {
        let id = SnapshotID::from(111222);
        assert_eq!(id.value(), 111222);
    }

    #[test]
    fn test_snapshot_id_display() {
        let id = SnapshotID::from(0xaabbcc);
        assert_eq!(format!("{}", id), "ss-aabbcc");
    }

    #[test]
    fn test_snapshot_id_try_from_str_valid() {
        let id = SnapshotID::try_from("ss-aabbcc").unwrap();
        assert_eq!(id.value(), 0xaabbcc);
    }

    #[test]
    fn test_snapshot_id_try_from_str_invalid_prefix() {
        let result = SnapshotID::try_from("snapshot-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'ss-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_snapshot_id_try_from_str_invalid_hex() {
        let result = SnapshotID::try_from("ss-ggg");
        assert!(result.is_err());
    }

    #[test]
    fn test_snapshot_id_into_u64() {
        let id = SnapshotID::from(333444);
        let value: u64 = id.into();
        assert_eq!(value, 333444);
    }

    // ---- TaskID ("task-") ----

    #[test]
    fn test_task_id_from_u64() {
        let id = TaskID::from(555666);
        assert_eq!(id.value(), 555666);
    }

    #[test]
    fn test_task_id_display() {
        let id = TaskID::from(0x123456);
        assert_eq!(format!("{}", id), "task-123456");
    }

    #[test]
    fn test_task_id_from_str_valid() {
        let id = TaskID::from_str("task-123456").unwrap();
        assert_eq!(id.value(), 0x123456);
    }

    #[test]
    fn test_task_id_try_from_str_valid() {
        let id = TaskID::try_from("task-abcdef").unwrap();
        assert_eq!(id.value(), 0xabcdef);
    }

    #[test]
    fn test_task_id_try_from_string_valid() {
        let id = TaskID::try_from("task-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_task_id_from_str_invalid_prefix() {
        let result = TaskID::from_str("t-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'task-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_task_id_from_str_invalid_hex() {
        let result = TaskID::from_str("task-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_task_id_into_u64() {
        let id = TaskID::from(777888);
        let value: u64 = id.into();
        assert_eq!(value, 777888);
    }

    // ---- DatasetID ("ds-") ----

    #[test]
    fn test_dataset_id_from_u64() {
        let id = DatasetID::from(1193046);
        assert_eq!(id.value(), 1193046);
    }

    #[test]
    fn test_dataset_id_display() {
        let id = DatasetID::from(0x123abc);
        assert_eq!(format!("{}", id), "ds-123abc");
    }

    #[test]
    fn test_dataset_id_from_str_valid() {
        let id = DatasetID::from_str("ds-456def").unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_dataset_id_try_from_str_valid() {
        let id = DatasetID::try_from("ds-789abc").unwrap();
        assert_eq!(id.value(), 0x789abc);
    }

    #[test]
    fn test_dataset_id_try_from_string_valid() {
        let id = DatasetID::try_from("ds-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_dataset_id_from_str_invalid_prefix() {
        let result = DatasetID::from_str("dataset-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'ds-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_dataset_id_from_str_invalid_hex() {
        let result = DatasetID::from_str("ds-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_dataset_id_into_u64() {
        let id = DatasetID::from(111111);
        let value: u64 = id.into();
        assert_eq!(value, 111111);
    }

    // ---- AnnotationSetID ("as-") ----

    #[test]
    fn test_annotation_set_id_from_u64() {
        let id = AnnotationSetID::from(222333);
        assert_eq!(id.value(), 222333);
    }

    #[test]
    fn test_annotation_set_id_display() {
        let id = AnnotationSetID::from(0xabcdef);
        assert_eq!(format!("{}", id), "as-abcdef");
    }

    #[test]
    fn test_annotation_set_id_from_str_valid() {
        let id = AnnotationSetID::from_str("as-abcdef").unwrap();
        assert_eq!(id.value(), 0xabcdef);
    }

    #[test]
    fn test_annotation_set_id_try_from_str_valid() {
        let id = AnnotationSetID::try_from("as-123456").unwrap();
        assert_eq!(id.value(), 0x123456);
    }

    #[test]
    fn test_annotation_set_id_try_from_string_valid() {
        let id = AnnotationSetID::try_from("as-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_annotation_set_id_from_str_invalid_prefix() {
        let result = AnnotationSetID::from_str("annotation-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'as-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_annotation_set_id_from_str_invalid_hex() {
        let result = AnnotationSetID::from_str("as-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_annotation_set_id_into_u64() {
        let id = AnnotationSetID::from(444555);
        let value: u64 = id.into();
        assert_eq!(value, 444555);
    }

    // ---- SampleID ("s-") ----

    #[test]
    fn test_sample_id_from_u64() {
        let id = SampleID::from(666777);
        assert_eq!(id.value(), 666777);
    }

    #[test]
    fn test_sample_id_display() {
        let id = SampleID::from(0x987654);
        assert_eq!(format!("{}", id), "s-987654");
    }

    #[test]
    fn test_sample_id_try_from_str_valid() {
        let id = SampleID::try_from("s-987654").unwrap();
        assert_eq!(id.value(), 0x987654);
    }

    #[test]
    fn test_sample_id_try_from_str_invalid_prefix() {
        let result = SampleID::try_from("sample-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 's-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_sample_id_try_from_str_invalid_hex() {
        let result = SampleID::try_from("s-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_sample_id_into_u64() {
        let id = SampleID::from(888999);
        let value: u64 = id.into();
        assert_eq!(value, 888999);
    }

    // ---- AppId ("app-") ----

    #[test]
    fn test_app_id_from_u64() {
        let id = AppId::from(123123);
        assert_eq!(id.value(), 123123);
    }

    #[test]
    fn test_app_id_display() {
        let id = AppId::from(0x456789);
        assert_eq!(format!("{}", id), "app-456789");
    }

    #[test]
    fn test_app_id_try_from_str_valid() {
        let id = AppId::try_from("app-456789").unwrap();
        assert_eq!(id.value(), 0x456789);
    }

    #[test]
    fn test_app_id_try_from_str_invalid_prefix() {
        let result = AppId::try_from("application-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'app-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_app_id_try_from_str_invalid_hex() {
        let result = AppId::try_from("app-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_app_id_into_u64() {
        let id = AppId::from(321321);
        let value: u64 = id.into();
        assert_eq!(value, 321321);
    }

    // ---- ImageId ("im-") ----

    #[test]
    fn test_image_id_from_u64() {
        let id = ImageId::from(789789);
        assert_eq!(id.value(), 789789);
    }

    #[test]
    fn test_image_id_display() {
        let id = ImageId::from(0xabcd1234);
        assert_eq!(format!("{}", id), "im-abcd1234");
    }

    #[test]
    fn test_image_id_try_from_str_valid() {
        let id = ImageId::try_from("im-abcd1234").unwrap();
        assert_eq!(id.value(), 0xabcd1234);
    }

    #[test]
    fn test_image_id_try_from_str_invalid_prefix() {
        let result = ImageId::try_from("image-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'im-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_image_id_try_from_str_invalid_hex() {
        let result = ImageId::try_from("im-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_image_id_into_u64() {
        let id = ImageId::from(987987);
        let value: u64 = id.into();
        assert_eq!(value, 987987);
    }

    // ---- Properties shared by every generated ID type ----

    #[test]
    fn test_id_types_equality() {
        let id1 = ProjectID::from(12345);
        let id2 = ProjectID::from(12345);
        let id3 = ProjectID::from(54321);

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_id_types_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(DatasetID::from(100));
        set.insert(DatasetID::from(200));
        // Equal IDs hash identically, so this duplicate must not grow the set.
        set.insert(DatasetID::from(100));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&DatasetID::from(100)));
        assert!(set.contains(&DatasetID::from(200)));
    }

    #[test]
    fn test_id_types_copy_clone() {
        let id1 = ExperimentID::from(999);
        // IDs derive `Copy`: both bindings copy `id1`, which stays usable.
        let id2 = id1;
        let id3 = id1;
        assert_eq!(id1, id2);
        assert_eq!(id1, id3);
    }

    #[test]
    fn test_id_zero_value() {
        let id = ProjectID::from(0);
        assert_eq!(format!("{}", id), "p-0");
        assert_eq!(id.value(), 0);
    }

    #[test]
    fn test_id_max_value() {
        let id = ProjectID::from(u64::MAX);
        assert_eq!(format!("{}", id), "p-ffffffffffffffff");
        assert_eq!(id.value(), u64::MAX);
    }

    #[test]
    fn test_id_round_trip_conversion() {
        let original = 0xdeadbeef_u64;
        let id = TrainingSessionID::from(original);
        let back: u64 = id.into();
        assert_eq!(original, back);
    }

    #[test]
    fn test_id_case_insensitive_hex() {
        // Parsing accepts upper-case hex digits even though Display emits lower case.
        let id1 = DatasetID::from_str("ds-ABCDEF").unwrap();
        let id2 = DatasetID::from_str("ds-abcdef").unwrap();
        assert_eq!(id1.value(), id2.value());
    }

    #[test]
    fn test_id_with_leading_zeros() {
        let id = ProjectID::from_str("p-00001234").unwrap();
        assert_eq!(id.value(), 0x1234);
    }

    // ---- Parameter ----

    #[test]
    fn test_parameter_integer() {
        let param = Parameter::Integer(42);
        match param {
            Parameter::Integer(val) => assert_eq!(val, 42),
            _ => panic!("Expected Integer variant"),
        }
    }

    #[test]
    fn test_parameter_real() {
        let param = Parameter::Real(2.5);
        match param {
            Parameter::Real(val) => assert_eq!(val, 2.5),
            _ => panic!("Expected Real variant"),
        }
    }

    #[test]
    fn test_parameter_boolean() {
        let param = Parameter::Boolean(true);
        match param {
            Parameter::Boolean(val) => assert!(val),
            _ => panic!("Expected Boolean variant"),
        }
    }

    #[test]
    fn test_parameter_string() {
        let param = Parameter::String("test".to_string());
        match param {
            Parameter::String(val) => assert_eq!(val, "test"),
            _ => panic!("Expected String variant"),
        }
    }

    #[test]
    fn test_parameter_array() {
        let param = Parameter::Array(vec![
            Parameter::Integer(1),
            Parameter::Integer(2),
            Parameter::Integer(3),
        ]);
        match param {
            Parameter::Array(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("Expected Array variant"),
        }
    }

    #[test]
    fn test_parameter_object() {
        let mut map = HashMap::new();
        map.insert("key".to_string(), Parameter::Integer(100));
        let param = Parameter::Object(map);
        match param {
            Parameter::Object(obj) => {
                assert_eq!(obj.len(), 1);
                assert!(obj.contains_key("key"));
            }
            _ => panic!("Expected Object variant"),
        }
    }

    #[test]
    fn test_parameter_clone() {
        let param1 = Parameter::Integer(42);
        let param2 = param1.clone();
        assert_eq!(param1, param2);
    }

    #[test]
    fn test_parameter_nested() {
        let inner_array = Parameter::Array(vec![Parameter::Integer(1), Parameter::Integer(2)]);
        let outer_array = Parameter::Array(vec![inner_array.clone(), inner_array]);

        match outer_array {
            Parameter::Array(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected Array variant"),
        }
    }

    // ---- Table-driven conversion checks, one test per ID type ----
    //
    // Each generated test exercises From<u64>, value(), Display, FromStr
    // (valid, wrong prefix, missing prefix, bad hex), TryFrom<&str>,
    // TryFrom<String>, a serde_json round trip and Into<u64>.
    macro_rules! test_typeid_conversions {
        ($test_name:ident, $type:ty, $prefix:literal, $wrong_prefix:literal) => {
            #[test]
            fn $test_name() {
                let id = <$type>::from(0xabc123);
                assert_eq!(id.value(), 0xabc123);

                assert_eq!(format!("{}", id), concat!($prefix, "-abc123"));

                let id: $type = concat!($prefix, "-abc123").parse().unwrap();
                assert_eq!(id.value(), 0xabc123);

                assert!(concat!($wrong_prefix, "-abc").parse::<$type>().is_err());

                assert!("abc123".parse::<$type>().is_err());

                assert!(concat!($prefix, "-xyz").parse::<$type>().is_err());

                let id = <$type>::try_from(concat!($prefix, "-abc123")).unwrap();
                assert_eq!(id.value(), 0xabc123);

                let id = <$type>::try_from(concat!($prefix, "-abc123").to_string()).unwrap();
                assert_eq!(id.value(), 0xabc123);

                let id = <$type>::from(0xabc123);
                let json = serde_json::to_string(&id).unwrap();
                let parsed: $type = serde_json::from_str(&json).unwrap();
                assert_eq!(id, parsed);

                let id = <$type>::from(0xabc123);
                let val: u64 = id.into();
                assert_eq!(val, 0xabc123);
            }
        };
    }

    test_typeid_conversions!(test_organization_id_conversions, OrganizationID, "org", "p");
    test_typeid_conversions!(test_project_id_conversions, ProjectID, "p", "org");
    test_typeid_conversions!(test_experiment_id_conversions, ExperimentID, "exp", "p");
    test_typeid_conversions!(
        test_training_session_id_conversions,
        TrainingSessionID,
        "t",
        "v"
    );
    test_typeid_conversions!(
        test_validation_session_id_conversions,
        ValidationSessionID,
        "v",
        "t"
    );
    test_typeid_conversions!(test_snapshot_id_conversions, SnapshotID, "ss", "ds");
    test_typeid_conversions!(test_task_id_conversions, TaskID, "task", "t");
    test_typeid_conversions!(test_dataset_id_conversions, DatasetID, "ds", "ss");
    test_typeid_conversions!(
        test_annotation_set_id_conversions,
        AnnotationSetID,
        "as",
        "ds"
    );
    test_typeid_conversions!(test_sample_id_conversions, SampleID, "s", "p");
    test_typeid_conversions!(test_app_id_conversions, AppId, "app", "p");
    test_typeid_conversions!(test_image_id_conversions, ImageId, "im", "se");
    test_typeid_conversions!(test_sequence_id_conversions, SequenceId, "se", "im");
}
2928
#[cfg(test)]
mod tests_task_data_list {
    use super::*;

    /// Sample `TaskDataList` payload in the shape the server returns.
    const SERVER_PAYLOAD: &str = r#"{
        "server": "test.edgefirst.studio",
        "organization_uid": "org-abc123",
        "traces": ["trace/imx95.json"],
        "data": {
            "predictions": ["predictions.parquet"],
            "trace": ["imx95.json"]
        }
    }"#;

    #[test]
    fn task_data_list_deserializes_from_server_shape() {
        let list: TaskDataList =
            serde_json::from_str(SERVER_PAYLOAD).expect("server payload should deserialize");

        assert_eq!(list.server, "test.edgefirst.studio");
        assert_eq!(list.organization_uid, "org-abc123");
        assert_eq!(list.traces, vec!["trace/imx95.json"]);

        let predictions = list
            .data
            .get("predictions")
            .expect("`predictions` key should be present");
        assert_eq!(predictions, &vec!["predictions.parquet".to_string()]);
    }
}
2954
#[cfg(test)]
mod tests_upload_data {
    /// Restates the `folder.filter(|s| !s.is_empty())` expression that
    /// `upload_data` uses before adding a `folder` form field.
    ///
    /// NOTE(review): this duplicates the expression rather than exercising
    /// `upload_data` itself, so a change there would not be caught here.
    fn normalise_folder(folder: Option<&str>) -> Option<&str> {
        folder.filter(|s| !s.is_empty())
    }

    #[test]
    fn folder_empty_string_is_normalised() {
        assert!(normalise_folder(Some("")).is_none());
        assert!(normalise_folder(Some("predictions")).is_some());
    }
}
2968}
2969
#[cfg(test)]
mod tests_job_struct {
    use super::*;

    /// A `Job` with empty strings, no launch time and `task_id` 0; tests
    /// override the fields they care about with struct-update syntax.
    fn blank_job() -> Job {
        Job {
            code: String::new(),
            title: String::new(),
            job_name: String::new(),
            job_id: String::new(),
            state: String::new(),
            launch: None,
            task_id: 0,
        }
    }

    #[test]
    fn job_deserializes_with_all_fields() {
        let payload = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789
        }"#;
        let job: Job = serde_json::from_str(payload).unwrap();

        let expected = [
            (&job.code, "edgefirst-validator:2.9.5"),
            (&job.title, "EdgeFirst Validator"),
            (&job.job_name, "smoke-test"),
            (&job.job_id, "aws-batch-abc"),
            (&job.state, "RUNNING"),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
        assert!(job.launch.is_some());
        assert_eq!(job.task_id, 6789);
    }

    #[test]
    fn job_tolerates_missing_optional_fields() {
        let job: Job = serde_json::from_str(r#"{ "task_id": 42 }"#).unwrap();

        assert_eq!(job.task_id, 42);
        for field in [&job.code, &job.title, &job.job_name, &job.job_id, &job.state] {
            assert!(field.is_empty());
        }
        assert!(job.launch.is_none());
    }

    #[test]
    fn job_task_id_accessor_saturates_negative_to_zero() {
        let job = Job {
            task_id: -1,
            ..blank_job()
        };
        assert_eq!(job.task_id().value(), 0);
    }

    #[test]
    fn job_task_id_accessor_passes_through_positive_values() {
        let job = Job {
            task_id: 12345,
            ..blank_job()
        };
        assert_eq!(job.task_id().value(), 12345);
    }

    #[test]
    fn job_ignores_unknown_fields() {
        // Extra server-side keys must not cause deserialization to fail.
        let payload = r#"{
            "code": "x",
            "task_id": 1,
            "docker_task": { "image": "x" },
            "aws_region": "us-east-1",
            "tags": ["a", "b"]
        }"#;
        let job: Job = serde_json::from_str(payload).unwrap();
        assert_eq!(job.task_id, 1);
    }
}
3059
#[cfg(test)]
mod tests_task_info_schema_tolerance {
    use super::*;

    /// Deserializes `json` as a `TaskInfo`, panicking if it is rejected.
    fn parse_task_info(json: &str) -> TaskInfo {
        serde_json::from_str(json).expect("TaskInfo should deserialize")
    }

    #[test]
    fn task_info_accepts_task_description_field() {
        let info = parse_task_info(
            r#"{
                "id": 6699,
                "type": "edgefirst-validator:2.9.5",
                "task_description": "Profiler run for IMX95",
                "status": "running"
            }"#,
        );
        assert_eq!(info.description(), "Profiler run for IMX95");
    }

    #[test]
    fn task_info_accepts_legacy_description_field() {
        let info = parse_task_info(
            r#"{
                "id": 6699,
                "type": "edgefirst-validator:2.9.5",
                "description": "Legacy description"
            }"#,
        );
        assert_eq!(info.description(), "Legacy description");
    }

    #[test]
    fn task_info_tolerates_missing_description() {
        let info = parse_task_info(r#"{ "id": 6699, "type": "x" }"#);
        assert!(info.description().is_empty());
    }

    #[test]
    fn task_info_tolerates_missing_dates_via_default() {
        // No date fields in the payload; parsing must still succeed.
        let info = parse_task_info(r#"{ "id": 6699, "type": "x" }"#);
        assert_eq!(info.id().value(), 6699);
    }

    #[test]
    fn task_info_status_accessor_returns_option() {
        let info = parse_task_info(r#"{ "id": 1, "type": "x" }"#);
        assert!(info.status().is_none());
    }

    #[test]
    fn task_info_stages_returns_empty_map_when_unset() {
        let info = parse_task_info(r#"{ "id": 1, "type": "x" }"#);
        assert!(info.stages().is_empty());
    }
}
3137
#[cfg(test)]
/// Tests for `Stage` construction and its JSON (de)serialization.
mod tests_stage_struct {
    use super::*;

    #[test]
    fn stage_new_sets_only_supplied_fields() {
        let stage = Stage::new(
            None,
            "download".into(),
            Some("running".into()),
            Some("fetching".into()),
            42,
        );
        assert!(stage.task_id().is_none());
        assert_eq!(stage.stage(), "download");
        assert_eq!(stage.status().as_deref(), Some("running"));
        assert_eq!(stage.message().as_deref(), Some("fetching"));
        assert_eq!(stage.percentage(), 42);
        // `Stage::new` takes no description argument, so it stays unset.
        assert!(stage.description().is_none());
    }

    #[test]
    fn stage_serializes_without_optional_none_fields() {
        let stage = Stage::new(None, "init".into(), None, None, 0);
        let json = serde_json::to_value(&stage).unwrap();
        // `None` optionals must be absent from the object, not present as `null`.
        assert!(json.get("status").is_none(), "got: {json}");
        assert!(json.get("message").is_none(), "got: {json}");
        assert!(json.get("docker_task_id").is_none(), "got: {json}");
        assert_eq!(json["stage"], "init");
        assert_eq!(json["percentage"], 0);
    }

    #[test]
    fn stage_serializes_task_id_when_present() {
        let task_id = TaskID::from(0xdeadu64);
        let stage = Stage::new(Some(task_id), "x".into(), None, None, 0);
        let json = serde_json::to_value(&stage).unwrap();
        // A set task id is emitted under the `docker_task_id` wire key.
        assert!(json.get("docker_task_id").is_some());
    }

    #[test]
    fn stage_round_trips_through_json() {
        let stage = Stage::new(
            None,
            "train".into(),
            Some("done".into()),
            Some("epoch 100".into()),
            100,
        );
        let s = serde_json::to_string(&stage).unwrap();
        let back: Stage = serde_json::from_str(&s).unwrap();
        assert_eq!(back.stage(), "train");
        assert_eq!(back.status().as_deref(), Some("done"));
        assert_eq!(back.message().as_deref(), Some("epoch 100"));
        assert_eq!(back.percentage(), 100);
        // Unset fields must survive the round trip as `None` too. Previously only
        // the populated fields were checked, so a bad default on deserialize
        // (e.g. a spurious task id or description) would have gone unnoticed.
        assert!(back.task_id().is_none());
        assert!(back.description().is_none());
    }
}
3200
#[cfg(test)]
/// Deserialization tests for `TaskDataList` payloads.
mod tests_task_data_list_extra {
    use super::*;

    #[test]
    fn task_data_list_with_empty_data_map() {
        // Empty `traces` and `data` collections parse to empty containers.
        let json = r#"{
            "server": "studio",
            "organization_uid": "org-1",
            "traces": [],
            "data": {}
        }"#;
        let parsed: TaskDataList = serde_json::from_str(json).unwrap();
        assert!(parsed.traces.is_empty());
        assert!(parsed.data.is_empty());
    }

    #[test]
    fn task_data_list_multiple_folders() {
        let json = r#"{
            "server": "studio",
            "organization_uid": "org-1",
            "traces": ["t1", "t2"],
            "data": {
                "predictions": ["a.parquet", "b.parquet"],
                "metrics": ["loss.json"]
            }
        }"#;
        let parsed: TaskDataList = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.traces.len(), 2);
        assert_eq!(parsed.data.len(), 2);
        // Check both folders, so each one keeps its own file list
        // (previously only `predictions` was verified).
        assert_eq!(parsed.data["predictions"].len(), 2);
        assert_eq!(parsed.data["metrics"].len(), 1);
    }
}
3235
#[cfg(test)]
/// Accessor tests for `Artifact`.
mod tests_artifact_struct {
    use super::*;

    #[test]
    fn artifact_accessors_return_strs() {
        // The payload uses the camelCase key `modelType`, read back via `model_type()`.
        let artifact = serde_json::from_str::<Artifact>(
            r#"{ "name": "best.onnx", "modelType": "yolo" }"#,
        )
        .unwrap();
        assert_eq!(artifact.name(), "best.onnx");
        assert_eq!(artifact.model_type(), "yolo");
    }
}
3250
#[cfg(test)]
/// Serialization tests for `TaskStatus`.
mod tests_task_status_serialize {
    use super::*;

    #[test]
    fn task_status_uses_docker_task_id_wire_field() {
        let s = TaskStatus {
            task_id: TaskID::from(0x1a2bu64),
            status: "training".into(),
        };
        let json = serde_json::to_value(&s).unwrap();
        // The Rust field `task_id` must appear under the wire name
        // `docker_task_id` and only under that name. Without the absence check,
        // a payload carrying both keys would also have passed.
        assert!(json.get("docker_task_id").is_some(), "got: {json}");
        assert!(json.get("task_id").is_none(), "got: {json}");
        assert_eq!(json["status"], "training");
    }
}
3267
#[cfg(test)]
/// Serialization tests for `TaskStages`.
mod tests_task_stages_serialize {
    use super::*;

    #[test]
    fn task_stages_omits_empty_vec() {
        let payload = TaskStages {
            task_id: TaskID::from(1u64),
            stages: vec![],
        };
        let json = serde_json::to_value(&payload).unwrap();
        // An empty stage list is dropped from the object entirely.
        assert!(json.get("stages").is_none(), "got: {json}");
    }

    #[test]
    fn task_stages_serializes_non_empty_vec() {
        // A single stage entry mapping "stage" -> "download".
        let download = std::collections::HashMap::from([(
            String::from("stage"),
            String::from("download"),
        )]);
        let payload = TaskStages {
            task_id: TaskID::from(1u64),
            stages: vec![download],
        };
        let value = serde_json::to_value(&payload).unwrap();
        assert_eq!(value["stages"][0]["stage"], "download");
    }
}
3295}