1use std::{collections::HashMap, fmt::Display};
5
6use crate::{
7 Client, Error,
8 api::{AnnotationSetID, DatasetID, ProjectID, SampleID},
9 mask::MaskData,
10};
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13
14#[cfg(feature = "polars")]
15use polars::prelude::*;
16
/// Kinds of sample files handled by this module.
///
/// The textual name of each variant is produced by the `Display` impl and
/// parsed back by `TryFrom<&str>` / `FromStr`.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum FileType {
    /// Displayed as `"image"`.
    Image,
    /// Displayed as `"lidar.pcd"`.
    LidarPcd,
    /// Displayed as `"lidar.depth"`.
    LidarDepth,
    /// Displayed as `"lidar.reflect"`.
    LidarReflect,
    /// Displayed as `"radar.pcd"`.
    RadarPcd,
    /// Displayed as `"radar.png"`.
    RadarCube,
    /// Wildcard; `FileType::expand_types` replaces it with every concrete type.
    All,
}
68
69impl std::fmt::Display for FileType {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 let value = match self {
74 FileType::Image => "image",
75 FileType::LidarPcd => "lidar.pcd",
76 FileType::LidarDepth => "lidar.depth",
77 FileType::LidarReflect => "lidar.reflect",
78 FileType::RadarPcd => "radar.pcd",
79 FileType::RadarCube => "radar.png",
80 FileType::All => "all",
81 };
82 write!(f, "{}", value)
83 }
84}
85
86impl FileType {
87 pub fn file_extension(&self) -> &'static str {
90 match self {
91 FileType::Image => "jpg", FileType::LidarPcd => "lidar.pcd",
93 FileType::LidarDepth => "lidar.png",
94 FileType::LidarReflect => "lidar.jpg",
95 FileType::RadarPcd => "radar.pcd",
96 FileType::RadarCube => "radar.png",
97 FileType::All => "",
98 }
99 }
100}
101
102impl TryFrom<&str> for FileType {
103 type Error = crate::Error;
104
105 fn try_from(s: &str) -> Result<Self, Self::Error> {
106 match s {
107 "image" => Ok(FileType::Image),
108 "lidar.pcd" => Ok(FileType::LidarPcd),
109 "lidar.png" | "lidar.depth" | "depth.png" | "depthmap" => Ok(FileType::LidarDepth),
111 "lidar.jpg" | "lidar.jpeg" | "lidar.reflect" => Ok(FileType::LidarReflect),
112 "radar.pcd" | "pcd" => Ok(FileType::RadarPcd),
113 "radar.png" | "cube" => Ok(FileType::RadarCube),
114 "all" => Ok(FileType::All),
115 _ => Err(crate::Error::InvalidFileType(s.to_string())),
116 }
117 }
118}
119
120impl std::str::FromStr for FileType {
121 type Err = crate::Error;
122
123 fn from_str(s: &str) -> Result<Self, Self::Err> {
124 s.try_into()
125 }
126}
127
128impl FileType {
129 pub fn all_sensor_types() -> Vec<FileType> {
144 vec![
145 FileType::Image,
146 FileType::LidarPcd,
147 FileType::LidarDepth,
148 FileType::LidarReflect,
149 FileType::RadarPcd,
150 FileType::RadarCube,
151 ]
152 }
153
154 pub fn type_names() -> Vec<&'static str> {
166 vec![
167 "image",
168 "lidar.pcd",
169 "lidar.png",
170 "lidar.jpg",
171 "radar.pcd",
172 "radar.png",
173 "all",
174 ]
175 }
176
177 pub fn expand_types(types: &[FileType]) -> Vec<FileType> {
197 if types.contains(&FileType::All) {
198 FileType::all_sensor_types()
199 } else {
200 types.to_vec()
201 }
202 }
203}
204
/// Kinds of annotation geometry distinguished by this module.
///
/// The textual form is produced by the `Display` impl; `as_server_type`
/// yields a separate set of names.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AnnotationType {
    /// Displayed as `"box2d"`.
    Box2d,
    /// Displayed as `"box3d"`.
    Box3d,
    /// Displayed as `"polygon"`.
    Polygon,
    /// Displayed as `"mask"`.
    Mask,
}
247
248impl TryFrom<&str> for AnnotationType {
249 type Error = crate::Error;
250
251 fn try_from(s: &str) -> Result<Self, Self::Error> {
252 match s {
253 "box2d" => Ok(AnnotationType::Box2d),
254 "box3d" => Ok(AnnotationType::Box3d),
255 "polygon" => Ok(AnnotationType::Polygon),
256 "seg" => Ok(AnnotationType::Polygon),
257 "mask" => Ok(AnnotationType::Polygon), "raster" => Ok(AnnotationType::Mask),
259 _ => Err(crate::Error::InvalidAnnotationType(s.to_string())),
260 }
261 }
262}
263
264impl From<String> for AnnotationType {
265 fn from(s: String) -> Self {
266 s.as_str().try_into().unwrap_or(AnnotationType::Box2d)
268 }
269}
270
271impl From<&String> for AnnotationType {
272 fn from(s: &String) -> Self {
273 s.as_str().try_into().unwrap_or(AnnotationType::Box2d)
275 }
276}
277
278impl AnnotationType {
279 pub fn as_server_type(&self) -> &'static str {
287 match self {
288 AnnotationType::Box2d => "box",
289 AnnotationType::Box3d => "box3d",
290 AnnotationType::Polygon => "seg",
291 AnnotationType::Mask => "seg",
292 }
293 }
294}
295
296impl std::fmt::Display for AnnotationType {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 let value = match self {
299 AnnotationType::Box2d => "box2d",
300 AnnotationType::Box3d => "box3d",
301 AnnotationType::Polygon => "polygon",
302 AnnotationType::Mask => "mask",
303 };
304 write!(f, "{}", value)
305 }
306}
307
/// A dataset record as deserialized from a server response.
///
/// Fields are private; use the accessor methods on `Dataset` to read them.
#[derive(Deserialize, Clone, Debug)]
pub struct Dataset {
    /// Identifier of this dataset.
    id: DatasetID,
    /// Identifier of the project this dataset belongs to.
    project_id: ProjectID,
    /// Dataset name.
    name: String,
    /// Dataset description.
    description: String,
    /// Cloud key string as received in the payload.
    cloud_key: String,
    /// Creation timestamp, read from the `createdAt` key.
    #[serde(rename = "createdAt")]
    created: DateTime<Utc>,
}
356
impl Display for Dataset {
    /// Formats the dataset as `"<id> <name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{} {}", self.id, self.name)
    }
}
362
363impl Dataset {
364 pub fn id(&self) -> DatasetID {
365 self.id
366 }
367
368 pub fn project_id(&self) -> ProjectID {
369 self.project_id
370 }
371
372 pub fn name(&self) -> &str {
373 &self.name
374 }
375
376 pub fn description(&self) -> &str {
377 &self.description
378 }
379
380 pub fn cloud_key(&self) -> &str {
381 &self.cloud_key
382 }
383
384 pub fn created(&self) -> &DateTime<Utc> {
385 &self.created
386 }
387
388 pub async fn project(&self, client: &Client) -> Result<crate::api::Project, Error> {
389 client.project(self.project_id).await
390 }
391
392 pub async fn annotation_sets(&self, client: &Client) -> Result<Vec<AnnotationSet>, Error> {
393 client.annotation_sets(self.id).await
394 }
395
396 pub async fn labels(&self, client: &Client) -> Result<Vec<Label>, Error> {
397 client.labels(self.id).await
398 }
399
400 pub async fn add_label(&self, client: &Client, name: &str) -> Result<(), Error> {
401 client.add_label(self.id, name).await
402 }
403
404 pub async fn remove_label(&self, client: &Client, name: &str) -> Result<(), Error> {
405 let labels = self.labels(client).await?;
406 let label = labels
407 .iter()
408 .find(|l| l.name() == name)
409 .ok_or_else(|| Error::MissingLabel(name.to_string()))?;
410 client.remove_label(label.id()).await
411 }
412}
413
414#[derive(Deserialize)]
418pub struct AnnotationSet {
419 id: AnnotationSetID,
420 dataset_id: DatasetID,
421 name: String,
422 description: String,
423 #[serde(rename = "date")]
424 created: DateTime<Utc>,
425}
426
impl Display for AnnotationSet {
    /// Formats the annotation set as `"<id> <name>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{} {}", self.id, self.name)
    }
}
432
impl AnnotationSet {
    /// Identifier of this annotation set.
    pub fn id(&self) -> AnnotationSetID {
        self.id
    }

    /// Identifier of the dataset this annotation set belongs to.
    pub fn dataset_id(&self) -> DatasetID {
        self.dataset_id
    }

    /// Annotation set name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Annotation set description.
    pub fn description(&self) -> &str {
        &self.description
    }

    /// Creation timestamp, returned by value.
    pub fn created(&self) -> DateTime<Utc> {
        self.created
    }

    /// Fetches the owning dataset through `client`.
    ///
    /// Errors from `Client::dataset` are returned unchanged.
    pub async fn dataset(&self, client: &Client) -> Result<Dataset, Error> {
        client.dataset(self.dataset_id).await
    }
}
458
/// Optional timing measurements attached to a sample.
///
/// Every field defaults to `None`.
// NOTE(review): the unit of these values is not visible in this file — confirm
// before documenting it.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Timing {
    /// Measurement for the load stage, if recorded.
    pub load: Option<i64>,
    /// Measurement for the preprocess stage, if recorded.
    pub preprocess: Option<i64>,
    /// Measurement for the inference stage, if recorded.
    pub inference: Option<i64>,
    /// Measurement for the decode stage, if recorded.
    pub decode: Option<i64>,
}
474
/// A dataset sample together with its files and annotations.
///
/// Only `Serialize` is derived here. Deserialization is implemented by hand
/// through the intermediate `SampleRaw` type, so the deserialize-side serde
/// attributes below (`alias`, `deserialize_with`, `default`) are not what
/// drives parsing of incoming data.
#[derive(Serialize, Clone, Debug)]
pub struct Sample {
    /// Sample identifier; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<SampleID>,
    /// Group name; serialized under the key `group`.
    #[serde(
        alias = "group_name",
        rename(serialize = "group", deserialize = "group_name"),
        skip_serializing_if = "Option::is_none"
    )]
    pub group: Option<String>,
    /// Name of the sequence this sample belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_name: Option<String>,
    /// UUID of the sequence this sample belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_uuid: Option<String>,
    /// Description of the sequence this sample belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sequence_description: Option<String>,
    /// Frame number within the sequence, if any.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_frame_number"
    )]
    pub frame_number: Option<u32>,
    /// UUID of the sample, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uuid: Option<String>,
    /// Image file name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_name: Option<String>,
    /// Image URL, if any; used by `Sample::download` for `FileType::Image`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<String>,
    /// Image width, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub width: Option<u32>,
    /// Image height, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub height: Option<u32>,
    /// Timestamp of the sample, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub date: Option<DateTime<Utc>>,
    /// Source string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// GPS/IMU data; serialized under the key `sensors`.
    #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "sensors"))]
    pub location: Option<Location>,
    /// Degradation string, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub degradation: Option<String>,
    /// Negative label indices, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub neg_label_indices: Option<Vec<u32>>,
    /// Non-exhaustive label indices, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub not_exhaustive_label_indices: Option<Vec<u32>>,
    /// Files attached to the sample; serialized by `serialize_files` and
    /// omitted when empty.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_files"
    )]
    pub files: Vec<SampleFile>,
    /// Annotations attached to the sample; omitted from output when empty.
    #[serde(
        default,
        skip_serializing_if = "Vec::is_empty",
        serialize_with = "serialize_annotations"
    )]
    pub annotations: Vec<Annotation>,
    /// Timing information; never serialized.
    #[serde(skip)]
    pub timing: Option<Timing>,
}
559
560fn deserialize_frame_number<'de, D>(deserializer: D) -> Result<Option<u32>, D::Error>
563where
564 D: serde::Deserializer<'de>,
565{
566 use serde::Deserialize;
567
568 let value = Option::<i32>::deserialize(deserializer)?;
569 Ok(value.and_then(|v| if v < 0 { None } else { Some(v as u32) }))
570}
571
/// Returns `true` when `s` starts with the `http://` or `https://` scheme.
///
/// The check is a case-sensitive prefix test; nothing after the scheme is
/// validated.
fn is_valid_url(s: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| s.starts_with(*scheme))
}
577
578fn serialize_files<S>(files: &[SampleFile], serializer: S) -> Result<S::Ok, S::Error>
581where
582 S: serde::Serializer,
583{
584 use serde::Serialize;
585 let map: HashMap<String, String> = files
586 .iter()
587 .filter_map(|f| {
588 f.filename()
589 .map(|filename| (f.file_type().to_string(), filename.to_string()))
590 })
591 .collect();
592 map.serialize(serializer)
593}
594
595fn serialize_annotations<S>(annotations: &Vec<Annotation>, serializer: S) -> Result<S::Ok, S::Error>
599where
600 S: serde::Serializer,
601{
602 serde::Serialize::serialize(annotations, serializer)
603}
604
605fn deserialize_annotations<'de, D>(deserializer: D) -> Result<Vec<Annotation>, D::Error>
608where
609 D: serde::Deserializer<'de>,
610{
611 use serde::Deserialize;
612
613 #[derive(Deserialize)]
614 #[serde(untagged)]
615 enum AnnotationsFormat {
616 Vec(Vec<Annotation>),
617 Map(HashMap<String, Vec<Annotation>>),
618 }
619
620 let value = Option::<AnnotationsFormat>::deserialize(deserializer)?;
621 Ok(value
622 .map(|v| match v {
623 AnnotationsFormat::Vec(annotations) => annotations,
624 AnnotationsFormat::Map(map) => convert_annotations_map_to_vec(map),
625 })
626 .unwrap_or_default())
627}
628
/// Result of decoding a sample's raw `sensors` value.
#[derive(Debug, Default)]
struct SensorsData {
    /// Files extracted from the sensors value.
    files: Vec<SampleFile>,
    /// GPS/IMU data extracted from the sensors value, if any.
    location: Option<Location>,
}
636
637fn deserialize_sensors_data(value: Option<serde_json::Value>) -> SensorsData {
639 use serde_json::Value;
640
641 fn create_sample_file(file_type: String, value: String) -> SampleFile {
644 if is_valid_url(&value) {
645 SampleFile::with_url(file_type, value)
646 } else {
647 SampleFile::with_data(file_type, value)
648 }
649 }
650
651 fn create_sample_file_from_value(file_type: String, value: Value) -> Option<SampleFile> {
653 match value {
654 Value::String(s) => Some(create_sample_file(file_type, s)),
655 Value::Object(_) | Value::Array(_) => {
656 serde_json::to_string(&value)
658 .ok()
659 .map(|data| SampleFile::with_data(file_type, data))
660 }
661 _ => None,
662 }
663 }
664
665 fn extract_location(map: &serde_json::Map<String, Value>) -> Option<Location> {
667 let gps = map
668 .get("gps")
669 .and_then(|v| serde_json::from_value::<GpsData>(v.clone()).ok());
670 let imu = map
671 .get("imu")
672 .and_then(|v| serde_json::from_value::<ImuData>(v.clone()).ok());
673
674 if gps.is_some() || imu.is_some() {
675 Some(Location { gps, imu })
676 } else {
677 None
678 }
679 }
680
681 let mut result = SensorsData::default();
682
683 match value {
684 None => result,
685 Some(Value::Array(arr)) => {
686 for item in arr {
688 if let Value::Object(map) = item {
689 if map.contains_key("type") {
691 if let Ok(file) =
693 serde_json::from_value::<SampleFile>(Value::Object(map.clone()))
694 {
695 result.files.push(file);
696 }
697 } else {
698 if let Some(loc) = extract_location(&map) {
700 if let Some(ref mut existing) = result.location {
702 if loc.gps.is_some() {
703 existing.gps = loc.gps;
704 }
705 if loc.imu.is_some() {
706 existing.imu = loc.imu;
707 }
708 } else {
709 result.location = Some(loc);
710 }
711 } else {
712 for (file_type, value) in map {
714 if let Some(file) = create_sample_file_from_value(file_type, value)
715 {
716 result.files.push(file);
717 }
718 }
719 }
720 }
721 }
722 }
723 result
724 }
725 Some(Value::Object(map)) => {
726 if let Some(loc) = extract_location(&map) {
728 result.location = Some(loc);
729 }
730
731 for (key, value) in map {
733 if key != "gps"
734 && key != "imu"
735 && let Some(file) = create_sample_file_from_value(key, value)
736 {
737 result.files.push(file);
738 }
739 }
740 result
741 }
742 Some(_) => result,
743 }
744}
745
/// Wire representation of a sample used only for deserialization.
///
/// `Sample` is built from this type via `From<SampleRaw>`; the raw `sensors`
/// value is decoded separately by `deserialize_sensors_data`.
#[derive(Deserialize)]
struct SampleRaw {
    #[serde(default)]
    id: Option<SampleID>,
    /// Accepted under either `group` or `group_name`.
    #[serde(alias = "group_name")]
    group: Option<String>,
    sequence_name: Option<String>,
    sequence_uuid: Option<String>,
    sequence_description: Option<String>,
    /// Negative values are mapped to `None` by `deserialize_frame_number`.
    #[serde(default, deserialize_with = "deserialize_frame_number")]
    frame_number: Option<u32>,
    uuid: Option<String>,
    image_name: Option<String>,
    image_url: Option<String>,
    width: Option<u32>,
    height: Option<u32>,
    date: Option<DateTime<Utc>>,
    source: Option<String>,
    degradation: Option<String>,
    #[serde(default)]
    neg_label_indices: Option<Vec<u32>>,
    #[serde(default)]
    not_exhaustive_label_indices: Option<Vec<u32>>,
    /// Untyped sensors payload, kept as raw JSON for later decoding.
    // NOTE(review): the alias repeats the field's own name and so adds
    // nothing — confirm whether a different key was intended.
    #[serde(default, alias = "sensors")]
    sensors: Option<serde_json::Value>,
    /// Accepts either a list or a keyed map (see `deserialize_annotations`).
    #[serde(default, deserialize_with = "deserialize_annotations")]
    annotations: Vec<Annotation>,
}
778
779impl From<SampleRaw> for Sample {
780 fn from(raw: SampleRaw) -> Self {
781 let sensors_data = deserialize_sensors_data(raw.sensors);
782
783 Sample {
784 id: raw.id,
785 group: raw.group,
786 sequence_name: raw.sequence_name,
787 sequence_uuid: raw.sequence_uuid,
788 sequence_description: raw.sequence_description,
789 frame_number: raw.frame_number,
790 uuid: raw.uuid,
791 image_name: raw.image_name,
792 image_url: raw.image_url,
793 width: raw.width,
794 height: raw.height,
795 date: raw.date,
796 source: raw.source,
797 location: sensors_data.location,
798 degradation: raw.degradation,
799 neg_label_indices: raw.neg_label_indices,
800 not_exhaustive_label_indices: raw.not_exhaustive_label_indices,
801 files: sensors_data.files,
802 annotations: raw.annotations,
803 timing: None,
804 }
805 }
806}
807
808impl<'de> serde::Deserialize<'de> for Sample {
809 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
810 where
811 D: serde::Deserializer<'de>,
812 {
813 let raw = SampleRaw::deserialize(deserializer)?;
814 Ok(Sample::from(raw))
815 }
816}
817
818impl Display for Sample {
819 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
820 write!(
821 f,
822 "{} {}",
823 self.id
824 .map(|id| id.to_string())
825 .unwrap_or_else(|| "unknown".to_string()),
826 self.image_name().unwrap_or("unknown")
827 )
828 }
829}
830
impl Default for Sample {
    /// Equivalent to [`Sample::new`].
    fn default() -> Self {
        Self::new()
    }
}
836
837impl Sample {
838 pub fn new() -> Self {
840 Self {
841 id: None,
842 group: None,
843 sequence_name: None,
844 sequence_uuid: None,
845 sequence_description: None,
846 frame_number: None,
847 uuid: None,
848 image_name: None,
849 image_url: None,
850 width: None,
851 height: None,
852 date: None,
853 source: None,
854 location: None,
855 degradation: None,
856 neg_label_indices: None,
857 not_exhaustive_label_indices: None,
858 files: vec![],
859 annotations: vec![],
860 timing: None,
861 }
862 }
863
864 pub fn id(&self) -> Option<SampleID> {
865 self.id
866 }
867
868 pub fn name(&self) -> Option<String> {
869 self.image_name.as_ref().map(|n| extract_sample_name(n))
870 }
871
872 pub fn group(&self) -> Option<&String> {
873 self.group.as_ref()
874 }
875
876 pub fn sequence_name(&self) -> Option<&String> {
877 self.sequence_name.as_ref()
878 }
879
880 pub fn sequence_uuid(&self) -> Option<&String> {
881 self.sequence_uuid.as_ref()
882 }
883
884 pub fn sequence_description(&self) -> Option<&String> {
885 self.sequence_description.as_ref()
886 }
887
888 pub fn frame_number(&self) -> Option<u32> {
889 self.frame_number
890 }
891
892 pub fn uuid(&self) -> Option<&String> {
893 self.uuid.as_ref()
894 }
895
896 pub fn image_name(&self) -> Option<&str> {
897 self.image_name.as_deref()
898 }
899
900 pub fn image_url(&self) -> Option<&str> {
901 self.image_url.as_deref()
902 }
903
904 pub fn width(&self) -> Option<u32> {
905 self.width
906 }
907
908 pub fn height(&self) -> Option<u32> {
909 self.height
910 }
911
912 pub fn date(&self) -> Option<DateTime<Utc>> {
913 self.date
914 }
915
916 pub fn source(&self) -> Option<&String> {
917 self.source.as_ref()
918 }
919
920 pub fn location(&self) -> Option<&Location> {
921 self.location.as_ref()
922 }
923
924 pub fn files(&self) -> &[SampleFile] {
925 &self.files
926 }
927
928 pub fn annotations(&self) -> &[Annotation] {
929 &self.annotations
930 }
931
932 pub fn with_annotations(mut self, annotations: Vec<Annotation>) -> Self {
933 self.annotations = annotations;
934 self
935 }
936
937 pub fn with_frame_number(mut self, frame_number: Option<u32>) -> Self {
938 self.frame_number = frame_number;
939 self
940 }
941
942 pub async fn download(
949 &self,
950 client: &Client,
951 file_type: FileType,
952 ) -> Result<Option<Vec<u8>>, Error> {
953 use base64::{Engine, engine::general_purpose::STANDARD};
954
955 if file_type == FileType::Image {
957 if let Some(url) = self.image_url.as_deref()
958 && is_valid_url(url)
959 {
960 return Ok(Some(client.download(url).await?));
961 }
962 return Ok(None);
963 }
964
965 let file = resolve_file(&file_type, &self.files);
967
968 match file {
969 Some(f) => {
970 if let Some(url) = f.url() {
972 return Ok(Some(client.download(url).await?));
973 }
974
975 if let Some(data) = f.data() {
977 let decoded = if let Ok(bytes) = STANDARD.decode(data) {
984 if let Ok(text) = String::from_utf8(bytes.clone()) {
986 if text.starts_with('{') {
987 text
989 } else {
990 return Ok(Some(bytes));
992 }
993 } else {
994 return Ok(Some(bytes));
996 }
997 } else {
998 data.to_string()
1000 };
1001
1002 let content = if decoded.starts_with('{') {
1004 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&decoded) {
1005 if let Some(obj) = json.as_object() {
1006 obj.values()
1007 .next()
1008 .and_then(|v| v.as_str())
1009 .map(|s| s.to_string())
1010 .unwrap_or(decoded)
1011 } else {
1012 decoded
1013 }
1014 } else {
1015 decoded
1016 }
1017 } else {
1018 decoded
1019 };
1020
1021 return Ok(Some(content.as_bytes().to_vec()));
1022 }
1023
1024 Ok(None)
1025 }
1026 None => Ok(None),
1027 }
1028 }
1029}
1030
/// A file attached to a sample, referenced by URL, file name, inline data
/// or raw bytes.
///
/// The constructors on `SampleFile` each populate one of these sources
/// (`with_bytes` also sets the file name).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SampleFile {
    /// File type name, serialized under the key `type`.
    r#type: String,
    /// Remote location of the file, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    url: Option<String>,
    /// File name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    filename: Option<String>,
    /// Inline data; serialized when present but never read during
    /// deserialization.
    #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)]
    data: Option<String>,
    /// Raw bytes; excluded from both serialization and deserialization.
    #[serde(skip)]
    bytes: Option<Vec<u8>>,
}
1053
impl SampleFile {
    /// Creates a file of the given type that is referenced by URL.
    pub fn with_url(file_type: String, url: String) -> Self {
        Self {
            r#type: file_type,
            url: Some(url),
            filename: None,
            data: None,
            bytes: None,
        }
    }

    /// Creates a file of the given type that is referenced by file name.
    pub fn with_filename(file_type: String, filename: String) -> Self {
        Self {
            r#type: file_type,
            url: None,
            filename: Some(filename),
            data: None,
            bytes: None,
        }
    }

    /// Creates a file of the given type that carries inline string data.
    pub fn with_data(file_type: String, data: String) -> Self {
        Self {
            r#type: file_type,
            url: None,
            filename: None,
            data: Some(data),
            bytes: None,
        }
    }

    /// Creates a file of the given type that carries raw bytes together with
    /// a file name.
    pub fn with_bytes(file_type: String, filename: String, bytes: Vec<u8>) -> Self {
        Self {
            r#type: file_type,
            url: None,
            filename: Some(filename),
            data: None,
            bytes: Some(bytes),
        }
    }

    /// File type name.
    pub fn file_type(&self) -> &str {
        &self.r#type
    }

    /// URL of the file, if set.
    pub fn url(&self) -> Option<&str> {
        self.url.as_deref()
    }

    /// File name, if set.
    pub fn filename(&self) -> Option<&str> {
        self.filename.as_deref()
    }

    /// Inline string data, if set.
    pub fn data(&self) -> Option<&str> {
        self.data.as_deref()
    }

    /// Raw bytes, if set.
    pub fn bytes(&self) -> Option<&[u8]> {
        self.bytes.as_deref()
    }
}
1129
/// Optional GPS and IMU readings associated with a sample.
///
/// Each part is omitted from serialized output when `None`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Location {
    /// GPS reading, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub gps: Option<GpsData>,
    /// IMU reading, if available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub imu: Option<ImuData>,
}
1141
/// A GPS reading.
// NOTE(review): units and valid ranges are enforced by
// `validate_gps_coordinates`, which is not visible here — presumably degrees;
// confirm.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct GpsData {
    /// Latitude component.
    pub lat: f64,
    /// Longitude component.
    pub lon: f64,
}
1148
impl GpsData {
    /// Validates this reading by delegating to `validate_gps_coordinates`
    /// with `lat` and `lon`.
    ///
    /// Returns `Ok(())` when that check passes and its error message
    /// otherwise.
    pub fn validate(&self) -> Result<(), String> {
        validate_gps_coordinates(self.lat, self.lon)
    }
}
1183
/// An IMU orientation reading.
// NOTE(review): units and valid ranges are enforced by
// `validate_imu_orientation`, which is not visible here — confirm.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ImuData {
    /// Roll component.
    pub roll: f64,
    /// Pitch component.
    pub pitch: f64,
    /// Yaw component.
    pub yaw: f64,
}
1191
impl ImuData {
    /// Validates this reading by delegating to `validate_imu_orientation`
    /// with `roll`, `pitch` and `yaw`.
    ///
    /// Returns `Ok(())` when that check passes and its error message
    /// otherwise.
    pub fn validate(&self) -> Result<(), String> {
        validate_imu_orientation(self.roll, self.pitch, self.yaw)
    }
}
1229
/// Associates a geometry type with its textual type name.
// `dead_code` is allowed here; the reason is not visible in this part of the
// file (presumably the trait is only used under some configurations).
#[allow(dead_code)]
pub trait TypeName {
    /// Returns the type name as an owned string.
    fn type_name() -> String;
}
1234
/// A 3D box stored as a centre point and extents.
///
/// `x`, `y`, `z` hold the centre (see `Box3d::new`); `w`, `h`, `l` hold the
/// width, height and length.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Box3d {
    /// Centre x.
    x: f32,
    /// Centre y.
    y: f32,
    /// Centre z.
    z: f32,
    /// Width (extent along x).
    w: f32,
    /// Height (extent along y).
    h: f32,
    /// Length (extent along z).
    l: f32,
}
1244
1245impl TypeName for Box3d {
1246 fn type_name() -> String {
1247 "box3d".to_owned()
1248 }
1249}
1250
impl Box3d {
    /// Creates a box from its centre (`cx`, `cy`, `cz`) and its extents.
    pub fn new(cx: f32, cy: f32, cz: f32, width: f32, height: f32, length: f32) -> Self {
        Self {
            x: cx,
            y: cy,
            z: cz,
            w: width,
            h: height,
            l: length,
        }
    }

    /// Extent along x.
    pub fn width(&self) -> f32 {
        self.w
    }

    /// Extent along y.
    pub fn height(&self) -> f32 {
        self.h
    }

    /// Extent along z.
    pub fn length(&self) -> f32 {
        self.l
    }

    /// Centre x.
    pub fn cx(&self) -> f32 {
        self.x
    }

    /// Centre y.
    pub fn cy(&self) -> f32 {
        self.y
    }

    /// Centre z.
    pub fn cz(&self) -> f32 {
        self.z
    }

    /// Minimum x: centre x minus half the width.
    pub fn left(&self) -> f32 {
        self.x - self.w / 2.0
    }

    /// Minimum y: centre y minus half the height.
    pub fn top(&self) -> f32 {
        self.y - self.h / 2.0
    }

    /// Minimum z: centre z minus half the length.
    pub fn front(&self) -> f32 {
        self.z - self.l / 2.0
    }
}
1299
/// A 2D box stored as a left/top corner and extents.
///
/// `x`, `y` hold the left/top corner (see `Box2d::new`); `w`, `h` hold the
/// width and height.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Box2d {
    /// Height.
    h: f32,
    /// Width.
    w: f32,
    /// Left edge.
    x: f32,
    /// Top edge.
    y: f32,
}
1307
1308impl TypeName for Box2d {
1309 fn type_name() -> String {
1310 "box2d".to_owned()
1311 }
1312}
1313
impl Box2d {
    /// Creates a box from its left/top corner and its width and height.
    pub fn new(left: f32, top: f32, width: f32, height: f32) -> Self {
        Self {
            x: left,
            y: top,
            w: width,
            h: height,
        }
    }

    /// Box width.
    pub fn width(&self) -> f32 {
        self.w
    }

    /// Box height.
    pub fn height(&self) -> f32 {
        self.h
    }

    /// Left edge.
    pub fn left(&self) -> f32 {
        self.x
    }

    /// Top edge.
    pub fn top(&self) -> f32 {
        self.y
    }

    /// Centre x: left edge plus half the width.
    pub fn cx(&self) -> f32 {
        self.x + self.w / 2.0
    }

    /// Centre y: top edge plus half the height.
    pub fn cy(&self) -> f32 {
        self.y + self.h / 2.0
    }
}
1348
/// A polygon made of one or more rings of `(x, y)` points.
///
/// Serialization and deserialization are implemented by hand: the polygon is
/// written as the bare `rings` list.
#[derive(Clone, Debug, PartialEq)]
pub struct Polygon {
    /// Rings of the polygon; each ring is a list of `(x, y)` points.
    pub rings: Vec<Vec<(f32, f32)>>,
}
1353
1354impl TypeName for Polygon {
1355 fn type_name() -> String {
1356 "polygon".to_owned()
1357 }
1358}
1359
impl Polygon {
    /// Creates a polygon from the given rings without any validation.
    pub fn new(rings: Vec<Vec<(f32, f32)>>) -> Self {
        Self { rings }
    }
}
1365
1366impl serde::Serialize for Polygon {
1367 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1368 where
1369 S: serde::Serializer,
1370 {
1371 serde::Serialize::serialize(&self.rings, serializer)
1372 }
1373}
1374
1375impl<'de> serde::Deserialize<'de> for Polygon {
1376 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1377 where
1378 D: serde::Deserializer<'de>,
1379 {
1380 let value = serde_json::Value::deserialize(deserializer)?;
1382
1383 let polygon_value = if let Some(obj) = value.as_object() {
1385 obj.get("rings")
1387 .or_else(|| obj.get("polygon"))
1388 .cloned()
1389 .unwrap_or(serde_json::Value::Null)
1390 } else {
1391 value
1393 };
1394
1395 let rings = parse_polygon_value(&polygon_value);
1397
1398 Ok(Self { rings })
1399 }
1400}
1401
1402fn parse_polygon_value(value: &serde_json::Value) -> Vec<Vec<(f32, f32)>> {
1410 let Some(outer_array) = value.as_array() else {
1411 return vec![];
1412 };
1413
1414 let mut result = Vec::new();
1415
1416 for ring in outer_array {
1417 let Some(ring_array) = ring.as_array() else {
1418 continue;
1419 };
1420
1421 let is_3d = ring_array
1423 .first()
1424 .map(|first| first.is_array())
1425 .unwrap_or(false);
1426
1427 let points: Vec<(f32, f32)> = if is_3d {
1428 ring_array
1430 .iter()
1431 .filter_map(|point| {
1432 let arr = point.as_array()?;
1433 if arr.len() >= 2 {
1434 let x = arr[0].as_f64()? as f32;
1435 let y = arr[1].as_f64()? as f32;
1436 if x.is_finite() && y.is_finite() {
1437 Some((x, y))
1438 } else {
1439 None
1440 }
1441 } else {
1442 None
1443 }
1444 })
1445 .collect()
1446 } else {
1447 ring_array
1449 .chunks(2)
1450 .filter_map(|chunk| {
1451 if chunk.len() >= 2 {
1452 let x = chunk[0].as_f64()? as f32;
1453 let y = chunk[1].as_f64()? as f32;
1454 if x.is_finite() && y.is_finite() {
1455 Some((x, y))
1456 } else {
1457 None
1458 }
1459 } else {
1460 None
1461 }
1462 })
1463 .collect()
1464 };
1465
1466 if points.len() >= 3 {
1468 result.push(points);
1469 }
1470 }
1471
1472 result
1473}
1474
/// Wire representation of an annotation used only for deserialization.
///
/// Besides the structured `box2d` field it also captures flat `x`, `y`, `w`,
/// `h` values, which `Annotation`'s `Deserialize` impl uses as a fallback
/// source for the 2D box.
#[derive(Deserialize)]
struct AnnotationRaw {
    #[serde(default)]
    sample_id: Option<SampleID>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    sequence_name: Option<String>,
    // NOTE(review): unlike `SampleRaw::frame_number`, this field does not use
    // `deserialize_frame_number`, so a negative value is a deserialization
    // error here — confirm that is intended.
    #[serde(default)]
    frame_number: Option<u32>,
    /// Read from the `group_name` key.
    #[serde(rename = "group_name", default)]
    group: Option<String>,
    /// Read from `object_reference`, or `object_id` as an alias.
    #[serde(rename = "object_reference", alias = "object_id", default)]
    object_id: Option<String>,
    #[serde(default)]
    label_name: Option<String>,
    #[serde(default)]
    label_index: Option<u64>,
    #[serde(default)]
    iscrowd: Option<bool>,
    #[serde(default)]
    category_frequency: Option<String>,
    /// Structured 2D box; takes precedence over the flat `x`/`y`/`w`/`h`.
    #[serde(default)]
    box2d: Option<Box2d>,
    #[serde(default)]
    box3d: Option<Box3d>,
    /// Read from `polygon`, or `mask` as an alias.
    #[serde(default, alias = "mask")]
    polygon: Option<Polygon>,
    /// Flat box fields used only when `box2d` is absent.
    #[serde(default)]
    x: Option<f64>,
    #[serde(default)]
    y: Option<f64>,
    #[serde(default)]
    w: Option<f64>,
    #[serde(default)]
    h: Option<f64>,
}
1518
/// A single annotation: label information plus optional geometry and scores.
///
/// Only `Serialize` is derived; deserialization is implemented by hand via
/// `AnnotationRaw`. All `None` fields are omitted from serialized output and
/// `mask` is never serialized.
#[derive(Serialize, Clone, Debug)]
pub struct Annotation {
    /// Identifier of the sample this annotation belongs to.
    #[serde(skip_serializing_if = "Option::is_none")]
    sample_id: Option<SampleID>,
    /// Name associated with the annotation.
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Sequence name, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    sequence_name: Option<String>,
    /// Frame number, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    frame_number: Option<u32>,
    /// Group name; serialized under the key `group_name`.
    #[serde(rename = "group_name", skip_serializing_if = "Option::is_none")]
    group: Option<String>,
    /// Object identifier; serialized under the key `object_reference`.
    #[serde(
        rename = "object_reference",
        alias = "object_id",
        skip_serializing_if = "Option::is_none"
    )]
    object_id: Option<String>,
    /// Label name.
    #[serde(skip_serializing_if = "Option::is_none")]
    label_name: Option<String>,
    /// Label index.
    #[serde(skip_serializing_if = "Option::is_none")]
    label_index: Option<u64>,
    /// Crowd flag, if provided.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    iscrowd: Option<bool>,
    /// Category frequency string, if provided.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    category_frequency: Option<String>,
    /// 2D box geometry.
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d: Option<Box2d>,
    /// 3D box geometry.
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d: Option<Box3d>,
    /// Polygon geometry.
    #[serde(skip_serializing_if = "Option::is_none")]
    polygon: Option<Polygon>,
    /// Mask data; never serialized, and set to `None` by deserialization.
    #[serde(skip)]
    mask: Option<MaskData>,
    /// Score for the 2D box, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    box2d_score: Option<f32>,
    /// Score for the 3D box, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    box3d_score: Option<f32>,
    /// Score for the polygon, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    polygon_score: Option<f32>,
    /// Score for the mask, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    mask_score: Option<f32>,
}
1575
1576impl<'de> serde::Deserialize<'de> for Annotation {
1577 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1578 where
1579 D: serde::Deserializer<'de>,
1580 {
1581 let raw: AnnotationRaw = serde::Deserialize::deserialize(deserializer)?;
1583
1584 let box2d = raw.box2d.or_else(|| match (raw.x, raw.y, raw.w, raw.h) {
1586 (Some(x), Some(y), Some(w), Some(h)) if w > 0.0 && h > 0.0 => {
1587 Some(Box2d::new(x as f32, y as f32, w as f32, h as f32))
1588 }
1589 _ => None,
1590 });
1591
1592 Ok(Annotation {
1593 sample_id: raw.sample_id,
1594 name: raw.name,
1595 sequence_name: raw.sequence_name,
1596 frame_number: raw.frame_number,
1597 group: raw.group,
1598 object_id: raw.object_id,
1599 label_name: raw.label_name,
1600 label_index: raw.label_index,
1601 iscrowd: raw.iscrowd,
1602 category_frequency: raw.category_frequency,
1603 box2d,
1604 box3d: raw.box3d,
1605 polygon: raw.polygon,
1606 mask: None,
1607 box2d_score: None,
1608 box3d_score: None,
1609 polygon_score: None,
1610 mask_score: None,
1611 })
1612 }
1613}
1614
impl Default for Annotation {
    /// Equivalent to [`Annotation::new`].
    fn default() -> Self {
        Self::new()
    }
}
1620
1621impl Annotation {
1622 pub fn new() -> Self {
1623 Self {
1624 sample_id: None,
1625 name: None,
1626 sequence_name: None,
1627 frame_number: None,
1628 group: None,
1629 object_id: None,
1630 label_name: None,
1631 label_index: None,
1632 iscrowd: None,
1633 category_frequency: None,
1634 box2d: None,
1635 box3d: None,
1636 polygon: None,
1637 mask: None,
1638 box2d_score: None,
1639 box3d_score: None,
1640 polygon_score: None,
1641 mask_score: None,
1642 }
1643 }
1644
1645 pub fn set_sample_id(&mut self, sample_id: Option<SampleID>) {
1646 self.sample_id = sample_id;
1647 }
1648
1649 pub fn sample_id(&self) -> Option<SampleID> {
1650 self.sample_id
1651 }
1652
1653 pub fn set_name(&mut self, name: Option<String>) {
1654 self.name = name;
1655 }
1656
1657 pub fn name(&self) -> Option<&String> {
1658 self.name.as_ref()
1659 }
1660
1661 pub fn set_sequence_name(&mut self, sequence_name: Option<String>) {
1662 self.sequence_name = sequence_name;
1663 }
1664
1665 pub fn sequence_name(&self) -> Option<&String> {
1666 self.sequence_name.as_ref()
1667 }
1668
1669 pub fn set_frame_number(&mut self, frame_number: Option<u32>) {
1670 self.frame_number = frame_number;
1671 }
1672
1673 pub fn frame_number(&self) -> Option<u32> {
1674 self.frame_number
1675 }
1676
1677 pub fn set_group(&mut self, group: Option<String>) {
1678 self.group = group;
1679 }
1680
1681 pub fn group(&self) -> Option<&String> {
1682 self.group.as_ref()
1683 }
1684
1685 pub fn object_id(&self) -> Option<&String> {
1686 self.object_id.as_ref()
1687 }
1688
1689 pub fn set_object_id(&mut self, object_id: Option<String>) {
1690 self.object_id = object_id;
1691 }
1692
1693 pub fn label(&self) -> Option<&String> {
1694 self.label_name.as_ref()
1695 }
1696
1697 pub fn set_label(&mut self, label_name: Option<String>) {
1698 self.label_name = label_name;
1699 }
1700
1701 pub fn label_index(&self) -> Option<u64> {
1702 self.label_index
1703 }
1704
1705 pub fn set_label_index(&mut self, label_index: Option<u64>) {
1706 self.label_index = label_index;
1707 }
1708
1709 pub fn iscrowd(&self) -> Option<bool> {
1710 self.iscrowd
1711 }
1712
    /// Stores the `iscrowd` flag on this annotation, replacing any previous
    /// value (`None` clears it).
    pub fn set_iscrowd(&mut self, iscrowd: Option<bool>) {
        self.iscrowd = iscrowd;
    }
1716
    /// Borrows the stored `category_frequency`, or returns `None` when it is
    /// unset.
    pub fn category_frequency(&self) -> Option<&String> {
        self.category_frequency.as_ref()
    }
1720
    /// Stores `category_frequency` on this annotation, replacing any previous
    /// value (`None` clears it).
    pub fn set_category_frequency(&mut self, category_frequency: Option<String>) {
        self.category_frequency = category_frequency;
    }
1724
    /// Borrows the stored 2D box, or returns `None` when it is unset.
    pub fn box2d(&self) -> Option<&Box2d> {
        self.box2d.as_ref()
    }
1728
    /// Stores `box2d` on this annotation, replacing any previous value
    /// (`None` clears it).
    pub fn set_box2d(&mut self, box2d: Option<Box2d>) {
        self.box2d = box2d;
    }
1732
    /// Borrows the stored 3D box, or returns `None` when it is unset.
    pub fn box3d(&self) -> Option<&Box3d> {
        self.box3d.as_ref()
    }
1736
    /// Stores `box3d` on this annotation, replacing any previous value
    /// (`None` clears it).
    pub fn set_box3d(&mut self, box3d: Option<Box3d>) {
        self.box3d = box3d;
    }
1740
    /// Borrows the stored polygon, or returns `None` when it is unset.
    pub fn polygon(&self) -> Option<&Polygon> {
        self.polygon.as_ref()
    }
1744
    /// Stores `polygon` on this annotation, replacing any previous value
    /// (`None` clears it).
    pub fn set_polygon(&mut self, polygon: Option<Polygon>) {
        self.polygon = polygon;
    }
1748
    /// Borrows the stored mask data, or returns `None` when it is unset.
    pub fn mask(&self) -> Option<&MaskData> {
        self.mask.as_ref()
    }
1752
    /// Stores `mask` on this annotation, replacing any previous value
    /// (`None` clears it).
    pub fn set_mask(&mut self, mask: Option<MaskData>) {
        self.mask = mask;
    }
1756
    /// Returns the stored `box2d_score`, or `None` when it is unset.
    pub fn box2d_score(&self) -> Option<f32> {
        self.box2d_score
    }
1760
    /// Stores `score` in the `box2d_score` field, replacing any previous value
    /// (`None` clears it).
    pub fn set_box2d_score(&mut self, score: Option<f32>) {
        self.box2d_score = score;
    }
1764
    /// Returns the stored `box3d_score`, or `None` when it is unset.
    pub fn box3d_score(&self) -> Option<f32> {
        self.box3d_score
    }
1768
    /// Stores `score` in the `box3d_score` field, replacing any previous value
    /// (`None` clears it).
    pub fn set_box3d_score(&mut self, score: Option<f32>) {
        self.box3d_score = score;
    }
1772
    /// Returns the stored `polygon_score`, or `None` when it is unset.
    pub fn polygon_score(&self) -> Option<f32> {
        self.polygon_score
    }
1776
    /// Stores `score` in the `polygon_score` field, replacing any previous
    /// value (`None` clears it).
    pub fn set_polygon_score(&mut self, score: Option<f32>) {
        self.polygon_score = score;
    }
1780
    /// Returns the stored `mask_score`, or `None` when it is unset.
    pub fn mask_score(&self) -> Option<f32> {
        self.mask_score
    }
1784
    /// Stores `score` in the `mask_score` field, replacing any previous value
    /// (`None` clears it).
    pub fn set_mask_score(&mut self, score: Option<f32>) {
        self.mask_score = score;
    }
1788}
1789
/// A label record: a numeric ID, the ID of the dataset it is attached to, an
/// index, and a name.
///
/// The type derives serde `Serialize`/`Deserialize`, so renaming a field
/// changes the serialized form.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Label {
    /// Identifier handed to `Client::remove_label` by [`Label::remove`].
    id: u64,
    /// ID of the dataset this label refers to.
    dataset_id: DatasetID,
    /// Numeric index of the label; presumably its position within the
    /// dataset's label list — TODO confirm against the server API.
    index: u64,
    /// Label name; this is what the `Display` implementation prints.
    name: String,
}
1797
1798impl Label {
1799 pub fn id(&self) -> u64 {
1800 self.id
1801 }
1802
1803 pub fn dataset_id(&self) -> DatasetID {
1804 self.dataset_id
1805 }
1806
1807 pub fn index(&self) -> u64 {
1808 self.index
1809 }
1810
1811 pub fn name(&self) -> &str {
1812 &self.name
1813 }
1814
1815 pub async fn remove(&self, client: &Client) -> Result<(), Error> {
1816 client.remove_label(self.id()).await
1817 }
1818
1819 pub async fn set_name(&mut self, client: &Client, name: &str) -> Result<(), Error> {
1820 self.name = name.to_string();
1821 client.update_label(self).await
1822 }
1823
1824 pub async fn set_index(&mut self, client: &Client, index: u64) -> Result<(), Error> {
1825 self.index = index;
1826 client.update_label(self).await
1827 }
1828}
1829
1830impl Display for Label {
1831 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1832 write!(f, "{}", self.name())
1833 }
1834}
1835
/// Serializable description of a single label to create; used as the element
/// type of [`NewLabel::labels`].
#[derive(Serialize, Clone, Debug)]
pub struct NewLabelObject {
    /// Name of the label.
    pub name: String,
}
1840
/// Serializable payload pairing a dataset ID with a list of labels.
///
/// NOTE(review): presumably the request body for creating labels on a
/// dataset — confirm against the client call that builds it.
#[derive(Serialize, Clone, Debug)]
pub struct NewLabel {
    /// Dataset the labels belong to.
    pub dataset_id: DatasetID,
    /// One entry per label.
    pub labels: Vec<NewLabelObject>,
}
1846
/// A deserialized group record consisting of a numeric ID and a name.
///
/// NOTE(review): presumably a dataset group/split such as "train" or "val" —
/// confirm against the API that returns it.
#[derive(Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Group {
    /// Numeric identifier of the group.
    pub id: u64,

    /// Name of the group.
    pub name: String,
}
1889
/// Derives the `(name, frame)` pair used to identify an annotation's row.
///
/// Returns `None` when the annotation has no name or the name has no UTF-8
/// file stem. Otherwise yields `(sequence_name, frame_number)` if a sequence
/// name is set, and `(file stem of name, None)` if it is not.
#[cfg(feature = "polars")]
fn extract_annotation_name(ann: &Annotation) -> Option<(String, Option<u32>)> {
    use std::path::Path;

    // The stem is resolved up front so a missing or unusable name rejects the
    // annotation even when a sequence name is available.
    let stem = Path::new(ann.name.as_ref()?).file_stem()?.to_str()?;

    let key = if let Some(sequence) = &ann.sequence_name {
        (sequence.clone(), ann.frame_number)
    } else {
        (stem.to_string(), None)
    };
    Some(key)
}
1904
/// Converts a polygon into a nested `Series`: one inner series per ring, each
/// holding that ring's points flattened as `x, y, x, y, ...`.
#[cfg(feature = "polars")]
fn convert_polygon_to_nested_series(polygon: &Polygon) -> Series {
    let mut per_ring: Vec<Option<Series>> = Vec::new();
    for ring in polygon.rings.iter() {
        let mut flat: Vec<f32> = Vec::new();
        for &(x, y) in ring.iter() {
            flat.push(x);
            flat.push(y);
        }
        per_ring.push(Some(Series::new("".into(), flat)));
    }
    Series::new("".into(), per_ring)
}
1920
/// Flattens `samples` into a Polars `DataFrame` with one row per annotation.
///
/// A sample without annotations still contributes a single row whose
/// annotation-level cells are null. A row is skipped when no name can be
/// derived for it (see `extract_annotation_name` and
/// `extract_annotation_name_from_sample`). Columns that end up entirely null
/// are dropped from the result, except `name`, which is always kept.
///
/// # Errors
///
/// Propagates any failure from the Polars casts, from building the `timing`
/// struct column, or from `DataFrame::new`.
#[cfg(feature = "polars")]
pub fn samples_dataframe(samples: &[Sample]) -> Result<DataFrame, Error> {
    // One accumulator per output column; every emitted row pushes exactly one
    // value onto each of them so the columns stay the same length.
    let mut names: Vec<String> = Vec::new();
    let mut frames: Vec<Option<u32>> = Vec::new();
    let mut objects: Vec<Option<String>> = Vec::new();
    let mut labels: Vec<Option<String>> = Vec::new();
    let mut label_indices: Vec<Option<u64>> = Vec::new();
    let mut groups: Vec<Option<String>> = Vec::new();
    let mut polygons: Vec<Option<Series>> = Vec::new();
    let mut boxes2d: Vec<Option<Series>> = Vec::new();
    let mut boxes3d: Vec<Option<Series>> = Vec::new();
    let mut mask_bytes: Vec<Option<Vec<u8>>> = Vec::new();
    let mut box2d_scores: Vec<Option<f32>> = Vec::new();
    let mut box3d_scores: Vec<Option<f32>> = Vec::new();
    let mut polygon_scores: Vec<Option<f32>> = Vec::new();
    let mut mask_scores: Vec<Option<f32>> = Vec::new();
    let mut sizes: Vec<Option<Vec<u32>>> = Vec::new();
    let mut locations: Vec<Option<Vec<f32>>> = Vec::new();
    let mut poses: Vec<Option<Vec<f32>>> = Vec::new();
    let mut degradations: Vec<Option<String>> = Vec::new();
    let mut iscrowds: Vec<Option<bool>> = Vec::new();
    let mut category_frequencies: Vec<Option<String>> = Vec::new();
    let mut neg_label_indices_vec: Vec<Option<Vec<u32>>> = Vec::new();
    let mut not_exhaustive_label_indices_vec: Vec<Option<Vec<u32>>> = Vec::new();
    let mut timing_load: Vec<Option<i64>> = Vec::new();
    let mut timing_preprocess: Vec<Option<i64>> = Vec::new();
    let mut timing_inference: Vec<Option<i64>> = Vec::new();
    let mut timing_decode: Vec<Option<i64>> = Vec::new();

    for sample in samples {
        // Sample-level values are computed once and repeated on every row the
        // sample produces.

        // `[width, height]`, only when both dimensions are present.
        let size = match (sample.width, sample.height) {
            (Some(w), Some(h)) => Some(vec![w, h]),
            _ => None,
        };

        // `[lat, lon]`, narrowed to f32.
        let location = sample.location.as_ref().and_then(|loc| {
            loc.gps
                .as_ref()
                .map(|gps| vec![gps.lat as f32, gps.lon as f32])
        });

        // `[yaw, pitch, roll]`, narrowed to f32.
        let pose = sample.location.as_ref().and_then(|loc| {
            loc.imu
                .as_ref()
                .map(|imu| vec![imu.yaw as f32, imu.pitch as f32, imu.roll as f32])
        });

        let degradation = sample.degradation.clone();

        let t_load = sample.timing.as_ref().and_then(|t| t.load);
        let t_preprocess = sample.timing.as_ref().and_then(|t| t.preprocess);
        let t_inference = sample.timing.as_ref().and_then(|t| t.inference);
        let t_decode = sample.timing.as_ref().and_then(|t| t.decode);

        // Pushes the sample-level values for one row; shared by both branches
        // below.
        macro_rules! push_sample_fields {
            () => {
                sizes.push(size.clone());
                locations.push(location.clone());
                poses.push(pose.clone());
                degradations.push(degradation.clone());
                neg_label_indices_vec.push(sample.neg_label_indices.clone());
                not_exhaustive_label_indices_vec.push(sample.not_exhaustive_label_indices.clone());
                timing_load.push(t_load);
                timing_preprocess.push(t_preprocess);
                timing_inference.push(t_inference);
                timing_decode.push(t_decode);
            };
        }

        if sample.annotations.is_empty() {
            // No annotations: emit one row with null annotation cells, unless
            // the sample yields no usable name, in which case it is skipped.
            let (name, frame) = match extract_annotation_name_from_sample(sample) {
                Some(nf) => nf,
                None => continue,
            };

            names.push(name);
            frames.push(frame);
            objects.push(None);
            labels.push(None);
            label_indices.push(None);
            groups.push(sample.group.clone());
            polygons.push(None);
            boxes2d.push(None);
            boxes3d.push(None);
            mask_bytes.push(None);
            box2d_scores.push(None);
            box3d_scores.push(None);
            polygon_scores.push(None);
            mask_scores.push(None);
            iscrowds.push(None);
            category_frequencies.push(None);
            push_sample_fields!();
        } else {
            for ann in &sample.annotations {
                // Annotations without a usable name are skipped individually.
                let (name, frame) = match extract_annotation_name(ann) {
                    Some(nf) => nf,
                    None => continue,
                };

                let polygon = ann.polygon.as_ref().map(convert_polygon_to_nested_series);

                // Stored as `[cx, cy, width, height]`.
                let box2d = ann
                    .box2d
                    .as_ref()
                    .map(|b| Series::new("box2d".into(), [b.cx(), b.cy(), b.width(), b.height()]));

                // Stored as `[x, y, z, w, h, l]`.
                let box3d = ann
                    .box3d
                    .as_ref()
                    .map(|b| Series::new("box3d".into(), [b.x, b.y, b.z, b.w, b.h, b.l]));

                names.push(name);
                frames.push(frame);
                objects.push(ann.object_id().cloned());
                labels.push(ann.label_name.clone());
                label_indices.push(ann.label_index);
                // The group comes from the sample, not from the annotation.
                groups.push(sample.group.clone());
                polygons.push(polygon);
                boxes2d.push(box2d);
                boxes3d.push(box3d);
                mask_bytes.push(ann.mask.as_ref().map(|m| m.as_bytes().to_vec()));
                box2d_scores.push(ann.box2d_score());
                box3d_scores.push(ann.box3d_score());
                polygon_scores.push(ann.polygon_score());
                mask_scores.push(ann.mask_score());
                iscrowds.push(ann.iscrowd);
                category_frequencies.push(ann.category_frequency.clone());
                push_sample_fields!();
            }
        }
    }

    let names_col: Column = Series::new("name".into(), names).into();
    let frames_col: Column = Series::new("frame".into(), frames).into();
    let objects_col: Column = Series::new("object_id".into(), objects).into();

    // Labels are stored as a categorical with a u8 physical type.
    // NOTE(review): a u8 physical presumably caps the number of distinct
    // labels — confirm against the Polars categorical documentation.
    let labels_col: Column = Series::new("label".into(), labels)
        .cast(&DataType::Categorical(
            Categories::new("labels".into(), "labels".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::with_hasher(
                u8::MAX as usize,
                Default::default(),
            )),
        ))?
        .into();

    let label_indices_col: Column = Series::new("label_index".into(), label_indices).into();

    // Groups use the same categorical layout as labels, under their own name.
    let groups_col: Column = Series::new("group".into(), groups)
        .cast(&DataType::Categorical(
            Categories::new("groups".into(), "groups".into(), CategoricalPhysical::U8),
            Arc::new(CategoricalMapping::with_hasher(
                u8::MAX as usize,
                Default::default(),
            )),
        ))?
        .into();

    // With no polygon at all, a null column is built directly instead of
    // attempting the nested-list cast.
    let polygons_col: Column = if polygons.iter().all(|p| p.is_none()) {
        Series::new_null("polygon".into(), polygons.len()).into()
    } else {
        // Best effort: cast each polygon to List<Float32>, keeping the
        // original series if that cast fails.
        let typed_polygons: Vec<Option<Series>> = polygons
            .into_iter()
            .map(|opt| {
                opt.map(|s| {
                    s.cast(&DataType::List(Box::new(DataType::Float32)))
                        .unwrap_or(s)
                })
            })
            .collect();
        Series::new("polygon".into(), &typed_polygons)
            .cast(&DataType::List(Box::new(DataType::List(Box::new(
                DataType::Float32,
            )))))?
            .into()
    };

    // Boxes become fixed-width Float32 arrays: 4 values for 2D, 6 for 3D.
    let boxes2d_col: Column = Series::new("box2d".into(), boxes2d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 4))?
        .into();
    let boxes3d_col: Column = Series::new("box3d".into(), boxes3d)
        .cast(&DataType::Array(Box::new(DataType::Float32), 6))?
        .into();

    // Masks are carried as the raw bytes returned by `MaskData::as_bytes`.
    let mask_col: Column = Series::new("mask".into(), mask_bytes).into();

    let box2d_score_col: Column = Series::new("box2d_score".into(), box2d_scores).into();
    let box3d_score_col: Column = Series::new("box3d_score".into(), box3d_scores).into();
    let polygon_score_col: Column = Series::new("polygon_score".into(), polygon_scores).into();
    let mask_score_col: Column = Series::new("mask_score".into(), mask_scores).into();

    // Fixed-width array columns: size (2 x UInt32), location (2 x Float32),
    // pose (3 x Float32).
    let size_series: Vec<Option<Series>> = sizes
        .into_iter()
        .map(|opt_vec| opt_vec.map(|vec| Series::new("size".into(), vec)))
        .collect();
    let sizes_col: Column = Series::new("size".into(), size_series)
        .cast(&DataType::Array(Box::new(DataType::UInt32), 2))?
        .into();

    let location_series: Vec<Option<Series>> = locations
        .into_iter()
        .map(|opt_vec| opt_vec.map(|vec| Series::new("location".into(), vec)))
        .collect();
    let locations_col: Column = Series::new("location".into(), location_series)
        .cast(&DataType::Array(Box::new(DataType::Float32), 2))?
        .into();

    let pose_series: Vec<Option<Series>> = poses
        .into_iter()
        .map(|opt_vec| opt_vec.map(|vec| Series::new("pose".into(), vec)))
        .collect();
    let poses_col: Column = Series::new("pose".into(), pose_series)
        .cast(&DataType::Array(Box::new(DataType::Float32), 3))?
        .into();

    let degradations_col: Column = Series::new("degradation".into(), degradations).into();

    let iscrowds_col: Column = Series::new("iscrowd".into(), iscrowds).into();

    let category_frequencies_col: Column =
        Series::new("category_frequency".into(), category_frequencies)
            .cast(&DataType::Categorical(
                Categories::new(
                    "cat_freq".into(),
                    "cat_freq".into(),
                    CategoricalPhysical::U8,
                ),
                Arc::new(CategoricalMapping::with_hasher(
                    u8::MAX as usize,
                    Default::default(),
                )),
            ))?
            .into();

    // Variable-length UInt32 lists.
    let neg_label_indices_series: Vec<Option<Series>> = neg_label_indices_vec
        .into_iter()
        .map(|opt_vec| opt_vec.map(|vec| Series::new("neg_label_indices".into(), vec)))
        .collect();
    let neg_label_indices_col: Column =
        Series::new("neg_label_indices".into(), neg_label_indices_series)
            .cast(&DataType::List(Box::new(DataType::UInt32)))?
            .into();

    let not_exhaustive_label_indices_series: Vec<Option<Series>> = not_exhaustive_label_indices_vec
        .into_iter()
        .map(|opt_vec| opt_vec.map(|vec| Series::new("not_exhaustive_label_indices".into(), vec)))
        .collect();
    let not_exhaustive_label_indices_col: Column = Series::new(
        "not_exhaustive_label_indices".into(),
        not_exhaustive_label_indices_series,
    )
    .cast(&DataType::List(Box::new(DataType::UInt32)))?
    .into();

    // The four timing values are grouped into one struct column; its length
    // is taken from the `frame` column, which has one entry per row.
    let timing_col: Column = StructChunked::from_series(
        "timing".into(),
        frames_col.len(),
        [
            Series::new("load".into(), &timing_load),
            Series::new("preprocess".into(), &timing_preprocess),
            Series::new("inference".into(), &timing_inference),
            Series::new("decode".into(), &timing_decode),
        ]
        .iter(),
    )?
    .into_series()
    .into();

    let all_columns: Vec<Column> = vec![
        names_col,
        frames_col,
        objects_col,
        labels_col,
        label_indices_col,
        groups_col,
        polygons_col,
        boxes2d_col,
        boxes3d_col,
        mask_col,
        box2d_score_col,
        box3d_score_col,
        polygon_score_col,
        mask_score_col,
        sizes_col,
        locations_col,
        poses_col,
        degradations_col,
        iscrowds_col,
        category_frequencies_col,
        neg_label_indices_col,
        not_exhaustive_label_indices_col,
        timing_col,
    ];

    // The row count is read before filtering so it reflects the `name` column.
    let height = all_columns.first().map(|c| c.len()).unwrap_or(0);

    // Drop columns that carry no data; `name` is kept unconditionally.
    let non_empty_columns: Vec<Column> = all_columns
        .into_iter()
        .filter(|col| col.name() == "name" || !is_all_null_column(col))
        .collect();

    Ok(DataFrame::new(height, non_empty_columns)?)
}
2295
/// Reports whether `col` carries no data.
///
/// A column counts as all-null when it is empty, when every entry is null,
/// or when it is a struct column whose every field is entirely null.
#[cfg(feature = "polars")]
fn is_all_null_column(col: &Column) -> bool {
    if col.is_empty() || col.null_count() == col.len() {
        return true;
    }

    // A struct column is also checked field by field, since its own null
    // count is not all nulls at this point.
    match col.dtype() {
        DataType::Struct(..) => match col.as_materialized_series().struct_() {
            Ok(fields) => fields
                .fields_as_series()
                .iter()
                .all(|member| member.null_count() == member.len()),
            Err(_) => false,
        },
        _ => false,
    }
}
2318
/// Derives the `(name, frame)` pair for a sample that has no annotations.
///
/// Returns `None` when the sample has no image name or the image name has no
/// UTF-8 file stem. Otherwise yields `(sequence_name, frame_number)` if a
/// sequence name is set, and `(file stem of image_name, None)` if it is not.
#[cfg(feature = "polars")]
fn extract_annotation_name_from_sample(sample: &Sample) -> Option<(String, Option<u32>)> {
    use std::path::Path;

    // The stem is resolved up front so a missing or unusable image name
    // rejects the sample even when a sequence name is available.
    let stem = Path::new(sample.image_name.as_ref()?).file_stem()?.to_str()?;

    let key = if let Some(sequence) = &sample.sequence_name {
        (sequence.clone(), sample.frame_number)
    } else {
        (stem.to_string(), None)
    };
    Some(key)
}
2334
/// Derives a sample name from an image file name.
///
/// First drops everything from the last `.` onwards (the extension), then
/// drops everything from the last `.camera` onwards. Either step is skipped
/// when it would leave an empty string, so `".jpg"` is returned unchanged.
fn extract_sample_name(image_name: &str) -> String {
    // Step 1: strip the extension unless nothing would remain.
    let without_ext = match image_name.rsplit_once('.') {
        Some((head, _)) if !head.is_empty() => head,
        _ => image_name,
    };

    // Step 2: cut at the last ".camera" unless nothing would remain.
    let base = match without_ext.rsplit_once(".camera") {
        Some((head, _)) if !head.is_empty() => head,
        _ => without_ext,
    };

    base.to_string()
}
2373
2374fn resolve_file<'a>(file_type: &FileType, files: &'a [SampleFile]) -> Option<&'a SampleFile> {
2383 match file_type {
2384 FileType::Image => None, FileType::All => None, file => {
2387 let type_names = file_type_names(file);
2389 files
2390 .iter()
2391 .find(|f| type_names.contains(&f.r#type.as_str()))
2392 }
2393 }
2394}
2395
2396fn file_type_names(file_type: &FileType) -> Vec<&'static str> {
2399 match file_type {
2400 FileType::Image => vec!["image"],
2401 FileType::LidarPcd => vec!["lidar.pcd"],
2402 FileType::LidarDepth => vec!["lidar.depth", "depth.png", "depthmap"],
2403 FileType::LidarReflect => vec!["lidar.reflect"],
2404 FileType::RadarPcd => vec!["radar.pcd", "pcd"],
2405 FileType::RadarCube => vec!["radar.png", "cube"],
2406 FileType::All => vec![],
2407 }
2408}
2409
2410fn convert_annotations_map_to_vec(map: HashMap<String, Vec<Annotation>>) -> Vec<Annotation> {
2423 let mut all_annotations = Vec::new();
2424 if let Some(bbox_anns) = map.get("bbox") {
2425 all_annotations.extend(bbox_anns.clone());
2426 }
2427 if let Some(box3d_anns) = map.get("box3d") {
2428 all_annotations.extend(box3d_anns.clone());
2429 }
2430 if let Some(mask_anns) = map.get("mask") {
2431 all_annotations.extend(mask_anns.clone());
2432 }
2433 all_annotations
2434}
2435
/// Checks that a latitude/longitude pair is finite and in range.
///
/// Both values are checked for finiteness first (latitude, then longitude),
/// and only then against their ranges: latitude in `[-90, 90]`, longitude in
/// `[-180, 180]`, bounds inclusive.
///
/// # Errors
///
/// Returns a message describing the first check that failed.
fn validate_gps_coordinates(lat: f64, lon: f64) -> Result<(), String> {
    // Arms are tried top to bottom, which fixes the order of the checks.
    match (lat, lon) {
        (lat, _) if !lat.is_finite() => Err(format!("GPS latitude is not finite: {}", lat)),
        (_, lon) if !lon.is_finite() => Err(format!("GPS longitude is not finite: {}", lon)),
        (lat, _) if !(-90.0..=90.0).contains(&lat) => {
            Err(format!("GPS latitude out of range [-90, 90]: {}", lat))
        }
        (_, lon) if !(-180.0..=180.0).contains(&lon) => {
            Err(format!("GPS longitude out of range [-180, 180]: {}", lon))
        }
        _ => Ok(()),
    }
}
2470
/// Checks that roll, pitch and yaw are finite and in range.
///
/// All three values are checked for finiteness first (roll, pitch, yaw), and
/// only then against their ranges: roll and yaw in `[-180, 180]`, pitch in
/// `[-90, 90]`, bounds inclusive.
///
/// # Errors
///
/// Returns a message describing the first check that failed.
fn validate_imu_orientation(roll: f64, pitch: f64, yaw: f64) -> Result<(), String> {
    // Arms are tried top to bottom, which fixes the order of the checks.
    match (roll, pitch, yaw) {
        (roll, _, _) if !roll.is_finite() => Err(format!("IMU roll is not finite: {}", roll)),
        (_, pitch, _) if !pitch.is_finite() => Err(format!("IMU pitch is not finite: {}", pitch)),
        (_, _, yaw) if !yaw.is_finite() => Err(format!("IMU yaw is not finite: {}", yaw)),
        (roll, _, _) if !(-180.0..=180.0).contains(&roll) => {
            Err(format!("IMU roll out of range [-180, 180]: {}", roll))
        }
        (_, pitch, _) if !(-90.0..=90.0).contains(&pitch) => {
            Err(format!("IMU pitch out of range [-90, 90]: {}", pitch))
        }
        (_, _, yaw) if !(-180.0..=180.0).contains(&yaw) => {
            Err(format!("IMU yaw out of range [-180, 180]: {}", yaw))
        }
        _ => Ok(()),
    }
}
2510
2511#[cfg(feature = "polars")]
/// Rebuilds polygons from a flat coordinate buffer in which `NaN` values act
/// as separators.
///
/// Each run of non-`NaN` values is read as consecutive `(x, y)` pairs. A
/// trailing value without a partner is dropped, and runs that yield no pair
/// produce no polygon.
pub fn unflatten_polygon_coordinates(coords: &[f32]) -> Vec<Vec<(f32, f32)>> {
    coords
        .split(|value| value.is_nan())
        .map(|run| {
            // `chunks_exact(2)` ignores a dangling final coordinate.
            run.chunks_exact(2)
                .map(|pair| (pair[0], pair[1]))
                .collect::<Vec<(f32, f32)>>()
        })
        .filter(|polygon| !polygon.is_empty())
        .collect()
}
2572
2573#[cfg(test)]
2574mod tests {
2575 use super::*;
2576
2577 fn flatten_annotation_map(
2586 map: std::collections::HashMap<String, Vec<Annotation>>,
2587 ) -> Vec<Annotation> {
2588 let mut all_annotations = Vec::new();
2589
2590 for key in ["bbox", "box3d", "mask"] {
2592 if let Some(mut anns) = map.get(key).cloned() {
2593 all_annotations.append(&mut anns);
2594 }
2595 }
2596
2597 all_annotations
2598 }
2599
2600 fn annotation_group_field_name() -> &'static str {
2602 "group_name"
2603 }
2604
2605 fn annotation_object_id_field_name() -> &'static str {
2607 "object_reference"
2608 }
2609
2610 fn annotation_object_id_alias() -> &'static str {
2612 "object_id"
2613 }
2614
2615 fn validate_annotation_field_names(
2618 json_str: &str,
2619 expected_group: bool,
2620 expected_object_ref: bool,
2621 ) -> Result<(), String> {
2622 if expected_group && !json_str.contains("\"group_name\"") {
2623 return Err("Missing expected field: group_name".to_string());
2624 }
2625 if expected_object_ref && !json_str.contains("\"object_reference\"") {
2626 return Err("Missing expected field: object_reference".to_string());
2627 }
2628 Ok(())
2629 }
2630
2631 #[test]
2633 fn test_file_type_conversions() {
2634 let api_cases = vec![
2636 (FileType::Image, "image"),
2637 (FileType::LidarPcd, "lidar.pcd"),
2638 (FileType::LidarDepth, "lidar.depth"),
2639 (FileType::LidarReflect, "lidar.reflect"),
2640 (FileType::RadarPcd, "radar.pcd"),
2641 (FileType::RadarCube, "radar.png"),
2642 ];
2643
2644 let ext_cases = vec![
2646 (FileType::Image, "jpg"),
2647 (FileType::LidarPcd, "lidar.pcd"),
2648 (FileType::LidarDepth, "lidar.png"),
2649 (FileType::LidarReflect, "lidar.jpg"),
2650 (FileType::RadarPcd, "radar.pcd"),
2651 (FileType::RadarCube, "radar.png"),
2652 ];
2653
2654 for (file_type, expected_str) in &api_cases {
2656 assert_eq!(file_type.to_string(), *expected_str);
2657 }
2658
2659 for (file_type, expected_ext) in &ext_cases {
2661 assert_eq!(file_type.file_extension(), *expected_ext);
2662 }
2663
2664 assert_eq!(
2666 FileType::try_from("lidar.depth").unwrap(),
2667 FileType::LidarDepth
2668 );
2669 assert_eq!(
2670 FileType::try_from("lidar.png").unwrap(),
2671 FileType::LidarDepth
2672 );
2673 assert_eq!(
2674 FileType::try_from("depth.png").unwrap(),
2675 FileType::LidarDepth
2676 );
2677 assert_eq!(
2678 FileType::try_from("lidar.reflect").unwrap(),
2679 FileType::LidarReflect
2680 );
2681 assert_eq!(
2682 FileType::try_from("lidar.jpg").unwrap(),
2683 FileType::LidarReflect
2684 );
2685 assert_eq!(
2686 FileType::try_from("lidar.jpeg").unwrap(),
2687 FileType::LidarReflect
2688 );
2689
2690 assert!(FileType::try_from("invalid").is_err());
2692
2693 for (file_type, _) in &api_cases {
2695 let s = file_type.to_string();
2696 let parsed = FileType::try_from(s.as_str()).unwrap();
2697 assert_eq!(parsed, *file_type);
2698 }
2699 }
2700
2701 #[test]
2703 fn test_annotation_type_conversions() {
2704 let cases = vec![
2705 (AnnotationType::Box2d, "box2d"),
2706 (AnnotationType::Box3d, "box3d"),
2707 (AnnotationType::Polygon, "polygon"),
2708 (AnnotationType::Mask, "mask"),
2709 ];
2710
2711 for (ann_type, expected_str) in &cases {
2713 assert_eq!(ann_type.to_string(), *expected_str);
2714 }
2715
2716 assert_eq!(
2718 AnnotationType::try_from("box2d").unwrap(),
2719 AnnotationType::Box2d
2720 );
2721 assert_eq!(
2722 AnnotationType::try_from("box3d").unwrap(),
2723 AnnotationType::Box3d
2724 );
2725 assert_eq!(
2726 AnnotationType::try_from("polygon").unwrap(),
2727 AnnotationType::Polygon
2728 );
2729 assert_eq!(
2731 AnnotationType::try_from("mask").unwrap(),
2732 AnnotationType::Polygon
2733 );
2734 assert_eq!(
2736 AnnotationType::try_from("raster").unwrap(),
2737 AnnotationType::Mask
2738 );
2739
2740 assert_eq!(
2742 AnnotationType::from("box2d".to_string()),
2743 AnnotationType::Box2d
2744 );
2745 assert_eq!(
2746 AnnotationType::from("box3d".to_string()),
2747 AnnotationType::Box3d
2748 );
2749 assert_eq!(
2750 AnnotationType::from("polygon".to_string()),
2751 AnnotationType::Polygon
2752 );
2753 assert_eq!(
2755 AnnotationType::from("mask".to_string()),
2756 AnnotationType::Polygon
2757 );
2758
2759 assert_eq!(
2761 AnnotationType::from("invalid".to_string()),
2762 AnnotationType::Box2d
2763 );
2764
2765 assert!(AnnotationType::try_from("invalid").is_err());
2767
2768 assert_eq!(
2773 AnnotationType::try_from(AnnotationType::Box2d.to_string().as_str()).unwrap(),
2774 AnnotationType::Box2d
2775 );
2776 assert_eq!(
2777 AnnotationType::try_from(AnnotationType::Box3d.to_string().as_str()).unwrap(),
2778 AnnotationType::Box3d
2779 );
2780 assert_eq!(
2781 AnnotationType::try_from(AnnotationType::Polygon.to_string().as_str()).unwrap(),
2782 AnnotationType::Polygon
2783 );
2784 }
2785
2786 #[test]
2788 fn test_extract_sample_name_with_extension_and_camera() {
2789 assert_eq!(extract_sample_name("scene_001.camera.jpg"), "scene_001");
2790 }
2791
2792 #[test]
2793 fn test_extract_sample_name_multiple_dots() {
2794 assert_eq!(extract_sample_name("image.v2.camera.png"), "image.v2");
2795 }
2796
2797 #[test]
2798 fn test_extract_sample_name_extension_only() {
2799 assert_eq!(extract_sample_name("test.jpg"), "test");
2800 }
2801
2802 #[test]
2803 fn test_extract_sample_name_no_extension() {
2804 assert_eq!(extract_sample_name("test"), "test");
2805 }
2806
2807 #[test]
2808 fn test_extract_sample_name_edge_case_dot_prefix() {
2809 assert_eq!(extract_sample_name(".jpg"), ".jpg");
2810 }
2811
2812 #[test]
2814 fn test_resolve_file_image_type_returns_none() {
2815 let files = vec![];
2817 let result = resolve_file(&FileType::Image, &files);
2818 assert!(result.is_none());
2819 }
2820
2821 #[test]
2822 fn test_resolve_file_lidar_pcd() {
2823 let files = vec![
2824 SampleFile::with_url(
2825 "lidar.pcd".to_string(),
2826 "https://example.com/file.pcd".to_string(),
2827 ),
2828 SampleFile::with_url(
2829 "radar.pcd".to_string(),
2830 "https://example.com/radar.pcd".to_string(),
2831 ),
2832 ];
2833 let result = resolve_file(&FileType::LidarPcd, &files);
2834 assert!(result.is_some());
2835 assert_eq!(result.unwrap().url(), Some("https://example.com/file.pcd"));
2836 }
2837
2838 #[test]
2839 fn test_resolve_file_not_found() {
2840 let files = vec![SampleFile::with_url(
2841 "lidar.pcd".to_string(),
2842 "https://example.com/file.pcd".to_string(),
2843 )];
2844 let result = resolve_file(&FileType::RadarPcd, &files);
2846 assert!(result.is_none());
2847 }
2848
2849 #[test]
2850 fn test_resolve_file_lidar_depth() {
2851 let files = vec![SampleFile::with_url(
2853 "lidar.depth".to_string(),
2854 "https://example.com/depth.png".to_string(),
2855 )];
2856 let result = resolve_file(&FileType::LidarDepth, &files);
2857 assert!(result.is_some());
2858 assert_eq!(result.unwrap().url(), Some("https://example.com/depth.png"));
2859 }
2860
2861 #[test]
2862 fn test_resolve_file_lidar_reflect() {
2863 let files = vec![SampleFile::with_url(
2865 "lidar.reflect".to_string(),
2866 "https://example.com/reflect.png".to_string(),
2867 )];
2868 let result = resolve_file(&FileType::LidarReflect, &files);
2869 assert!(result.is_some());
2870 assert_eq!(
2871 result.unwrap().url(),
2872 Some("https://example.com/reflect.png")
2873 );
2874 }
2875
2876 #[test]
2877 fn test_resolve_file_radar_cube() {
2878 let files = vec![SampleFile::with_url(
2880 "radar.png".to_string(),
2881 "https://example.com/radar.png".to_string(),
2882 )];
2883 let result = resolve_file(&FileType::RadarCube, &files);
2884 assert!(result.is_some());
2885 assert_eq!(result.unwrap().url(), Some("https://example.com/radar.png"));
2886 }
2887
2888 #[test]
2889 fn test_resolve_file_with_inline_data() {
2890 let files = vec![SampleFile::with_data(
2892 "radar.pcd".to_string(),
2893 "SGVsbG8gV29ybGQ=".to_string(), )];
2895 let result = resolve_file(&FileType::RadarPcd, &files);
2896 assert!(result.is_some());
2897 let file = result.unwrap();
2898 assert!(file.url().is_none());
2899 assert_eq!(file.data(), Some("SGVsbG8gV29ybGQ="));
2900 }
2901
2902 #[test]
2903 fn test_convert_annotations_map_to_vec_with_bbox() {
2904 let mut map = HashMap::new();
2905 let bbox_ann = Annotation::new();
2906 map.insert("bbox".to_string(), vec![bbox_ann.clone()]);
2907
2908 let annotations = convert_annotations_map_to_vec(map);
2909 assert_eq!(annotations.len(), 1);
2910 }
2911
2912 #[test]
2913 fn test_convert_annotations_map_to_vec_all_types() {
2914 let mut map = HashMap::new();
2915 map.insert("bbox".to_string(), vec![Annotation::new()]);
2916 map.insert("box3d".to_string(), vec![Annotation::new()]);
2917 map.insert("mask".to_string(), vec![Annotation::new()]);
2918
2919 let annotations = convert_annotations_map_to_vec(map);
2920 assert_eq!(annotations.len(), 3);
2921 }
2922
2923 #[test]
2924 fn test_convert_annotations_map_to_vec_empty() {
2925 let map = HashMap::new();
2926 let annotations = convert_annotations_map_to_vec(map);
2927 assert_eq!(annotations.len(), 0);
2928 }
2929
2930 #[test]
2931 fn test_convert_annotations_map_to_vec_unknown_type_ignored() {
2932 let mut map = HashMap::new();
2933 map.insert("unknown".to_string(), vec![Annotation::new()]);
2934
2935 let annotations = convert_annotations_map_to_vec(map);
2936 assert_eq!(annotations.len(), 0);
2938 }
2939
2940 #[test]
2942 fn test_annotation_group_field_name() {
2943 assert_eq!(annotation_group_field_name(), "group_name");
2944 }
2945
2946 #[test]
2947 fn test_annotation_object_id_field_name() {
2948 assert_eq!(annotation_object_id_field_name(), "object_reference");
2949 }
2950
2951 #[test]
2952 fn test_annotation_object_id_alias() {
2953 assert_eq!(annotation_object_id_alias(), "object_id");
2954 }
2955
2956 #[test]
2957 fn test_validate_annotation_field_names_success() {
2958 let json = r#"{"group_name":"train","object_reference":"obj1"}"#;
2959 assert!(validate_annotation_field_names(json, true, true).is_ok());
2960 }
2961
2962 #[test]
2963 fn test_validate_annotation_field_names_missing_group() {
2964 let json = r#"{"object_reference":"obj1"}"#;
2965 let result = validate_annotation_field_names(json, true, false);
2966 assert!(result.is_err());
2967 assert!(result.unwrap_err().contains("group_name"));
2968 }
2969
2970 #[test]
2971 fn test_validate_annotation_field_names_missing_object_ref() {
2972 let json = r#"{"group_name":"train"}"#;
2973 let result = validate_annotation_field_names(json, false, true);
2974 assert!(result.is_err());
2975 assert!(result.unwrap_err().contains("object_reference"));
2976 }
2977
2978 #[test]
2979 fn test_annotation_serialization_field_names() {
2980 let mut ann = Annotation::new();
2982 ann.set_group(Some("train".to_string()));
2983 ann.set_object_id(Some("obj1".to_string()));
2984
2985 let json = serde_json::to_string(&ann).unwrap();
2986 assert!(validate_annotation_field_names(&json, true, true).is_ok());
2988 }
2989
2990 #[test]
2992 fn test_validate_gps_coordinates_valid() {
2993 assert!(validate_gps_coordinates(37.7749, -122.4194).is_ok()); assert!(validate_gps_coordinates(0.0, 0.0).is_ok()); assert!(validate_gps_coordinates(90.0, 180.0).is_ok()); assert!(validate_gps_coordinates(-90.0, -180.0).is_ok()); }
2998
2999 #[test]
3000 fn test_validate_gps_coordinates_invalid_latitude() {
3001 let result = validate_gps_coordinates(91.0, 0.0);
3002 assert!(result.is_err());
3003 assert!(result.unwrap_err().contains("latitude out of range"));
3004
3005 let result = validate_gps_coordinates(-91.0, 0.0);
3006 assert!(result.is_err());
3007 assert!(result.unwrap_err().contains("latitude out of range"));
3008 }
3009
3010 #[test]
3011 fn test_validate_gps_coordinates_invalid_longitude() {
3012 let result = validate_gps_coordinates(0.0, 181.0);
3013 assert!(result.is_err());
3014 assert!(result.unwrap_err().contains("longitude out of range"));
3015
3016 let result = validate_gps_coordinates(0.0, -181.0);
3017 assert!(result.is_err());
3018 assert!(result.unwrap_err().contains("longitude out of range"));
3019 }
3020
3021 #[test]
3022 fn test_validate_gps_coordinates_non_finite() {
3023 let result = validate_gps_coordinates(f64::NAN, 0.0);
3024 assert!(result.is_err());
3025 assert!(result.unwrap_err().contains("not finite"));
3026
3027 let result = validate_gps_coordinates(0.0, f64::INFINITY);
3028 assert!(result.is_err());
3029 assert!(result.unwrap_err().contains("not finite"));
3030 }
3031
3032 #[test]
3033 fn test_validate_imu_orientation_valid() {
3034 assert!(validate_imu_orientation(0.0, 0.0, 0.0).is_ok());
3035 assert!(validate_imu_orientation(45.0, 30.0, 90.0).is_ok());
3036 assert!(validate_imu_orientation(180.0, 90.0, -180.0).is_ok()); assert!(validate_imu_orientation(-180.0, -90.0, 180.0).is_ok()); }
3039
3040 #[test]
3041 fn test_validate_imu_orientation_invalid_roll() {
3042 let result = validate_imu_orientation(181.0, 0.0, 0.0);
3043 assert!(result.is_err());
3044 assert!(result.unwrap_err().contains("roll out of range"));
3045
3046 let result = validate_imu_orientation(-181.0, 0.0, 0.0);
3047 assert!(result.is_err());
3048 }
3049
3050 #[test]
3051 fn test_validate_imu_orientation_invalid_pitch() {
3052 let result = validate_imu_orientation(0.0, 91.0, 0.0);
3053 assert!(result.is_err());
3054 assert!(result.unwrap_err().contains("pitch out of range"));
3055
3056 let result = validate_imu_orientation(0.0, -91.0, 0.0);
3057 assert!(result.is_err());
3058 }
3059
3060 #[test]
3061 fn test_validate_imu_orientation_non_finite() {
3062 let result = validate_imu_orientation(f64::NAN, 0.0, 0.0);
3063 assert!(result.is_err());
3064 assert!(result.unwrap_err().contains("not finite"));
3065
3066 let result = validate_imu_orientation(0.0, f64::INFINITY, 0.0);
3067 assert!(result.is_err());
3068
3069 let result = validate_imu_orientation(0.0, 0.0, f64::NEG_INFINITY);
3070 assert!(result.is_err());
3071 }
3072
3073 #[test]
3075 #[cfg(feature = "polars")]
3076 fn test_unflatten_polygon_coordinates_single_polygon() {
3077 let coords = vec![1.0, 2.0, 3.0, 4.0];
3078 let result = unflatten_polygon_coordinates(&coords);
3079
3080 assert_eq!(result.len(), 1);
3081 assert_eq!(result[0].len(), 2);
3082 assert_eq!(result[0][0], (1.0, 2.0));
3083 assert_eq!(result[0][1], (3.0, 4.0));
3084 }
3085
3086 #[test]
3087 #[cfg(feature = "polars")]
3088 fn test_unflatten_polygon_coordinates_multiple_polygons() {
3089 let coords = vec![1.0, 2.0, 3.0, 4.0, f32::NAN, 5.0, 6.0, 7.0, 8.0];
3090 let result = unflatten_polygon_coordinates(&coords);
3091
3092 assert_eq!(result.len(), 2);
3093 assert_eq!(result[0].len(), 2);
3094 assert_eq!(result[0][0], (1.0, 2.0));
3095 assert_eq!(result[0][1], (3.0, 4.0));
3096 assert_eq!(result[1].len(), 2);
3097 assert_eq!(result[1][0], (5.0, 6.0));
3098 assert_eq!(result[1][1], (7.0, 8.0));
3099 }
3100
3101 #[test]
3102 #[cfg(feature = "polars")]
3103 fn test_unflatten_polygon_coordinates_roundtrip() {
3104 let flat = vec![1.0, 2.0, 3.0, 4.0, f32::NAN, 5.0, 6.0, 7.0, 8.0];
3106 let result = unflatten_polygon_coordinates(&flat);
3107
3108 let expected = vec![vec![(1.0, 2.0), (3.0, 4.0)], vec![(5.0, 6.0), (7.0, 8.0)]];
3109 assert_eq!(result, expected);
3110 }
3111
3112 #[test]
3114 fn test_flatten_annotation_map_all_types() {
3115 use std::collections::HashMap;
3116
3117 let mut map = HashMap::new();
3118
3119 let mut bbox_ann = Annotation::new();
3121 bbox_ann.set_label(Some("bbox_label".to_string()));
3122
3123 let mut box3d_ann = Annotation::new();
3124 box3d_ann.set_label(Some("box3d_label".to_string()));
3125
3126 let mut mask_ann = Annotation::new();
3127 mask_ann.set_label(Some("mask_label".to_string()));
3128
3129 map.insert("bbox".to_string(), vec![bbox_ann.clone()]);
3130 map.insert("box3d".to_string(), vec![box3d_ann.clone()]);
3131 map.insert("mask".to_string(), vec![mask_ann.clone()]);
3132
3133 let result = flatten_annotation_map(map);
3134
3135 assert_eq!(result.len(), 3);
3136 assert_eq!(result[0].label(), Some(&"bbox_label".to_string()));
3138 assert_eq!(result[1].label(), Some(&"box3d_label".to_string()));
3139 assert_eq!(result[2].label(), Some(&"mask_label".to_string()));
3140 }
3141
3142 #[test]
3143 fn test_flatten_annotation_map_single_type() {
3144 use std::collections::HashMap;
3145
3146 let mut map = HashMap::new();
3147 let mut bbox_ann = Annotation::new();
3148 bbox_ann.set_label(Some("test".to_string()));
3149 map.insert("bbox".to_string(), vec![bbox_ann]);
3150
3151 let result = flatten_annotation_map(map);
3152
3153 assert_eq!(result.len(), 1);
3154 assert_eq!(result[0].label(), Some(&"test".to_string()));
3155 }
3156
3157 #[test]
3158 fn test_flatten_annotation_map_empty() {
3159 use std::collections::HashMap;
3160
3161 let map = HashMap::new();
3162 let result = flatten_annotation_map(map);
3163
3164 assert_eq!(result.len(), 0);
3165 }
3166
3167 #[test]
3168 fn test_flatten_annotation_map_deterministic_order() {
3169 use std::collections::HashMap;
3170
3171 let mut map = HashMap::new();
3172
3173 let mut bbox_ann = Annotation::new();
3174 bbox_ann.set_label(Some("bbox".to_string()));
3175
3176 let mut box3d_ann = Annotation::new();
3177 box3d_ann.set_label(Some("box3d".to_string()));
3178
3179 let mut mask_ann = Annotation::new();
3180 mask_ann.set_label(Some("mask".to_string()));
3181
3182 map.insert("mask".to_string(), vec![mask_ann]);
3184 map.insert("box3d".to_string(), vec![box3d_ann]);
3185 map.insert("bbox".to_string(), vec![bbox_ann]);
3186
3187 let result = flatten_annotation_map(map);
3188
3189 assert_eq!(result.len(), 3);
3191 assert_eq!(result[0].label(), Some(&"bbox".to_string()));
3192 assert_eq!(result[1].label(), Some(&"box3d".to_string()));
3193 assert_eq!(result[2].label(), Some(&"mask".to_string()));
3194 }
3195
3196 #[test]
3198 fn test_box2d_construction_and_accessors() {
3199 let bbox = Box2d::new(10.0, 20.0, 100.0, 50.0);
3201 assert_eq!(
3202 (bbox.left(), bbox.top(), bbox.width(), bbox.height()),
3203 (10.0, 20.0, 100.0, 50.0)
3204 );
3205
3206 assert_eq!((bbox.cx(), bbox.cy()), (60.0, 45.0)); let bbox = Box2d::new(0.0, 0.0, 640.0, 480.0);
3211 assert_eq!(
3212 (bbox.left(), bbox.top(), bbox.width(), bbox.height()),
3213 (0.0, 0.0, 640.0, 480.0)
3214 );
3215 assert_eq!((bbox.cx(), bbox.cy()), (320.0, 240.0));
3216 }
3217
3218 #[test]
3219 fn test_box2d_center_calculation() {
3220 let bbox = Box2d::new(10.0, 20.0, 100.0, 50.0);
3221
3222 assert_eq!(bbox.cx(), 60.0); assert_eq!(bbox.cy(), 45.0); }
3226
3227 #[test]
3228 fn test_box2d_zero_dimensions() {
3229 let bbox = Box2d::new(10.0, 20.0, 0.0, 0.0);
3230
3231 assert_eq!(bbox.cx(), 10.0);
3233 assert_eq!(bbox.cy(), 20.0);
3234 }
3235
3236 #[test]
3237 fn test_box2d_negative_dimensions() {
3238 let bbox = Box2d::new(100.0, 100.0, -50.0, -50.0);
3239
3240 assert_eq!(bbox.width(), -50.0);
3242 assert_eq!(bbox.height(), -50.0);
3243 assert_eq!(bbox.cx(), 75.0); assert_eq!(bbox.cy(), 75.0); }
3246
3247 #[test]
3249 fn test_box3d_construction_and_accessors() {
3250 let bbox = Box3d::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
3252 assert_eq!((bbox.cx(), bbox.cy(), bbox.cz()), (1.0, 2.0, 3.0));
3253 assert_eq!(
3254 (bbox.width(), bbox.height(), bbox.length()),
3255 (4.0, 5.0, 6.0)
3256 );
3257
3258 let bbox = Box3d::new(10.0, 20.0, 30.0, 4.0, 6.0, 8.0);
3260 assert_eq!((bbox.left(), bbox.top(), bbox.front()), (8.0, 17.0, 26.0)); let bbox = Box3d::new(0.0, 0.0, 0.0, 2.0, 3.0, 4.0);
3264 assert_eq!((bbox.cx(), bbox.cy(), bbox.cz()), (0.0, 0.0, 0.0));
3265 assert_eq!(
3266 (bbox.width(), bbox.height(), bbox.length()),
3267 (2.0, 3.0, 4.0)
3268 );
3269 assert_eq!((bbox.left(), bbox.top(), bbox.front()), (-1.0, -1.5, -2.0));
3270 }
3271
3272 #[test]
3273 fn test_box3d_center_calculation() {
3274 let bbox = Box3d::new(10.0, 20.0, 30.0, 100.0, 50.0, 40.0);
3275
3276 assert_eq!(bbox.cx(), 10.0);
3278 assert_eq!(bbox.cy(), 20.0);
3279 assert_eq!(bbox.cz(), 30.0);
3280 }
3281
3282 #[test]
3283 fn test_box3d_zero_dimensions() {
3284 let bbox = Box3d::new(5.0, 10.0, 15.0, 0.0, 0.0, 0.0);
3285
3286 assert_eq!(bbox.cx(), 5.0);
3288 assert_eq!(bbox.cy(), 10.0);
3289 assert_eq!(bbox.cz(), 15.0);
3290 assert_eq!((bbox.left(), bbox.top(), bbox.front()), (5.0, 10.0, 15.0));
3291 }
3292
3293 #[test]
3294 fn test_box3d_negative_dimensions() {
3295 let bbox = Box3d::new(100.0, 100.0, 100.0, -50.0, -50.0, -50.0);
3296
3297 assert_eq!(bbox.width(), -50.0);
3299 assert_eq!(bbox.height(), -50.0);
3300 assert_eq!(bbox.length(), -50.0);
3301 assert_eq!(
3302 (bbox.left(), bbox.top(), bbox.front()),
3303 (125.0, 125.0, 125.0)
3304 );
3305 }
3306
3307 #[test]
3309 fn test_polygon_creation_and_deserialization() {
3310 let rings = vec![vec![(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]];
3312 let polygon = Polygon::new(rings.clone());
3313 assert_eq!(polygon.rings, rings);
3314
3315 let legacy = serde_json::json!({
3317 "polygon": {
3318 "polygon": [[
3319 [0.0_f32, 0.0_f32],
3320 [1.0_f32, 0.0_f32],
3321 [1.0_f32, 1.0_f32]
3322 ]]
3323 }
3324 });
3325
3326 #[derive(serde::Deserialize)]
3327 struct Wrapper {
3328 polygon: Polygon,
3329 }
3330
3331 let parsed: Wrapper = serde_json::from_value(legacy).unwrap();
3332 assert_eq!(parsed.polygon.rings.len(), 1);
3333 assert_eq!(parsed.polygon.rings[0].len(), 3);
3334 }
3335
3336 #[test]
3338 fn test_sample_construction_and_accessors() {
3339 let sample = Sample::new();
3341 assert_eq!(sample.id(), None);
3342 assert_eq!(sample.image_name(), None);
3343 assert_eq!(sample.width(), None);
3344 assert_eq!(sample.height(), None);
3345
3346 let mut sample = Sample::new();
3348 sample.image_name = Some("test.jpg".to_string());
3349 sample.width = Some(1920);
3350 sample.height = Some(1080);
3351 sample.group = Some("group1".to_string());
3352
3353 assert_eq!(sample.image_name(), Some("test.jpg"));
3354 assert_eq!(sample.width(), Some(1920));
3355 assert_eq!(sample.height(), Some(1080));
3356 assert_eq!(sample.group(), Some(&"group1".to_string()));
3357 }
3358
3359 #[test]
3360 fn test_sample_name_extraction_from_image_name() {
3361 let mut sample = Sample::new();
3362
3363 sample.image_name = Some("test_image.jpg".to_string());
3365 assert_eq!(sample.name(), Some("test_image".to_string()));
3366
3367 sample.image_name = Some("test_image.camera.jpg".to_string());
3369 assert_eq!(sample.name(), Some("test_image".to_string()));
3370
3371 sample.image_name = Some("test_image".to_string());
3373 assert_eq!(sample.name(), Some("test_image".to_string()));
3374 }
3375
3376 #[test]
3378 fn test_annotation_construction_and_setters() {
3379 let ann = Annotation::new();
3381 assert_eq!(ann.sample_id(), None);
3382 assert_eq!(ann.label(), None);
3383 assert_eq!(ann.box2d(), None);
3384 assert_eq!(ann.box3d(), None);
3385 assert_eq!(ann.polygon(), None);
3386
3387 let mut ann = Annotation::new();
3389 ann.set_label(Some("car".to_string()));
3390 assert_eq!(ann.label(), Some(&"car".to_string()));
3391
3392 ann.set_label_index(Some(42));
3393 assert_eq!(ann.label_index(), Some(42));
3394
3395 let bbox = Box2d::new(10.0, 20.0, 100.0, 50.0);
3397 ann.set_box2d(Some(bbox.clone()));
3398 assert!(ann.box2d().is_some());
3399 assert_eq!(ann.box2d().unwrap().left(), 10.0);
3400 }
3401
3402 #[test]
3404 fn test_sample_file_with_url_and_filename() {
3405 let file = SampleFile::with_url(
3407 "lidar.pcd".to_string(),
3408 "https://example.com/file.pcd".to_string(),
3409 );
3410 assert_eq!(file.file_type(), "lidar.pcd");
3411 assert_eq!(file.url(), Some("https://example.com/file.pcd"));
3412 assert_eq!(file.filename(), None);
3413
3414 let file = SampleFile::with_filename("image".to_string(), "test.jpg".to_string());
3416 assert_eq!(file.file_type(), "image");
3417 assert_eq!(file.filename(), Some("test.jpg"));
3418 assert_eq!(file.url(), None);
3419 }
3420
3421 #[test]
3423 fn test_sample_deserializes_gps_imu_from_sensors() {
3424 use serde_json::json;
3425
3426 let sample_json = json!({
3428 "id": 123,
3429 "image_name": "test.jpg",
3430 "sensors": [
3431 {"gps": {"lat": 37.7749, "lon": -122.4194}},
3432 {"imu": {"roll": 1.5, "pitch": 2.5, "yaw": 3.5}},
3433 {"radar.pcd": "https://example.com/radar.pcd"}
3434 ]
3435 });
3436
3437 let sample: Sample = serde_json::from_value(sample_json).unwrap();
3438
3439 assert!(sample.location.is_some());
3441 let location = sample.location.as_ref().unwrap();
3442
3443 assert!(location.gps.is_some());
3445 let gps = location.gps.as_ref().unwrap();
3446 assert!((gps.lat - 37.7749).abs() < 0.0001);
3447 assert!((gps.lon - (-122.4194)).abs() < 0.0001);
3448
3449 assert!(location.imu.is_some());
3451 let imu = location.imu.as_ref().unwrap();
3452 assert!((imu.roll - 1.5).abs() < 0.0001);
3453 assert!((imu.pitch - 2.5).abs() < 0.0001);
3454 assert!((imu.yaw - 3.5).abs() < 0.0001);
3455
3456 assert_eq!(sample.files.len(), 1);
3458 assert_eq!(sample.files[0].file_type(), "radar.pcd");
3459 assert_eq!(sample.files[0].url(), Some("https://example.com/radar.pcd"));
3460 }
3461
3462 #[test]
3463 fn test_sample_deserializes_gps_only() {
3464 use serde_json::json;
3465
3466 let sample_json = json!({
3468 "id": 456,
3469 "sensors": [
3470 {"gps": {"lat": 40.7128, "lon": -74.0060}}
3471 ]
3472 });
3473
3474 let sample: Sample = serde_json::from_value(sample_json).unwrap();
3475
3476 assert!(sample.location.is_some());
3477 let location = sample.location.as_ref().unwrap();
3478
3479 assert!(location.gps.is_some());
3480 assert!(location.imu.is_none());
3481
3482 let gps = location.gps.as_ref().unwrap();
3483 assert!((gps.lat - 40.7128).abs() < 0.0001);
3484 assert!((gps.lon - (-74.0060)).abs() < 0.0001);
3485 }
3486
3487 #[test]
3488 fn test_sample_deserializes_without_location() {
3489 use serde_json::json;
3490
3491 let sample_json = json!({
3493 "id": 789,
3494 "sensors": [
3495 {"radar.pcd": "https://example.com/radar.pcd"},
3496 {"lidar.pcd": "https://example.com/lidar.pcd"}
3497 ]
3498 });
3499
3500 let sample: Sample = serde_json::from_value(sample_json).unwrap();
3501
3502 assert!(sample.location.is_none());
3504
3505 assert_eq!(sample.files.len(), 2);
3507 }
3508
3509 #[test]
3511 fn test_label_deserialization_and_accessors() {
3512 use serde_json::json;
3513
3514 let label_json = json!({
3516 "id": 123,
3517 "dataset_id": 456,
3518 "index": 5,
3519 "name": "car"
3520 });
3521
3522 let label: Label = serde_json::from_value(label_json).unwrap();
3523 assert_eq!(label.id(), 123);
3524 assert_eq!(label.index(), 5);
3525 assert_eq!(label.name(), "car");
3526 assert_eq!(label.to_string(), "car");
3527 assert_eq!(format!("{}", label), "car");
3528
3529 let label_json = json!({
3531 "id": 1,
3532 "dataset_id": 100,
3533 "index": 0,
3534 "name": "person"
3535 });
3536
3537 let label: Label = serde_json::from_value(label_json).unwrap();
3538 assert_eq!(format!("{}", label), "person");
3539 }
3540
3541 #[test]
3543 fn test_annotation_serialization_with_mask_and_box() {
3544 let polygon = vec![vec![
3545 (0.0_f32, 0.0_f32),
3546 (1.0_f32, 0.0_f32),
3547 (1.0_f32, 1.0_f32),
3548 ]];
3549
3550 let mut annotation = Annotation::new();
3551 annotation.set_label(Some("test".to_string()));
3552 annotation.set_box2d(Some(Box2d::new(10.0, 20.0, 30.0, 40.0)));
3553 annotation.set_polygon(Some(Polygon::new(polygon)));
3554
3555 let mut sample = Sample::new();
3556 sample.annotations.push(annotation);
3557
3558 let json = serde_json::to_value(&sample).unwrap();
3559 let annotations = json
3560 .get("annotations")
3561 .and_then(|value| value.as_array())
3562 .expect("annotations serialized as array");
3563 assert_eq!(annotations.len(), 1);
3564
3565 let annotation_json = annotations[0].as_object().expect("annotation object");
3566 assert!(annotation_json.contains_key("box2d"));
3567 assert!(annotation_json.contains_key("polygon"));
3568 assert!(!annotation_json.contains_key("x"));
3569 assert!(
3570 annotation_json
3571 .get("polygon")
3572 .and_then(|value| value.as_array())
3573 .is_some()
3574 );
3575 }
3576
3577 #[test]
3578 fn test_frame_number_negative_one_deserializes_as_none() {
3579 let json = r#"{
3582 "uuid": "test-uuid",
3583 "frame_number": -1
3584 }"#;
3585
3586 let sample: Sample = serde_json::from_str(json).unwrap();
3587 assert_eq!(sample.frame_number, None);
3588 }
3589
3590 #[test]
3591 fn test_frame_number_positive_value_deserializes_correctly() {
3592 let json = r#"{
3594 "uuid": "test-uuid",
3595 "frame_number": 5
3596 }"#;
3597
3598 let sample: Sample = serde_json::from_str(json).unwrap();
3599 assert_eq!(sample.frame_number, Some(5));
3600 }
3601
3602 #[test]
3603 fn test_frame_number_null_deserializes_as_none() {
3604 let json = r#"{
3606 "uuid": "test-uuid",
3607 "frame_number": null
3608 }"#;
3609
3610 let sample: Sample = serde_json::from_str(json).unwrap();
3611 assert_eq!(sample.frame_number, None);
3612 }
3613
3614 #[test]
3615 fn test_frame_number_missing_deserializes_as_none() {
3616 let json = r#"{
3618 "uuid": "test-uuid"
3619 }"#;
3620
3621 let sample: Sample = serde_json::from_str(json).unwrap();
3622 assert_eq!(sample.frame_number, None);
3623 }
3624
3625 #[cfg(feature = "polars")]
3630 #[test]
3631 fn test_samples_dataframe_preserves_group_for_samples_without_annotations() {
3632 use polars::prelude::*;
3633
3634 let mut sample_with_ann = Sample::new();
3636 sample_with_ann.image_name = Some("annotated.jpg".to_string());
3637 sample_with_ann.group = Some("train".to_string());
3638 let mut annotation = Annotation::new();
3639 annotation.set_label(Some("car".to_string()));
3640 annotation.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
3641 annotation.set_name(Some("annotated".to_string()));
3642 sample_with_ann.annotations = vec![annotation];
3643
3644 let mut sample_no_ann = Sample::new();
3646 sample_no_ann.image_name = Some("unannotated.jpg".to_string());
3647 sample_no_ann.group = Some("val".to_string()); sample_no_ann.annotations = vec![]; let samples = vec![sample_with_ann, sample_no_ann];
3651
3652 let df = samples_dataframe(&samples).expect("Failed to create DataFrame");
3654
3655 assert_eq!(df.height(), 2, "Expected 2 rows (one per sample)");
3657
3658 let groups_col = df.column("group").expect("group column should exist");
3660 let groups_cast = groups_col.cast(&DataType::String).expect("cast to string");
3661 let groups = groups_cast.str().expect("as str");
3662
3663 let names_col = df.column("name").expect("name column should exist");
3665 let names_cast = names_col.cast(&DataType::String).expect("cast to string");
3666 let names = names_cast.str().expect("as str");
3667
3668 let mut found_unannotated = false;
3669 for idx in 0..df.height() {
3670 if let Some(name) = names.get(idx)
3671 && name == "unannotated"
3672 {
3673 found_unannotated = true;
3674 let group = groups.get(idx);
3675 assert_eq!(
3676 group,
3677 Some("val"),
3678 "CRITICAL: Sample 'unannotated' without annotations must have group 'val'"
3679 );
3680 }
3681 }
3682
3683 assert!(
3684 found_unannotated,
3685 "Did not find 'unannotated' sample in DataFrame - \
3686 this means samples without annotations are not being included"
3687 );
3688 }
3689
3690 #[cfg(feature = "polars")]
3691 #[test]
3692 fn test_samples_dataframe_includes_all_samples_even_without_annotations() {
3693 let mut sample1 = Sample::new();
3697 sample1.image_name = Some("with_ann.jpg".to_string());
3698 sample1.group = Some("train".to_string());
3699 let mut ann = Annotation::new();
3700 ann.set_label(Some("person".to_string()));
3701 ann.set_box2d(Some(Box2d::new(0.0, 0.0, 0.5, 0.5)));
3702 ann.set_name(Some("with_ann".to_string()));
3703 sample1.annotations = vec![ann];
3704
3705 let mut sample2 = Sample::new();
3706 sample2.image_name = Some("no_ann_train.jpg".to_string());
3707 sample2.group = Some("train".to_string());
3708 sample2.annotations = vec![];
3709
3710 let mut sample3 = Sample::new();
3711 sample3.image_name = Some("no_ann_val.jpg".to_string());
3712 sample3.group = Some("val".to_string());
3713 sample3.annotations = vec![];
3714
3715 let samples = vec![sample1, sample2, sample3];
3716
3717 let df = samples_dataframe(&samples).expect("Failed to create DataFrame");
3718
3719 assert_eq!(
3721 df.height(),
3722 3,
3723 "Expected 3 rows (samples without annotations should create one row each)"
3724 );
3725
3726 let groups_col = df.column("group").expect("group column");
3728 let groups_cast = groups_col.cast(&polars::prelude::DataType::String).unwrap();
3729 let groups = groups_cast.str().unwrap();
3730
3731 let mut train_count = 0;
3732 let mut val_count = 0;
3733
3734 for idx in 0..df.height() {
3735 match groups.get(idx) {
3736 Some("train") => train_count += 1,
3737 Some("val") => val_count += 1,
3738 other => panic!(
3739 "Unexpected group value at row {}: {:?}. \
3740 All samples should have their group preserved.",
3741 idx, other
3742 ),
3743 }
3744 }
3745
3746 assert_eq!(train_count, 2, "Expected 2 samples in 'train' group");
3747 assert_eq!(val_count, 1, "Expected 1 sample in 'val' group");
3748 }
3749
3750 #[cfg(feature = "polars")]
3751 #[test]
3752 fn test_samples_dataframe_group_is_not_null_for_samples_with_group() {
3753 let mut sample = Sample::new();
3757 sample.image_name = Some("test.jpg".to_string());
3758 sample.group = Some("test_group".to_string());
3759 sample.annotations = vec![];
3760
3761 let df = samples_dataframe(&[sample]).expect("Failed to create DataFrame");
3762
3763 let groups_col = df.column("group").expect("group column");
3764
3765 assert_eq!(
3767 groups_col.null_count(),
3768 0,
3769 "Sample with group='test_group' but no annotations has NULL group in DataFrame. \
3770 This is a bug in samples_dataframe - group must be preserved!"
3771 );
3772 }
3773
3774 #[cfg(feature = "polars")]
3775 #[test]
3776 fn test_samples_dataframe_group_consistent_across_all_rows_for_same_image() {
3777 use polars::prelude::*;
3778
3779 let mut sample = Sample::new();
3783 sample.image_name = Some("multi_ann.jpg".to_string());
3784 sample.group = Some("train".to_string());
3785
3786 let mut ann1 = Annotation::new();
3788 ann1.set_label(Some("car".to_string()));
3789 ann1.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
3790 ann1.set_name(Some("multi_ann".to_string()));
3791
3792 let mut ann2 = Annotation::new();
3793 ann2.set_label(Some("truck".to_string()));
3794 ann2.set_box2d(Some(Box2d::new(0.5, 0.6, 0.2, 0.2)));
3795 ann2.set_name(Some("multi_ann".to_string()));
3796
3797 let mut ann3 = Annotation::new();
3798 ann3.set_label(Some("bus".to_string()));
3799 ann3.set_box2d(Some(Box2d::new(0.7, 0.8, 0.1, 0.1)));
3800 ann3.set_name(Some("multi_ann".to_string()));
3801
3802 sample.annotations = vec![ann1, ann2, ann3];
3803
3804 let df = samples_dataframe(&[sample]).expect("Failed to create DataFrame");
3805
3806 assert_eq!(df.height(), 3, "Expected 3 rows (one per annotation)");
3808
3809 let groups_col = df.column("group").expect("group column");
3811 let groups_cast = groups_col.cast(&DataType::String).expect("cast to string");
3812 let groups = groups_cast.str().expect("as str");
3813
3814 assert_eq!(groups_col.null_count(), 0, "No rows should have null group");
3816
3817 for idx in 0..df.height() {
3819 let group = groups.get(idx);
3820 assert_eq!(
3821 group,
3822 Some("train"),
3823 "Row {} should have group 'train', got {:?}. \
3824 All rows for the same image must have identical group values.",
3825 idx,
3826 group
3827 );
3828 }
3829 }
3830
3831 #[cfg(feature = "polars")]
3832 #[test]
3833 fn test_samples_dataframe_lvis_columns() {
3834 let mut ann = Annotation::new();
3835 ann.set_name(Some("test".to_string()));
3836 ann.set_label(Some("person".to_string()));
3837 ann.set_label_index(Some(1));
3838 ann.set_iscrowd(Some(false));
3839 ann.set_category_frequency(Some("f".to_string()));
3840
3841 let sample = Sample {
3842 image_name: Some("test.jpg".to_string()),
3843 width: Some(640),
3844 height: Some(480),
3845 annotations: vec![ann],
3846 neg_label_indices: Some(vec![5, 12]),
3847 not_exhaustive_label_indices: Some(vec![3]),
3848 ..Default::default()
3849 };
3850
3851 let df = samples_dataframe(&[sample]).unwrap();
3852
3853 assert!(df.column("iscrowd").is_ok(), "iscrowd column missing");
3855 assert!(
3856 df.column("category_frequency").is_ok(),
3857 "category_frequency column missing"
3858 );
3859 assert!(
3860 df.column("neg_label_indices").is_ok(),
3861 "neg_label_indices column missing"
3862 );
3863 assert!(
3864 df.column("not_exhaustive_label_indices").is_ok(),
3865 "not_exhaustive_label_indices column missing"
3866 );
3867
3868 assert!(
3870 df.column("polygon").is_err(),
3871 "polygon column should be dropped (all null)"
3872 );
3873 assert!(
3874 df.column("box2d").is_err(),
3875 "box2d column should be dropped (all null)"
3876 );
3877 }
3878
3879 #[test]
3880 fn test_annotation_serialization_skips_lvis_fields() {
3881 let ann = Annotation::new();
3882 let json = serde_json::to_string(&ann).unwrap();
3883 assert!(
3884 !json.contains("iscrowd"),
3885 "iscrowd should be omitted when None"
3886 );
3887 assert!(
3888 !json.contains("category_frequency"),
3889 "category_frequency should be omitted when None"
3890 );
3891 }
3892
3893 #[test]
3894 fn test_sample_serialization_skips_lvis_fields() {
3895 let sample = Sample::new();
3896 let json = serde_json::to_string(&sample).unwrap();
3897 assert!(
3898 !json.contains("neg_label_indices"),
3899 "neg_label_indices should be omitted when None"
3900 );
3901 assert!(
3902 !json.contains("not_exhaustive_label_indices"),
3903 "not_exhaustive_label_indices should be omitted when None"
3904 );
3905 }
3906
3907 #[test]
3908 fn test_annotation_score_fields() {
3909 let mut ann = Annotation::default();
3910 assert!(ann.box2d_score.is_none());
3911 assert!(ann.polygon_score.is_none());
3912 assert!(ann.mask_score.is_none());
3913 ann.box2d_score = Some(0.95);
3914 ann.polygon_score = Some(0.87);
3915 ann.mask_score = Some(0.42);
3916 assert_eq!(ann.box2d_score, Some(0.95));
3917 assert_eq!(ann.polygon_score, Some(0.87));
3918 assert_eq!(ann.mask_score, Some(0.42));
3919 }
3920
3921 #[test]
3922 fn test_timing_struct() {
3923 let timing = Timing {
3924 load: Some(1_000_000),
3925 preprocess: Some(2_000_000),
3926 inference: Some(50_000_000),
3927 decode: Some(3_000_000),
3928 };
3929 assert_eq!(timing.inference, Some(50_000_000));
3930
3931 let default = Timing::default();
3932 assert!(default.load.is_none());
3933 }
3934
3935 #[test]
3936 fn test_sample_timing() {
3937 let mut sample = Sample::default();
3938 assert!(sample.timing.is_none());
3939 sample.timing = Some(Timing {
3940 load: Some(1_000_000),
3941 ..Default::default()
3942 });
3943 assert!(sample.timing.is_some());
3944 }
3945
3946 #[cfg(feature = "polars")]
3951 #[test]
3952 fn test_samples_dataframe_polygon_column() {
3953 let mut ann = Annotation::new();
3954 ann.set_name(Some("test".to_string()));
3955 ann.set_polygon(Some(Polygon::new(vec![vec![
3956 (0.1, 0.2),
3957 (0.3, 0.4),
3958 (0.5, 0.6),
3959 ]])));
3960
3961 let sample = Sample {
3962 image_name: Some("test.jpg".to_string()),
3963 annotations: vec![ann],
3964 ..Default::default()
3965 };
3966
3967 let df = samples_dataframe(&[sample]).unwrap();
3968
3969 assert!(df.column("polygon").is_ok(), "Should have polygon column");
3971
3972 if let Ok(mask_col) = df.column("mask") {
3975 assert_eq!(
3977 mask_col.dtype(),
3978 &polars::prelude::DataType::Binary,
3979 "mask column must be Binary type (PNG bytes), not float list"
3980 );
3981 }
3982 }
3983
3984 #[cfg(feature = "polars")]
3985 #[test]
3986 fn test_samples_dataframe_column_presence_drops_all_null() {
3987 let sample = Sample {
3989 image_name: Some("test.jpg".to_string()),
3990 ..Default::default()
3991 };
3992
3993 let df = samples_dataframe(&[sample]).unwrap();
3994
3995 assert!(df.column("name").is_ok(), "name column must always exist");
3997
3998 assert!(
4000 df.column("polygon").is_err(),
4001 "All-null polygon should be dropped"
4002 );
4003 assert!(
4004 df.column("box2d").is_err(),
4005 "All-null box2d should be dropped"
4006 );
4007 assert!(
4008 df.column("box3d").is_err(),
4009 "All-null box3d should be dropped"
4010 );
4011 assert!(
4012 df.column("mask").is_err(),
4013 "All-null mask should be dropped"
4014 );
4015 assert!(
4016 df.column("box2d_score").is_err(),
4017 "All-null score columns should be dropped"
4018 );
4019 assert!(
4020 df.column("timing").is_err(),
4021 "All-null timing should be dropped"
4022 );
4023 }
4024
4025 #[cfg(feature = "polars")]
4026 #[test]
4027 fn test_samples_dataframe_score_columns() {
4028 let mut ann = Annotation::new();
4029 ann.set_name(Some("test".to_string()));
4030 ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
4031 ann.set_box2d_score(Some(0.95));
4032 ann.set_polygon(Some(Polygon::new(vec![vec![
4033 (0.0, 0.0),
4034 (1.0, 0.0),
4035 (1.0, 1.0),
4036 ]])));
4037 ann.set_polygon_score(Some(0.87));
4038
4039 let sample = Sample {
4040 image_name: Some("test.jpg".to_string()),
4041 annotations: vec![ann],
4042 ..Default::default()
4043 };
4044
4045 let df = samples_dataframe(&[sample]).unwrap();
4046
4047 assert!(
4049 df.column("box2d_score").is_ok(),
4050 "box2d_score column missing"
4051 );
4052 assert!(
4053 df.column("polygon_score").is_ok(),
4054 "polygon_score column missing"
4055 );
4056
4057 assert!(
4059 df.column("box3d_score").is_err(),
4060 "box3d_score should be dropped (all null)"
4061 );
4062 assert!(
4063 df.column("mask_score").is_err(),
4064 "mask_score should be dropped (all null)"
4065 );
4066
4067 let box2d_scores = df.column("box2d_score").unwrap();
4069 let val = box2d_scores.f32().unwrap().get(0);
4070 assert_eq!(val, Some(0.95));
4071 }
4072
4073 #[cfg(feature = "polars")]
4074 #[test]
4075 fn test_samples_dataframe_timing_column() {
4076 let mut ann = Annotation::new();
4077 ann.set_name(Some("test".to_string()));
4078 ann.set_label(Some("person".to_string()));
4079
4080 let sample = Sample {
4081 image_name: Some("test.jpg".to_string()),
4082 annotations: vec![ann],
4083 timing: Some(Timing {
4084 load: Some(1_000_000),
4085 preprocess: Some(2_000_000),
4086 inference: Some(50_000_000),
4087 decode: Some(3_000_000),
4088 }),
4089 ..Default::default()
4090 };
4091
4092 let df = samples_dataframe(&[sample]).unwrap();
4093
4094 assert!(df.column("timing").is_ok(), "timing column missing");
4096
4097 let timing_col = df.column("timing").unwrap();
4099 assert!(
4100 matches!(timing_col.dtype(), polars::prelude::DataType::Struct(..)),
4101 "timing column should be Struct type, got {:?}",
4102 timing_col.dtype()
4103 );
4104 }
4105
4106 #[cfg(feature = "polars")]
4107 #[test]
4108 fn test_samples_dataframe_mask_binary_column() {
4109 let mut ann = Annotation::new();
4110 ann.set_name(Some("test".to_string()));
4111 let pixels = vec![0u8, 255, 128, 64];
4113 let mask_data = MaskData::encode(&pixels, 2, 2, 8).unwrap();
4114 ann.set_mask(Some(mask_data));
4115
4116 let sample = Sample {
4117 image_name: Some("test.jpg".to_string()),
4118 annotations: vec![ann],
4119 ..Default::default()
4120 };
4121
4122 let df = samples_dataframe(&[sample]).unwrap();
4123
4124 let mask_col = df.column("mask").unwrap();
4126 assert_eq!(
4127 mask_col.dtype(),
4128 &polars::prelude::DataType::Binary,
4129 "mask column should be Binary"
4130 );
4131 assert_eq!(mask_col.null_count(), 0, "mask value should not be null");
4132 }
4133
4134 #[test]
4139 fn test_annotation_type_seg_alias() {
4140 assert_eq!(
4141 AnnotationType::try_from("seg").unwrap(),
4142 AnnotationType::Polygon,
4143 "\"seg\" should map to Polygon for server round-trip"
4144 );
4145 }
4146
4147 #[cfg(feature = "polars")]
4152 #[test]
4153 fn test_samples_dataframe_timing_partial() {
4154 let mut ann = Annotation::new();
4156 ann.set_name(Some("test".to_string()));
4157 ann.set_label(Some("person".to_string()));
4158
4159 let sample = Sample {
4160 image_name: Some("test.jpg".to_string()),
4161 annotations: vec![ann],
4162 timing: Some(Timing {
4163 load: Some(1000),
4164 ..Default::default()
4165 }),
4166 ..Default::default()
4167 };
4168
4169 let df = samples_dataframe(&[sample]).unwrap();
4170
4171 assert!(
4173 df.column("timing").is_ok(),
4174 "timing column should be present when partial data exists"
4175 );
4176 }
4177
4178 #[cfg(feature = "polars")]
4179 #[test]
4180 fn test_samples_dataframe_timing_all_none_omitted() {
4181 let mut ann = Annotation::new();
4183 ann.set_name(Some("test".to_string()));
4184 ann.set_label(Some("person".to_string()));
4185
4186 let sample = Sample {
4187 image_name: Some("test.jpg".to_string()),
4188 annotations: vec![ann],
4189 timing: None,
4190 ..Default::default()
4191 };
4192
4193 let df = samples_dataframe(&[sample]).unwrap();
4194
4195 assert!(
4196 df.column("timing").is_err(),
4197 "timing column should be omitted when all samples have timing: None"
4198 );
4199 }
4200
4201 #[cfg(feature = "polars")]
4206 #[test]
4207 fn test_samples_dataframe_score_zero_survives() {
4208 let mut ann = Annotation::new();
4210 ann.set_name(Some("test".to_string()));
4211 ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
4212 ann.set_box2d_score(Some(0.0));
4213
4214 let sample = Sample {
4215 image_name: Some("test.jpg".to_string()),
4216 annotations: vec![ann],
4217 ..Default::default()
4218 };
4219
4220 let df = samples_dataframe(&[sample]).unwrap();
4221
4222 let scores = df.column("box2d_score").unwrap();
4223 let val = scores.f32().unwrap().get(0);
4224 assert_eq!(val, Some(0.0), "score of 0.0 should survive as non-null");
4225 }
4226
4227 #[cfg(feature = "polars")]
4228 #[test]
4229 fn test_samples_dataframe_score_one_survives() {
4230 let mut ann = Annotation::new();
4231 ann.set_name(Some("test".to_string()));
4232 ann.set_box2d(Some(Box2d::new(0.1, 0.2, 0.3, 0.4)));
4233 ann.set_box2d_score(Some(1.0));
4234
4235 let sample = Sample {
4236 image_name: Some("test.jpg".to_string()),
4237 annotations: vec![ann],
4238 ..Default::default()
4239 };
4240
4241 let df = samples_dataframe(&[sample]).unwrap();
4242
4243 let scores = df.column("box2d_score").unwrap();
4244 let val = scores.f32().unwrap().get(0);
4245 assert_eq!(val, Some(1.0), "score of 1.0 should survive as non-null");
4246 }
4247}