1use crate::configs::{self, deserialize_dshape, DimName, QuantTuple};
47use crate::{ConfigOutput, ConfigOutputs, DecoderError, DecoderResult};
48use serde::{Deserialize, Serialize};
49
/// Highest `schema_version` this crate can parse.
///
/// `SchemaV2::from_json_value` rejects documents declaring a higher version
/// with `DecoderError::NotSupported`, and `SchemaV2::validate` rejects
/// schemas whose version is 0 or above this value.
pub const MAX_SUPPORTED_SCHEMA_VERSION: u32 = 2;
54
/// Top-level decoder metadata in the v2 schema.
///
/// Load it with [`SchemaV2::parse_json`], [`SchemaV2::parse_yaml`] or
/// [`SchemaV2::parse_file`]. Those entry points also accept legacy v1
/// documents (no `schema_version`, or a version below 2) and convert them
/// through [`SchemaV2::from_v1`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaV2 {
    /// Version of this document.
    ///
    /// [`SchemaV2::default`] and [`SchemaV2::from_v1`] set it to 2.
    /// [`SchemaV2::validate`] rejects 0 and anything above
    /// [`MAX_SUPPORTED_SCHEMA_VERSION`].
    pub schema_version: u32,

    /// Optional description of the model input tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<InputSpec>,

    /// Logical model outputs; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub outputs: Vec<LogicalOutput>,

    /// Optional NMS mode, carried into the legacy config by
    /// [`SchemaV2::to_legacy_config_outputs`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<NmsMode>,

    /// Optional model-family version (see [`DecoderVersion`]), carried into
    /// the legacy config by [`SchemaV2::to_legacy_config_outputs`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<DecoderVersion>,
}
89
90impl Default for SchemaV2 {
91 fn default() -> Self {
92 Self {
93 schema_version: 2,
94 input: None,
95 outputs: Vec::new(),
96 nms: None,
97 decoder_version: None,
98 }
99 }
100}
101
/// Description of the model input tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InputSpec {
    /// Input tensor shape.
    pub shape: Vec<usize>,

    /// Optional named dimensions, presumably one entry per axis of `shape`.
    ///
    /// Parsed by `deserialize_dshape`; omitted from output when empty.
    ///
    /// NOTE(review): serialization uses the default `Serialize` for
    /// `Vec<(DimName, usize)>` (there is no `serialize_with`). Confirm that
    /// the emitted form is accepted back by `deserialize_dshape` for
    /// round-trips.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Optional camera adaptor identifier.
    ///
    /// NOTE(review): the accepted values and their meaning are not visible
    /// in this module.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cameraadaptor: Option<String>,
}
123
/// One logical model output in the v2 schema (for example `boxes`, `scores`
/// or `protos`).
///
/// A logical output is either *flat*, bound to a single tensor with
/// `outputs` empty, or *split* into the physical children listed in
/// `outputs`. See [`LogicalOutput::is_split`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LogicalOutput {
    /// Optional name. Error messages fall back to `<anonymous>` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Semantic type, serialized under the JSON key `type`.
    ///
    /// Outputs without a type are skipped by
    /// [`SchemaV2::to_legacy_config_outputs`].
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub type_: Option<LogicalType>,

    /// Logical tensor shape.
    pub shape: Vec<usize>,

    /// Optional named dimensions, parsed by `deserialize_dshape`.
    ///
    /// Legacy conversion squeezes `padding` axes out of `shape` and `dshape`
    /// unless `decoder` is ModelPack.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Decoder family. Legacy conversion defaults to Ultralytics when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder: Option<DecoderKind>,

    /// Box encoding.
    ///
    /// `dfl` on a `boxes` output triggers the DFL feature-axis check on its
    /// children, and it is rejected on flat outputs during legacy
    /// conversion.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encoding: Option<BoxEncoding>,

    /// Score layout. Legacy v1 `scores` outputs are converted with
    /// `per_class`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score_format: Option<ScoreFormat>,

    /// Whether box coordinates are normalized.
    ///
    /// Carried to and from the v1 `Boxes` / `Detection` configs unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<bool>,

    /// Anchor pairs, carried to and from v1 `Detection` configs.
    ///
    /// NOTE(review): the component order of each pair is not visible here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub anchors: Option<Vec<[f32; 2]>>,

    /// Optional stride of this (flat) output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stride: Option<Stride>,

    /// Optional element type of the flat tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dtype: Option<DType>,

    /// Quantization of the flat tensor. Per-channel quantization is rejected
    /// by legacy conversion.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quantization: Option<Quantization>,

    /// Physical children.
    ///
    /// Either per-scale children, which all carry `stride`, or channel
    /// sub-split children, which carry none. `validate_logical` rejects a
    /// mix of the two.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub outputs: Vec<PhysicalOutput>,

    /// Activation recorded as already applied.
    ///
    /// NOTE(review): presumably applied inside the model graph; confirm
    /// with the decoder that reads it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_applied: Option<Activation>,

    /// Activation recorded as still required.
    ///
    /// NOTE(review): presumably applied by the decoder; confirm with the
    /// decoder that reads it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_required: Option<Activation>,
}
226
227impl LogicalOutput {
228 pub fn is_split(&self) -> bool {
231 !self.outputs.is_empty()
232 }
233}
234
/// A physical tensor that makes up part of a split [`LogicalOutput`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PhysicalOutput {
    /// Tensor name used for binding. `validate_logical` rejects empty names.
    pub name: String,

    /// Physical type, serialized under the JSON key `type`.
    ///
    /// `validate_logical` rejects two typed siblings that share both shape
    /// and type.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub type_: Option<PhysicalType>,

    /// Physical tensor shape.
    pub shape: Vec<usize>,

    /// Optional named dimensions, parsed by `deserialize_dshape`.
    ///
    /// `last_feature_axis` prefers these names over the last `shape` entry.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Element type of the tensor (required).
    pub dtype: DType,

    /// Optional quantization parameters of this tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quantization: Option<Quantization>,

    /// Stride of a per-scale child.
    ///
    /// Siblings must either all carry a stride or all omit it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stride: Option<Stride>,

    /// Optional scale index.
    ///
    /// NOTE(review): presumably the FPN level of a per-scale child.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub scale_index: Option<usize>,

    /// Activation recorded as already applied to this tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_applied: Option<Activation>,

    /// Activation recorded as still required for this tensor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_required: Option<Activation>,
}
297
/// Affine quantization parameters for a tensor.
///
/// One `scale` entry means per-tensor quantization; more than one means
/// per-channel quantization along `axis`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    /// Scale(s). Accepts either a bare number or an array of numbers.
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    pub scale: Vec<f32>,

    /// Zero point(s). Accepts a bare integer, an array, `null`, or absence.
    ///
    /// `None`, or all zeros, is treated as symmetric quantization.
    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    pub zero_point: Option<Vec<i32>>,

    /// Channel axis for per-channel quantization.
    ///
    /// Required when converting per-channel parameters into
    /// `edgefirst_tensor::Quantization`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub axis: Option<usize>,

    /// Optional quantized element type.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dtype: Option<DType>,
}
331
332impl Quantization {
333 pub fn is_per_tensor(&self) -> bool {
335 self.scale.len() == 1
336 }
337
338 pub fn is_per_channel(&self) -> bool {
340 self.scale.len() > 1
341 }
342
343 pub fn is_symmetric(&self) -> bool {
345 match &self.zero_point {
346 None => true,
347 Some(zps) => zps.iter().all(|&z| z == 0),
348 }
349 }
350
351 pub fn zero_point_at(&self, channel: usize) -> i32 {
354 match &self.zero_point {
355 None => 0,
356 Some(zps) if zps.len() == 1 => zps[0],
357 Some(zps) => zps.get(channel).copied().unwrap_or(0),
358 }
359 }
360
361 pub fn scale_at(&self, channel: usize) -> f32 {
363 if self.scale.len() == 1 {
364 self.scale[0]
365 } else {
366 self.scale.get(channel).copied().unwrap_or(0.0)
367 }
368 }
369}
370
371impl TryFrom<&Quantization> for edgefirst_tensor::Quantization {
381 type Error = edgefirst_tensor::Error;
382
383 fn try_from(q: &Quantization) -> Result<Self, Self::Error> {
384 match (q.scale.as_slice(), q.zero_point.as_deref(), q.axis) {
385 ([scale], None, None) => Ok(Self::per_tensor_symmetric(*scale)),
387 ([scale], Some([zp]), None) => Ok(Self::per_tensor(*scale, *zp)),
389 ([scale], Some([zp]), Some(_)) => Ok(Self::per_tensor(*scale, *zp)),
391 ([scale], None, Some(_)) => Ok(Self::per_tensor_symmetric(*scale)),
392 (scales, None, Some(axis)) if scales.len() > 1 => {
394 Self::per_channel_symmetric(scales.to_vec(), axis)
395 }
396 (scales, Some(zps), Some(axis)) if scales.len() > 1 => {
397 Self::per_channel(scales.to_vec(), zps.to_vec(), axis)
398 }
399 (scales, _, None) if scales.len() > 1 => {
401 Err(edgefirst_tensor::Error::QuantizationInvalid {
402 field: "axis",
403 expected: "Some(axis) for per-channel".into(),
404 got: "None".into(),
405 })
406 }
407 _ => Err(edgefirst_tensor::Error::QuantizationInvalid {
408 field: "scale",
409 expected: "non-empty".into(),
410 got: format!("len={}", q.scale.len()),
411 }),
412 }
413 }
414}
415
416fn deserialize_scalar_or_vec_f32<'de, D>(de: D) -> Result<Vec<f32>, D::Error>
418where
419 D: serde::Deserializer<'de>,
420{
421 #[derive(Deserialize)]
422 #[serde(untagged)]
423 enum OneOrMany {
424 One(f32),
425 Many(Vec<f32>),
426 }
427 match OneOrMany::deserialize(de)? {
428 OneOrMany::One(v) => Ok(vec![v]),
429 OneOrMany::Many(vs) => Ok(vs),
430 }
431}
432
433fn deserialize_opt_scalar_or_vec_i32<'de, D>(de: D) -> Result<Option<Vec<i32>>, D::Error>
435where
436 D: serde::Deserializer<'de>,
437{
438 #[derive(Deserialize)]
439 #[serde(untagged)]
440 enum OneOrMany {
441 One(i32),
442 Many(Vec<i32>),
443 }
444 match Option::<OneOrMany>::deserialize(de)? {
445 None => Ok(None),
446 Some(OneOrMany::One(v)) => Ok(Some(vec![v])),
447 Some(OneOrMany::Many(vs)) => Ok(Some(vs)),
448 }
449}
450
/// Spatial stride of an output.
///
/// Deserialized untagged: a bare integer yields [`Stride::Square`] and a
/// two-element array yields [`Stride::Rect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Stride {
    /// The same stride on both axes.
    Square(u32),
    /// Per-axis stride `[x, y]`, as returned by [`Stride::x`] and
    /// [`Stride::y`].
    Rect([u32; 2]),
}
459
460impl Stride {
461 pub fn x(self) -> u32 {
463 match self {
464 Stride::Square(s) => s,
465 Stride::Rect([sx, _]) => sx,
466 }
467 }
468
469 pub fn y(self) -> u32 {
471 match self {
472 Stride::Square(s) => s,
473 Stride::Rect([_, sy]) => sy,
474 }
475 }
476}
477
/// Semantic type of a [`LogicalOutput`], serialized in snake_case (for
/// example `mask_coefs`).
///
/// In legacy v1 conversion, `detection` and `detections` both map to the v1
/// `Detection` config. `objectness` and `landmarks` have no v1 counterpart.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LogicalType {
    Boxes,
    Scores,
    /// No legacy v1 equivalent.
    Objectness,
    Classes,
    MaskCoefs,
    Protos,
    /// No legacy v1 equivalent.
    Landmarks,
    /// Converted to the v1 `Detection` config, like `Detection`.
    Detections,
    Segmentation,
    Masks,
    /// The type produced for v1 `Detection` outputs.
    Detection,
}
505
/// Type of a [`PhysicalOutput`], serialized in snake_case.
///
/// It mirrors [`LogicalType`] and adds `boxes_xy` / `boxes_wh`.
/// NOTE(review): these presumably name channel sub-splits of a boxes
/// output; confirm with the binder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PhysicalType {
    Boxes,
    Scores,
    Objectness,
    Classes,
    MaskCoefs,
    Protos,
    Landmarks,
    Detections,
    Segmentation,
    Masks,
    Detection,
    BoxesXy,
    BoxesWh,
}
530
/// How box regression values are encoded. Serialized as `dfl`, `direct`
/// (alias `ltrb`) or `anchor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BoxEncoding {
    /// Distribution focal loss encoding.
    ///
    /// Only supported on split (per-scale) boxes outputs. Each child's
    /// feature axis must be divisible by 4.
    Dfl,
    /// Directly encoded boxes; `ltrb` is accepted as an alias on input.
    #[serde(alias = "ltrb")]
    Direct,
    /// Anchor-based encoding; produced for v1 ModelPack detections that
    /// carry anchors.
    Anchor,
}
546
/// Layout of classification scores. Serialized as `per_class` or
/// `obj_x_class`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ScoreFormat {
    /// The format assigned to legacy v1 `scores` outputs.
    PerClass,
    /// NOTE(review): presumably objectness multiplied by class score;
    /// confirm with the decoder.
    ObjXClass,
}
559
/// Element-wise activation function, serialized as `sigmoid`, `softmax` or
/// `tanh`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Activation {
    Sigmoid,
    Softmax,
    Tanh,
}
568
/// Decoder family that interprets an output.
// Explicit renames are used because `rename_all = "snake_case"` would turn
// `ModelPack` into `model_pack` rather than `modelpack`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DecoderKind {
    #[serde(rename = "modelpack")]
    ModelPack,
    #[serde(rename = "ultralytics")]
    Ultralytics,
}
579
/// Model-family version. Serialized as `yolov5`, `yolov8`, `yolo11` or
/// `yolo26`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DecoderVersion {
    Yolov5,
    Yolov8,
    Yolo11,
    /// The only version for which [`DecoderVersion::is_end_to_end`]
    /// returns `true`.
    Yolo26,
}
589
590impl DecoderVersion {
591 pub fn is_end_to_end(self) -> bool {
593 matches!(self, DecoderVersion::Yolo26)
594 }
595}
596
/// Non-maximum-suppression mode, serialized as `class_agnostic` or
/// `class_aware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NmsMode {
    /// Also the result of converting the v1 `Auto` mode.
    ClassAgnostic,
    ClassAware,
}
607
/// Tensor element type, serialized in snake_case (`int8`, `uint8`, …,
/// `float32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DType {
    Int8,
    Uint8,
    Int16,
    Uint16,
    Int32,
    Uint32,
    Float16,
    Float32,
}
621
622impl DType {
623 pub fn size_bytes(self) -> usize {
625 match self {
626 DType::Int8 | DType::Uint8 => 1,
627 DType::Int16 | DType::Uint16 | DType::Float16 => 2,
628 DType::Int32 | DType::Uint32 | DType::Float32 => 4,
629 }
630 }
631
632 pub fn is_integer(self) -> bool {
634 matches!(
635 self,
636 DType::Int8
637 | DType::Uint8
638 | DType::Int16
639 | DType::Uint16
640 | DType::Int32
641 | DType::Uint32
642 )
643 }
644
645 pub fn is_float(self) -> bool {
647 matches!(self, DType::Float16 | DType::Float32)
648 }
649}
650
651impl SchemaV2 {
656 pub fn parse_json(s: &str) -> DecoderResult<Self> {
664 let value: serde_json::Value = serde_json::from_str(s)?;
665 Self::from_json_value(value)
666 }
667
668 pub fn parse_yaml(s: &str) -> DecoderResult<Self> {
672 let value: serde_yaml::Value = serde_yaml::from_str(s)?;
673 let json = serde_json::to_value(value)
674 .map_err(|e| DecoderError::InvalidConfig(format!("yaml→json bridge failed: {e}")))?;
675 Self::from_json_value(json)
676 }
677
678 pub fn parse_file(path: impl AsRef<std::path::Path>) -> DecoderResult<Self> {
682 let path = path.as_ref();
683 let content = std::fs::read_to_string(path)
684 .map_err(|e| DecoderError::InvalidConfig(format!("read {}: {e}", path.display())))?;
685 let ext = path
686 .extension()
687 .and_then(|e| e.to_str())
688 .map(str::to_ascii_lowercase);
689 match ext.as_deref() {
690 Some("json") => Self::parse_json(&content),
691 Some("yaml") | Some("yml") => Self::parse_yaml(&content),
692 _ => Self::parse_json(&content).or_else(|_| Self::parse_yaml(&content)),
693 }
694 }
695
696 pub fn from_json_value(value: serde_json::Value) -> DecoderResult<Self> {
699 let version = value
700 .get("schema_version")
701 .and_then(|v| v.as_u64())
702 .map(|v| v as u32)
703 .unwrap_or(1);
704
705 if version > MAX_SUPPORTED_SCHEMA_VERSION {
706 return Err(DecoderError::NotSupported(format!(
707 "schema_version {version} is not supported by this HAL \
708 (maximum supported version is {MAX_SUPPORTED_SCHEMA_VERSION}); \
709 upgrade the HAL or downgrade the metadata"
710 )));
711 }
712
713 if version >= 2 {
714 serde_json::from_value(value).map_err(DecoderError::Json)
715 } else {
716 let v1: ConfigOutputs = serde_json::from_value(value).map_err(DecoderError::Json)?;
717 Self::from_v1(&v1)
718 }
719 }
720
721 pub fn from_v1(v1: &ConfigOutputs) -> DecoderResult<Self> {
737 let outputs = v1
738 .outputs
739 .iter()
740 .map(logical_from_v1)
741 .collect::<DecoderResult<Vec<_>>>()?;
742 Ok(SchemaV2 {
743 schema_version: 2,
744 input: None,
745 outputs,
746 nms: v1.nms.as_ref().map(NmsMode::from_v1),
747 decoder_version: v1.decoder_version.as_ref().map(DecoderVersion::from_v1),
748 })
749 }
750}
751
752impl SchemaV2 {
753 pub fn to_legacy_config_outputs(&self) -> DecoderResult<ConfigOutputs> {
780 let mut outputs = Vec::with_capacity(self.outputs.len());
781 for logical in &self.outputs {
782 if logical.type_.is_none() {
788 continue;
789 }
790 if logical.type_ == Some(LogicalType::Boxes)
797 && logical.encoding == Some(BoxEncoding::Dfl)
798 && logical.outputs.is_empty()
799 {
800 return Err(DecoderError::NotSupported(format!(
801 "`boxes` output `{}` has `encoding: dfl` on a flat \
802 logical (no per-scale children); the HAL's DFL \
803 decode kernel only runs inside the per-scale merge \
804 path. Split the boxes output into per-FPN-level \
805 children (Hailo convention) or pre-decode to 4 \
806 channels in the model graph (TFLite convention).",
807 logical.name.as_deref().unwrap_or("<anonymous>"),
808 )));
809 }
810 if let Some(q) = &logical.quantization {
811 if q.is_per_channel() {
812 return Err(DecoderError::NotSupported(format!(
813 "logical `{}` uses per-channel quantization \
814 (axis {:?}, {} scales); the v1 decoder only \
815 supports per-tensor quantization",
816 logical.name.as_deref().unwrap_or("<anonymous>"),
817 q.axis,
818 q.scale.len(),
819 )));
820 }
821 }
822 outputs.push(logical_to_legacy_config_output(logical)?);
823 }
824 Ok(ConfigOutputs {
825 outputs,
826 nms: self.nms.map(NmsMode::to_v1),
827 decoder_version: self.decoder_version.map(|v| v.to_v1()),
828 })
829 }
830
831 pub fn validate(&self) -> DecoderResult<()> {
849 if self.schema_version == 0 || self.schema_version > MAX_SUPPORTED_SCHEMA_VERSION {
850 return Err(DecoderError::InvalidConfig(format!(
851 "schema_version {} outside supported range [1, {MAX_SUPPORTED_SCHEMA_VERSION}]",
852 self.schema_version
853 )));
854 }
855
856 for logical in &self.outputs {
857 validate_logical(logical)?;
858 }
859
860 Ok(())
861 }
862}
863
864fn validate_logical(logical: &LogicalOutput) -> DecoderResult<()> {
865 if logical.outputs.is_empty() {
866 return Ok(());
867 }
868
869 for child in &logical.outputs {
871 if child.name.is_empty() {
872 return Err(DecoderError::InvalidConfig(format!(
873 "physical child of logical `{}` is missing `name`; name is \
874 required for tensor binding",
875 logical.name.as_deref().unwrap_or("<anonymous>")
876 )));
877 }
878 }
879
880 for (i, a) in logical.outputs.iter().enumerate() {
889 for b in &logical.outputs[i + 1..] {
890 let (Some(ta), Some(tb)) = (a.type_, b.type_) else {
891 continue;
892 };
893 if a.shape == b.shape && ta == tb {
894 return Err(DecoderError::InvalidConfig(format!(
895 "physical children `{}` and `{}` share shape {:?} and \
896 type; tensor binding cannot be resolved",
897 a.name, b.name, a.shape
898 )));
899 }
900 }
901 }
902
903 let strided: Vec<_> = logical.outputs.iter().map(|c| c.stride.is_some()).collect();
907 let all_strided = strided.iter().all(|&b| b);
908 let none_strided = strided.iter().all(|&b| !b);
909 if !(all_strided || none_strided) {
910 return Err(DecoderError::InvalidConfig(format!(
911 "logical `{}` mixes per-scale children (with stride) and \
912 channel sub-split children (without stride); decomposition \
913 must be uniform",
914 logical.name.as_deref().unwrap_or("<anonymous>")
915 )));
916 }
917
918 if logical.type_ == Some(LogicalType::Boxes) && logical.encoding == Some(BoxEncoding::Dfl) {
921 for child in &logical.outputs {
922 if let Some(feat) = last_feature_axis(child) {
923 if feat % 4 != 0 {
924 return Err(DecoderError::InvalidConfig(format!(
925 "DFL boxes child `{}` feature axis {feat} is not \
926 divisible by 4 (reg_max×4)",
927 child.name
928 )));
929 }
930 }
931 }
932 }
933
934 Ok(())
935}
936
937pub(crate) fn last_feature_axis(child: &PhysicalOutput) -> Option<usize> {
940 for (name, size) in &child.dshape {
943 if matches!(
944 name,
945 DimName::NumFeatures
946 | DimName::NumClasses
947 | DimName::NumProtos
948 | DimName::BoxCoords
949 | DimName::NumAnchorsXFeatures
950 ) {
951 return Some(*size);
952 }
953 }
954 child.shape.last().copied()
955}
956
957fn quantization_from_v1(q: Option<QuantTuple>) -> Option<Quantization> {
958 q.map(|QuantTuple(scale, zp)| Quantization {
959 scale: vec![scale],
960 zero_point: Some(vec![zp]),
961 axis: None,
962 dtype: None,
963 })
964}
965
966fn logical_from_v1(v1: &ConfigOutput) -> DecoderResult<LogicalOutput> {
967 match v1 {
968 ConfigOutput::Detection(d) => {
969 let encoding = match (d.decoder, d.anchors.is_some()) {
975 (configs::DecoderType::ModelPack, true) => Some(BoxEncoding::Anchor),
976 (configs::DecoderType::Ultralytics, _) => Some(BoxEncoding::Direct),
977 (configs::DecoderType::ModelPack, false) => None,
980 };
981 Ok(LogicalOutput {
982 name: None,
983 type_: Some(LogicalType::Detection),
984 shape: d.shape.clone(),
985 dshape: d.dshape.clone(),
986 decoder: Some(DecoderKind::from_v1(d.decoder)),
987 encoding,
988 score_format: None,
989 normalized: d.normalized,
990 anchors: d.anchors.clone(),
991 stride: None,
992 dtype: None,
993 quantization: quantization_from_v1(d.quantization),
994 outputs: Vec::new(),
995 activation_applied: None,
996 activation_required: None,
997 })
998 }
999 ConfigOutput::Boxes(b) => Ok(LogicalOutput {
1000 name: None,
1001 type_: Some(LogicalType::Boxes),
1002 shape: b.shape.clone(),
1003 dshape: b.dshape.clone(),
1004 decoder: Some(DecoderKind::from_v1(b.decoder)),
1005 encoding: Some(BoxEncoding::Direct),
1009 score_format: None,
1010 normalized: b.normalized,
1011 anchors: None,
1012 stride: None,
1013 dtype: None,
1014 quantization: quantization_from_v1(b.quantization),
1015 outputs: Vec::new(),
1016 activation_applied: None,
1017 activation_required: None,
1018 }),
1019 ConfigOutput::Scores(s) => Ok(LogicalOutput {
1020 name: None,
1021 type_: Some(LogicalType::Scores),
1022 shape: s.shape.clone(),
1023 dshape: s.dshape.clone(),
1024 decoder: Some(DecoderKind::from_v1(s.decoder)),
1025 encoding: None,
1026 score_format: Some(ScoreFormat::PerClass),
1030 normalized: None,
1031 anchors: None,
1032 stride: None,
1033 dtype: None,
1034 quantization: quantization_from_v1(s.quantization),
1035 outputs: Vec::new(),
1036 activation_applied: None,
1037 activation_required: None,
1038 }),
1039 ConfigOutput::Protos(p) => Ok(LogicalOutput {
1040 name: None,
1041 type_: Some(LogicalType::Protos),
1042 shape: p.shape.clone(),
1043 dshape: p.dshape.clone(),
1044 decoder: Some(DecoderKind::from_v1(p.decoder)),
1046 encoding: None,
1047 score_format: None,
1048 normalized: None,
1049 anchors: None,
1050 stride: None,
1051 dtype: None,
1052 quantization: quantization_from_v1(p.quantization),
1053 outputs: Vec::new(),
1054 activation_applied: None,
1055 activation_required: None,
1056 }),
1057 ConfigOutput::MaskCoefficients(m) => Ok(LogicalOutput {
1058 name: None,
1059 type_: Some(LogicalType::MaskCoefs),
1060 shape: m.shape.clone(),
1061 dshape: m.dshape.clone(),
1062 decoder: Some(DecoderKind::from_v1(m.decoder)),
1063 encoding: None,
1064 score_format: None,
1065 normalized: None,
1066 anchors: None,
1067 stride: None,
1068 dtype: None,
1069 quantization: quantization_from_v1(m.quantization),
1070 outputs: Vec::new(),
1071 activation_applied: None,
1072 activation_required: None,
1073 }),
1074 ConfigOutput::Segmentation(seg) => Ok(LogicalOutput {
1075 name: None,
1076 type_: Some(LogicalType::Segmentation),
1077 shape: seg.shape.clone(),
1078 dshape: seg.dshape.clone(),
1079 decoder: Some(DecoderKind::from_v1(seg.decoder)),
1080 encoding: None,
1081 score_format: None,
1082 normalized: None,
1083 anchors: None,
1084 stride: None,
1085 dtype: None,
1086 quantization: quantization_from_v1(seg.quantization),
1087 outputs: Vec::new(),
1088 activation_applied: None,
1089 activation_required: None,
1090 }),
1091 ConfigOutput::Mask(m) => Ok(LogicalOutput {
1092 name: None,
1093 type_: Some(LogicalType::Masks),
1094 shape: m.shape.clone(),
1095 dshape: m.dshape.clone(),
1096 decoder: Some(DecoderKind::from_v1(m.decoder)),
1097 encoding: None,
1098 score_format: None,
1099 normalized: None,
1100 anchors: None,
1101 stride: None,
1102 dtype: None,
1103 quantization: quantization_from_v1(m.quantization),
1104 outputs: Vec::new(),
1105 activation_applied: None,
1106 activation_required: None,
1107 }),
1108 ConfigOutput::Classes(c) => Ok(LogicalOutput {
1109 name: None,
1110 type_: Some(LogicalType::Classes),
1111 shape: c.shape.clone(),
1112 dshape: c.dshape.clone(),
1113 decoder: Some(DecoderKind::from_v1(c.decoder)),
1114 encoding: None,
1115 score_format: None,
1116 normalized: None,
1117 anchors: None,
1118 stride: None,
1119 dtype: None,
1120 quantization: quantization_from_v1(c.quantization),
1121 outputs: Vec::new(),
1122 activation_applied: None,
1123 activation_required: None,
1124 }),
1125 }
1126}
1127
1128impl DecoderKind {
1129 pub fn from_v1(v: configs::DecoderType) -> Self {
1131 match v {
1132 configs::DecoderType::ModelPack => DecoderKind::ModelPack,
1133 configs::DecoderType::Ultralytics => DecoderKind::Ultralytics,
1134 }
1135 }
1136
1137 pub fn to_v1(self) -> configs::DecoderType {
1139 match self {
1140 DecoderKind::ModelPack => configs::DecoderType::ModelPack,
1141 DecoderKind::Ultralytics => configs::DecoderType::Ultralytics,
1142 }
1143 }
1144}
1145
1146impl DecoderVersion {
1147 pub fn from_v1(v: &configs::DecoderVersion) -> Self {
1149 match v {
1150 configs::DecoderVersion::Yolov5 => DecoderVersion::Yolov5,
1151 configs::DecoderVersion::Yolov8 => DecoderVersion::Yolov8,
1152 configs::DecoderVersion::Yolo11 => DecoderVersion::Yolo11,
1153 configs::DecoderVersion::Yolo26 => DecoderVersion::Yolo26,
1154 }
1155 }
1156
1157 pub fn to_v1(self) -> configs::DecoderVersion {
1159 match self {
1160 DecoderVersion::Yolov5 => configs::DecoderVersion::Yolov5,
1161 DecoderVersion::Yolov8 => configs::DecoderVersion::Yolov8,
1162 DecoderVersion::Yolo11 => configs::DecoderVersion::Yolo11,
1163 DecoderVersion::Yolo26 => configs::DecoderVersion::Yolo26,
1164 }
1165 }
1166}
1167
1168impl NmsMode {
1169 pub fn from_v1(v: &configs::Nms) -> Self {
1171 match v {
1172 configs::Nms::Auto | configs::Nms::ClassAgnostic => NmsMode::ClassAgnostic,
1173 configs::Nms::ClassAware => NmsMode::ClassAware,
1174 }
1175 }
1176
1177 pub fn to_v1(self) -> configs::Nms {
1179 match self {
1180 NmsMode::ClassAgnostic => configs::Nms::ClassAgnostic,
1181 NmsMode::ClassAware => configs::Nms::ClassAware,
1182 }
1183 }
1184}
1185
1186fn quantization_to_legacy(q: &Quantization) -> DecoderResult<QuantTuple> {
1189 if q.is_per_channel() {
1190 return Err(DecoderError::NotSupported(
1191 "per-channel quantization cannot be expressed as a v1 QuantTuple".into(),
1192 ));
1193 }
1194 let scale = *q.scale.first().unwrap_or(&0.0);
1195 let zp = q.zero_point_at(0);
1196 Ok(QuantTuple(scale, zp))
1197}
1198
1199pub(crate) fn squeeze_padding_dims(
1205 shape: Vec<usize>,
1206 dshape: Vec<(DimName, usize)>,
1207) -> (Vec<usize>, Vec<(DimName, usize)>) {
1208 if dshape.is_empty() {
1212 return (shape, dshape);
1213 }
1214 let keep: Vec<bool> = dshape
1215 .iter()
1216 .map(|(n, _)| !matches!(n, DimName::Padding))
1217 .collect();
1218 let shape = shape
1219 .into_iter()
1220 .zip(keep.iter())
1221 .filter_map(|(s, &k)| k.then_some(s))
1222 .collect();
1223 let dshape = dshape
1224 .into_iter()
1225 .zip(keep.iter())
1226 .filter_map(|(d, &k)| k.then_some(d))
1227 .collect();
1228 (shape, dshape)
1229}
1230
1231pub(crate) fn padding_axes(dshape: &[(DimName, usize)]) -> Vec<usize> {
1236 let mut v: Vec<usize> = dshape
1237 .iter()
1238 .enumerate()
1239 .filter_map(|(i, (n, _))| matches!(n, DimName::Padding).then_some(i))
1240 .collect();
1241 v.sort_by(|a, b| b.cmp(a));
1242 v
1243}
1244
1245fn logical_to_legacy_config_output(logical: &LogicalOutput) -> DecoderResult<ConfigOutput> {
1246 let decoder = logical
1247 .decoder
1248 .map(|d| d.to_v1())
1249 .unwrap_or(configs::DecoderType::Ultralytics);
1250 let quantization = logical
1251 .quantization
1252 .as_ref()
1253 .map(quantization_to_legacy)
1254 .transpose()?;
1255 let (shape, dshape) = match logical.decoder {
1261 Some(DecoderKind::ModelPack) => (logical.shape.clone(), logical.dshape.clone()),
1262 _ => squeeze_padding_dims(logical.shape.clone(), logical.dshape.clone()),
1263 };
1264
1265 let ty = logical.type_.ok_or_else(|| {
1266 DecoderError::InvalidConfig(format!(
1270 "logical output `{}` has no type; typeless outputs should be \
1271 filtered before legacy conversion",
1272 logical.name.as_deref().unwrap_or("<anonymous>")
1273 ))
1274 })?;
1275
1276 Ok(match ty {
1277 LogicalType::Boxes => ConfigOutput::Boxes(configs::Boxes {
1278 decoder,
1279 quantization,
1280 shape,
1281 dshape,
1282 normalized: logical.normalized,
1283 }),
1284 LogicalType::Scores => ConfigOutput::Scores(configs::Scores {
1285 decoder,
1286 quantization,
1287 shape,
1288 dshape,
1289 }),
1290 LogicalType::Protos => ConfigOutput::Protos(configs::Protos {
1291 decoder,
1292 quantization,
1293 shape,
1294 dshape,
1295 }),
1296 LogicalType::MaskCoefs => ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
1297 decoder,
1298 quantization,
1299 shape,
1300 dshape,
1301 }),
1302 LogicalType::Segmentation => ConfigOutput::Segmentation(configs::Segmentation {
1303 decoder,
1304 quantization,
1305 shape,
1306 dshape,
1307 }),
1308 LogicalType::Masks => ConfigOutput::Mask(configs::Mask {
1309 decoder,
1310 quantization,
1311 shape,
1312 dshape,
1313 }),
1314 LogicalType::Classes => ConfigOutput::Classes(configs::Classes {
1315 decoder,
1316 quantization,
1317 shape,
1318 dshape,
1319 }),
1320 LogicalType::Detection | LogicalType::Detections => {
1324 ConfigOutput::Detection(configs::Detection {
1325 anchors: logical.anchors.clone(),
1326 decoder,
1327 quantization,
1328 shape,
1329 dshape,
1330 normalized: logical.normalized,
1331 })
1332 }
1333 LogicalType::Objectness | LogicalType::Landmarks => {
1336 return Err(DecoderError::NotSupported(format!(
1337 "logical type {:?} has no legacy v1 equivalent; use the \
1338 native v2 decoder path",
1339 ty
1340 )));
1341 }
1342 })
1343}
1344
1345#[cfg(test)]
1346#[cfg_attr(coverage_nightly, coverage(off))]
1347mod tests {
1348 use super::*;
1349
1350 #[test]
1351 fn schema_default_is_v2() {
1352 let s = SchemaV2::default();
1353 assert_eq!(s.schema_version, 2);
1354 assert!(s.outputs.is_empty());
1355 }
1356
1357 #[test]
1358 fn fixtures_round_trip_through_serde() {
1359 let yolov8 =
1360 edgefirst_bench::testdata::read_to_string("per_scale/synthetic_yolov8n_schema.json");
1361 let _: super::SchemaV2 = serde_json::from_str(&yolov8).expect("yolov8n fixture must parse");
1362
1363 let yolo26 =
1364 edgefirst_bench::testdata::read_to_string("per_scale/synthetic_yolo26n_schema.json");
1365 let _: super::SchemaV2 = serde_json::from_str(&yolo26).expect("yolo26n fixture must parse");
1366
1367 let flat =
1368 edgefirst_bench::testdata::read_to_string("per_scale/synthetic_flat_schema.json");
1369 let _: super::SchemaV2 = serde_json::from_str(&flat).expect("flat fixture must parse");
1370 }
1371
1372 #[test]
1373 fn box_encoding_accepts_ltrb_alias_for_direct() {
1374 let dfl: BoxEncoding = serde_json::from_str("\"dfl\"").unwrap();
1375 assert_eq!(dfl, BoxEncoding::Dfl);
1376
1377 let direct: BoxEncoding = serde_json::from_str("\"direct\"").unwrap();
1378 assert_eq!(direct, BoxEncoding::Direct);
1379
1380 let ltrb: BoxEncoding = serde_json::from_str("\"ltrb\"").unwrap();
1382 assert_eq!(ltrb, BoxEncoding::Direct);
1383 }
1384
1385 #[test]
1386 fn dtype_roundtrip() {
1387 for d in [
1388 DType::Int8,
1389 DType::Uint8,
1390 DType::Int16,
1391 DType::Uint16,
1392 DType::Float16,
1393 DType::Float32,
1394 ] {
1395 let j = serde_json::to_string(&d).unwrap();
1396 let back: DType = serde_json::from_str(&j).unwrap();
1397 assert_eq!(back, d);
1398 }
1399 }
1400
1401 #[test]
1402 fn dtype_widths() {
1403 assert_eq!(DType::Int8.size_bytes(), 1);
1404 assert_eq!(DType::Float16.size_bytes(), 2);
1405 assert_eq!(DType::Float32.size_bytes(), 4);
1406 }
1407
1408 #[test]
1409 fn stride_accepts_scalar_or_pair() {
1410 let a: Stride = serde_json::from_str("8").unwrap();
1411 let b: Stride = serde_json::from_str("[8, 16]").unwrap();
1412 assert_eq!(a, Stride::Square(8));
1413 assert_eq!(b, Stride::Rect([8, 16]));
1414 assert_eq!(a.x(), 8);
1415 assert_eq!(a.y(), 8);
1416 assert_eq!(b.x(), 8);
1417 assert_eq!(b.y(), 16);
1418 }
1419
1420 #[test]
1421 fn quantization_scalar_scale() {
1422 let j = r#"{"scale": 0.00392, "zero_point": 0, "dtype": "int8"}"#;
1423 let q: Quantization = serde_json::from_str(j).unwrap();
1424 assert!(q.is_per_tensor());
1425 assert!(q.is_symmetric());
1426 assert_eq!(q.scale_at(0), 0.00392);
1427 assert_eq!(q.scale_at(5), 0.00392);
1428 assert_eq!(q.zero_point_at(0), 0);
1429 }
1430
1431 #[test]
1432 fn quantization_per_channel() {
1433 let j = r#"{"scale": [0.054, 0.089, 0.195], "axis": 0, "dtype": "int8"}"#;
1434 let q: Quantization = serde_json::from_str(j).unwrap();
1435 assert!(q.is_per_channel());
1436 assert!(q.is_symmetric());
1437 assert_eq!(q.axis, Some(0));
1438 assert_eq!(q.scale_at(0), 0.054);
1439 assert_eq!(q.scale_at(2), 0.195);
1440 }
1441
1442 #[test]
1443 fn quantization_asymmetric_per_tensor() {
1444 let j = r#"{"scale": 0.176, "zero_point": 198, "dtype": "uint8"}"#;
1445 let q: Quantization = serde_json::from_str(j).unwrap();
1446 assert!(!q.is_symmetric());
1447 assert_eq!(q.zero_point_at(0), 198);
1448 assert_eq!(q.zero_point_at(10), 198);
1449 }
1450
1451 #[test]
1452 fn quantization_symmetric_default_zero_point() {
1453 let j = r#"{"scale": 0.00392, "dtype": "int8"}"#;
1454 let q: Quantization = serde_json::from_str(j).unwrap();
1455 assert!(q.is_symmetric());
1456 assert_eq!(q.zero_point_at(0), 0);
1457 }
1458
1459 #[test]
1460 fn quantization_to_tensor_per_tensor_asymmetric() {
1461 let q = Quantization {
1462 scale: vec![0.1],
1463 zero_point: Some(vec![-5]),
1464 axis: None,
1465 dtype: Some(DType::Int8),
1466 };
1467 let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
1468 assert!(t.is_per_tensor());
1469 assert!(!t.is_symmetric());
1470 assert_eq!(t.scale(), &[0.1]);
1471 assert_eq!(t.zero_point(), Some(&[-5][..]));
1472 }
1473
1474 #[test]
1475 fn quantization_to_tensor_per_tensor_symmetric() {
1476 let q = Quantization {
1477 scale: vec![0.05],
1478 zero_point: None,
1479 axis: None,
1480 dtype: Some(DType::Int8),
1481 };
1482 let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
1483 assert!(t.is_per_tensor());
1484 assert!(t.is_symmetric());
1485 }
1486
1487 #[test]
1488 fn quantization_to_tensor_per_channel_asymmetric() {
1489 let q = Quantization {
1490 scale: vec![0.1, 0.2, 0.3],
1491 zero_point: Some(vec![-1, 0, 1]),
1492 axis: Some(2),
1493 dtype: Some(DType::Int8),
1494 };
1495 let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
1496 assert!(t.is_per_channel());
1497 assert_eq!(t.axis(), Some(2));
1498 assert_eq!(t.scale().len(), 3);
1499 assert_eq!(t.zero_point().map(|z| z.len()), Some(3));
1500 }
1501
1502 #[test]
1503 fn quantization_to_tensor_per_channel_symmetric() {
1504 let q = Quantization {
1505 scale: vec![0.054, 0.089, 0.195],
1506 zero_point: None,
1507 axis: Some(0),
1508 dtype: Some(DType::Int8),
1509 };
1510 let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
1511 assert!(t.is_per_channel());
1512 assert!(t.is_symmetric());
1513 assert_eq!(t.axis(), Some(0));
1514 }
1515
1516 #[test]
1517 fn quantization_to_tensor_per_channel_missing_axis_errors() {
1518 let q = Quantization {
1519 scale: vec![0.1, 0.2, 0.3],
1520 zero_point: None,
1521 axis: None,
1522 dtype: None,
1523 };
1524 let err = edgefirst_tensor::Quantization::try_from(&q).unwrap_err();
1525 assert!(matches!(
1526 err,
1527 edgefirst_tensor::Error::QuantizationInvalid { .. }
1528 ));
1529 }
1530
1531 #[test]
1532 fn logical_output_flat_tflite_boxes() {
1533 let j = r#"{
1535 "name": "boxes", "type": "boxes",
1536 "shape": [1, 64, 8400],
1537 "dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
1538 "dtype": "int8",
1539 "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
1540 "decoder": "ultralytics",
1541 "encoding": "dfl",
1542 "normalized": true
1543 }"#;
1544 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1545 assert_eq!(lo.type_, Some(LogicalType::Boxes));
1546 assert_eq!(lo.encoding, Some(BoxEncoding::Dfl));
1547 assert_eq!(lo.normalized, Some(true));
1548 assert!(!lo.is_split());
1549 assert_eq!(lo.dtype, Some(DType::Int8));
1550 }
1551
1552 #[test]
1553 fn logical_output_hailo_per_scale_split() {
1554 let j = r#"{
1556 "name": "boxes", "type": "boxes",
1557 "shape": [1, 64, 8400],
1558 "encoding": "dfl", "decoder": "ultralytics", "normalized": true,
1559 "outputs": [
1560 {
1561 "name": "boxes_0", "type": "boxes",
1562 "stride": 8, "scale_index": 0,
1563 "shape": [1, 80, 80, 64],
1564 "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 64}],
1565 "dtype": "uint8",
1566 "quantization": {"scale": 0.0234, "zero_point": 128, "dtype": "uint8"}
1567 }
1568 ]
1569 }"#;
1570 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1571 assert!(lo.is_split());
1572 assert_eq!(lo.outputs.len(), 1);
1573 let child = &lo.outputs[0];
1574 assert_eq!(child.name, "boxes_0");
1575 assert_eq!(child.type_, Some(PhysicalType::Boxes));
1576 assert_eq!(child.stride, Some(Stride::Square(8)));
1577 assert_eq!(child.scale_index, Some(0));
1578 assert_eq!(child.dtype, DType::Uint8);
1579 }
1580
1581 #[test]
1582 fn logical_output_ara2_xy_wh_channel_split() {
1583 let j = r#"{
1585 "name": "boxes", "type": "boxes",
1586 "shape": [1, 4, 8400, 1],
1587 "encoding": "direct", "decoder": "ultralytics", "normalized": true,
1588 "outputs": [
1589 {
1590 "name": "_model_22_Div_1_output_0", "type": "boxes_xy",
1591 "shape": [1, 2, 8400, 1],
1592 "dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
1593 "dtype": "int16",
1594 "quantization": {"scale": 3.129e-5, "zero_point": 0, "dtype": "int16"}
1595 },
1596 {
1597 "name": "_model_22_Sub_1_output_0", "type": "boxes_wh",
1598 "shape": [1, 2, 8400, 1],
1599 "dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
1600 "dtype": "int16",
1601 "quantization": {"scale": 3.149e-5, "zero_point": 0, "dtype": "int16"}
1602 }
1603 ]
1604 }"#;
1605 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1606 assert_eq!(lo.encoding, Some(BoxEncoding::Direct));
1607 assert_eq!(lo.outputs.len(), 2);
1608 assert_eq!(lo.outputs[0].type_, Some(PhysicalType::BoxesXy));
1609 assert_eq!(lo.outputs[1].type_, Some(PhysicalType::BoxesWh));
1610 assert!(lo.outputs[0].stride.is_none());
1611 assert!(lo.outputs[1].stride.is_none());
1612 }
1613
1614 #[test]
1615 fn logical_output_hailo_scores_sigmoid_applied() {
1616 let j = r#"{
1617 "name": "scores", "type": "scores",
1618 "shape": [1, 80, 8400],
1619 "decoder": "ultralytics", "score_format": "per_class",
1620 "outputs": [
1621 {
1622 "name": "scores_0", "type": "scores",
1623 "stride": 8, "scale_index": 0,
1624 "shape": [1, 80, 80, 80],
1625 "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_classes": 80}],
1626 "dtype": "uint8",
1627 "quantization": {"scale": 0.003922, "dtype": "uint8"},
1628 "activation_applied": "sigmoid"
1629 }
1630 ]
1631 }"#;
1632 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1633 assert_eq!(lo.score_format, Some(ScoreFormat::PerClass));
1634 let child = &lo.outputs[0];
1635 assert_eq!(child.activation_applied, Some(Activation::Sigmoid));
1636 assert!(child.activation_required.is_none());
1637 }
1638
1639 #[test]
1640 fn yolo26_end_to_end_detections() {
1641 let j = r#"{
1642 "schema_version": 2,
1643 "decoder_version": "yolo26",
1644 "outputs": [{
1645 "name": "output0", "type": "detections",
1646 "shape": [1, 100, 6],
1647 "dshape": [{"batch": 1}, {"num_boxes": 100}, {"num_features": 6}],
1648 "dtype": "int8",
1649 "quantization": {"scale": 0.0078, "zero_point": 0, "dtype": "int8"},
1650 "normalized": false,
1651 "decoder": "ultralytics"
1652 }]
1653 }"#;
1654 let s: SchemaV2 = serde_json::from_str(j).unwrap();
1655 assert_eq!(s.decoder_version, Some(DecoderVersion::Yolo26));
1656 assert!(s.decoder_version.unwrap().is_end_to_end());
1657 assert_eq!(s.outputs[0].type_, Some(LogicalType::Detections));
1658 assert_eq!(s.outputs[0].normalized, Some(false));
1659 assert!(s.nms.is_none());
1660 }
1661
1662 #[test]
1663 fn modelpack_anchor_detection_with_rect_stride() {
1664 let j = r#"{
1665 "schema_version": 2,
1666 "outputs": [{
1667 "name": "output_0", "type": "detection",
1668 "shape": [1, 40, 40, 54],
1669 "dshape": [{"batch": 1}, {"height": 40}, {"width": 40}, {"num_anchors_x_features": 54}],
1670 "dtype": "uint8",
1671 "quantization": {"scale": 0.176, "zero_point": 198, "dtype": "uint8"},
1672 "decoder": "modelpack",
1673 "encoding": "anchor",
1674 "stride": [16, 16],
1675 "anchors": [[0.054, 0.065], [0.089, 0.139], [0.195, 0.196]]
1676 }]
1677 }"#;
1678 let s: SchemaV2 = serde_json::from_str(j).unwrap();
1679 let lo = &s.outputs[0];
1680 assert_eq!(lo.encoding, Some(BoxEncoding::Anchor));
1681 assert_eq!(lo.stride, Some(Stride::Rect([16, 16])));
1682 assert_eq!(lo.anchors.as_ref().map(|a| a.len()), Some(3));
1683 }
1684
1685 #[test]
1686 fn yolov5_obj_x_class_objectness_logical() {
1687 let j = r#"{
1688 "name": "objectness", "type": "objectness",
1689 "shape": [1, 3, 8400],
1690 "decoder": "ultralytics",
1691 "outputs": [{
1692 "name": "objectness_0", "type": "objectness",
1693 "stride": 8, "scale_index": 0,
1694 "shape": [1, 80, 80, 3],
1695 "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 3}],
1696 "dtype": "uint8",
1697 "quantization": {"scale": 0.0039, "zero_point": 0, "dtype": "uint8"},
1698 "activation_applied": "sigmoid"
1699 }]
1700 }"#;
1701 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1702 assert_eq!(lo.type_, Some(LogicalType::Objectness));
1703 assert_eq!(lo.outputs[0].activation_applied, Some(Activation::Sigmoid));
1704 }
1705
1706 #[test]
1707 fn direct_protos_no_decoder() {
1708 let j = r#"{
1710 "name": "protos", "type": "protos",
1711 "shape": [1, 32, 160, 160],
1712 "dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}],
1713 "dtype": "uint8",
1714 "quantization": {"scale": 0.0203, "zero_point": 45, "dtype": "uint8"},
1715 "stride": 4
1716 }"#;
1717 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
1718 assert_eq!(lo.type_, Some(LogicalType::Protos));
1719 assert!(lo.decoder.is_none());
1720 assert_eq!(lo.stride, Some(Stride::Square(4)));
1721 }
1722
1723 #[test]
1724 fn full_yolov8_tflite_flat_detection() {
1725 let j = r#"{
1727 "schema_version": 2,
1728 "decoder_version": "yolov8",
1729 "nms": "class_agnostic",
1730 "input": { "shape": [1, 640, 640, 3], "cameraadaptor": "rgb" },
1731 "outputs": [
1732 {
1733 "name": "boxes", "type": "boxes",
1734 "shape": [1, 64, 8400],
1735 "dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
1736 "dtype": "int8",
1737 "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
1738 "decoder": "ultralytics",
1739 "encoding": "dfl",
1740 "normalized": true
1741 },
1742 {
1743 "name": "scores", "type": "scores",
1744 "shape": [1, 80, 8400],
1745 "dshape": [{"batch": 1}, {"num_classes": 80}, {"num_boxes": 8400}],
1746 "dtype": "int8",
1747 "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
1748 "decoder": "ultralytics",
1749 "score_format": "per_class"
1750 }
1751 ]
1752 }"#;
1753 let s: SchemaV2 = serde_json::from_str(j).unwrap();
1754 assert_eq!(s.schema_version, 2);
1755 assert_eq!(s.decoder_version, Some(DecoderVersion::Yolov8));
1756 assert_eq!(s.nms, Some(NmsMode::ClassAgnostic));
1757 assert_eq!(s.input.as_ref().unwrap().shape, vec![1, 640, 640, 3]);
1758 assert_eq!(s.outputs.len(), 2);
1759 }
1760
1761 #[test]
1762 fn schema_unknown_version_parses_without_validation() {
1763 let j = r#"{"schema_version": 99, "outputs": []}"#;
1766 let s: SchemaV2 = serde_json::from_str(j).unwrap();
1767 assert_eq!(s.schema_version, 99);
1768 }
1769
1770 #[test]
1771 fn serde_roundtrip_preserves_fields() {
1772 let original = SchemaV2 {
1773 schema_version: 2,
1774 input: Some(InputSpec {
1775 shape: vec![1, 3, 640, 640],
1776 dshape: vec![],
1777 cameraadaptor: Some("rgb".into()),
1778 }),
1779 outputs: vec![LogicalOutput {
1780 name: Some("boxes".into()),
1781 type_: Some(LogicalType::Boxes),
1782 shape: vec![1, 4, 8400],
1783 dshape: vec![],
1784 decoder: Some(DecoderKind::Ultralytics),
1785 encoding: Some(BoxEncoding::Direct),
1786 score_format: None,
1787 normalized: Some(true),
1788 anchors: None,
1789 stride: None,
1790 dtype: Some(DType::Float32),
1791 quantization: None,
1792 outputs: vec![],
1793 activation_applied: None,
1794 activation_required: None,
1795 }],
1796 nms: Some(NmsMode::ClassAgnostic),
1797 decoder_version: Some(DecoderVersion::Yolov8),
1798 };
1799 let j = serde_json::to_string(&original).unwrap();
1800 let parsed: SchemaV2 = serde_json::from_str(&j).unwrap();
1801 assert_eq!(parsed, original);
1802 }
1803
1804 #[test]
1807 fn parse_v1_yaml_yolov8_seg_testdata() {
1808 let yaml = edgefirst_bench::testdata::read_to_string("yolov8_seg.yaml");
1809 let schema = SchemaV2::parse_yaml(&yaml).expect("parse v1 yaml");
1810 assert_eq!(schema.schema_version, 2);
1811 assert_eq!(schema.outputs.len(), 2);
1812 let det = &schema.outputs[0];
1814 assert_eq!(det.type_, Some(LogicalType::Detection));
1815 assert_eq!(det.shape, vec![1, 116, 8400]);
1816 assert_eq!(det.decoder, Some(DecoderKind::Ultralytics));
1817 assert_eq!(det.encoding, Some(BoxEncoding::Direct));
1818 let q = det.quantization.as_ref().unwrap();
1819 assert_eq!(q.scale.len(), 1);
1820 assert!((q.scale[0] - 0.021_287_762).abs() < 1e-6);
1821 assert_eq!(q.zero_point, Some(vec![31]));
1822 let protos = &schema.outputs[1];
1824 assert_eq!(protos.type_, Some(LogicalType::Protos));
1825 assert_eq!(protos.shape, vec![1, 160, 160, 32]);
1826 }
1827
1828 #[test]
1829 fn parse_v1_json_modelpack_split_testdata() {
1830 let json = edgefirst_bench::testdata::read_to_string("modelpack_split.json");
1831 let schema = SchemaV2::parse_json(&json).expect("parse v1 json");
1832 assert_eq!(schema.schema_version, 2);
1833 assert_eq!(schema.outputs.len(), 2);
1834 for out in &schema.outputs {
1836 assert_eq!(out.type_, Some(LogicalType::Detection));
1837 assert_eq!(out.decoder, Some(DecoderKind::ModelPack));
1838 assert_eq!(out.encoding, Some(BoxEncoding::Anchor));
1839 assert_eq!(out.anchors.as_ref().map(|a| a.len()), Some(3));
1840 }
1841 }
1842
1843 #[test]
1844 fn parse_v2_json_direct_when_schema_version_present() {
1845 let j = r#"{
1846 "schema_version": 2,
1847 "outputs": [{
1848 "name": "boxes", "type": "boxes",
1849 "shape": [1, 4, 8400],
1850 "dshape": [{"batch": 1}, {"box_coords": 4}, {"num_boxes": 8400}],
1851 "dtype": "float32",
1852 "decoder": "ultralytics",
1853 "encoding": "direct",
1854 "normalized": true
1855 }]
1856 }"#;
1857 let schema = SchemaV2::parse_json(j).unwrap();
1858 assert_eq!(schema.schema_version, 2);
1859 assert_eq!(schema.outputs[0].type_, Some(LogicalType::Boxes));
1860 }
1861
1862 #[test]
1863 fn parse_rejects_future_schema_version() {
1864 let j = r#"{"schema_version": 99, "outputs": []}"#;
1865 let err = SchemaV2::parse_json(j).unwrap_err();
1866 matches!(err, DecoderError::NotSupported(_));
1867 }
1868
1869 #[test]
1870 fn parse_absent_schema_version_treats_as_v1() {
1871 let j = r#"{
1873 "outputs": [
1874 {
1875 "type": "boxes", "decoder": "ultralytics",
1876 "shape": [1, 4, 8400],
1877 "quantization": [0.00392, 0]
1878 },
1879 {
1880 "type": "scores", "decoder": "ultralytics",
1881 "shape": [1, 80, 8400],
1882 "quantization": [0.00392, 0]
1883 }
1884 ]
1885 }"#;
1886 let schema = SchemaV2::parse_json(j).expect("v1 legacy parse");
1887 assert_eq!(schema.schema_version, 2); assert_eq!(schema.outputs.len(), 2);
1889 assert_eq!(schema.outputs[0].type_, Some(LogicalType::Boxes));
1890 assert_eq!(schema.outputs[1].type_, Some(LogicalType::Scores));
1891 assert_eq!(schema.outputs[1].score_format, Some(ScoreFormat::PerClass));
1893 }
1894
1895 #[test]
1896 fn from_v1_preserves_nms_and_decoder_version() {
1897 let v1 = ConfigOutputs {
1898 outputs: vec![ConfigOutput::Boxes(crate::configs::Boxes {
1899 decoder: crate::configs::DecoderType::Ultralytics,
1900 quantization: Some(crate::configs::QuantTuple(0.01, 5)),
1901 shape: vec![1, 4, 8400],
1902 dshape: vec![],
1903 normalized: Some(true),
1904 })],
1905 nms: Some(crate::configs::Nms::ClassAware),
1906 decoder_version: Some(crate::configs::DecoderVersion::Yolo11),
1907 };
1908 let v2 = SchemaV2::from_v1(&v1).unwrap();
1909 assert_eq!(v2.nms, Some(NmsMode::ClassAware));
1910 assert_eq!(v2.decoder_version, Some(DecoderVersion::Yolo11));
1911 assert_eq!(v2.outputs[0].normalized, Some(true));
1912 let q = v2.outputs[0].quantization.as_ref().unwrap();
1913 assert_eq!(q.scale, vec![0.01]);
1914 assert_eq!(q.zero_point, Some(vec![5]));
1915 assert_eq!(q.dtype, None); }
1917
1918 #[test]
1925 fn typeless_logical_output_parses_and_roundtrips() {
1926 let j = r#"{
1927 "schema_version": 2,
1928 "outputs": [
1929 {
1930 "name": "extra_telemetry",
1931 "shape": [1, 16]
1932 },
1933 {
1934 "name": "boxes",
1935 "type": "boxes",
1936 "shape": [1, 4, 8400]
1937 }
1938 ]
1939 }"#;
1940 let schema: SchemaV2 = serde_json::from_str(j).unwrap();
1941 assert_eq!(schema.outputs.len(), 2);
1942 assert_eq!(schema.outputs[0].type_, None);
1943 assert_eq!(schema.outputs[0].name.as_deref(), Some("extra_telemetry"));
1944 assert_eq!(schema.outputs[1].type_, Some(LogicalType::Boxes));
1945
1946 let round = serde_json::to_string(&schema).unwrap();
1948 let first_obj = round
1949 .split("\"outputs\":[")
1950 .nth(1)
1951 .and_then(|s| s.split("}").next())
1952 .expect("outputs array");
1953 assert!(
1954 !first_obj.contains("\"type\""),
1955 "typeless output must not serialize a `type` field, got: {first_obj}"
1956 );
1957 }
1958
1959 #[test]
1965 fn typeless_outputs_filtered_from_legacy_config() {
1966 let schema = SchemaV2 {
1967 schema_version: 2,
1968 input: None,
1969 outputs: vec![
1970 LogicalOutput {
1971 name: Some("diagnostic_histogram".into()),
1972 type_: None,
1973 shape: vec![1, 256],
1974 dshape: vec![],
1975 decoder: None,
1976 encoding: None,
1977 score_format: None,
1978 normalized: None,
1979 anchors: None,
1980 stride: None,
1981 dtype: None,
1982 quantization: None,
1983 outputs: vec![],
1984 activation_applied: None,
1985 activation_required: None,
1986 },
1987 LogicalOutput {
1988 name: Some("boxes".into()),
1989 type_: Some(LogicalType::Boxes),
1990 shape: vec![1, 4, 8400],
1991 dshape: vec![],
1992 decoder: Some(DecoderKind::Ultralytics),
1993 encoding: Some(BoxEncoding::Direct),
1994 score_format: None,
1995 normalized: Some(true),
1996 anchors: None,
1997 stride: None,
1998 dtype: None,
1999 quantization: None,
2000 outputs: vec![],
2001 activation_applied: None,
2002 activation_required: None,
2003 },
2004 ],
2005 nms: None,
2006 decoder_version: None,
2007 };
2008 let legacy = schema.to_legacy_config_outputs().unwrap();
2009 assert_eq!(
2010 legacy.outputs.len(),
2011 1,
2012 "typeless output must be filtered from legacy config"
2013 );
2014 assert!(
2015 matches!(legacy.outputs[0], ConfigOutput::Boxes(_)),
2016 "only the typed `boxes` output should survive lowering"
2017 );
2018 }
2019
2020 #[test]
2025 fn all_typeless_schema_produces_empty_legacy_config() {
2026 let schema = SchemaV2 {
2027 schema_version: 2,
2028 input: None,
2029 outputs: vec![LogicalOutput {
2030 name: Some("aux".into()),
2031 type_: None,
2032 shape: vec![1, 8],
2033 dshape: vec![],
2034 decoder: None,
2035 encoding: None,
2036 score_format: None,
2037 normalized: None,
2038 anchors: None,
2039 stride: None,
2040 dtype: None,
2041 quantization: None,
2042 outputs: vec![],
2043 activation_applied: None,
2044 activation_required: None,
2045 }],
2046 nms: None,
2047 decoder_version: None,
2048 };
2049 let legacy = schema.to_legacy_config_outputs().unwrap();
2050 assert!(legacy.outputs.is_empty());
2051 }
2052
2053 #[test]
2059 fn typeless_physical_child_parses_and_skips_uniqueness() {
2060 let j = r#"{
2061 "name": "boxes",
2062 "type": "boxes",
2063 "shape": [1, 8400, 4],
2064 "outputs": [
2065 {
2066 "name": "boxes_xy",
2067 "type": "boxes_xy",
2068 "shape": [1, 8400, 2],
2069 "dtype": "float32"
2070 },
2071 {
2072 "name": "aux_user_managed",
2073 "shape": [1, 8400, 2],
2074 "dtype": "float32"
2075 }
2076 ]
2077 }"#;
2078 let lo: LogicalOutput = serde_json::from_str(j).unwrap();
2079 assert_eq!(lo.outputs.len(), 2);
2080 assert_eq!(lo.outputs[0].type_, Some(PhysicalType::BoxesXy));
2081 assert_eq!(lo.outputs[1].type_, None);
2082
2083 let schema = SchemaV2 {
2087 schema_version: 2,
2088 input: None,
2089 outputs: vec![lo],
2090 nms: None,
2091 decoder_version: None,
2092 };
2093 schema.validate().expect(
2094 "typed + typeless children with equal shape must not trigger \
2095 uniqueness error",
2096 );
2097
2098 let s = serde_json::to_string(&schema).unwrap();
2100 assert!(
2101 s.contains("\"aux_user_managed\""),
2102 "typeless child must survive round-trip: {s}"
2103 );
2104 let aux_obj = s
2106 .split("\"aux_user_managed\"")
2107 .nth(1)
2108 .and_then(|s| s.split('}').next())
2109 .unwrap_or("");
2110 assert!(
2111 !aux_obj.contains("\"type\""),
2112 "typeless child must not serialize `type`, got: {aux_obj}"
2113 );
2114 }
2115
2116 #[test]
2117 fn from_v1_modelpack_anchor_detection_maps_encoding() {
2118 let v1 = ConfigOutputs {
2119 outputs: vec![ConfigOutput::Detection(crate::configs::Detection {
2120 anchors: Some(vec![[0.1, 0.2], [0.3, 0.4]]),
2121 decoder: crate::configs::DecoderType::ModelPack,
2122 quantization: Some(crate::configs::QuantTuple(0.176, 198)),
2123 shape: vec![1, 40, 40, 54],
2124 dshape: vec![],
2125 normalized: None,
2126 })],
2127 nms: None,
2128 decoder_version: None,
2129 };
2130 let v2 = SchemaV2::from_v1(&v1).unwrap();
2131 assert_eq!(v2.outputs[0].encoding, Some(BoxEncoding::Anchor));
2132 assert_eq!(v2.outputs[0].decoder, Some(DecoderKind::ModelPack));
2133 assert_eq!(v2.outputs[0].anchors.as_ref().map(|a| a.len()), Some(2));
2134 }
2135
2136 #[test]
2139 fn validate_accepts_flat_v2_yolov8_detection() {
2140 let j = r#"{
2141 "schema_version": 2,
2142 "outputs": [
2143 {"name":"boxes","type":"boxes","shape":[1,64,8400],
2144 "dtype":"int8","decoder":"ultralytics","encoding":"dfl"},
2145 {"name":"scores","type":"scores","shape":[1,80,8400],
2146 "dtype":"int8","decoder":"ultralytics","score_format":"per_class"}
2147 ]
2148 }"#;
2149 SchemaV2::parse_json(j).unwrap().validate().unwrap();
2150 }
2151
2152 #[test]
2153 fn validate_rejects_unnamed_physical_child() {
2154 let j = r#"{
2155 "schema_version": 2,
2156 "outputs": [{
2157 "name":"boxes","type":"boxes","shape":[1,64,8400],
2158 "encoding":"dfl","decoder":"ultralytics",
2159 "outputs": [{
2160 "name":"","type":"boxes","stride":8,
2161 "shape":[1,80,80,64],"dtype":"uint8"
2162 }]
2163 }]
2164 }"#;
2165 let err = SchemaV2::parse_json(j).unwrap().validate().unwrap_err();
2166 let msg = format!("{err}");
2167 assert!(msg.contains("missing `name`"), "got: {msg}");
2168 }
2169
2170 #[test]
2171 fn validate_rejects_duplicate_physical_shapes() {
2172 let j = r#"{
2173 "schema_version": 2,
2174 "outputs": [{
2175 "name":"boxes","type":"boxes","shape":[1,64,8400],
2176 "encoding":"dfl","decoder":"ultralytics",
2177 "outputs": [
2178 {"name":"a","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"},
2179 {"name":"b","type":"boxes","stride":16,"shape":[1,80,80,64],"dtype":"uint8"}
2180 ]
2181 }]
2182 }"#;
2183 let err = SchemaV2::parse_json(j).unwrap().validate().unwrap_err();
2184 let msg = format!("{err}");
2185 assert!(msg.contains("share shape"), "got: {msg}");
2186 }
2187
2188 #[test]
2189 fn validate_rejects_mixed_decomposition() {
2190 let j = r#"{
2192 "schema_version": 2,
2193 "outputs": [{
2194 "name":"boxes","type":"boxes","shape":[1,4,8400,1],
2195 "encoding":"direct","decoder":"ultralytics",
2196 "outputs": [
2197 {"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],"dtype":"int16"},
2198 {"name":"p0","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"}
2199 ]
2200 }]
2201 }"#;
2202 let err = SchemaV2::parse_json(j).unwrap().validate().unwrap_err();
2203 let msg = format!("{err}");
2204 assert!(msg.contains("uniform"), "got: {msg}");
2205 }
2206
2207 #[test]
2208 fn validate_rejects_dfl_boxes_feature_not_divisible_by_4() {
2209 let j = r#"{
2210 "schema_version": 2,
2211 "outputs": [{
2212 "name":"boxes","type":"boxes","shape":[1,63,8400],
2213 "encoding":"dfl","decoder":"ultralytics",
2214 "outputs": [{
2215 "name":"b0","type":"boxes","stride":8,
2216 "shape":[1,80,80,63],
2217 "dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":63}],
2218 "dtype":"uint8"
2219 }]
2220 }]
2221 }"#;
2222 let err = SchemaV2::parse_json(j).unwrap().validate().unwrap_err();
2223 let msg = format!("{err}");
2224 assert!(msg.contains("not"), "got: {msg}");
2225 assert!(msg.contains("divisible by 4"), "got: {msg}");
2226 }
2227
2228 #[test]
2229 fn validate_accepts_hailo_per_scale_yolov8() {
2230 let j = r#"{
2231 "schema_version": 2,
2232 "outputs": [{
2233 "name":"boxes","type":"boxes","shape":[1,64,8400],
2234 "encoding":"dfl","decoder":"ultralytics","normalized":true,
2235 "outputs": [
2236 {"name":"b0","type":"boxes","stride":8,
2237 "shape":[1,80,80,64],
2238 "dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":64}],
2239 "dtype":"uint8",
2240 "quantization":{"scale":0.0234,"zero_point":128,"dtype":"uint8"}},
2241 {"name":"b1","type":"boxes","stride":16,
2242 "shape":[1,40,40,64],
2243 "dshape":[{"batch":1},{"height":40},{"width":40},{"num_features":64}],
2244 "dtype":"uint8",
2245 "quantization":{"scale":0.0198,"zero_point":130,"dtype":"uint8"}},
2246 {"name":"b2","type":"boxes","stride":32,
2247 "shape":[1,20,20,64],
2248 "dshape":[{"batch":1},{"height":20},{"width":20},{"num_features":64}],
2249 "dtype":"uint8",
2250 "quantization":{"scale":0.0312,"zero_point":125,"dtype":"uint8"}}
2251 ]
2252 }]
2253 }"#;
2254 let s = SchemaV2::parse_json(j).unwrap();
2255 s.validate().unwrap();
2256 }
2257
2258 #[test]
2259 fn validate_accepts_ara2_xy_wh() {
2260 let j = r#"{
2261 "schema_version": 2,
2262 "outputs": [{
2263 "name":"boxes","type":"boxes","shape":[1,4,8400,1],
2264 "encoding":"direct","decoder":"ultralytics","normalized":true,
2265 "outputs": [
2266 {"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],
2267 "dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
2268 "dtype":"int16",
2269 "quantization":{"scale":3.1e-5,"zero_point":0,"dtype":"int16"}},
2270 {"name":"wh","type":"boxes_wh","shape":[1,2,8400,1],
2271 "dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
2272 "dtype":"int16",
2273 "quantization":{"scale":3.2e-5,"zero_point":0,"dtype":"int16"}}
2274 ]
2275 }]
2276 }"#;
2277 SchemaV2::parse_json(j).unwrap().validate().unwrap();
2278 }
2279
2280 #[test]
2281 fn parse_file_auto_detects_json() {
2282 let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.json", std::process::id()));
2283 std::fs::write(&tmp, r#"{"schema_version":2,"outputs":[]}"#).unwrap();
2284 let s = SchemaV2::parse_file(&tmp).unwrap();
2285 assert_eq!(s.schema_version, 2);
2286 let _ = std::fs::remove_file(&tmp);
2287 }
2288
2289 #[test]
2290 fn parse_file_auto_detects_yaml() {
2291 let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.yaml", std::process::id()));
2292 std::fs::write(&tmp, "schema_version: 2\noutputs: []\n").unwrap();
2293 let s = SchemaV2::parse_file(&tmp).unwrap();
2294 assert_eq!(s.schema_version, 2);
2295 let _ = std::fs::remove_file(&tmp);
2296 }
2297
2298 #[test]
2301 fn parse_real_ara2_int8_dvm_metadata() {
2302 let json = edgefirst_bench::testdata::read_to_string("ara2_int8_edgefirst.json");
2303 let schema = SchemaV2::parse_json(&json).expect("ARA-2 int8 parse");
2304 assert_eq!(schema.schema_version, 2);
2305 assert_eq!(schema.decoder_version, Some(DecoderVersion::Yolov8));
2306 assert_eq!(schema.nms, Some(NmsMode::ClassAgnostic));
2307 assert_eq!(schema.input.as_ref().unwrap().shape, vec![1, 3, 640, 640]);
2308
2309 assert_eq!(schema.outputs.len(), 4);
2311 let boxes = &schema.outputs[0];
2312 assert_eq!(boxes.type_, Some(LogicalType::Boxes));
2313 assert_eq!(boxes.encoding, Some(BoxEncoding::Direct));
2314 assert_eq!(boxes.normalized, Some(true));
2315 assert_eq!(boxes.shape, vec![1, 4, 8400, 1]); assert_eq!(boxes.outputs.len(), 2);
2317 assert_eq!(boxes.outputs[0].type_, Some(PhysicalType::BoxesXy));
2318 assert_eq!(boxes.outputs[1].type_, Some(PhysicalType::BoxesWh));
2319 let q_xy = boxes.outputs[0].quantization.as_ref().unwrap();
2321 assert_eq!(q_xy.dtype, Some(DType::Int8));
2322 assert!((q_xy.scale[0] - 0.004_177_792).abs() < 1e-6);
2323 assert_eq!(q_xy.zero_point_at(0), -122);
2324
2325 let scores = &schema.outputs[1];
2326 assert_eq!(scores.type_, Some(LogicalType::Scores));
2327 assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
2328 assert_eq!(scores.shape, vec![1, 80, 8400, 1]);
2329
2330 let mask_coefs = &schema.outputs[2];
2331 assert_eq!(mask_coefs.type_, Some(LogicalType::MaskCoefs));
2332 assert_eq!(mask_coefs.shape, vec![1, 32, 8400, 1]);
2333
2334 let protos = &schema.outputs[3];
2335 assert_eq!(protos.type_, Some(LogicalType::Protos));
2336 assert_eq!(protos.shape, vec![1, 32, 160, 160]);
2337
2338 schema.validate().expect("ARA-2 int8 validate");
2340 }
2341
2342 #[test]
2343 fn parse_real_ara2_int16_dvm_metadata() {
2344 let json = edgefirst_bench::testdata::read_to_string("ara2_int16_edgefirst.json");
2345 let schema = SchemaV2::parse_json(&json).expect("ARA-2 int16 parse");
2346 assert_eq!(schema.schema_version, 2);
2347 assert_eq!(schema.outputs.len(), 4);
2348 let boxes = &schema.outputs[0];
2349 assert_eq!(boxes.outputs.len(), 2);
2350 let q_xy = boxes.outputs[0].quantization.as_ref().unwrap();
2351 assert_eq!(q_xy.dtype, Some(DType::Int16));
2352 assert!((q_xy.scale[0] - 3.211_570_6e-5).abs() < 1e-10);
2353 assert_eq!(q_xy.zero_point_at(0), 0);
2354 let mc_q = schema.outputs[2].quantization.as_ref().unwrap();
2356 assert_eq!(mc_q.dtype, Some(DType::Int16));
2357 schema.validate().expect("ARA-2 int16 validate");
2358 }
2359
2360 #[test]
2361 fn parse_yaml_with_explicit_schema_version_2() {
2362 let yaml = r#"
2363schema_version: 2
2364outputs:
2365 - name: scores
2366 type: scores
2367 shape: [1, 80, 8400]
2368 dtype: int8
2369 quantization:
2370 scale: 0.00392
2371 dtype: int8
2372 decoder: ultralytics
2373 score_format: per_class
2374"#;
2375 let schema = SchemaV2::parse_yaml(yaml).unwrap();
2376 assert_eq!(schema.schema_version, 2);
2377 assert_eq!(schema.outputs[0].score_format, Some(ScoreFormat::PerClass));
2378 }
2379
2380 #[test]
2383 fn squeeze_padding_dims_preserves_shape_when_dshape_absent() {
2384 let (shape, dshape) = squeeze_padding_dims(vec![1, 4, 8400], vec![]);
2389 assert_eq!(shape, vec![1, 4, 8400]);
2390 assert!(dshape.is_empty());
2391 }
2392
2393 #[test]
2394 fn to_legacy_modelpack_boxes_preserves_padding_dim() {
2395 let j = r#"{
2400 "schema_version": 2,
2401 "outputs": [
2402 {"name":"boxes","type":"boxes",
2403 "shape":[1,1935,1,4],
2404 "dshape":[{"batch":1},{"num_boxes":1935},{"padding":1},{"box_coords":4}],
2405 "decoder":"modelpack"}
2406 ]
2407 }"#;
2408 let schema = SchemaV2::parse_json(j).unwrap();
2409 let legacy = schema.to_legacy_config_outputs().expect("lowers cleanly");
2410 let boxes = match &legacy.outputs[0] {
2411 crate::ConfigOutput::Boxes(b) => b,
2412 other => panic!("expected Boxes, got {other:?}"),
2413 };
2414 assert_eq!(boxes.shape, vec![1, 1935, 1, 4]);
2417 assert_eq!(
2418 boxes.dshape,
2419 vec![
2420 (DimName::Batch, 1),
2421 (DimName::NumBoxes, 1935),
2422 (DimName::Padding, 1),
2423 (DimName::BoxCoords, 4),
2424 ]
2425 );
2426 }
2427
2428 #[test]
2429 fn to_legacy_preserves_shape_for_v2_split_boxes_without_dshape() {
2430 let j = r#"{
2434 "schema_version": 2,
2435 "outputs": [
2436 {"name":"boxes","type":"boxes","shape":[1,4,8400],
2437 "dtype":"float32","decoder":"ultralytics","encoding":"direct"},
2438 {"name":"scores","type":"scores","shape":[1,80,8400],
2439 "dtype":"float32","decoder":"ultralytics","score_format":"per_class"}
2440 ]
2441 }"#;
2442 let schema = SchemaV2::parse_json(j).unwrap();
2443 let legacy = schema.to_legacy_config_outputs().expect("lowers cleanly");
2444 let boxes = match &legacy.outputs[0] {
2445 crate::ConfigOutput::Boxes(b) => b,
2446 other => panic!("expected Boxes, got {other:?}"),
2447 };
2448 assert_eq!(boxes.shape, vec![1, 4, 8400]);
2449 let scores = match &legacy.outputs[1] {
2450 crate::ConfigOutput::Scores(s) => s,
2451 other => panic!("expected Scores, got {other:?}"),
2452 };
2453 assert_eq!(scores.shape, vec![1, 80, 8400]);
2454 }
2455}