1use crate::{AnnotationSet, Client, Dataset, Error, Progress, Sample, client};
5use chrono::{DateTime, Utc};
6use log::trace;
7use reqwest::multipart::{Form, Part};
8use serde::{Deserialize, Deserializer, Serialize};
9use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
10
/// A dynamically-typed value exchanged with the server (metrics, model and
/// validation parameters, chart data).
///
/// Deserialization is `#[serde(untagged)]`: serde tries the variants in the
/// order they are declared below and keeps the first one that fits, so the
/// declaration order is significant (a JSON integer becomes `Integer` before
/// `Real` is ever tried).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum Parameter {
    /// A signed 64-bit integer.
    Integer(i64),
    /// A 64-bit floating-point number.
    Real(f64),
    /// A boolean flag.
    Boolean(bool),
    /// A UTF-8 string.
    String(String),
    /// An ordered list of nested values.
    Array(Vec<Parameter>),
    /// A string-keyed map of nested values.
    Object(HashMap<String, Parameter>),
}
56
/// Deserialized body of a login reply.
///
/// Only the `token` key is read (other keys are ignored by serde's default
/// behaviour); the token is visible inside this crate only.
#[derive(Deserialize)]
pub struct LoginResult {
    pub(crate) token: String,
}
61
/// Generates a strongly-typed, `Copy` newtype around a `u64` server identifier.
///
/// The generated type `$name`:
/// * serializes/deserializes as the bare `u64` (derived serde impls);
/// * displays as `"<prefix>-<lowercase hex>"`, e.g. `typeid!(Foo, "foo")`
///   formats `Foo(255)` as `foo-ff`;
/// * parses the `"<prefix>-<hex>"` form via `FromStr`, `TryFrom<&str>` and
///   `TryFrom<String>` (hex digits are accepted in either case, but no sign);
/// * converts losslessly to and from `u64`.
///
/// Outer attributes given before the name (typically `///` doc comments) are
/// forwarded to the generated struct.
macro_rules! typeid {
    ($(#[$meta:meta])* $name:ident, $prefix:literal) => {
        $(#[$meta])*
        #[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq, Eq, Hash)]
        pub struct $name(u64);

        impl Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(f, concat!($prefix, "-{:x}"), self.0)
            }
        }

        impl From<u64> for $name {
            fn from(id: u64) -> Self {
                $name(id)
            }
        }

        impl From<$name> for u64 {
            fn from(val: $name) -> Self {
                val.0
            }
        }

        impl $name {
            /// Returns the raw numeric identifier.
            pub fn value(&self) -> u64 {
                self.0
            }
        }

        impl TryFrom<&str> for $name {
            type Error = Error;

            fn try_from(s: &str) -> Result<Self, Self::Error> {
                $name::from_str(s)
            }
        }

        impl TryFrom<String> for $name {
            type Error = Error;

            fn try_from(s: String) -> Result<Self, Self::Error> {
                $name::from_str(&s)
            }
        }

        impl FromStr for $name {
            type Err = Error;

            /// Parses `"<prefix>-<hex>"` into an identifier.
            ///
            /// # Errors
            /// * `Error::InvalidParameters` if the prefix is missing or the hex
            ///   part starts with a `+` sign;
            /// * the `Error` converted from `ParseIntError` if the hex part is
            ///   empty, contains non-hex characters, or overflows `u64`.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                let hex_part =
                    s.strip_prefix(concat!($prefix, "-")).ok_or_else(|| {
                        Error::InvalidParameters(format!(
                            "{} must start with '{}-' prefix",
                            stringify!($name),
                            $prefix
                        ))
                    })?;
                // `u64::from_str_radix` tolerates an optional leading `+`, which
                // would let non-canonical strings such as "p-+1f" parse. `Display`
                // never emits a sign, so reject it to keep parsing canonical.
                if hex_part.starts_with('+') {
                    return Err(Error::InvalidParameters(format!(
                        "{} must be '{}-' followed by hex digits, got '{}'",
                        stringify!($name),
                        $prefix,
                        s
                    )));
                }
                let id = u64::from_str_radix(hex_part, 16)?;
                Ok($name(id))
            }
        }
    };
}
135
typeid!(
    /// Identifier of an organization; formats and parses as `org-<hex>`.
    OrganizationID,
    "org"
);
159
160#[derive(Deserialize, Clone, Debug)]
181pub struct Organization {
182 id: OrganizationID,
183 name: String,
184 #[serde(rename = "latest_credit")]
185 credits: i64,
186}
187
188impl Display for Organization {
189 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
190 write!(f, "{}", self.name())
191 }
192}
193
194impl Organization {
195 pub fn id(&self) -> OrganizationID {
196 self.id
197 }
198
199 pub fn name(&self) -> &str {
200 &self.name
201 }
202
203 pub fn credits(&self) -> i64 {
204 self.credits
205 }
206}
207
typeid!(
    /// Identifier of a project; formats and parses as `p-<hex>`.
    ProjectID,
    "p"
);

typeid!(
    /// Identifier of an experiment; formats and parses as `exp-<hex>`.
    ExperimentID,
    "exp"
);

typeid!(
    /// Identifier of a training session; formats and parses as `t-<hex>`.
    TrainingSessionID,
    "t"
);

typeid!(
    /// Identifier of a validation session; formats and parses as `v-<hex>`.
    ValidationSessionID,
    "v"
);

typeid!(
    /// Identifier of a snapshot; formats and parses as `ss-<hex>`.
    SnapshotID,
    "ss"
);

typeid!(
    /// Identifier of a task; formats and parses as `task-<hex>`.
    TaskID,
    "task"
);

typeid!(
    /// Identifier of a dataset; formats and parses as `ds-<hex>`.
    DatasetID,
    "ds"
);

typeid!(
    /// Identifier of an annotation set; formats and parses as `as-<hex>`.
    AnnotationSetID,
    "as"
);

typeid!(
    /// Identifier of a sample; formats and parses as `s-<hex>`.
    SampleID,
    "s"
);

typeid!(
    /// Typed ID with the `app-` prefix; formats and parses as `app-<hex>`.
    AppId,
    "app"
);

typeid!(
    /// Typed image ID; formats and parses as `im-<hex>`.
    ImageId,
    "im"
);

typeid!(
    /// Typed sequence ID; formats and parses as `se-<hex>`.
    SequenceId,
    "se"
);
441
442#[derive(Deserialize, Clone, Debug)]
446pub struct Project {
447 id: ProjectID,
448 name: String,
449 description: String,
450}
451
452impl Display for Project {
453 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
454 write!(f, "{} {}", self.id(), self.name())
455 }
456}
457
458impl Project {
459 pub fn id(&self) -> ProjectID {
460 self.id
461 }
462
463 pub fn name(&self) -> &str {
464 &self.name
465 }
466
467 pub fn description(&self) -> &str {
468 &self.description
469 }
470
471 pub async fn datasets(
472 &self,
473 client: &client::Client,
474 name: Option<&str>,
475 ) -> Result<Vec<Dataset>, Error> {
476 client.datasets(self.id, name).await
477 }
478
479 pub async fn experiments(
480 &self,
481 client: &client::Client,
482 name: Option<&str>,
483 ) -> Result<Vec<Experiment>, Error> {
484 client.experiments(self.id, name).await
485 }
486}
487
/// Reply carrying a sample count in `total`.
#[derive(Deserialize, Debug)]
pub struct SamplesCountResult {
    pub total: u64,
}

/// Parameters for listing the samples of a dataset.
///
/// `None` options and empty vectors are left out of the serialized request.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesListParams {
    pub dataset_id: DatasetID,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    /// Paging cursor; NOTE(review): presumably the `continue_token` returned
    /// by a previous `SamplesListResult` — confirm against the paging loop.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub types: Vec<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub group_names: Vec<String>,
}

/// One page of a sample listing.
///
/// NOTE(review): `continue_token` is presumably `None` on the last page —
/// confirm with the server API.
#[derive(Deserialize, Debug)]
pub struct SamplesListResult {
    pub samples: Vec<Sample>,
    pub continue_token: Option<String>,
}
511
/// Parameters for populating a dataset with samples.
///
/// `annotation_set_id` and `presigned_urls` are omitted from the request when
/// `None`; `samples` is always sent.
#[derive(Serialize, Clone, Debug)]
pub struct SamplesPopulateParams {
    pub dataset_id: DatasetID,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotation_set_id: Option<AnnotationSetID>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presigned_urls: Option<bool>,
    pub samples: Vec<Sample>,
}

/// Reply to a populate request: an identifier string plus a list of presigned URLs.
///
/// NOTE(review): presumably `urls` is only non-empty when `presigned_urls`
/// was requested — confirm with the server API.
#[derive(Deserialize, Debug, Clone)]
pub struct SamplesPopulateResult {
    pub uuid: String,
    pub urls: Vec<PresignedUrl>,
}

/// A presigned URL issued for one file.
#[derive(Deserialize, Debug, Clone)]
pub struct PresignedUrl {
    /// Name of the file the URL was issued for.
    pub filename: String,
    /// Storage key; NOTE(review): presumably the object key the file maps to.
    pub key: String,
    /// The presigned URL itself.
    pub url: String,
}
549
/// Annotation record in the server's wire format (the element type of
/// `AnnotationAddBulkParams::annotations`).
///
/// `None` label fields, an empty `polygon`, and a `None` `object_reference` are
/// omitted from the JSON; `annotation_type` is serialized under the key `type`.
#[derive(Serialize, Clone, Debug)]
pub struct ServerAnnotation {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_id: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_index: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub label_name: Option<String>,
    #[serde(rename = "type")]
    pub annotation_type: String,
    /// NOTE(review): `x`/`y`/`w`/`h` presumably describe a bounding box —
    /// confirm units and normalization with the server API.
    pub x: f64,
    pub y: f64,
    pub w: f64,
    pub h: f64,
    pub score: f64,
    /// Polygon data as a string; its encoding is not visible here.
    #[serde(skip_serializing_if = "String::is_empty")]
    pub polygon: String,
    pub image_id: u64,
    pub annotation_set_id: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub object_reference: Option<String>,
}

/// Request body for adding many annotations to one annotation set.
#[derive(Serialize, Debug)]
pub struct AnnotationAddBulkParams {
    pub annotation_set_id: u64,
    pub annotations: Vec<ServerAnnotation>,
}

/// Request body for bulk-deleting annotations of the given types from an
/// annotation set.
///
/// `image_ids` is omitted when empty and `delete_all` when `None`.
#[derive(Serialize, Debug)]
pub struct AnnotationBulkDeleteParams {
    pub annotation_set_id: u64,
    pub annotation_types: Vec<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub image_ids: Vec<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delete_all: Option<bool>,
}
617
618#[derive(Deserialize)]
619pub struct Snapshot {
620 id: SnapshotID,
621 description: String,
622 status: String,
623 path: String,
624 #[serde(rename = "date")]
625 created: DateTime<Utc>,
626}
627
628impl Display for Snapshot {
629 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
630 write!(f, "{} {}", self.id, self.description)
631 }
632}
633
634impl Snapshot {
635 pub fn id(&self) -> SnapshotID {
636 self.id
637 }
638
639 pub fn description(&self) -> &str {
640 &self.description
641 }
642
643 pub fn status(&self) -> &str {
644 &self.status
645 }
646
647 pub fn path(&self) -> &str {
648 &self.path
649 }
650
651 pub fn created(&self) -> &DateTime<Utc> {
652 &self.created
653 }
654}
655
/// Request to restore a snapshot into a project.
///
/// JSON key mapping: `topics` → `enabled_topics`, `autolabel` → `label_names`,
/// `autodepth` → `depth_gen`. Empty vectors and `None` options are omitted.
#[derive(Serialize, Debug)]
pub struct SnapshotRestore {
    pub project_id: ProjectID,
    pub snapshot_id: SnapshotID,
    pub fps: u64,
    #[serde(rename = "enabled_topics", skip_serializing_if = "Vec::is_empty")]
    pub topics: Vec<String>,
    #[serde(rename = "label_names", skip_serializing_if = "Vec::is_empty")]
    pub autolabel: Vec<String>,
    #[serde(rename = "depth_gen")]
    pub autodepth: bool,
    pub agtg_pipeline: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dataset_description: Option<String>,
}

/// Reply to a snapshot restore request.
///
/// `task_id` is `None` when the reply has no `task_id` key.
#[derive(Deserialize, Debug)]
pub struct SnapshotRestoreResult {
    pub id: SnapshotID,
    pub description: String,
    pub dataset_name: String,
    pub dataset_id: DatasetID,
    pub annotation_set_id: AnnotationSetID,
    #[serde(default)]
    pub task_id: Option<TaskID>,
    pub date: DateTime<Utc>,
}

/// Request to create a snapshot from a dataset and one of its annotation sets.
#[derive(Serialize, Debug)]
pub struct SnapshotCreateFromDataset {
    pub description: String,
    pub dataset_id: DatasetID,
    pub annotation_set_id: AnnotationSetID,
}

/// Reply to a create-snapshot-from-dataset request (presumably
/// `SnapshotCreateFromDataset`).
///
/// The snapshot ID is accepted under either `id` or `snapshot_id`; `task_id`
/// is `None` when missing.
#[derive(Deserialize, Debug)]
pub struct SnapshotFromDatasetResult {
    #[serde(alias = "snapshot_id")]
    pub id: SnapshotID,
    #[serde(default)]
    pub task_id: Option<TaskID>,
}
712
713#[derive(Deserialize)]
714pub struct Experiment {
715 id: ExperimentID,
716 project_id: ProjectID,
717 name: String,
718 description: String,
719}
720
721impl Display for Experiment {
722 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
723 write!(f, "{} {}", self.id, self.name)
724 }
725}
726
727impl Experiment {
728 pub fn id(&self) -> ExperimentID {
729 self.id
730 }
731
732 pub fn project_id(&self) -> ProjectID {
733 self.project_id
734 }
735
736 pub fn name(&self) -> &str {
737 &self.name
738 }
739
740 pub fn description(&self) -> &str {
741 &self.description
742 }
743
744 pub async fn project(&self, client: &client::Client) -> Result<Project, Error> {
745 client.project(self.project_id).await
746 }
747
748 pub async fn training_sessions(
749 &self,
750 client: &client::Client,
751 name: Option<&str>,
752 ) -> Result<Vec<TrainingSession>, Error> {
753 client.training_sessions(self.id, name).await
754 }
755}
756
757#[derive(Serialize, Debug)]
758pub struct PublishMetrics {
759 #[serde(rename = "trainer_session_id", skip_serializing_if = "Option::is_none")]
760 pub trainer_session_id: Option<TrainingSessionID>,
761 #[serde(
762 rename = "validate_session_id",
763 skip_serializing_if = "Option::is_none"
764 )]
765 pub validate_session_id: Option<ValidationSessionID>,
766 pub metrics: HashMap<String, Parameter>,
767}
768
/// Nested `params` object of a training session: the model parameters and the
/// dataset selection.
#[derive(Deserialize)]
struct TrainingSessionParams {
    model_params: HashMap<String, Parameter>,
    dataset_params: DatasetParams,
}
774
775#[derive(Deserialize)]
776pub struct TrainingSession {
777 id: TrainingSessionID,
778 #[serde(rename = "trainer_id")]
779 experiment_id: ExperimentID,
780 model: String,
781 name: String,
782 description: String,
783 params: TrainingSessionParams,
784 #[serde(rename = "docker_task")]
785 task: Task,
786}
787
788impl Display for TrainingSession {
789 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
790 write!(f, "{} {}", self.id, self.name())
791 }
792}
793
794impl TrainingSession {
795 pub fn id(&self) -> TrainingSessionID {
796 self.id
797 }
798
799 pub fn name(&self) -> &str {
800 &self.name
801 }
802
803 pub fn description(&self) -> &str {
804 &self.description
805 }
806
807 pub fn model(&self) -> &str {
808 &self.model
809 }
810
811 pub fn experiment_id(&self) -> ExperimentID {
812 self.experiment_id
813 }
814
815 pub fn task(&self) -> Task {
816 self.task.clone()
817 }
818
819 pub fn model_params(&self) -> &HashMap<String, Parameter> {
820 &self.params.model_params
821 }
822
823 pub fn dataset_params(&self) -> &DatasetParams {
824 &self.params.dataset_params
825 }
826
827 pub fn train_group(&self) -> &str {
828 &self.params.dataset_params.train_group
829 }
830
831 pub fn val_group(&self) -> &str {
832 &self.params.dataset_params.val_group
833 }
834
835 pub async fn experiment(&self, client: &client::Client) -> Result<Experiment, Error> {
836 client.experiment(self.experiment_id).await
837 }
838
839 pub async fn dataset(&self, client: &client::Client) -> Result<Dataset, Error> {
840 client.dataset(self.params.dataset_params.dataset_id).await
841 }
842
843 pub async fn annotation_set(&self, client: &client::Client) -> Result<AnnotationSet, Error> {
844 client
845 .annotation_set(self.params.dataset_params.annotation_set_id)
846 .await
847 }
848
849 pub async fn artifacts(&self, client: &client::Client) -> Result<Vec<Artifact>, Error> {
850 client.artifacts(self.id).await
851 }
852
853 pub async fn metrics(
854 &self,
855 client: &client::Client,
856 ) -> Result<HashMap<String, Parameter>, Error> {
857 #[derive(Deserialize)]
858 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
859 enum Response {
860 Empty {},
861 Map(HashMap<String, Parameter>),
862 String(String),
863 }
864
865 let params = HashMap::from([("trainer_session_id", self.id().value())]);
866 let resp: Response = client
867 .rpc("trainer.session.metrics".to_owned(), Some(params))
868 .await?;
869
870 Ok(match resp {
871 Response::String(metrics) => serde_json::from_str(&metrics)?,
872 Response::Map(metrics) => metrics,
873 Response::Empty {} => HashMap::new(),
874 })
875 }
876
877 pub async fn set_metrics(
878 &self,
879 client: &client::Client,
880 metrics: HashMap<String, Parameter>,
881 ) -> Result<(), Error> {
882 let metrics = PublishMetrics {
883 trainer_session_id: Some(self.id()),
884 validate_session_id: None,
885 metrics,
886 };
887
888 let _: String = client
889 .rpc("trainer.session.metrics".to_owned(), Some(metrics))
890 .await?;
891
892 Ok(())
893 }
894
895 pub async fn download_artifact(
897 &self,
898 client: &client::Client,
899 filename: &str,
900 ) -> Result<Vec<u8>, Error> {
901 client
902 .fetch(&format!(
903 "download_model?training_session_id={}&file={}",
904 self.id().value(),
905 filename
906 ))
907 .await
908 }
909
910 pub async fn upload_artifact(
914 &self,
915 client: &client::Client,
916 filename: &str,
917 path: PathBuf,
918 ) -> Result<(), Error> {
919 self.upload(client, &[(format!("artifacts/{}", filename), path)])
920 .await
921 }
922
923 pub async fn download_checkpoint(
925 &self,
926 client: &client::Client,
927 filename: &str,
928 ) -> Result<Vec<u8>, Error> {
929 client
930 .fetch(&format!(
931 "download_checkpoint?folder=checkpoints&training_session_id={}&file={}",
932 self.id().value(),
933 filename
934 ))
935 .await
936 }
937
938 pub async fn upload_checkpoint(
942 &self,
943 client: &client::Client,
944 filename: &str,
945 path: PathBuf,
946 ) -> Result<(), Error> {
947 self.upload(client, &[(format!("checkpoints/{}", filename), path)])
948 .await
949 }
950
951 pub async fn download(&self, client: &client::Client, filename: &str) -> Result<String, Error> {
955 #[derive(Serialize)]
956 struct DownloadRequest {
957 session_id: TrainingSessionID,
958 file_path: String,
959 }
960
961 let params = DownloadRequest {
962 session_id: self.id(),
963 file_path: filename.to_string(),
964 };
965
966 client
967 .rpc("trainer.download.file".to_owned(), Some(params))
968 .await
969 }
970
971 pub async fn upload(
972 &self,
973 client: &client::Client,
974 files: &[(String, PathBuf)],
975 ) -> Result<(), Error> {
976 let mut parts = Form::new().part(
977 "params",
978 Part::text(format!("{{ \"session_id\": {} }}", self.id().value())),
979 );
980
981 for (name, path) in files {
982 let file_part = Part::file(path).await?.file_name(name.to_owned());
983 parts = parts.part("file", file_part);
984 }
985
986 let result = client.post_multipart("trainer.upload.files", parts).await?;
987 trace!("TrainingSession::upload: {:?}", result);
988 Ok(())
989 }
990}
991
/// A validation session, linked to a training session and an experiment.
///
/// JSON key mapping: `annotation_set_id` is read from `gt_annotation_set_id`
/// and `task` from `docker_task`. `params` is flattened from the server's
/// nested layout by `validation_session_params`.
#[derive(Deserialize, Clone, Debug)]
pub struct ValidationSession {
    id: ValidationSessionID,
    description: String,
    dataset_id: DatasetID,
    experiment_id: ExperimentID,
    training_session_id: TrainingSessionID,
    #[serde(rename = "gt_annotation_set_id")]
    annotation_set_id: AnnotationSetID,
    #[serde(deserialize_with = "validation_session_params")]
    params: HashMap<String, Parameter>,
    #[serde(rename = "docker_task")]
    task: Task,
}
1006
1007fn validation_session_params<'de, D>(
1008 deserializer: D,
1009) -> Result<HashMap<String, Parameter>, D::Error>
1010where
1011 D: Deserializer<'de>,
1012{
1013 #[derive(Deserialize)]
1014 struct ModelParams {
1015 validation: Option<HashMap<String, Parameter>>,
1016 }
1017
1018 #[derive(Deserialize)]
1019 struct ValidateParams {
1020 model: String,
1021 }
1022
1023 #[derive(Deserialize)]
1024 struct Params {
1025 model_params: ModelParams,
1026 validate_params: ValidateParams,
1027 }
1028
1029 let params = Params::deserialize(deserializer)?;
1030 let params = match params.model_params.validation {
1031 Some(mut map) => {
1032 map.insert(
1033 "model".to_string(),
1034 Parameter::String(params.validate_params.model),
1035 );
1036 map
1037 }
1038 None => HashMap::from([(
1039 "model".to_string(),
1040 Parameter::String(params.validate_params.model),
1041 )]),
1042 };
1043
1044 Ok(params)
1045}
1046
1047impl ValidationSession {
1048 pub fn id(&self) -> ValidationSessionID {
1049 self.id
1050 }
1051
1052 pub fn name(&self) -> &str {
1053 self.task.name()
1054 }
1055
1056 pub fn description(&self) -> &str {
1057 &self.description
1058 }
1059
1060 pub fn dataset_id(&self) -> DatasetID {
1061 self.dataset_id
1062 }
1063
1064 pub fn experiment_id(&self) -> ExperimentID {
1065 self.experiment_id
1066 }
1067
1068 pub fn training_session_id(&self) -> TrainingSessionID {
1069 self.training_session_id
1070 }
1071
1072 pub fn annotation_set_id(&self) -> AnnotationSetID {
1073 self.annotation_set_id
1074 }
1075
1076 pub fn params(&self) -> &HashMap<String, Parameter> {
1077 &self.params
1078 }
1079
1080 pub fn task(&self) -> &Task {
1081 &self.task
1082 }
1083
1084 pub async fn metrics(
1085 &self,
1086 client: &client::Client,
1087 ) -> Result<HashMap<String, Parameter>, Error> {
1088 #[derive(Deserialize)]
1089 #[serde(untagged, deny_unknown_fields, expecting = "map, empty map or string")]
1090 enum Response {
1091 Empty {},
1092 Map(HashMap<String, Parameter>),
1093 String(String),
1094 }
1095
1096 let params = HashMap::from([("validate_session_id", self.id().value())]);
1097 let resp: Response = client
1098 .rpc("validate.session.metrics".to_owned(), Some(params))
1099 .await?;
1100
1101 Ok(match resp {
1102 Response::String(metrics) => serde_json::from_str(&metrics)?,
1103 Response::Map(metrics) => metrics,
1104 Response::Empty {} => HashMap::new(),
1105 })
1106 }
1107
1108 pub async fn set_metrics(
1109 &self,
1110 client: &client::Client,
1111 metrics: HashMap<String, Parameter>,
1112 ) -> Result<(), Error> {
1113 let metrics = PublishMetrics {
1114 trainer_session_id: None,
1115 validate_session_id: Some(self.id()),
1116 metrics,
1117 };
1118
1119 let _: String = client
1120 .rpc("validate.session.metrics".to_owned(), Some(metrics))
1121 .await?;
1122
1123 Ok(())
1124 }
1125
1126 pub async fn upload_data(
1151 &self,
1152 client: &client::Client,
1153 files: &[(String, std::path::PathBuf)],
1154 folder: Option<&str>,
1155 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1156 ) -> Result<(), Error> {
1157 use futures::StreamExt;
1158 use std::sync::{
1159 Arc,
1160 atomic::{AtomicUsize, Ordering},
1161 };
1162 use tokio_util::io::ReaderStream;
1163
1164 let mut total: usize = 0;
1166 let mut file_meta = Vec::with_capacity(files.len());
1167 for (name, path) in files {
1168 let f = tokio::fs::File::open(path).await?;
1169 let len = f.metadata().await?.len() as usize;
1170 total += len;
1171 file_meta.push((name.clone(), f, len));
1172 }
1173
1174 let sent = Arc::new(AtomicUsize::new(0));
1176
1177 let mut form = Form::new().text("session_id", self.id().value().to_string());
1178 if let Some(folder) = folder.filter(|s| !s.is_empty()) {
1179 form = form.text("folder", folder.to_owned());
1180 }
1181
1182 for (name, file, len) in file_meta {
1183 let reader_stream = ReaderStream::new(file);
1184 let sent_clone = sent.clone();
1185 let progress_clone = progress.clone();
1186 let progress_stream = reader_stream.inspect(move |chunk_result| {
1187 if let Ok(chunk) = chunk_result {
1188 let current =
1189 sent_clone.fetch_add(chunk.len(), Ordering::Relaxed) + chunk.len();
1190 if let Some(tx) = &progress_clone {
1195 let _ = tx.try_send(Progress {
1196 current,
1197 total,
1198 status: None,
1199 });
1200 }
1201 }
1202 });
1203 let body = reqwest::Body::wrap_stream(progress_stream);
1204 let part = Part::stream_with_length(body, len as u64).file_name(name);
1205 form = form.part("file", part);
1206 }
1207
1208 let result = match client.post_multipart("val.data.upload", form).await {
1209 Ok(_) => Ok(()),
1210 Err(Error::RpcError(code, msg)) => {
1211 Err(client::map_rpc_error("val.data.upload", code, msg, None))
1212 }
1213 Err(e) => Err(e),
1214 };
1215
1216 if result.is_ok()
1221 && let Some(tx) = progress
1222 {
1223 let _ = tx
1224 .send(Progress {
1225 current: total,
1226 total,
1227 status: None,
1228 })
1229 .await;
1230 }
1231 result
1232 }
1233
1234 pub async fn download_data(
1254 &self,
1255 client: &client::Client,
1256 filename: &str,
1257 output_path: &std::path::Path,
1258 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1259 ) -> Result<(), Error> {
1260 let req = client::ValDataDownloadRequest {
1261 session_id: self.id().value(),
1262 filename: filename.to_owned(),
1263 };
1264 match client
1265 .rpc_download("val.data.download", &req, output_path, progress)
1266 .await
1267 {
1268 Ok(()) => Ok(()),
1269 Err(Error::RpcError(code, msg)) => {
1270 Err(client::map_rpc_error("val.data.download", code, msg, None))
1271 }
1272 Err(e) => Err(e),
1273 }
1274 }
1275
1276 pub async fn data_list(&self, client: &client::Client) -> Result<Vec<String>, Error> {
1291 let req = client::ValDataListRequest {
1292 session_id: self.id().value(),
1293 };
1294 match client.rpc("val.data.list".to_owned(), Some(&req)).await {
1295 Ok(r) => Ok(r),
1296 Err(Error::RpcError(code, msg)) => {
1297 Err(client::map_rpc_error("val.data.list", code, msg, None))
1298 }
1299 Err(e) => Err(e),
1300 }
1301 }
1302}
1303
/// Arguments for starting a validation session.
///
/// NOTE(review): only the field list is visible here. How each field reaches
/// the server, and what `val_type`, `is_local` and `is_kubernetes` mean, is
/// defined by the code that consumes this struct — confirm there.
#[derive(Debug, Clone)]
pub struct StartValidationRequest {
    pub project_id: ProjectID,
    pub name: String,
    pub training_session_id: TrainingSessionID,
    pub model_file: String,
    pub val_type: String,
    pub params: HashMap<String, Parameter>,
    pub is_local: bool,
    pub is_kubernetes: bool,
    pub description: Option<String>,
    pub dataset_id: Option<DatasetID>,
    pub annotation_set_id: Option<AnnotationSetID>,
    pub snapshot_id: Option<SnapshotID>,
}
1337
1338#[derive(Deserialize, Debug, Clone)]
1353pub struct NewValidationSession {
1354 #[serde(rename = "id")]
1355 pub task_id: TaskID,
1356 #[serde(rename = "val_session_id", default)]
1357 pub session_id: Option<ValidationSessionID>,
1358}
1359
1360impl Display for NewValidationSession {
1361 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1362 match self.session_id {
1363 Some(id) => write!(f, "task {} session {}", self.task_id, id),
1364 None => write!(f, "task {} (no session)", self.task_id),
1365 }
1366 }
1367}
1368
1369#[derive(Deserialize, Clone, Debug)]
1370pub struct DatasetParams {
1371 dataset_id: DatasetID,
1372 annotation_set_id: AnnotationSetID,
1373 #[serde(rename = "train_group_name")]
1374 train_group: String,
1375 #[serde(rename = "val_group_name")]
1376 val_group: String,
1377}
1378
1379impl DatasetParams {
1380 pub fn dataset_id(&self) -> DatasetID {
1381 self.dataset_id
1382 }
1383
1384 pub fn annotation_set_id(&self) -> AnnotationSetID {
1385 self.annotation_set_id
1386 }
1387
1388 pub fn train_group(&self) -> &str {
1389 &self.train_group
1390 }
1391
1392 pub fn val_group(&self) -> &str {
1393 &self.val_group
1394 }
1395}
1396
/// Filters and paging cursor for listing tasks.
///
/// Every field is optional and is omitted from the request when `None`.
/// `manager` is serialized under the key `manage_types`.
#[derive(Serialize, Debug, Clone)]
pub struct TasksListParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub continue_token: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub types: Option<Vec<String>>,
    #[serde(rename = "manage_types", skip_serializing_if = "Option::is_none")]
    pub manager: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub status: Option<Vec<String>>,
}
1408
1409#[derive(Debug, Clone, Serialize, Deserialize)]
1415pub struct TaskDataList {
1416 pub server: String,
1417 #[serde(rename = "organization_uid")]
1418 pub organization_uid: String,
1419 #[serde(default)]
1420 pub traces: Vec<String>,
1421 #[serde(default)]
1422 pub data: std::collections::HashMap<String, Vec<String>>,
1423}
1424
1425#[derive(Debug, Clone, Serialize, Deserialize)]
1430pub struct Job {
1431 #[serde(default)]
1433 pub code: String,
1434 #[serde(default)]
1436 pub title: String,
1437 #[serde(default)]
1439 pub job_name: String,
1440 #[serde(default)]
1442 pub job_id: String,
1443 #[serde(default)]
1445 pub state: String,
1446 #[serde(default)]
1448 pub launch: Option<DateTime<Utc>>,
1449 pub task_id: i64,
1454}
1455
1456impl Job {
1457 pub fn task_id(&self) -> TaskID {
1463 TaskID::from(self.task_id.max(0) as u64)
1464 }
1465}
1466
/// One page of a task listing.
///
/// NOTE(review): `continue_token` is presumably passed back through
/// `TasksListParams::continue_token` to fetch the next page — confirm.
#[derive(Deserialize, Debug, Clone)]
pub struct TasksListResult {
    pub tasks: Vec<Task>,
    pub continue_token: Option<String>,
}
1472
1473#[derive(Deserialize, Debug, Clone)]
1474pub struct Task {
1475 id: TaskID,
1476 name: String,
1477 #[serde(rename = "type")]
1478 workflow: String,
1479 status: String,
1480 #[serde(rename = "manage_type")]
1481 manager: Option<String>,
1482 #[serde(rename = "instance_type")]
1483 instance: String,
1484 #[serde(rename = "date")]
1485 created: DateTime<Utc>,
1486}
1487
1488impl Task {
1489 pub fn id(&self) -> TaskID {
1490 self.id
1491 }
1492
1493 pub fn name(&self) -> &str {
1494 &self.name
1495 }
1496
1497 pub fn workflow(&self) -> &str {
1498 &self.workflow
1499 }
1500
1501 pub fn status(&self) -> &str {
1502 &self.status
1503 }
1504
1505 pub fn manager(&self) -> Option<&str> {
1506 self.manager.as_deref()
1507 }
1508
1509 pub fn instance(&self) -> &str {
1510 &self.instance
1511 }
1512
1513 pub fn created(&self) -> &DateTime<Utc> {
1514 &self.created
1515 }
1516}
1517
1518impl Display for Task {
1519 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1520 write!(
1521 f,
1522 "{} [{:?} {}] {}",
1523 self.id,
1524 self.manager(),
1525 self.workflow(),
1526 self.name()
1527 )
1528 }
1529}
1530
/// Detailed information about a single task.
///
/// JSON key mapping (aliases are accepted as listed):
/// * `description` ← `task_description` or `description`; empty if missing;
/// * `workflow` ← `type`;
/// * `created` ← `created_date` or `created`; Unix epoch if missing;
/// * `completed` ← `end_date` or `completed`; Unix epoch if missing.
///
/// `progress` defaults to an empty `TaskProgress` when missing.
#[derive(Deserialize, Debug, Clone)]
pub struct TaskInfo {
    id: TaskID,
    project_id: Option<ProjectID>,
    #[serde(rename = "task_description", alias = "description", default)]
    description: String,
    #[serde(rename = "type")]
    workflow: String,
    status: Option<String>,
    #[serde(default)]
    progress: TaskProgress,
    #[serde(
        rename = "created_date",
        alias = "created",
        default = "default_datetime_utc"
    )]
    created: DateTime<Utc>,
    #[serde(
        rename = "end_date",
        alias = "completed",
        default = "default_datetime_utc"
    )]
    completed: DateTime<Utc>,
}

/// Serde default for `TaskInfo` timestamps that are missing from the JSON:
/// the Unix epoch (1970-01-01T00:00:00Z).
fn default_datetime_utc() -> DateTime<Utc> {
    DateTime::UNIX_EPOCH
}
1559
1560impl Display for TaskInfo {
1561 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1562 write!(f, "{} {}: {}", self.id, self.workflow(), self.description())
1563 }
1564}
1565
1566impl TaskInfo {
1567 pub fn id(&self) -> TaskID {
1568 self.id
1569 }
1570
1571 pub fn project_id(&self) -> Option<ProjectID> {
1572 self.project_id
1573 }
1574
1575 pub fn description(&self) -> &str {
1576 &self.description
1577 }
1578
1579 pub fn workflow(&self) -> &str {
1580 &self.workflow
1581 }
1582
1583 pub fn status(&self) -> &Option<String> {
1584 &self.status
1585 }
1586
1587 pub async fn set_status(&mut self, client: &Client, status: &str) -> Result<(), Error> {
1588 let t = client.task_status(self.id(), status).await?;
1589 self.status = Some(t.status);
1590 Ok(())
1591 }
1592
1593 pub fn stages(&self) -> HashMap<String, Stage> {
1594 match &self.progress.stages {
1595 Some(stages) => stages.clone(),
1596 None => HashMap::new(),
1597 }
1598 }
1599
1600 pub async fn update_stage(
1601 &mut self,
1602 client: &Client,
1603 stage: &str,
1604 status: &str,
1605 message: &str,
1606 percentage: u8,
1607 ) -> Result<(), Error> {
1608 client
1609 .update_stage(self.id(), stage, status, message, percentage)
1610 .await?;
1611 let t = client.task_info(self.id()).await?;
1612 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1613 Ok(())
1614 }
1615
1616 pub async fn set_stages(
1617 &mut self,
1618 client: &Client,
1619 stages: &[(&str, &str)],
1620 ) -> Result<(), Error> {
1621 client.set_stages(self.id(), stages).await?;
1622 let t = client.task_info(self.id()).await?;
1623 self.progress.stages = Some(t.progress.stages.unwrap_or_default());
1624 Ok(())
1625 }
1626
1627 pub async fn data_list(&self, client: &client::Client) -> Result<TaskDataList, Error> {
1643 let req = client::TaskDataListRequest {
1644 task_id: self.id().value(),
1645 };
1646 match client.rpc("task.data.list".to_owned(), Some(&req)).await {
1647 Ok(r) => Ok(r),
1648 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1649 "task.data.list",
1650 code,
1651 msg,
1652 Some(self.id()),
1653 )),
1654 Err(e) => Err(e),
1655 }
1656 }
1657
1658 pub async fn upload_data(
1679 &self,
1680 client: &client::Client,
1681 path: &std::path::Path,
1682 folder: Option<&str>,
1683 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1684 ) -> Result<(), Error> {
1685 use futures::StreamExt;
1686 use std::sync::{
1687 Arc,
1688 atomic::{AtomicUsize, Ordering},
1689 };
1690 use tokio_util::io::ReaderStream;
1691
1692 let file_name = path
1693 .file_name()
1694 .and_then(|s| s.to_str())
1695 .ok_or_else(|| Error::InvalidParameters("path must have a UTF-8 filename".into()))?
1696 .to_owned();
1697
1698 let file = tokio::fs::File::open(path).await?;
1699 let total = file.metadata().await?.len() as usize;
1700 let sent = Arc::new(AtomicUsize::new(0));
1701
1702 let reader_stream = ReaderStream::new(file);
1703 let sent_clone = sent.clone();
1704 let progress_clone = progress.clone();
1705 let progress_stream = reader_stream.inspect(move |chunk_result| {
1706 if let Ok(chunk) = chunk_result {
1707 let current = sent_clone.fetch_add(chunk.len(), Ordering::Relaxed) + chunk.len();
1708 if let Some(tx) = &progress_clone {
1714 let _ = tx.try_send(Progress {
1715 current,
1716 total,
1717 status: None,
1718 });
1719 }
1720 }
1721 });
1722
1723 let body = reqwest::Body::wrap_stream(progress_stream);
1724 let file_part = Part::stream_with_length(body, total as u64).file_name(file_name);
1725
1726 let mut form = Form::new().text("task_id", self.id().value().to_string());
1727 if let Some(folder) = folder.filter(|s| !s.is_empty()) {
1728 form = form.text("folder", folder.to_owned());
1729 }
1730 form = form.part("file", file_part);
1731
1732 let result = match client.post_multipart("task.data.upload", form).await {
1733 Ok(_) => Ok(()),
1734 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1735 "task.data.upload",
1736 code,
1737 msg,
1738 Some(self.id()),
1739 )),
1740 Err(e) => Err(e),
1741 };
1742
1743 if result.is_ok()
1747 && let Some(tx) = progress
1748 {
1749 let _ = tx
1750 .send(Progress {
1751 current: total,
1752 total,
1753 status: None,
1754 })
1755 .await;
1756 }
1757 result
1758 }
1759
1760 pub async fn download_data(
1789 &self,
1790 client: &client::Client,
1791 file: &str,
1792 folder: Option<&str>,
1793 output_path: &std::path::Path,
1794 progress: Option<tokio::sync::mpsc::Sender<Progress>>,
1795 ) -> Result<(), Error> {
1796 let folder = folder.unwrap_or("").to_owned();
1797 let req = client::TaskDataDownloadRequest {
1798 task_id: self.id().value(),
1799 folder,
1800 file: file.to_owned(),
1801 };
1802 match client
1803 .rpc_download("task.data.download", &req, output_path, progress)
1804 .await
1805 {
1806 Ok(()) => Ok(()),
1807 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1808 "task.data.download",
1809 code,
1810 msg,
1811 Some(self.id()),
1812 )),
1813 Err(e) => Err(e),
1814 }
1815 }
1816
1817 pub async fn add_chart(
1845 &self,
1846 client: &client::Client,
1847 group: &str,
1848 name: &str,
1849 data: Parameter,
1850 params: Option<Parameter>,
1851 ) -> Result<(), Error> {
1852 client::validate_chart_args(group, name)?;
1853 let req = client::TaskChartAddRequest {
1854 task_id: self.id().value(),
1855 group_name: group.to_owned(),
1856 chart_name: name.to_owned(),
1857 params,
1858 data,
1859 };
1860 let _resp: serde_json::Value =
1861 match client.rpc("task.chart.add".to_owned(), Some(&req)).await {
1862 Ok(r) => r,
1863 Err(Error::RpcError(code, msg)) => {
1864 return Err(client::map_rpc_error(
1865 "task.chart.add",
1866 code,
1867 msg,
1868 Some(self.id()),
1869 ));
1870 }
1871 Err(e) => return Err(e),
1872 };
1873 Ok(())
1874 }
1875
1876 pub async fn list_charts(
1893 &self,
1894 client: &client::Client,
1895 group: Option<&str>,
1896 ) -> Result<TaskDataList, Error> {
1897 let req = client::TaskChartListRequest {
1898 task_id: self.id().value(),
1899 group_name: group.unwrap_or("").to_owned(),
1900 };
1901 match client.rpc("task.chart.list".to_owned(), Some(&req)).await {
1902 Ok(r) => Ok(r),
1903 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1904 "task.chart.list",
1905 code,
1906 msg,
1907 Some(self.id()),
1908 )),
1909 Err(e) => Err(e),
1910 }
1911 }
1912
1913 pub async fn get_chart(
1932 &self,
1933 client: &client::Client,
1934 group: &str,
1935 name: &str,
1936 ) -> Result<Parameter, Error> {
1937 client::validate_chart_args(group, name)?;
1938 let req = client::TaskChartGetRequest {
1939 task_id: self.id().value(),
1940 group_name: group.to_owned(),
1941 chart_name: name.to_owned(),
1942 };
1943 match client.rpc("task.chart.get".to_owned(), Some(&req)).await {
1944 Ok(r) => Ok(r),
1945 Err(Error::RpcError(code, msg)) => Err(client::map_rpc_error(
1946 "task.chart.get",
1947 code,
1948 msg,
1949 Some(self.id()),
1950 )),
1951 Err(e) => Err(e),
1952 }
1953 }
1954
1955 pub fn created(&self) -> &DateTime<Utc> {
1956 &self.created
1957 }
1958
1959 pub fn completed(&self) -> &DateTime<Utc> {
1960 &self.completed
1961 }
1962}
1963
/// Deserialized task-progress payload.
///
/// `stages` maps a string key to a [`Stage`]. The key is presumably the
/// stage name (compare `Stage::stage`); confirm this against the server schema.
/// Because the field is an `Option`, serde's derive sets it to `None` when
/// the key is missing from the input.
#[derive(Deserialize, Debug, Default, Clone)]
pub struct TaskProgress {
    stages: Option<HashMap<String, Stage>>,
}
1968
/// Serialize-only payload that pairs a task ID with a status string.
#[derive(Serialize, Debug, Clone)]
pub struct TaskStatus {
    /// Task identifier, emitted under the JSON key `docker_task_id`.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// Status value, serialized as given.
    pub status: String,
}
1975
/// One stage of a task's progress, (de)serialized with serde.
///
/// Optional fields that are `None` are left out when serializing. When
/// deserializing, missing optional fields become `None`. Field order here
/// sets the order of keys in the serialized output.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Stage {
    /// Owning task, read and written under the JSON key `docker_task_id`.
    #[serde(rename = "docker_task_id", skip_serializing_if = "Option::is_none")]
    task_id: Option<TaskID>,
    /// Stage name. Always present in the serialized form.
    stage: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    status: Option<String>,
    /// `Stage::new` always leaves this as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    message: Option<String>,
    /// Progress value. It is presumably 0–100, but `u8` allows values up to
    /// 255. NOTE(review): confirm whether values above 100 should be rejected.
    percentage: u8,
}
1989
1990impl Stage {
1991 pub fn new(
1992 task_id: Option<TaskID>,
1993 stage: String,
1994 status: Option<String>,
1995 message: Option<String>,
1996 percentage: u8,
1997 ) -> Self {
1998 Stage {
1999 task_id,
2000 stage,
2001 status,
2002 description: None,
2003 message,
2004 percentage,
2005 }
2006 }
2007
2008 pub fn task_id(&self) -> &Option<TaskID> {
2009 &self.task_id
2010 }
2011
2012 pub fn stage(&self) -> &str {
2013 &self.stage
2014 }
2015
2016 pub fn status(&self) -> &Option<String> {
2017 &self.status
2018 }
2019
2020 pub fn description(&self) -> &Option<String> {
2021 &self.description
2022 }
2023
2024 pub fn message(&self) -> &Option<String> {
2025 &self.message
2026 }
2027
2028 pub fn percentage(&self) -> u8 {
2029 self.percentage
2030 }
2031}
2032
/// Serialize-only payload that carries a task ID and a list of stage maps.
#[derive(Serialize, Debug)]
pub struct TaskStages {
    /// Emitted under the JSON key `docker_task_id`.
    #[serde(rename = "docker_task_id")]
    pub task_id: TaskID,
    /// Each entry is a flat string-to-string map. The expected keys are not
    /// defined here (TODO confirm). The whole field is omitted from the
    /// output when the vector is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stages: Vec<HashMap<String, String>>,
}
2040
/// A named artifact with a model type, deserialized with serde.
#[derive(Deserialize, Debug)]
pub struct Artifact {
    name: String,
    /// Read from the camelCase JSON key `modelType`.
    #[serde(rename = "modelType")]
    model_type: String,
}
2047
2048impl Artifact {
2049 pub fn name(&self) -> &str {
2050 &self.name
2051 }
2052
2053 pub fn model_type(&self) -> &str {
2054 &self.model_type
2055 }
2056}
2057
// Unit tests for the prefixed hexadecimal ID newtypes generated by `typeid!`
// and for the `Parameter` enum.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- OrganizationID ("org-" prefix) ----

    #[test]
    fn test_organization_id_from_u64() {
        let id = OrganizationID::from(12345);
        assert_eq!(id.value(), 12345);
    }

    #[test]
    fn test_organization_id_display() {
        let id = OrganizationID::from(0xabc123);
        assert_eq!(format!("{}", id), "org-abc123");
    }

    #[test]
    fn test_organization_id_try_from_str_valid() {
        let id = OrganizationID::try_from("org-abc123").unwrap();
        assert_eq!(id.value(), 0xabc123);
    }

    // A missing/incorrect prefix must produce InvalidParameters naming the
    // expected prefix.
    #[test]
    fn test_organization_id_try_from_str_invalid_prefix() {
        let result = OrganizationID::try_from("invalid-abc123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'org-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_organization_id_try_from_str_invalid_hex() {
        let result = OrganizationID::try_from("org-xyz");
        assert!(result.is_err());
    }

    // The prefix alone (empty hex part) is rejected.
    #[test]
    fn test_organization_id_try_from_str_empty() {
        let result = OrganizationID::try_from("org-");
        assert!(result.is_err());
    }

    #[test]
    fn test_organization_id_into_u64() {
        let id = OrganizationID::from(54321);
        let value: u64 = id.into();
        assert_eq!(value, 54321);
    }

    // ---- ProjectID ("p-" prefix) ----

    #[test]
    fn test_project_id_from_u64() {
        let id = ProjectID::from(78910);
        assert_eq!(id.value(), 78910);
    }

    #[test]
    fn test_project_id_display() {
        let id = ProjectID::from(0xdef456);
        assert_eq!(format!("{}", id), "p-def456");
    }

    #[test]
    fn test_project_id_from_str_valid() {
        let id = ProjectID::from_str("p-def456").unwrap();
        assert_eq!(id.value(), 0xdef456);
    }

    #[test]
    fn test_project_id_try_from_str_valid() {
        let id = ProjectID::try_from("p-123abc").unwrap();
        assert_eq!(id.value(), 0x123abc);
    }

    #[test]
    fn test_project_id_try_from_string_valid() {
        let id = ProjectID::try_from("p-456def".to_string()).unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_project_id_from_str_invalid_prefix() {
        let result = ProjectID::from_str("proj-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'p-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_project_id_from_str_invalid_hex() {
        let result = ProjectID::from_str("p-notahex");
        assert!(result.is_err());
    }

    #[test]
    fn test_project_id_into_u64() {
        let id = ProjectID::from(99999);
        let value: u64 = id.into();
        assert_eq!(value, 99999);
    }

    // ---- ExperimentID ("exp-" prefix) ----

    #[test]
    fn test_experiment_id_from_u64() {
        let id = ExperimentID::from(1193046);
        assert_eq!(id.value(), 1193046);
    }

    #[test]
    fn test_experiment_id_display() {
        let id = ExperimentID::from(0x123abc);
        assert_eq!(format!("{}", id), "exp-123abc");
    }

    #[test]
    fn test_experiment_id_from_str_valid() {
        let id = ExperimentID::from_str("exp-456def").unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_experiment_id_try_from_str_valid() {
        let id = ExperimentID::try_from("exp-789abc").unwrap();
        assert_eq!(id.value(), 0x789abc);
    }

    #[test]
    fn test_experiment_id_try_from_string_valid() {
        let id = ExperimentID::try_from("exp-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_experiment_id_from_str_invalid_prefix() {
        let result = ExperimentID::from_str("experiment-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'exp-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_experiment_id_from_str_invalid_hex() {
        let result = ExperimentID::from_str("exp-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_experiment_id_into_u64() {
        let id = ExperimentID::from(777777);
        let value: u64 = id.into();
        assert_eq!(value, 777777);
    }

    // ---- TrainingSessionID ("t-" prefix) ----

    #[test]
    fn test_training_session_id_from_u64() {
        let id = TrainingSessionID::from(7901234);
        assert_eq!(id.value(), 7901234);
    }

    #[test]
    fn test_training_session_id_display() {
        let id = TrainingSessionID::from(0xabc123);
        assert_eq!(format!("{}", id), "t-abc123");
    }

    #[test]
    fn test_training_session_id_from_str_valid() {
        let id = TrainingSessionID::from_str("t-abc123").unwrap();
        assert_eq!(id.value(), 0xabc123);
    }

    #[test]
    fn test_training_session_id_try_from_str_valid() {
        let id = TrainingSessionID::try_from("t-deadbeef").unwrap();
        assert_eq!(id.value(), 0xdeadbeef);
    }

    #[test]
    fn test_training_session_id_try_from_string_valid() {
        let id = TrainingSessionID::try_from("t-cafebabe".to_string()).unwrap();
        assert_eq!(id.value(), 0xcafebabe);
    }

    #[test]
    fn test_training_session_id_from_str_invalid_prefix() {
        let result = TrainingSessionID::from_str("training-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 't-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_training_session_id_from_str_invalid_hex() {
        let result = TrainingSessionID::from_str("t-qqq");
        assert!(result.is_err());
    }

    #[test]
    fn test_training_session_id_into_u64() {
        let id = TrainingSessionID::from(123456);
        let value: u64 = id.into();
        assert_eq!(value, 123456);
    }

    // ---- ValidationSessionID ("v-" prefix) ----

    #[test]
    fn test_validation_session_id_from_u64() {
        let id = ValidationSessionID::from(3456789);
        assert_eq!(id.value(), 3456789);
    }

    #[test]
    fn test_validation_session_id_display() {
        let id = ValidationSessionID::from(0x34c985);
        assert_eq!(format!("{}", id), "v-34c985");
    }

    #[test]
    fn test_validation_session_id_try_from_str_valid() {
        let id = ValidationSessionID::try_from("v-deadbeef").unwrap();
        assert_eq!(id.value(), 0xdeadbeef);
    }

    #[test]
    fn test_validation_session_id_try_from_string_valid() {
        let id = ValidationSessionID::try_from("v-12345678".to_string()).unwrap();
        assert_eq!(id.value(), 0x12345678);
    }

    #[test]
    fn test_validation_session_id_try_from_str_invalid_prefix() {
        let result = ValidationSessionID::try_from("validation-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'v-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_validation_session_id_try_from_str_invalid_hex() {
        let result = ValidationSessionID::try_from("v-xyz");
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_session_id_into_u64() {
        let id = ValidationSessionID::from(987654);
        let value: u64 = id.into();
        assert_eq!(value, 987654);
    }

    // ---- SnapshotID ("ss-" prefix) ----

    #[test]
    fn test_snapshot_id_from_u64() {
        let id = SnapshotID::from(111222);
        assert_eq!(id.value(), 111222);
    }

    #[test]
    fn test_snapshot_id_display() {
        let id = SnapshotID::from(0xaabbcc);
        assert_eq!(format!("{}", id), "ss-aabbcc");
    }

    #[test]
    fn test_snapshot_id_try_from_str_valid() {
        let id = SnapshotID::try_from("ss-aabbcc").unwrap();
        assert_eq!(id.value(), 0xaabbcc);
    }

    #[test]
    fn test_snapshot_id_try_from_str_invalid_prefix() {
        let result = SnapshotID::try_from("snapshot-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'ss-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_snapshot_id_try_from_str_invalid_hex() {
        let result = SnapshotID::try_from("ss-ggg");
        assert!(result.is_err());
    }

    #[test]
    fn test_snapshot_id_into_u64() {
        let id = SnapshotID::from(333444);
        let value: u64 = id.into();
        assert_eq!(value, 333444);
    }

    // ---- TaskID ("task-" prefix) ----

    #[test]
    fn test_task_id_from_u64() {
        let id = TaskID::from(555666);
        assert_eq!(id.value(), 555666);
    }

    #[test]
    fn test_task_id_display() {
        let id = TaskID::from(0x123456);
        assert_eq!(format!("{}", id), "task-123456");
    }

    #[test]
    fn test_task_id_from_str_valid() {
        let id = TaskID::from_str("task-123456").unwrap();
        assert_eq!(id.value(), 0x123456);
    }

    #[test]
    fn test_task_id_try_from_str_valid() {
        let id = TaskID::try_from("task-abcdef").unwrap();
        assert_eq!(id.value(), 0xabcdef);
    }

    #[test]
    fn test_task_id_try_from_string_valid() {
        let id = TaskID::try_from("task-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    // "t-" is the TrainingSessionID prefix, so it must not parse as a TaskID.
    #[test]
    fn test_task_id_from_str_invalid_prefix() {
        let result = TaskID::from_str("t-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'task-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_task_id_from_str_invalid_hex() {
        let result = TaskID::from_str("task-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_task_id_into_u64() {
        let id = TaskID::from(777888);
        let value: u64 = id.into();
        assert_eq!(value, 777888);
    }

    // ---- DatasetID ("ds-" prefix) ----

    #[test]
    fn test_dataset_id_from_u64() {
        let id = DatasetID::from(1193046);
        assert_eq!(id.value(), 1193046);
    }

    #[test]
    fn test_dataset_id_display() {
        let id = DatasetID::from(0x123abc);
        assert_eq!(format!("{}", id), "ds-123abc");
    }

    #[test]
    fn test_dataset_id_from_str_valid() {
        let id = DatasetID::from_str("ds-456def").unwrap();
        assert_eq!(id.value(), 0x456def);
    }

    #[test]
    fn test_dataset_id_try_from_str_valid() {
        let id = DatasetID::try_from("ds-789abc").unwrap();
        assert_eq!(id.value(), 0x789abc);
    }

    #[test]
    fn test_dataset_id_try_from_string_valid() {
        let id = DatasetID::try_from("ds-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_dataset_id_from_str_invalid_prefix() {
        let result = DatasetID::from_str("dataset-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'ds-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_dataset_id_from_str_invalid_hex() {
        let result = DatasetID::from_str("ds-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_dataset_id_into_u64() {
        let id = DatasetID::from(111111);
        let value: u64 = id.into();
        assert_eq!(value, 111111);
    }

    // ---- AnnotationSetID ("as-" prefix) ----

    #[test]
    fn test_annotation_set_id_from_u64() {
        let id = AnnotationSetID::from(222333);
        assert_eq!(id.value(), 222333);
    }

    #[test]
    fn test_annotation_set_id_display() {
        let id = AnnotationSetID::from(0xabcdef);
        assert_eq!(format!("{}", id), "as-abcdef");
    }

    #[test]
    fn test_annotation_set_id_from_str_valid() {
        let id = AnnotationSetID::from_str("as-abcdef").unwrap();
        assert_eq!(id.value(), 0xabcdef);
    }

    #[test]
    fn test_annotation_set_id_try_from_str_valid() {
        let id = AnnotationSetID::try_from("as-123456").unwrap();
        assert_eq!(id.value(), 0x123456);
    }

    #[test]
    fn test_annotation_set_id_try_from_string_valid() {
        let id = AnnotationSetID::try_from("as-fedcba".to_string()).unwrap();
        assert_eq!(id.value(), 0xfedcba);
    }

    #[test]
    fn test_annotation_set_id_from_str_invalid_prefix() {
        let result = AnnotationSetID::from_str("annotation-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'as-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_annotation_set_id_from_str_invalid_hex() {
        let result = AnnotationSetID::from_str("as-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_annotation_set_id_into_u64() {
        let id = AnnotationSetID::from(444555);
        let value: u64 = id.into();
        assert_eq!(value, 444555);
    }

    // ---- SampleID ("s-" prefix) ----

    #[test]
    fn test_sample_id_from_u64() {
        let id = SampleID::from(666777);
        assert_eq!(id.value(), 666777);
    }

    #[test]
    fn test_sample_id_display() {
        let id = SampleID::from(0x987654);
        assert_eq!(format!("{}", id), "s-987654");
    }

    #[test]
    fn test_sample_id_try_from_str_valid() {
        let id = SampleID::try_from("s-987654").unwrap();
        assert_eq!(id.value(), 0x987654);
    }

    #[test]
    fn test_sample_id_try_from_str_invalid_prefix() {
        let result = SampleID::try_from("sample-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 's-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_sample_id_try_from_str_invalid_hex() {
        let result = SampleID::try_from("s-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_sample_id_into_u64() {
        let id = SampleID::from(888999);
        let value: u64 = id.into();
        assert_eq!(value, 888999);
    }

    // ---- AppId ("app-" prefix) ----

    #[test]
    fn test_app_id_from_u64() {
        let id = AppId::from(123123);
        assert_eq!(id.value(), 123123);
    }

    #[test]
    fn test_app_id_display() {
        let id = AppId::from(0x456789);
        assert_eq!(format!("{}", id), "app-456789");
    }

    #[test]
    fn test_app_id_try_from_str_valid() {
        let id = AppId::try_from("app-456789").unwrap();
        assert_eq!(id.value(), 0x456789);
    }

    #[test]
    fn test_app_id_try_from_str_invalid_prefix() {
        let result = AppId::try_from("application-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'app-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_app_id_try_from_str_invalid_hex() {
        let result = AppId::try_from("app-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_app_id_into_u64() {
        let id = AppId::from(321321);
        let value: u64 = id.into();
        assert_eq!(value, 321321);
    }

    // ---- ImageId ("im-" prefix) ----

    #[test]
    fn test_image_id_from_u64() {
        let id = ImageId::from(789789);
        assert_eq!(id.value(), 789789);
    }

    #[test]
    fn test_image_id_display() {
        let id = ImageId::from(0xabcd1234);
        assert_eq!(format!("{}", id), "im-abcd1234");
    }

    #[test]
    fn test_image_id_try_from_str_valid() {
        let id = ImageId::try_from("im-abcd1234").unwrap();
        assert_eq!(id.value(), 0xabcd1234);
    }

    #[test]
    fn test_image_id_try_from_str_invalid_prefix() {
        let result = ImageId::try_from("image-123");
        assert!(result.is_err());
        match result {
            Err(Error::InvalidParameters(msg)) => {
                assert!(msg.contains("must start with 'im-'"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_image_id_try_from_str_invalid_hex() {
        let result = ImageId::try_from("im-zzz");
        assert!(result.is_err());
    }

    #[test]
    fn test_image_id_into_u64() {
        let id = ImageId::from(987987);
        let value: u64 = id.into();
        assert_eq!(value, 987987);
    }

    // ---- Properties shared by all ID types (derived traits, edge values) ----

    #[test]
    fn test_id_types_equality() {
        let id1 = ProjectID::from(12345);
        let id2 = ProjectID::from(12345);
        let id3 = ProjectID::from(54321);

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_id_types_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(DatasetID::from(100));
        set.insert(DatasetID::from(200));
        // Re-inserting an equal ID must not grow the set.
        set.insert(DatasetID::from(100));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&DatasetID::from(100)));
        assert!(set.contains(&DatasetID::from(200)));
    }

    #[test]
    fn test_id_types_copy_clone() {
        let id1 = ExperimentID::from(999);
        // IDs are `Copy`: `id1` stays usable after being assigned twice.
        let id2 = id1;
        let id3 = id1;
        assert_eq!(id1, id2);
        assert_eq!(id1, id3);
    }

    // Zero is rendered as a single "0" digit, not as an empty hex part.
    #[test]
    fn test_id_zero_value() {
        let id = ProjectID::from(0);
        assert_eq!(format!("{}", id), "p-0");
        assert_eq!(id.value(), 0);
    }

    // u64::MAX is rendered as 16 lowercase hex digits.
    #[test]
    fn test_id_max_value() {
        let id = ProjectID::from(u64::MAX);
        assert_eq!(format!("{}", id), "p-ffffffffffffffff");
        assert_eq!(id.value(), u64::MAX);
    }

    #[test]
    fn test_id_round_trip_conversion() {
        let original = 0xdeadbeef_u64;
        let id = TrainingSessionID::from(original);
        let back: u64 = id.into();
        assert_eq!(original, back);
    }

    #[test]
    fn test_id_case_insensitive_hex() {
        // Upper- and lower-case hex digits parse to the same value.
        let id1 = DatasetID::from_str("ds-ABCDEF").unwrap();
        let id2 = DatasetID::from_str("ds-abcdef").unwrap();
        assert_eq!(id1.value(), id2.value());
    }

    // Leading zeros in the hex part are accepted when parsing.
    #[test]
    fn test_id_with_leading_zeros() {
        let id = ProjectID::from_str("p-00001234").unwrap();
        assert_eq!(id.value(), 0x1234);
    }

    // ---- Parameter enum ----

    #[test]
    fn test_parameter_integer() {
        let param = Parameter::Integer(42);
        match param {
            Parameter::Integer(val) => assert_eq!(val, 42),
            _ => panic!("Expected Integer variant"),
        }
    }

    #[test]
    fn test_parameter_real() {
        let param = Parameter::Real(2.5);
        match param {
            Parameter::Real(val) => assert_eq!(val, 2.5),
            _ => panic!("Expected Real variant"),
        }
    }

    #[test]
    fn test_parameter_boolean() {
        let param = Parameter::Boolean(true);
        match param {
            Parameter::Boolean(val) => assert!(val),
            _ => panic!("Expected Boolean variant"),
        }
    }

    #[test]
    fn test_parameter_string() {
        let param = Parameter::String("test".to_string());
        match param {
            Parameter::String(val) => assert_eq!(val, "test"),
            _ => panic!("Expected String variant"),
        }
    }

    #[test]
    fn test_parameter_array() {
        let param = Parameter::Array(vec![
            Parameter::Integer(1),
            Parameter::Integer(2),
            Parameter::Integer(3),
        ]);
        match param {
            Parameter::Array(arr) => assert_eq!(arr.len(), 3),
            _ => panic!("Expected Array variant"),
        }
    }

    #[test]
    fn test_parameter_object() {
        let mut map = HashMap::new();
        map.insert("key".to_string(), Parameter::Integer(100));
        let param = Parameter::Object(map);
        match param {
            Parameter::Object(obj) => {
                assert_eq!(obj.len(), 1);
                assert!(obj.contains_key("key"));
            }
            _ => panic!("Expected Object variant"),
        }
    }

    #[test]
    fn test_parameter_clone() {
        let param1 = Parameter::Integer(42);
        let param2 = param1.clone();
        assert_eq!(param1, param2);
    }

    #[test]
    fn test_parameter_nested() {
        let inner_array = Parameter::Array(vec![Parameter::Integer(1), Parameter::Integer(2)]);
        let outer_array = Parameter::Array(vec![inner_array.clone(), inner_array]);

        match outer_array {
            Parameter::Array(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected Array variant"),
        }
    }

    // Generates one #[test] per ID type that covers every conversion `typeid!`
    // provides: From<u64>/value(), Display, FromStr (valid input, wrong
    // prefix, missing prefix, non-hex digits), TryFrom<&str>, TryFrom<String>,
    // a serde_json round trip, and Into<u64>.
    // `$wrong_prefix` should be another ID type's prefix; "<wrong>-abc" must
    // fail to parse as `$type`.
    macro_rules! test_typeid_conversions {
        ($test_name:ident, $type:ty, $prefix:literal, $wrong_prefix:literal) => {
            #[test]
            fn $test_name() {
                // From<u64> / value()
                let id = <$type>::from(0xabc123);
                assert_eq!(id.value(), 0xabc123);

                // Display renders "<prefix>-<lowercase hex>"
                assert_eq!(format!("{}", id), concat!($prefix, "-abc123"));

                // FromStr accepts the canonical form
                let id: $type = concat!($prefix, "-abc123").parse().unwrap();
                assert_eq!(id.value(), 0xabc123);

                // Another type's prefix is rejected
                assert!(concat!($wrong_prefix, "-abc").parse::<$type>().is_err());

                // A bare hex string without a prefix is rejected
                assert!("abc123".parse::<$type>().is_err());

                // Non-hex digits after a valid prefix are rejected
                assert!(concat!($prefix, "-xyz").parse::<$type>().is_err());

                // TryFrom<&str>
                let id = <$type>::try_from(concat!($prefix, "-abc123")).unwrap();
                assert_eq!(id.value(), 0xabc123);

                // TryFrom<String>
                let id = <$type>::try_from(concat!($prefix, "-abc123").to_string()).unwrap();
                assert_eq!(id.value(), 0xabc123);

                // serde JSON round trip
                let id = <$type>::from(0xabc123);
                let json = serde_json::to_string(&id).unwrap();
                let parsed: $type = serde_json::from_str(&json).unwrap();
                assert_eq!(id, parsed);

                // Into<u64>
                let id = <$type>::from(0xabc123);
                let val: u64 = id.into();
                assert_eq!(val, 0xabc123);
            }
        };
    }

    test_typeid_conversions!(test_organization_id_conversions, OrganizationID, "org", "p");
    test_typeid_conversions!(test_project_id_conversions, ProjectID, "p", "org");
    test_typeid_conversions!(test_experiment_id_conversions, ExperimentID, "exp", "p");
    test_typeid_conversions!(
        test_training_session_id_conversions,
        TrainingSessionID,
        "t",
        "v"
    );
    test_typeid_conversions!(
        test_validation_session_id_conversions,
        ValidationSessionID,
        "v",
        "t"
    );
    test_typeid_conversions!(test_snapshot_id_conversions, SnapshotID, "ss", "ds");
    test_typeid_conversions!(test_task_id_conversions, TaskID, "task", "t");
    test_typeid_conversions!(test_dataset_id_conversions, DatasetID, "ds", "ss");
    test_typeid_conversions!(
        test_annotation_set_id_conversions,
        AnnotationSetID,
        "as",
        "ds"
    );
    test_typeid_conversions!(test_sample_id_conversions, SampleID, "s", "p");
    test_typeid_conversions!(test_app_id_conversions, AppId, "app", "p");
    test_typeid_conversions!(test_image_id_conversions, ImageId, "im", "se");
    test_typeid_conversions!(test_sequence_id_conversions, SequenceId, "se", "im");
}
2907
// Deserialization test for `TaskDataList` against a sample payload.
#[cfg(test)]
mod tests_task_data_list {
    use super::*;

    // Every top-level field plus one entry of the `data` map is checked.
    #[test]
    fn task_data_list_deserializes_from_server_shape() {
        let json = r#"{
            "server": "test.edgefirst.studio",
            "organization_uid": "org-abc123",
            "traces": ["trace/imx95.json"],
            "data": {
                "predictions": ["predictions.parquet"],
                "trace": ["imx95.json"]
            }
        }"#;
        let parsed: TaskDataList = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.server, "test.edgefirst.studio");
        assert_eq!(parsed.organization_uid, "org-abc123");
        assert_eq!(parsed.traces, vec!["trace/imx95.json"]);
        assert_eq!(
            parsed.data.get("predictions").unwrap(),
            &vec!["predictions.parquet".to_string()]
        );
    }
}
2933
// Checks the folder-normalisation predicate that `upload_data` applies: an
// empty folder string is treated the same as no folder at all.
// NOTE(review): this repeats the `Option::filter` predicate rather than
// calling `upload_data`, so it will not catch a regression inside that method.
#[cfg(test)]
mod tests_upload_data {
    #[test]
    fn folder_empty_string_is_normalised() {
        let keep_non_empty = |s: &&str| !s.is_empty();

        assert_eq!(Some("").filter(keep_non_empty), None);
        assert_eq!(Some("predictions").filter(keep_non_empty), Some("predictions"));
    }
}
2948
// Deserialization and accessor tests for `Job`.
#[cfg(test)]
mod tests_job_struct {
    use super::*;

    // A payload carrying every known field fills in each one.
    #[test]
    fn job_deserializes_with_all_fields() {
        let json = r#"{
            "code": "edgefirst-validator:2.9.5",
            "title": "EdgeFirst Validator",
            "job_name": "smoke-test",
            "job_id": "aws-batch-abc",
            "state": "RUNNING",
            "launch": "2026-05-14T15:00:00Z",
            "task_id": 6789
        }"#;
        let job: Job = serde_json::from_str(json).unwrap();
        assert_eq!(job.code, "edgefirst-validator:2.9.5");
        assert_eq!(job.title, "EdgeFirst Validator");
        assert_eq!(job.job_name, "smoke-test");
        assert_eq!(job.job_id, "aws-batch-abc");
        assert_eq!(job.state, "RUNNING");
        assert!(job.launch.is_some());
        assert_eq!(job.task_id, 6789);
    }

    // With only `task_id` present, string fields default to empty and
    // `launch` to `None`.
    #[test]
    fn job_tolerates_missing_optional_fields() {
        let json = r#"{ "task_id": 42 }"#;
        let job: Job = serde_json::from_str(json).unwrap();
        assert_eq!(job.task_id, 42);
        assert!(job.code.is_empty());
        assert!(job.title.is_empty());
        assert!(job.job_name.is_empty());
        assert!(job.job_id.is_empty());
        assert!(job.state.is_empty());
        assert!(job.launch.is_none());
    }

    // The raw `task_id` field is signed; the `task_id()` accessor maps a
    // negative value to TaskID 0.
    #[test]
    fn job_task_id_accessor_saturates_negative_to_zero() {
        let job = Job {
            code: String::new(),
            title: String::new(),
            job_name: String::new(),
            job_id: String::new(),
            state: String::new(),
            launch: None,
            task_id: -1,
        };
        assert_eq!(job.task_id().value(), 0);
    }

    #[test]
    fn job_task_id_accessor_passes_through_positive_values() {
        let job = Job {
            code: String::new(),
            title: String::new(),
            job_name: String::new(),
            job_id: String::new(),
            state: String::new(),
            launch: None,
            task_id: 12345,
        };
        assert_eq!(job.task_id().value(), 12345);
    }

    // Keys that `Job` does not declare must be ignored, not rejected.
    #[test]
    fn job_ignores_unknown_fields() {
        let json = r#"{
            "code": "x",
            "task_id": 1,
            "docker_task": { "image": "x" },
            "aws_region": "us-east-1",
            "tags": ["a", "b"]
        }"#;
        let job: Job = serde_json::from_str(json).unwrap();
        assert_eq!(job.task_id, 1);
    }
}
3038
// `TaskInfo` must accept both spellings of the description key and tolerate
// payloads that carry only `id` and `type`.
#[cfg(test)]
mod tests_task_info_schema_tolerance {
    use super::*;

    // The current key, `task_description`, populates `description()`.
    #[test]
    fn task_info_accepts_task_description_field() {
        let json = r#"{
            "id": 6699,
            "type": "edgefirst-validator:2.9.5",
            "task_description": "Profiler run for IMX95",
            "status": "running"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description(), "Profiler run for IMX95");
    }

    // The legacy key, `description`, populates `description()` as well.
    #[test]
    fn task_info_accepts_legacy_description_field() {
        let json = r#"{
            "id": 6699,
            "type": "edgefirst-validator:2.9.5",
            "description": "Legacy description"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.description(), "Legacy description");
    }

    #[test]
    fn task_info_tolerates_missing_description() {
        let json = r#"{
            "id": 6699,
            "type": "x"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        assert!(info.description().is_empty());
    }

    // Deserialization must succeed with no date fields at all.
    #[test]
    fn task_info_tolerates_missing_dates_via_default() {
        let json = r#"{
            "id": 6699,
            "type": "x"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.id().value(), 6699);
    }

    #[test]
    fn task_info_status_accessor_returns_option() {
        let json = r#"{
            "id": 1,
            "type": "x"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        assert!(info.status().is_none());
    }

    #[test]
    fn task_info_stages_returns_empty_map_when_unset() {
        let json = r#"{
            "id": 1,
            "type": "x"
        }"#;
        let info: TaskInfo = serde_json::from_str(json).unwrap();
        let stages = info.stages();
        assert!(stages.is_empty());
    }
}
3116
// Construction and serde tests for `Stage`.
#[cfg(test)]
mod tests_stage_struct {
    use super::*;

    #[test]
    fn stage_new_sets_only_supplied_fields() {
        let stage = Stage::new(
            None,
            "download".into(),
            Some("running".into()),
            Some("fetching".into()),
            42,
        );
        assert!(stage.task_id().is_none());
        assert_eq!(stage.stage(), "download");
        assert_eq!(stage.status().as_deref(), Some("running"));
        assert_eq!(stage.message().as_deref(), Some("fetching"));
        assert_eq!(stage.percentage(), 42);
        // `new` has no description parameter, so the field stays unset.
        assert!(stage.description().is_none());
    }

    #[test]
    fn stage_serializes_without_optional_none_fields() {
        // Only the required fields (`stage`, `percentage`) should be emitted.
        let stage = Stage::new(None, "init".into(), None, None, 0);
        let json = serde_json::to_value(&stage).unwrap();
        assert!(json.get("status").is_none(), "got: {json}");
        assert!(json.get("message").is_none(), "got: {json}");
        assert!(json.get("docker_task_id").is_none(), "got: {json}");
        assert_eq!(json["stage"], "init");
        assert_eq!(json["percentage"], 0);
    }

    // A set task ID is emitted under its renamed key, `docker_task_id`.
    #[test]
    fn stage_serializes_task_id_when_present() {
        let task_id = TaskID::from(0xdeadu64);
        let stage = Stage::new(Some(task_id), "x".into(), None, None, 0);
        let json = serde_json::to_value(&stage).unwrap();
        assert!(json.get("docker_task_id").is_some());
    }

    #[test]
    fn stage_round_trips_through_json() {
        let stage = Stage::new(
            None,
            "train".into(),
            Some("done".into()),
            Some("epoch 100".into()),
            100,
        );
        let s = serde_json::to_string(&stage).unwrap();
        let back: Stage = serde_json::from_str(&s).unwrap();
        assert_eq!(back.stage(), "train");
        assert_eq!(back.status().as_deref(), Some("done"));
        assert_eq!(back.message().as_deref(), Some("epoch 100"));
        assert_eq!(back.percentage(), 100);
    }
}
3179
/// Extra deserialization cases for `TaskDataList`: empty and multi-folder
/// `data` maps.
#[cfg(test)]
mod tests_task_data_list_extra {
    use super::*;

    /// Deserializes a `TaskDataList` fixture, panicking if serde rejects it.
    fn parse(payload: &str) -> TaskDataList {
        serde_json::from_str(payload).unwrap()
    }

    /// Empty `traces` and `data` collections parse to empty containers.
    #[test]
    fn task_data_list_with_empty_data_map() {
        let list = parse(
            r#"{
                "server": "studio",
                "organization_uid": "org-1",
                "traces": [],
                "data": {}
            }"#,
        );
        assert!(list.traces.is_empty());
        assert!(list.data.is_empty());
    }

    /// Each folder key in `data` keeps its own list of files.
    #[test]
    fn task_data_list_multiple_folders() {
        let list = parse(
            r#"{
                "server": "studio",
                "organization_uid": "org-1",
                "traces": ["t1", "t2"],
                "data": {
                    "predictions": ["a.parquet", "b.parquet"],
                    "metrics": ["loss.json"]
                }
            }"#,
        );
        assert_eq!(list.traces.len(), 2);
        assert_eq!(list.data.len(), 2);

        let predictions = &list.data["predictions"];
        assert_eq!(predictions.len(), 2);
    }
}
3214
/// Accessor tests for `Artifact`.
#[cfg(test)]
mod tests_artifact_struct {
    use super::*;

    /// The `name` and camelCase `modelType` JSON keys are exposed through
    /// `name()` and `model_type()`.
    #[test]
    fn artifact_accessors_return_strs() {
        let artifact =
            serde_json::from_str::<Artifact>(r#"{ "name": "best.onnx", "modelType": "yolo" }"#)
                .unwrap();
        assert_eq!(artifact.name(), "best.onnx");
        assert_eq!(artifact.model_type(), "yolo");
    }
}
3229
/// Wire-format tests for `TaskStatus` serialization.
#[cfg(test)]
mod tests_task_status_serialize {
    use super::*;

    /// The task id is written under `docker_task_id`, not under the Rust field
    /// name `task_id`. Checking only for the presence of `docker_task_id`
    /// would not catch the rename being lost if both keys were emitted.
    #[test]
    fn task_status_uses_docker_task_id_wire_field() {
        let s = TaskStatus {
            task_id: TaskID::from(0x1a2bu64),
            status: "training".into(),
        };
        let json = serde_json::to_value(&s).unwrap();
        assert!(json.get("docker_task_id").is_some(), "got: {json}");
        assert!(json.get("task_id").is_none(), "got: {json}");
        assert_eq!(json["status"], "training");
    }
}
3246
/// Wire-format tests for `TaskStages` serialization.
#[cfg(test)]
mod tests_task_stages_serialize {
    use super::*;
    use std::collections::HashMap;

    /// An empty `stages` list is left out of the JSON object entirely.
    #[test]
    fn task_stages_omits_empty_vec() {
        let payload = TaskStages {
            task_id: TaskID::from(1u64),
            stages: vec![],
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert!(json.get("stages").is_none(), "got: {json}");
    }

    /// A non-empty `stages` list is emitted as an array of key/value objects.
    #[test]
    fn task_stages_serializes_non_empty_vec() {
        let download = HashMap::from([(String::from("stage"), String::from("download"))]);
        let payload = TaskStages {
            task_id: TaskID::from(1u64),
            stages: vec![download],
        };
        let json = serde_json::to_value(&payload).unwrap();
        assert_eq!(json["stages"][0]["stage"], "download");
    }
}
3274}