1use crate::configs::{self, deserialize_dshape, DimName, QuantTuple};
47use crate::{ConfigOutput, ConfigOutputs, DecoderError, DecoderResult};
48use serde::{Deserialize, Serialize};
49
/// Highest `schema_version` this module accepts.
///
/// `SchemaV2::from_json_value` and `SchemaV2::validate` reject any document whose
/// version is greater than this value.
pub const MAX_SUPPORTED_SCHEMA_VERSION: u32 = 2;
54
/// Top-level metadata document in the version-2 schema layout.
///
/// Only `schema_version` is mandatory when deserializing this struct directly;
/// every other field falls back to its default and is left out of serialized
/// output when it is `None` / empty.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaV2 {
    /// Schema version of the document (`Default` produces `2`).
    pub schema_version: u32,

    /// Optional description of the model input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<InputSpec>,

    /// Logical outputs of the model; empty when the key is absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub outputs: Vec<LogicalOutput>,

    /// Optional NMS mode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub nms: Option<NmsMode>,

    /// Optional decoder version tag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder_version: Option<DecoderVersion>,
}
89
90impl Default for SchemaV2 {
91 fn default() -> Self {
92 Self {
93 schema_version: 2,
94 input: None,
95 outputs: Vec::new(),
96 nms: None,
97 decoder_version: None,
98 }
99 }
100}
101
/// Description of the model input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InputSpec {
    /// Input shape as a list of dimension sizes (required).
    pub shape: Vec<usize>,

    /// Named dimensions paired with their sizes, parsed by `deserialize_dshape`.
    /// Empty when absent, and not serialized when empty.
    /// NOTE(review): presumably lines up axis-for-axis with `shape` — confirm.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Optional free-form string (a fixture in this file's tests uses `"rgb"`).
    /// NOTE(review): the accepted values are not defined in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cameraadaptor: Option<String>,
}
123
/// One logical model output.
///
/// A logical output is either "flat" (`outputs` is empty) or "split" into
/// physical children listed in `outputs`; see `LogicalOutput::is_split`.
/// All optional fields are omitted from serialized output when unset.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LogicalOutput {
    /// Optional name; error messages fall back to `<anonymous>` when it is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Kind of output; serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_: LogicalType,

    /// Shape as a list of dimension sizes (required).
    pub shape: Vec<usize>,

    /// Named dimensions paired with their sizes, parsed by `deserialize_dshape`.
    /// Empty when absent.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Optional decoder family; the legacy conversion assumes Ultralytics when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decoder: Option<DecoderKind>,

    /// Optional box encoding.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encoding: Option<BoxEncoding>,

    /// Optional score format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub score_format: Option<ScoreFormat>,

    /// Optional flag. NOTE(review): presumably whether coordinates are
    /// normalized — confirm against the decoder.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub normalized: Option<bool>,

    /// Optional list of anchor pairs.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub anchors: Option<Vec<[f32; 2]>>,

    /// Optional stride, given as a scalar or as a pair.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stride: Option<Stride>,

    /// Optional element type.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dtype: Option<DType>,

    /// Optional quantization parameters.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quantization: Option<Quantization>,

    /// Physical children of a split output; empty for a flat output.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub outputs: Vec<PhysicalOutput>,
}
200
201impl LogicalOutput {
202 pub fn is_split(&self) -> bool {
205 !self.outputs.is_empty()
206 }
207}
208
/// One physical child of a split [`LogicalOutput`].
///
/// Unlike the logical level, `name` and `dtype` are mandatory here.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PhysicalOutput {
    /// Child name; validation rejects an empty name because it is needed for
    /// tensor binding.
    pub name: String,

    /// Kind of physical output; serialized under the key `type`.
    #[serde(rename = "type")]
    pub type_: PhysicalType,

    /// Shape as a list of dimension sizes (required).
    pub shape: Vec<usize>,

    /// Named dimensions paired with their sizes, parsed by `deserialize_dshape`.
    /// Empty when absent.
    #[serde(
        default,
        deserialize_with = "deserialize_dshape",
        skip_serializing_if = "Vec::is_empty"
    )]
    pub dshape: Vec<(DimName, usize)>,

    /// Element type (required).
    pub dtype: DType,

    /// Optional quantization parameters.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quantization: Option<Quantization>,

    /// Optional stride. Validation requires that either every child of a
    /// logical output has one or none does.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stride: Option<Stride>,

    /// Optional scale index. NOTE(review): presumably the position of this
    /// child among per-scale siblings — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub scale_index: Option<usize>,

    /// Optional activation recorded as already applied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_applied: Option<Activation>,

    /// Optional activation recorded as still required.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation_required: Option<Activation>,
}
267
/// Quantization parameters.
///
/// `scale` and `zero_point` each accept either a single number or a list on
/// input and are always stored as vectors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
    /// Scale value(s); a scalar input becomes a one-element vector.
    #[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
    pub scale: Vec<f32>,

    /// Zero point(s); `None` when absent, which the accessors treat as zero.
    #[serde(
        default,
        deserialize_with = "deserialize_opt_scalar_or_vec_i32",
        skip_serializing_if = "Option::is_none"
    )]
    pub zero_point: Option<Vec<i32>>,

    /// Optional axis index. NOTE(review): presumably the channel axis for
    /// per-channel parameters — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub axis: Option<usize>,

    /// Optional element type.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dtype: Option<DType>,
}
301
302impl Quantization {
303 pub fn is_per_tensor(&self) -> bool {
305 self.scale.len() == 1
306 }
307
308 pub fn is_per_channel(&self) -> bool {
310 self.scale.len() > 1
311 }
312
313 pub fn is_symmetric(&self) -> bool {
315 match &self.zero_point {
316 None => true,
317 Some(zps) => zps.iter().all(|&z| z == 0),
318 }
319 }
320
321 pub fn zero_point_at(&self, channel: usize) -> i32 {
324 match &self.zero_point {
325 None => 0,
326 Some(zps) if zps.len() == 1 => zps[0],
327 Some(zps) => zps.get(channel).copied().unwrap_or(0),
328 }
329 }
330
331 pub fn scale_at(&self, channel: usize) -> f32 {
333 if self.scale.len() == 1 {
334 self.scale[0]
335 } else {
336 self.scale.get(channel).copied().unwrap_or(0.0)
337 }
338 }
339}
340
341fn deserialize_scalar_or_vec_f32<'de, D>(de: D) -> Result<Vec<f32>, D::Error>
343where
344 D: serde::Deserializer<'de>,
345{
346 #[derive(Deserialize)]
347 #[serde(untagged)]
348 enum OneOrMany {
349 One(f32),
350 Many(Vec<f32>),
351 }
352 match OneOrMany::deserialize(de)? {
353 OneOrMany::One(v) => Ok(vec![v]),
354 OneOrMany::Many(vs) => Ok(vs),
355 }
356}
357
358fn deserialize_opt_scalar_or_vec_i32<'de, D>(de: D) -> Result<Option<Vec<i32>>, D::Error>
360where
361 D: serde::Deserializer<'de>,
362{
363 #[derive(Deserialize)]
364 #[serde(untagged)]
365 enum OneOrMany {
366 One(i32),
367 Many(Vec<i32>),
368 }
369 match Option::<OneOrMany>::deserialize(de)? {
370 None => Ok(None),
371 Some(OneOrMany::One(v)) => Ok(Some(vec![v])),
372 Some(OneOrMany::Many(vs)) => Ok(Some(vs)),
373 }
374}
375
/// Stride given either as one value or as a pair.
///
/// Untagged on the wire: a bare integer deserializes as `Square` and a
/// two-element array as `Rect`. Serde tries the variants in declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Stride {
    /// Same stride on both axes.
    Square(u32),
    /// Separate strides; `x()` reads the first element and `y()` the second.
    Rect([u32; 2]),
}
384
385impl Stride {
386 pub fn x(self) -> u32 {
388 match self {
389 Stride::Square(s) => s,
390 Stride::Rect([sx, _]) => sx,
391 }
392 }
393
394 pub fn y(self) -> u32 {
396 match self {
397 Stride::Square(s) => s,
398 Stride::Rect([_, sy]) => sy,
399 }
400 }
401}
402
/// Kind of a logical output.
///
/// Serialized in `snake_case` (for example `MaskCoefs` is written `mask_coefs`).
/// `Detection` and `Detections` are distinct variants; the legacy conversion in
/// this file maps both to the v1 `Detection` output, while `Objectness` and
/// `Landmarks` have no v1 equivalent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LogicalType {
    Boxes,
    Scores,
    Objectness,
    Classes,
    MaskCoefs,
    Protos,
    Landmarks,
    Detections,
    Segmentation,
    Masks,
    Detection,
}
430
/// Kind of a physical child output.
///
/// Contains every [`LogicalType`] variant name plus `BoxesXy` and `BoxesWh`.
/// Serialized in `snake_case` (`boxes_xy`, `boxes_wh`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PhysicalType {
    Boxes,
    Scores,
    Objectness,
    Classes,
    MaskCoefs,
    Protos,
    Landmarks,
    Detections,
    Segmentation,
    Masks,
    Detection,
    /// Written `boxes_xy`. NOTE(review): presumably the centre half of a
    /// channel-split box tensor — confirm.
    BoxesXy,
    /// Written `boxes_wh`. NOTE(review): presumably the size half of a
    /// channel-split box tensor — confirm.
    BoxesWh,
}
455
/// How box values are encoded; serialized as `dfl`, `direct` or `anchor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BoxEncoding {
    /// Validation requires the feature axis of each DFL boxes child to be
    /// divisible by 4.
    Dfl,
    /// Assigned by the v1 conversion to `Boxes` outputs and to Ultralytics
    /// detections.
    Direct,
    /// Assigned by the v1 conversion to ModelPack detections that carry anchors.
    Anchor,
}
470
/// Format of a scores output; serialized as `per_class` or `obj_x_class`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ScoreFormat {
    /// Assigned by the v1 conversion to every `Scores` output.
    PerClass,
    /// NOTE(review): presumably objectness multiplied by class score — confirm.
    ObjXClass,
}
483
/// Activation function tag; serialized as `sigmoid`, `softmax` or `tanh`.
///
/// Used by the `activation_applied` and `activation_required` fields of
/// [`PhysicalOutput`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Activation {
    Sigmoid,
    Softmax,
    Tanh,
}
492
/// Decoder family of an output.
///
/// Each variant carries an explicit wire name (`modelpack`, `ultralytics`)
/// rather than relying on a `rename_all` rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DecoderKind {
    /// Written `modelpack`.
    #[serde(rename = "modelpack")]
    ModelPack,
    /// Written `ultralytics`.
    #[serde(rename = "ultralytics")]
    Ultralytics,
}
503
/// Decoder version tag; serialized in `snake_case` (the test fixtures in this
/// file use `yolov8` and `yolo26`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DecoderVersion {
    Yolov5,
    Yolov8,
    Yolo11,
    /// The only variant for which `is_end_to_end` returns `true`.
    Yolo26,
}
513
514impl DecoderVersion {
515 pub fn is_end_to_end(self) -> bool {
517 matches!(self, DecoderVersion::Yolo26)
518 }
519}
520
/// NMS mode; serialized as `class_agnostic` or `class_aware`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NmsMode {
    /// Converts to and from `configs::Nms::ClassAgnostic`.
    ClassAgnostic,
    /// Converts to and from `configs::Nms::ClassAware`.
    ClassAware,
}
531
/// Element type tag; serialized in `snake_case` (the test fixtures in this file
/// use `int8`, `uint8` and `int16`).
///
/// Byte widths are reported by `size_bytes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DType {
    Int8,
    Uint8,
    Int16,
    Uint16,
    Int32,
    Uint32,
    Float16,
    Float32,
}
545
546impl DType {
547 pub fn size_bytes(self) -> usize {
549 match self {
550 DType::Int8 | DType::Uint8 => 1,
551 DType::Int16 | DType::Uint16 | DType::Float16 => 2,
552 DType::Int32 | DType::Uint32 | DType::Float32 => 4,
553 }
554 }
555
556 pub fn is_integer(self) -> bool {
558 matches!(
559 self,
560 DType::Int8
561 | DType::Uint8
562 | DType::Int16
563 | DType::Uint16
564 | DType::Int32
565 | DType::Uint32
566 )
567 }
568
569 pub fn is_float(self) -> bool {
571 matches!(self, DType::Float16 | DType::Float32)
572 }
573}
574
575impl SchemaV2 {
580 pub fn parse_json(s: &str) -> DecoderResult<Self> {
588 let value: serde_json::Value = serde_json::from_str(s)?;
589 Self::from_json_value(value)
590 }
591
592 pub fn parse_yaml(s: &str) -> DecoderResult<Self> {
596 let value: serde_yaml::Value = serde_yaml::from_str(s)?;
597 let json = serde_json::to_value(value)
598 .map_err(|e| DecoderError::InvalidConfig(format!("yaml→json bridge failed: {e}")))?;
599 Self::from_json_value(json)
600 }
601
602 pub fn parse_file(path: impl AsRef<std::path::Path>) -> DecoderResult<Self> {
606 let path = path.as_ref();
607 let content = std::fs::read_to_string(path)
608 .map_err(|e| DecoderError::InvalidConfig(format!("read {}: {e}", path.display())))?;
609 let ext = path
610 .extension()
611 .and_then(|e| e.to_str())
612 .map(str::to_ascii_lowercase);
613 match ext.as_deref() {
614 Some("json") => Self::parse_json(&content),
615 Some("yaml") | Some("yml") => Self::parse_yaml(&content),
616 _ => Self::parse_json(&content).or_else(|_| Self::parse_yaml(&content)),
617 }
618 }
619
620 pub fn from_json_value(value: serde_json::Value) -> DecoderResult<Self> {
623 let version = value
624 .get("schema_version")
625 .and_then(|v| v.as_u64())
626 .map(|v| v as u32)
627 .unwrap_or(1);
628
629 if version > MAX_SUPPORTED_SCHEMA_VERSION {
630 return Err(DecoderError::NotSupported(format!(
631 "schema_version {version} is not supported by this HAL \
632 (maximum supported version is {MAX_SUPPORTED_SCHEMA_VERSION}); \
633 upgrade the HAL or downgrade the metadata"
634 )));
635 }
636
637 if version >= 2 {
638 serde_json::from_value(value).map_err(DecoderError::Json)
639 } else {
640 let v1: ConfigOutputs = serde_json::from_value(value).map_err(DecoderError::Json)?;
641 Self::from_v1(&v1)
642 }
643 }
644
645 pub fn from_v1(v1: &ConfigOutputs) -> DecoderResult<Self> {
661 let outputs = v1
662 .outputs
663 .iter()
664 .map(logical_from_v1)
665 .collect::<DecoderResult<Vec<_>>>()?;
666 Ok(SchemaV2 {
667 schema_version: 2,
668 input: None,
669 outputs,
670 nms: v1.nms.as_ref().map(NmsMode::from_v1),
671 decoder_version: v1.decoder_version.as_ref().map(DecoderVersion::from_v1),
672 })
673 }
674}
675
676impl SchemaV2 {
677 pub fn to_legacy_config_outputs(&self) -> DecoderResult<ConfigOutputs> {
704 let mut outputs = Vec::with_capacity(self.outputs.len());
705 for logical in &self.outputs {
706 if logical.type_ == LogicalType::Boxes
713 && logical.encoding == Some(BoxEncoding::Dfl)
714 && logical.outputs.is_empty()
715 {
716 return Err(DecoderError::NotSupported(format!(
717 "`boxes` output `{}` has `encoding: dfl` on a flat \
718 logical (no per-scale children); the HAL's DFL \
719 decode kernel only runs inside the per-scale merge \
720 path. Split the boxes output into per-FPN-level \
721 children (Hailo convention) or pre-decode to 4 \
722 channels in the model graph (TFLite convention).",
723 logical.name.as_deref().unwrap_or("<anonymous>"),
724 )));
725 }
726 if let Some(q) = &logical.quantization {
727 if q.is_per_channel() {
728 return Err(DecoderError::NotSupported(format!(
729 "logical `{}` uses per-channel quantization \
730 (axis {:?}, {} scales); the v1 decoder only \
731 supports per-tensor quantization",
732 logical.name.as_deref().unwrap_or("<anonymous>"),
733 q.axis,
734 q.scale.len(),
735 )));
736 }
737 }
738 outputs.push(logical_to_legacy_config_output(logical)?);
739 }
740 Ok(ConfigOutputs {
741 outputs,
742 nms: self.nms.map(NmsMode::to_v1),
743 decoder_version: self.decoder_version.map(|v| v.to_v1()),
744 })
745 }
746
747 pub fn validate(&self) -> DecoderResult<()> {
765 if self.schema_version == 0 || self.schema_version > MAX_SUPPORTED_SCHEMA_VERSION {
766 return Err(DecoderError::InvalidConfig(format!(
767 "schema_version {} outside supported range [1, {MAX_SUPPORTED_SCHEMA_VERSION}]",
768 self.schema_version
769 )));
770 }
771
772 for logical in &self.outputs {
773 validate_logical(logical)?;
774 }
775
776 Ok(())
777 }
778}
779
780fn validate_logical(logical: &LogicalOutput) -> DecoderResult<()> {
781 if logical.outputs.is_empty() {
782 return Ok(());
783 }
784
785 for child in &logical.outputs {
787 if child.name.is_empty() {
788 return Err(DecoderError::InvalidConfig(format!(
789 "physical child of logical `{}` is missing `name`; name is \
790 required for tensor binding",
791 logical.name.as_deref().unwrap_or("<anonymous>")
792 )));
793 }
794 }
795
796 for (i, a) in logical.outputs.iter().enumerate() {
800 for b in &logical.outputs[i + 1..] {
801 if a.shape == b.shape && a.type_ == b.type_ {
802 return Err(DecoderError::InvalidConfig(format!(
803 "physical children `{}` and `{}` share shape {:?} and \
804 type; tensor binding cannot be resolved",
805 a.name, b.name, a.shape
806 )));
807 }
808 }
809 }
810
811 let strided: Vec<_> = logical.outputs.iter().map(|c| c.stride.is_some()).collect();
815 let all_strided = strided.iter().all(|&b| b);
816 let none_strided = strided.iter().all(|&b| !b);
817 if !(all_strided || none_strided) {
818 return Err(DecoderError::InvalidConfig(format!(
819 "logical `{}` mixes per-scale children (with stride) and \
820 channel sub-split children (without stride); decomposition \
821 must be uniform",
822 logical.name.as_deref().unwrap_or("<anonymous>")
823 )));
824 }
825
826 if logical.type_ == LogicalType::Boxes && logical.encoding == Some(BoxEncoding::Dfl) {
829 for child in &logical.outputs {
830 if let Some(feat) = last_feature_axis(child) {
831 if feat % 4 != 0 {
832 return Err(DecoderError::InvalidConfig(format!(
833 "DFL boxes child `{}` feature axis {feat} is not \
834 divisible by 4 (reg_max×4)",
835 child.name
836 )));
837 }
838 }
839 }
840 }
841
842 Ok(())
843}
844
845fn last_feature_axis(child: &PhysicalOutput) -> Option<usize> {
848 for (name, size) in &child.dshape {
851 if matches!(
852 name,
853 DimName::NumFeatures
854 | DimName::NumClasses
855 | DimName::NumProtos
856 | DimName::BoxCoords
857 | DimName::NumAnchorsXFeatures
858 ) {
859 return Some(*size);
860 }
861 }
862 child.shape.last().copied()
863}
864
865fn quantization_from_v1(q: Option<QuantTuple>) -> Option<Quantization> {
866 q.map(|QuantTuple(scale, zp)| Quantization {
867 scale: vec![scale],
868 zero_point: Some(vec![zp]),
869 axis: None,
870 dtype: None,
871 })
872}
873
874fn logical_from_v1(v1: &ConfigOutput) -> DecoderResult<LogicalOutput> {
875 match v1 {
876 ConfigOutput::Detection(d) => {
877 let encoding = match (d.decoder, d.anchors.is_some()) {
883 (configs::DecoderType::ModelPack, true) => Some(BoxEncoding::Anchor),
884 (configs::DecoderType::Ultralytics, _) => Some(BoxEncoding::Direct),
885 (configs::DecoderType::ModelPack, false) => None,
888 };
889 Ok(LogicalOutput {
890 name: None,
891 type_: LogicalType::Detection,
892 shape: d.shape.clone(),
893 dshape: d.dshape.clone(),
894 decoder: Some(DecoderKind::from_v1(d.decoder)),
895 encoding,
896 score_format: None,
897 normalized: d.normalized,
898 anchors: d.anchors.clone(),
899 stride: None,
900 dtype: None,
901 quantization: quantization_from_v1(d.quantization),
902 outputs: Vec::new(),
903 })
904 }
905 ConfigOutput::Boxes(b) => Ok(LogicalOutput {
906 name: None,
907 type_: LogicalType::Boxes,
908 shape: b.shape.clone(),
909 dshape: b.dshape.clone(),
910 decoder: Some(DecoderKind::from_v1(b.decoder)),
911 encoding: Some(BoxEncoding::Direct),
915 score_format: None,
916 normalized: b.normalized,
917 anchors: None,
918 stride: None,
919 dtype: None,
920 quantization: quantization_from_v1(b.quantization),
921 outputs: Vec::new(),
922 }),
923 ConfigOutput::Scores(s) => Ok(LogicalOutput {
924 name: None,
925 type_: LogicalType::Scores,
926 shape: s.shape.clone(),
927 dshape: s.dshape.clone(),
928 decoder: Some(DecoderKind::from_v1(s.decoder)),
929 encoding: None,
930 score_format: Some(ScoreFormat::PerClass),
934 normalized: None,
935 anchors: None,
936 stride: None,
937 dtype: None,
938 quantization: quantization_from_v1(s.quantization),
939 outputs: Vec::new(),
940 }),
941 ConfigOutput::Protos(p) => Ok(LogicalOutput {
942 name: None,
943 type_: LogicalType::Protos,
944 shape: p.shape.clone(),
945 dshape: p.dshape.clone(),
946 decoder: Some(DecoderKind::from_v1(p.decoder)),
948 encoding: None,
949 score_format: None,
950 normalized: None,
951 anchors: None,
952 stride: None,
953 dtype: None,
954 quantization: quantization_from_v1(p.quantization),
955 outputs: Vec::new(),
956 }),
957 ConfigOutput::MaskCoefficients(m) => Ok(LogicalOutput {
958 name: None,
959 type_: LogicalType::MaskCoefs,
960 shape: m.shape.clone(),
961 dshape: m.dshape.clone(),
962 decoder: Some(DecoderKind::from_v1(m.decoder)),
963 encoding: None,
964 score_format: None,
965 normalized: None,
966 anchors: None,
967 stride: None,
968 dtype: None,
969 quantization: quantization_from_v1(m.quantization),
970 outputs: Vec::new(),
971 }),
972 ConfigOutput::Segmentation(seg) => Ok(LogicalOutput {
973 name: None,
974 type_: LogicalType::Segmentation,
975 shape: seg.shape.clone(),
976 dshape: seg.dshape.clone(),
977 decoder: Some(DecoderKind::from_v1(seg.decoder)),
978 encoding: None,
979 score_format: None,
980 normalized: None,
981 anchors: None,
982 stride: None,
983 dtype: None,
984 quantization: quantization_from_v1(seg.quantization),
985 outputs: Vec::new(),
986 }),
987 ConfigOutput::Mask(m) => Ok(LogicalOutput {
988 name: None,
989 type_: LogicalType::Masks,
990 shape: m.shape.clone(),
991 dshape: m.dshape.clone(),
992 decoder: Some(DecoderKind::from_v1(m.decoder)),
993 encoding: None,
994 score_format: None,
995 normalized: None,
996 anchors: None,
997 stride: None,
998 dtype: None,
999 quantization: quantization_from_v1(m.quantization),
1000 outputs: Vec::new(),
1001 }),
1002 ConfigOutput::Classes(c) => Ok(LogicalOutput {
1003 name: None,
1004 type_: LogicalType::Classes,
1005 shape: c.shape.clone(),
1006 dshape: c.dshape.clone(),
1007 decoder: Some(DecoderKind::from_v1(c.decoder)),
1008 encoding: None,
1009 score_format: None,
1010 normalized: None,
1011 anchors: None,
1012 stride: None,
1013 dtype: None,
1014 quantization: quantization_from_v1(c.quantization),
1015 outputs: Vec::new(),
1016 }),
1017 }
1018}
1019
1020impl DecoderKind {
1021 pub fn from_v1(v: configs::DecoderType) -> Self {
1023 match v {
1024 configs::DecoderType::ModelPack => DecoderKind::ModelPack,
1025 configs::DecoderType::Ultralytics => DecoderKind::Ultralytics,
1026 }
1027 }
1028
1029 pub fn to_v1(self) -> configs::DecoderType {
1031 match self {
1032 DecoderKind::ModelPack => configs::DecoderType::ModelPack,
1033 DecoderKind::Ultralytics => configs::DecoderType::Ultralytics,
1034 }
1035 }
1036}
1037
1038impl DecoderVersion {
1039 pub fn from_v1(v: &configs::DecoderVersion) -> Self {
1041 match v {
1042 configs::DecoderVersion::Yolov5 => DecoderVersion::Yolov5,
1043 configs::DecoderVersion::Yolov8 => DecoderVersion::Yolov8,
1044 configs::DecoderVersion::Yolo11 => DecoderVersion::Yolo11,
1045 configs::DecoderVersion::Yolo26 => DecoderVersion::Yolo26,
1046 }
1047 }
1048
1049 pub fn to_v1(self) -> configs::DecoderVersion {
1051 match self {
1052 DecoderVersion::Yolov5 => configs::DecoderVersion::Yolov5,
1053 DecoderVersion::Yolov8 => configs::DecoderVersion::Yolov8,
1054 DecoderVersion::Yolo11 => configs::DecoderVersion::Yolo11,
1055 DecoderVersion::Yolo26 => configs::DecoderVersion::Yolo26,
1056 }
1057 }
1058}
1059
1060impl NmsMode {
1061 pub fn from_v1(v: &configs::Nms) -> Self {
1063 match v {
1064 configs::Nms::ClassAgnostic => NmsMode::ClassAgnostic,
1065 configs::Nms::ClassAware => NmsMode::ClassAware,
1066 }
1067 }
1068
1069 pub fn to_v1(self) -> configs::Nms {
1071 match self {
1072 NmsMode::ClassAgnostic => configs::Nms::ClassAgnostic,
1073 NmsMode::ClassAware => configs::Nms::ClassAware,
1074 }
1075 }
1076}
1077
1078fn quantization_to_legacy(q: &Quantization) -> DecoderResult<QuantTuple> {
1081 if q.is_per_channel() {
1082 return Err(DecoderError::NotSupported(
1083 "per-channel quantization cannot be expressed as a v1 QuantTuple".into(),
1084 ));
1085 }
1086 let scale = *q.scale.first().unwrap_or(&0.0);
1087 let zp = q.zero_point_at(0);
1088 Ok(QuantTuple(scale, zp))
1089}
1090
1091pub(crate) fn squeeze_padding_dims(
1097 shape: Vec<usize>,
1098 dshape: Vec<(DimName, usize)>,
1099) -> (Vec<usize>, Vec<(DimName, usize)>) {
1100 let keep: Vec<bool> = dshape
1101 .iter()
1102 .map(|(n, _)| !matches!(n, DimName::Padding))
1103 .collect();
1104 let shape = shape
1105 .into_iter()
1106 .zip(keep.iter())
1107 .filter_map(|(s, &k)| k.then_some(s))
1108 .collect();
1109 let dshape = dshape
1110 .into_iter()
1111 .zip(keep.iter())
1112 .filter_map(|(d, &k)| k.then_some(d))
1113 .collect();
1114 (shape, dshape)
1115}
1116
1117pub(crate) fn padding_axes(dshape: &[(DimName, usize)]) -> Vec<usize> {
1122 let mut v: Vec<usize> = dshape
1123 .iter()
1124 .enumerate()
1125 .filter_map(|(i, (n, _))| matches!(n, DimName::Padding).then_some(i))
1126 .collect();
1127 v.sort_by(|a, b| b.cmp(a));
1128 v
1129}
1130
1131fn logical_to_legacy_config_output(logical: &LogicalOutput) -> DecoderResult<ConfigOutput> {
1132 let decoder = logical
1133 .decoder
1134 .map(|d| d.to_v1())
1135 .unwrap_or(configs::DecoderType::Ultralytics);
1136 let quantization = logical
1137 .quantization
1138 .as_ref()
1139 .map(quantization_to_legacy)
1140 .transpose()?;
1141 let (shape, dshape) = squeeze_padding_dims(logical.shape.clone(), logical.dshape.clone());
1146
1147 Ok(match logical.type_ {
1148 LogicalType::Boxes => ConfigOutput::Boxes(configs::Boxes {
1149 decoder,
1150 quantization,
1151 shape,
1152 dshape,
1153 normalized: logical.normalized,
1154 }),
1155 LogicalType::Scores => ConfigOutput::Scores(configs::Scores {
1156 decoder,
1157 quantization,
1158 shape,
1159 dshape,
1160 }),
1161 LogicalType::Protos => ConfigOutput::Protos(configs::Protos {
1162 decoder,
1163 quantization,
1164 shape,
1165 dshape,
1166 }),
1167 LogicalType::MaskCoefs => ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
1168 decoder,
1169 quantization,
1170 shape,
1171 dshape,
1172 }),
1173 LogicalType::Segmentation => ConfigOutput::Segmentation(configs::Segmentation {
1174 decoder,
1175 quantization,
1176 shape,
1177 dshape,
1178 }),
1179 LogicalType::Masks => ConfigOutput::Mask(configs::Mask {
1180 decoder,
1181 quantization,
1182 shape,
1183 dshape,
1184 }),
1185 LogicalType::Classes => ConfigOutput::Classes(configs::Classes {
1186 decoder,
1187 quantization,
1188 shape,
1189 dshape,
1190 }),
1191 LogicalType::Detection | LogicalType::Detections => {
1195 ConfigOutput::Detection(configs::Detection {
1196 anchors: logical.anchors.clone(),
1197 decoder,
1198 quantization,
1199 shape,
1200 dshape,
1201 normalized: logical.normalized,
1202 })
1203 }
1204 LogicalType::Objectness | LogicalType::Landmarks => {
1207 return Err(DecoderError::NotSupported(format!(
1208 "logical type {:?} has no legacy v1 equivalent; use the \
1209 native v2 decoder path",
1210 logical.type_
1211 )));
1212 }
1213 })
1214}
1215
1216#[cfg(test)]
1217#[cfg_attr(coverage_nightly, coverage(off))]
1218mod tests {
1219 use super::*;
1220
1221 #[test]
1222 fn schema_default_is_v2() {
1223 let s = SchemaV2::default();
1224 assert_eq!(s.schema_version, 2);
1225 assert!(s.outputs.is_empty());
1226 }
1227
1228 #[test]
1229 fn dtype_roundtrip() {
1230 for d in [
1231 DType::Int8,
1232 DType::Uint8,
1233 DType::Int16,
1234 DType::Uint16,
1235 DType::Float16,
1236 DType::Float32,
1237 ] {
1238 let j = serde_json::to_string(&d).unwrap();
1239 let back: DType = serde_json::from_str(&j).unwrap();
1240 assert_eq!(back, d);
1241 }
1242 }
1243
1244 #[test]
1245 fn dtype_widths() {
1246 assert_eq!(DType::Int8.size_bytes(), 1);
1247 assert_eq!(DType::Float16.size_bytes(), 2);
1248 assert_eq!(DType::Float32.size_bytes(), 4);
1249 }
1250
1251 #[test]
1252 fn stride_accepts_scalar_or_pair() {
1253 let a: Stride = serde_json::from_str("8").unwrap();
1254 let b: Stride = serde_json::from_str("[8, 16]").unwrap();
1255 assert_eq!(a, Stride::Square(8));
1256 assert_eq!(b, Stride::Rect([8, 16]));
1257 assert_eq!(a.x(), 8);
1258 assert_eq!(a.y(), 8);
1259 assert_eq!(b.x(), 8);
1260 assert_eq!(b.y(), 16);
1261 }
1262
1263 #[test]
1264 fn quantization_scalar_scale() {
1265 let j = r#"{"scale": 0.00392, "zero_point": 0, "dtype": "int8"}"#;
1266 let q: Quantization = serde_json::from_str(j).unwrap();
1267 assert!(q.is_per_tensor());
1268 assert!(q.is_symmetric());
1269 assert_eq!(q.scale_at(0), 0.00392);
1270 assert_eq!(q.scale_at(5), 0.00392);
1271 assert_eq!(q.zero_point_at(0), 0);
1272 }
1273
1274 #[test]
1275 fn quantization_per_channel() {
1276 let j = r#"{"scale": [0.054, 0.089, 0.195], "axis": 0, "dtype": "int8"}"#;
1277 let q: Quantization = serde_json::from_str(j).unwrap();
1278 assert!(q.is_per_channel());
1279 assert!(q.is_symmetric());
1280 assert_eq!(q.axis, Some(0));
1281 assert_eq!(q.scale_at(0), 0.054);
1282 assert_eq!(q.scale_at(2), 0.195);
1283 }
1284
1285 #[test]
1286 fn quantization_asymmetric_per_tensor() {
1287 let j = r#"{"scale": 0.176, "zero_point": 198, "dtype": "uint8"}"#;
1288 let q: Quantization = serde_json::from_str(j).unwrap();
1289 assert!(!q.is_symmetric());
1290 assert_eq!(q.zero_point_at(0), 198);
1291 assert_eq!(q.zero_point_at(10), 198);
1292 }
1293
1294 #[test]
1295 fn quantization_symmetric_default_zero_point() {
1296 let j = r#"{"scale": 0.00392, "dtype": "int8"}"#;
1297 let q: Quantization = serde_json::from_str(j).unwrap();
1298 assert!(q.is_symmetric());
1299 assert_eq!(q.zero_point_at(0), 0);
1300 }
1301
    /// A flat (no children) DFL-encoded `boxes` logical output deserializes with
    /// its type, encoding, `normalized` flag and dtype intact.
    #[test]
    fn logical_output_flat_tflite_boxes() {
        let j = r#"{
            "name": "boxes", "type": "boxes",
            "shape": [1, 64, 8400],
            "dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
            "dtype": "int8",
            "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
            "decoder": "ultralytics",
            "encoding": "dfl",
            "normalized": true
        }"#;
        let lo: LogicalOutput = serde_json::from_str(j).unwrap();
        assert_eq!(lo.type_, LogicalType::Boxes);
        assert_eq!(lo.encoding, Some(BoxEncoding::Dfl));
        assert_eq!(lo.normalized, Some(true));
        // No `outputs` key in the fixture, so the logical output is flat.
        assert!(!lo.is_split());
        assert_eq!(lo.dtype, Some(DType::Int8));
    }
1322
    /// A `boxes` logical output with one strided physical child deserializes as
    /// split, and the child's name, type, stride, scale index and dtype are read.
    #[test]
    fn logical_output_hailo_per_scale_split() {
        let j = r#"{
            "name": "boxes", "type": "boxes",
            "shape": [1, 64, 8400],
            "encoding": "dfl", "decoder": "ultralytics", "normalized": true,
            "outputs": [
                {
                    "name": "boxes_0", "type": "boxes",
                    "stride": 8, "scale_index": 0,
                    "shape": [1, 80, 80, 64],
                    "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 64}],
                    "dtype": "uint8",
                    "quantization": {"scale": 0.0234, "zero_point": 128, "dtype": "uint8"}
                }
            ]
        }"#;
        let lo: LogicalOutput = serde_json::from_str(j).unwrap();
        assert!(lo.is_split());
        assert_eq!(lo.outputs.len(), 1);
        let child = &lo.outputs[0];
        assert_eq!(child.name, "boxes_0");
        assert_eq!(child.type_, PhysicalType::Boxes);
        // A scalar stride parses as the `Square` variant.
        assert_eq!(child.stride, Some(Stride::Square(8)));
        assert_eq!(child.scale_index, Some(0));
        assert_eq!(child.dtype, DType::Uint8);
    }
1351
    /// A `boxes` logical output split into `boxes_xy` / `boxes_wh` children
    /// (neither with a stride) deserializes with both physical types recognised.
    #[test]
    fn logical_output_ara2_xy_wh_channel_split() {
        let j = r#"{
            "name": "boxes", "type": "boxes",
            "shape": [1, 4, 8400, 1],
            "encoding": "direct", "decoder": "ultralytics", "normalized": true,
            "outputs": [
                {
                    "name": "_model_22_Div_1_output_0", "type": "boxes_xy",
                    "shape": [1, 2, 8400, 1],
                    "dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
                    "dtype": "int16",
                    "quantization": {"scale": 3.129e-5, "zero_point": 0, "dtype": "int16"}
                },
                {
                    "name": "_model_22_Sub_1_output_0", "type": "boxes_wh",
                    "shape": [1, 2, 8400, 1],
                    "dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
                    "dtype": "int16",
                    "quantization": {"scale": 3.149e-5, "zero_point": 0, "dtype": "int16"}
                }
            ]
        }"#;
        let lo: LogicalOutput = serde_json::from_str(j).unwrap();
        assert_eq!(lo.encoding, Some(BoxEncoding::Direct));
        assert_eq!(lo.outputs.len(), 2);
        assert_eq!(lo.outputs[0].type_, PhysicalType::BoxesXy);
        assert_eq!(lo.outputs[1].type_, PhysicalType::BoxesWh);
        // Neither child in the fixture declares a stride.
        assert!(lo.outputs[0].stride.is_none());
        assert!(lo.outputs[1].stride.is_none());
    }
1384
    /// A `scores` child declaring `activation_applied: sigmoid` deserializes with
    /// that activation set and `activation_required` left unset.
    #[test]
    fn logical_output_hailo_scores_sigmoid_applied() {
        let j = r#"{
            "name": "scores", "type": "scores",
            "shape": [1, 80, 8400],
            "decoder": "ultralytics", "score_format": "per_class",
            "outputs": [
                {
                    "name": "scores_0", "type": "scores",
                    "stride": 8, "scale_index": 0,
                    "shape": [1, 80, 80, 80],
                    "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_classes": 80}],
                    "dtype": "uint8",
                    "quantization": {"scale": 0.003922, "dtype": "uint8"},
                    "activation_applied": "sigmoid"
                }
            ]
        }"#;
        let lo: LogicalOutput = serde_json::from_str(j).unwrap();
        assert_eq!(lo.score_format, Some(ScoreFormat::PerClass));
        let child = &lo.outputs[0];
        assert_eq!(child.activation_applied, Some(Activation::Sigmoid));
        assert!(child.activation_required.is_none());
    }
1409
    /// A full document with `decoder_version: yolo26` and a single `detections`
    /// output deserializes, reports end-to-end, and has no NMS mode.
    #[test]
    fn yolo26_end_to_end_detections() {
        let j = r#"{
            "schema_version": 2,
            "decoder_version": "yolo26",
            "outputs": [{
                "name": "output0", "type": "detections",
                "shape": [1, 100, 6],
                "dshape": [{"batch": 1}, {"num_boxes": 100}, {"num_features": 6}],
                "dtype": "int8",
                "quantization": {"scale": 0.0078, "zero_point": 0, "dtype": "int8"},
                "normalized": false,
                "decoder": "ultralytics"
            }]
        }"#;
        let s: SchemaV2 = serde_json::from_str(j).unwrap();
        assert_eq!(s.decoder_version, Some(DecoderVersion::Yolo26));
        assert!(s.decoder_version.unwrap().is_end_to_end());
        assert_eq!(s.outputs[0].type_, LogicalType::Detections);
        assert_eq!(s.outputs[0].normalized, Some(false));
        // The fixture has no `nms` key.
        assert!(s.nms.is_none());
    }
1432
    /// A ModelPack `detection` output with anchor encoding, a two-element stride
    /// and three anchors deserializes with all three fields populated.
    #[test]
    fn modelpack_anchor_detection_with_rect_stride() {
        let j = r#"{
            "schema_version": 2,
            "outputs": [{
                "name": "output_0", "type": "detection",
                "shape": [1, 40, 40, 54],
                "dshape": [{"batch": 1}, {"height": 40}, {"width": 40}, {"num_anchors_x_features": 54}],
                "dtype": "uint8",
                "quantization": {"scale": 0.176, "zero_point": 198, "dtype": "uint8"},
                "decoder": "modelpack",
                "encoding": "anchor",
                "stride": [16, 16],
                "anchors": [[0.054, 0.065], [0.089, 0.139], [0.195, 0.196]]
            }]
        }"#;
        let s: SchemaV2 = serde_json::from_str(j).unwrap();
        let lo = &s.outputs[0];
        assert_eq!(lo.encoding, Some(BoxEncoding::Anchor));
        // An array stride parses as the `Rect` variant even when both values match.
        assert_eq!(lo.stride, Some(Stride::Rect([16, 16])));
        assert_eq!(lo.anchors.as_ref().map(|a| a.len()), Some(3));
    }
1455
    /// An `objectness` logical output with one sigmoid-applied child
    /// deserializes with the expected logical type and child activation.
    #[test]
    fn yolov5_obj_x_class_objectness_logical() {
        let j = r#"{
            "name": "objectness", "type": "objectness",
            "shape": [1, 3, 8400],
            "decoder": "ultralytics",
            "outputs": [{
                "name": "objectness_0", "type": "objectness",
                "stride": 8, "scale_index": 0,
                "shape": [1, 80, 80, 3],
                "dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 3}],
                "dtype": "uint8",
                "quantization": {"scale": 0.0039, "zero_point": 0, "dtype": "uint8"},
                "activation_applied": "sigmoid"
            }]
        }"#;
        let lo: LogicalOutput = serde_json::from_str(j).unwrap();
        assert_eq!(lo.type_, LogicalType::Objectness);
        assert_eq!(lo.outputs[0].activation_applied, Some(Activation::Sigmoid));
    }
1476
#[test]
/// A `protos` output written without a "decoder" key deserializes with
/// `decoder == None` and a scalar stride.
fn direct_protos_no_decoder() {
    let raw = r#"{
        "name": "protos", "type": "protos",
        "shape": [1, 32, 160, 160],
        "dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}],
        "dtype": "uint8",
        "quantization": {"scale": 0.0203, "zero_point": 45, "dtype": "uint8"},
        "stride": 4
    }"#;
    let protos: LogicalOutput = serde_json::from_str(raw).unwrap();

    assert_eq!(protos.type_, LogicalType::Protos);
    assert!(protos.decoder.is_none());
    // A bare integer stride is expected to map to the `Square` variant.
    assert_eq!(protos.stride, Some(Stride::Square(4)));
}
1493
#[test]
/// A complete flat YOLOv8 document (input spec, NMS mode, boxes + scores)
/// deserializes with every top-level field populated.
fn full_yolov8_tflite_flat_detection() {
    let raw = r#"{
        "schema_version": 2,
        "decoder_version": "yolov8",
        "nms": "class_agnostic",
        "input": { "shape": [1, 640, 640, 3], "cameraadaptor": "rgb" },
        "outputs": [
            {
                "name": "boxes", "type": "boxes",
                "shape": [1, 64, 8400],
                "dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
                "dtype": "int8",
                "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
                "decoder": "ultralytics",
                "encoding": "dfl",
                "normalized": true
            },
            {
                "name": "scores", "type": "scores",
                "shape": [1, 80, 8400],
                "dshape": [{"batch": 1}, {"num_classes": 80}, {"num_boxes": 8400}],
                "dtype": "int8",
                "quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
                "decoder": "ultralytics",
                "score_format": "per_class"
            }
        ]
    }"#;
    let schema: SchemaV2 = serde_json::from_str(raw).unwrap();

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.decoder_version, Some(DecoderVersion::Yolov8));
    assert_eq!(schema.nms, Some(NmsMode::ClassAgnostic));

    let input = schema.input.as_ref().unwrap();
    assert_eq!(input.shape, vec![1, 640, 640, 3]);

    assert_eq!(schema.outputs.len(), 2);
}
1531
#[test]
/// Plain serde deserialization applies no version gate: a document with
/// `schema_version: 99` still yields a `SchemaV2` carrying that value.
fn schema_unknown_version_parses_without_validation() {
    let raw = r#"{"schema_version": 99, "outputs": []}"#;
    let schema: SchemaV2 = serde_json::from_str(raw).unwrap();
    assert_eq!(schema.schema_version, 99);
}
1540
#[test]
/// Serializing a populated `SchemaV2` to JSON and reading it back yields a
/// value equal to the one we started with.
fn serde_roundtrip_preserves_fields() {
    let input = InputSpec {
        shape: vec![1, 3, 640, 640],
        dshape: Vec::new(),
        cameraadaptor: Some(String::from("rgb")),
    };
    let boxes = LogicalOutput {
        name: Some(String::from("boxes")),
        type_: LogicalType::Boxes,
        shape: vec![1, 4, 8400],
        dshape: Vec::new(),
        decoder: Some(DecoderKind::Ultralytics),
        encoding: Some(BoxEncoding::Direct),
        score_format: None,
        normalized: Some(true),
        anchors: None,
        stride: None,
        dtype: Some(DType::Float32),
        quantization: None,
        outputs: Vec::new(),
    };
    let before = SchemaV2 {
        schema_version: 2,
        input: Some(input),
        outputs: vec![boxes],
        nms: Some(NmsMode::ClassAgnostic),
        decoder_version: Some(DecoderVersion::Yolov8),
    };

    let encoded = serde_json::to_string(&before).unwrap();
    let after: SchemaV2 = serde_json::from_str(&encoded).unwrap();
    assert_eq!(after, before);
}
1572
#[test]
/// `parse_yaml` on the bundled `yolov8_seg.yaml` fixture produces a
/// version-2 schema with a detection output followed by a protos output.
fn parse_v1_yaml_yolov8_seg_testdata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_seg.yaml"
    ));
    let schema = SchemaV2::parse_yaml(fixture).expect("parse v1 yaml");
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    // First output: the combined detection tensor.
    let detection = &schema.outputs[0];
    assert_eq!(detection.type_, LogicalType::Detection);
    assert_eq!(detection.shape, vec![1, 116, 8400]);
    assert_eq!(detection.decoder, Some(DecoderKind::Ultralytics));
    assert_eq!(detection.encoding, Some(BoxEncoding::Direct));

    // Its quantization carries exactly one scale and one zero point.
    let quant = detection.quantization.as_ref().unwrap();
    assert_eq!(quant.scale.len(), 1);
    let scale_error = (quant.scale[0] - 0.021_287_762).abs();
    assert!(scale_error < 1e-6);
    assert_eq!(quant.zero_point, Some(vec![31]));

    // Second output: the mask prototypes.
    let protos = &schema.outputs[1];
    assert_eq!(protos.type_, LogicalType::Protos);
    assert_eq!(protos.shape, vec![1, 160, 160, 32]);
}
1599
#[test]
/// `parse_json` on the bundled `modelpack_split.json` fixture produces two
/// ModelPack anchor-encoded detection outputs, each with three anchors.
fn parse_v1_json_modelpack_split_testdata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("parse v1 json");
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    // Both outputs must have the same type, decoder, encoding and anchor count.
    schema.outputs.iter().for_each(|detection| {
        assert_eq!(detection.type_, LogicalType::Detection);
        assert_eq!(detection.decoder, Some(DecoderKind::ModelPack));
        assert_eq!(detection.encoding, Some(BoxEncoding::Anchor));
        let anchor_count = detection.anchors.as_ref().map(|pairs| pairs.len());
        assert_eq!(anchor_count, Some(3));
    });
}
1617
#[test]
/// `parse_json` accepts a document that declares `schema_version: 2` and
/// keeps its `boxes` output typed as `LogicalType::Boxes`.
fn parse_v2_json_direct_when_schema_version_present() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name": "boxes", "type": "boxes",
            "shape": [1, 4, 8400],
            "dshape": [{"batch": 1}, {"box_coords": 4}, {"num_boxes": 8400}],
            "dtype": "float32",
            "decoder": "ultralytics",
            "encoding": "direct",
            "normalized": true
        }]
    }"#;
    let parsed = SchemaV2::parse_json(raw).unwrap();
    assert_eq!(parsed.schema_version, 2);

    let boxes = &parsed.outputs[0];
    assert_eq!(boxes.type_, LogicalType::Boxes);
}
1636
#[test]
/// `parse_json` must refuse a document whose `schema_version` (99) is above
/// `MAX_SUPPORTED_SCHEMA_VERSION`, reporting `DecoderError::NotSupported`.
fn parse_rejects_future_schema_version() {
    let j = r#"{"schema_version": 99, "outputs": []}"#;
    let err = SchemaV2::parse_json(j).unwrap_err();
    // `matches!` only evaluates to a bool; without `assert!` the variant
    // check is silently discarded and the test passes for any error.
    assert!(
        matches!(err, DecoderError::NotSupported(_)),
        "expected DecoderError::NotSupported, got: {err:?}"
    );
}
1643
#[test]
/// A document with no `schema_version` key and tuple-style quantization is
/// still accepted by `parse_json` and comes back as a version-2 schema.
fn parse_absent_schema_version_treats_as_v1() {
    let raw = r#"{
        "outputs": [
            {
                "type": "boxes", "decoder": "ultralytics",
                "shape": [1, 4, 8400],
                "quantization": [0.00392, 0]
            },
            {
                "type": "scores", "decoder": "ultralytics",
                "shape": [1, 80, 8400],
                "quantization": [0.00392, 0]
            }
        ]
    }"#;
    let schema = SchemaV2::parse_json(raw).expect("v1 legacy parse");
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    let boxes = &schema.outputs[0];
    assert_eq!(boxes.type_, LogicalType::Boxes);

    // The scores output gains a per-class score format even though the
    // input document does not mention one.
    let scores = &schema.outputs[1];
    assert_eq!(scores.type_, LogicalType::Scores);
    assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
}
1669
#[test]
/// `SchemaV2::from_v1` carries the NMS mode, decoder version, `normalized`
/// flag and quantization tuple over from a v1 `ConfigOutputs`.
fn from_v1_preserves_nms_and_decoder_version() {
    use crate::configs as legacy_cfg;

    let legacy_boxes = legacy_cfg::Boxes {
        decoder: legacy_cfg::DecoderType::Ultralytics,
        quantization: Some(legacy_cfg::QuantTuple(0.01, 5)),
        shape: vec![1, 4, 8400],
        dshape: vec![],
        normalized: Some(true),
    };
    let legacy = ConfigOutputs {
        outputs: vec![ConfigOutput::Boxes(legacy_boxes)],
        nms: Some(legacy_cfg::Nms::ClassAware),
        decoder_version: Some(legacy_cfg::DecoderVersion::Yolo11),
    };

    let converted = SchemaV2::from_v1(&legacy).unwrap();
    assert_eq!(converted.nms, Some(NmsMode::ClassAware));
    assert_eq!(converted.decoder_version, Some(DecoderVersion::Yolo11));

    let boxes = &converted.outputs[0];
    assert_eq!(boxes.normalized, Some(true));

    // The (scale, zero_point) tuple becomes one-element vectors; no dtype
    // is attached to the converted quantization.
    let quant = boxes.quantization.as_ref().unwrap();
    assert_eq!(quant.scale, vec![0.01]);
    assert_eq!(quant.zero_point, Some(vec![5]));
    assert_eq!(quant.dtype, None);
}
1692
#[test]
/// A v1 ModelPack detection that carries anchors converts to a v2 output
/// with anchor encoding, the ModelPack decoder and the same anchor count.
fn from_v1_modelpack_anchor_detection_maps_encoding() {
    use crate::configs as legacy_cfg;

    let legacy_detection = legacy_cfg::Detection {
        anchors: Some(vec![[0.1, 0.2], [0.3, 0.4]]),
        decoder: legacy_cfg::DecoderType::ModelPack,
        quantization: Some(legacy_cfg::QuantTuple(0.176, 198)),
        shape: vec![1, 40, 40, 54],
        dshape: vec![],
        normalized: None,
    };
    let legacy = ConfigOutputs {
        outputs: vec![ConfigOutput::Detection(legacy_detection)],
        nms: None,
        decoder_version: None,
    };

    let converted = SchemaV2::from_v1(&legacy).unwrap();
    let detection = &converted.outputs[0];
    assert_eq!(detection.encoding, Some(BoxEncoding::Anchor));
    assert_eq!(detection.decoder, Some(DecoderKind::ModelPack));

    let anchor_count = detection.anchors.as_ref().map(|pairs| pairs.len());
    assert_eq!(anchor_count, Some(2));
}
1712
#[test]
/// A flat (no physical children) YOLOv8 boxes + scores schema parses and
/// passes `validate`.
fn validate_accepts_flat_v2_yolov8_detection() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [
            {"name":"boxes","type":"boxes","shape":[1,64,8400],
             "dtype":"int8","decoder":"ultralytics","encoding":"dfl"},
            {"name":"scores","type":"scores","shape":[1,80,8400],
             "dtype":"int8","decoder":"ultralytics","score_format":"per_class"}
        ]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    schema.validate().unwrap();
}
1728
#[test]
/// `validate` fails when a physical child has an empty `name`, and the
/// error text mentions the missing name.
fn validate_rejects_unnamed_physical_child() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,64,8400],
            "encoding":"dfl","decoder":"ultralytics",
            "outputs": [{
                "name":"","type":"boxes","stride":8,
                "shape":[1,80,80,64],"dtype":"uint8"
            }]
        }]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let failure = schema.validate().unwrap_err();

    let msg = failure.to_string();
    assert!(msg.contains("missing `name`"), "got: {msg}");
}
1746
#[test]
/// `validate` fails when two physical children of one logical output
/// declare the same shape, and the error text says they share a shape.
fn validate_rejects_duplicate_physical_shapes() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,64,8400],
            "encoding":"dfl","decoder":"ultralytics",
            "outputs": [
                {"name":"a","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"},
                {"name":"b","type":"boxes","stride":16,"shape":[1,80,80,64],"dtype":"uint8"}
            ]
        }]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let failure = schema.validate().unwrap_err();

    let msg = failure.to_string();
    assert!(msg.contains("share shape"), "got: {msg}");
}
1764
#[test]
/// `validate` fails when one logical output mixes a `boxes_xy` child with a
/// strided `boxes` child; the error text asks for a uniform decomposition.
fn validate_rejects_mixed_decomposition() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,4,8400,1],
            "encoding":"direct","decoder":"ultralytics",
            "outputs": [
                {"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],"dtype":"int16"},
                {"name":"p0","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"}
            ]
        }]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let failure = schema.validate().unwrap_err();

    let msg = failure.to_string();
    assert!(msg.contains("uniform"), "got: {msg}");
}
1783
#[test]
/// `validate` fails for a DFL boxes child whose feature count (63) is not a
/// multiple of 4, and the error text says so.
fn validate_rejects_dfl_boxes_feature_not_divisible_by_4() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,63,8400],
            "encoding":"dfl","decoder":"ultralytics",
            "outputs": [{
                "name":"b0","type":"boxes","stride":8,
                "shape":[1,80,80,63],
                "dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":63}],
                "dtype":"uint8"
            }]
        }]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let failure = schema.validate().unwrap_err();

    // Both fragments are checked separately, as the message may put other
    // words between them.
    let msg = failure.to_string();
    assert!(msg.contains("not"), "got: {msg}");
    assert!(msg.contains("divisible by 4"), "got: {msg}");
}
1804
#[test]
/// A DFL boxes output split into three per-stride children (8, 16, 32) with
/// distinct shapes parses and passes `validate`.
fn validate_accepts_hailo_per_scale_yolov8() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,64,8400],
            "encoding":"dfl","decoder":"ultralytics","normalized":true,
            "outputs": [
                {"name":"b0","type":"boxes","stride":8,
                 "shape":[1,80,80,64],
                 "dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":64}],
                 "dtype":"uint8",
                 "quantization":{"scale":0.0234,"zero_point":128,"dtype":"uint8"}},
                {"name":"b1","type":"boxes","stride":16,
                 "shape":[1,40,40,64],
                 "dshape":[{"batch":1},{"height":40},{"width":40},{"num_features":64}],
                 "dtype":"uint8",
                 "quantization":{"scale":0.0198,"zero_point":130,"dtype":"uint8"}},
                {"name":"b2","type":"boxes","stride":32,
                 "shape":[1,20,20,64],
                 "dshape":[{"batch":1},{"height":20},{"width":20},{"num_features":64}],
                 "dtype":"uint8",
                 "quantization":{"scale":0.0312,"zero_point":125,"dtype":"uint8"}}
            ]
        }]
    }"#;
    SchemaV2::parse_json(raw).unwrap().validate().unwrap();
}
1834
#[test]
/// A direct-encoded boxes output split into a `boxes_xy` and a `boxes_wh`
/// child parses and passes `validate`.
fn validate_accepts_ara2_xy_wh() {
    let raw = r#"{
        "schema_version": 2,
        "outputs": [{
            "name":"boxes","type":"boxes","shape":[1,4,8400,1],
            "encoding":"direct","decoder":"ultralytics","normalized":true,
            "outputs": [
                {"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],
                 "dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
                 "dtype":"int16",
                 "quantization":{"scale":3.1e-5,"zero_point":0,"dtype":"int16"}},
                {"name":"wh","type":"boxes_wh","shape":[1,2,8400,1],
                 "dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
                 "dtype":"int16",
                 "quantization":{"scale":3.2e-5,"zero_point":0,"dtype":"int16"}}
            ]
        }]
    }"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    schema.validate().unwrap();
}
1856
#[test]
/// `parse_file` reads a JSON document stored in a `.json` file.
fn parse_file_auto_detects_json() {
    // The process id is part of the file name so separate test processes
    // write to different paths.
    let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.json", std::process::id()));
    std::fs::write(&tmp, r#"{"schema_version":2,"outputs":[]}"#).unwrap();

    // Capture the result first and clean up before unwrapping, so a failed
    // parse does not leave the temporary file behind.
    let parsed = SchemaV2::parse_file(&tmp);
    let _ = std::fs::remove_file(&tmp);

    let s = parsed.unwrap();
    assert_eq!(s.schema_version, 2);
}
1865
#[test]
/// `parse_file` reads a YAML document stored in a `.yaml` file.
fn parse_file_auto_detects_yaml() {
    // The process id is part of the file name so separate test processes
    // write to different paths.
    let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.yaml", std::process::id()));
    std::fs::write(&tmp, "schema_version: 2\noutputs: []\n").unwrap();

    // Capture the result first and clean up before unwrapping, so a failed
    // parse does not leave the temporary file behind.
    let parsed = SchemaV2::parse_file(&tmp);
    let _ = std::fs::remove_file(&tmp);

    let s = parsed.unwrap();
    assert_eq!(s.schema_version, 2);
}
1874
#[test]
/// The bundled `ara2_int8_edgefirst.json` fixture parses into four logical
/// outputs (boxes, scores, mask coefficients, protos) and validates.
fn parse_real_ara2_int8_dvm_metadata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/ara2_int8_edgefirst.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("ARA-2 int8 parse");

    // Top-level metadata.
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.decoder_version, Some(DecoderVersion::Yolov8));
    assert_eq!(schema.nms, Some(NmsMode::ClassAgnostic));
    let input = schema.input.as_ref().unwrap();
    assert_eq!(input.shape, vec![1, 3, 640, 640]);

    assert_eq!(schema.outputs.len(), 4);

    // Output 0: boxes, split into an xy child and a wh child.
    let boxes = &schema.outputs[0];
    assert_eq!(boxes.type_, LogicalType::Boxes);
    assert_eq!(boxes.encoding, Some(BoxEncoding::Direct));
    assert_eq!(boxes.normalized, Some(true));
    assert_eq!(boxes.shape, vec![1, 4, 8400, 1]);
    assert_eq!(boxes.outputs.len(), 2);
    let xy_child = &boxes.outputs[0];
    let wh_child = &boxes.outputs[1];
    assert_eq!(xy_child.type_, PhysicalType::BoxesXy);
    assert_eq!(wh_child.type_, PhysicalType::BoxesWh);

    // Quantization attached to the xy child.
    let xy_quant = xy_child.quantization.as_ref().unwrap();
    assert_eq!(xy_quant.dtype, Some(DType::Int8));
    let scale_error = (xy_quant.scale[0] - 0.004_177_792).abs();
    assert!(scale_error < 1e-6);
    assert_eq!(xy_quant.zero_point_at(0), -122);

    // Output 1: scores.
    let scores = &schema.outputs[1];
    assert_eq!(scores.type_, LogicalType::Scores);
    assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
    assert_eq!(scores.shape, vec![1, 80, 8400, 1]);

    // Output 2: mask coefficients.
    let mask_coefs = &schema.outputs[2];
    assert_eq!(mask_coefs.type_, LogicalType::MaskCoefs);
    assert_eq!(mask_coefs.shape, vec![1, 32, 8400, 1]);

    // Output 3: protos.
    let protos = &schema.outputs[3];
    assert_eq!(protos.type_, LogicalType::Protos);
    assert_eq!(protos.shape, vec![1, 32, 160, 160]);

    schema.validate().expect("ARA-2 int8 validate");
}
1921
#[test]
/// The bundled `ara2_int16_edgefirst.json` fixture parses into four outputs
/// whose quantization is tagged `Int16`, and the schema validates.
fn parse_real_ara2_int16_dvm_metadata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/ara2_int16_edgefirst.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("ARA-2 int16 parse");
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 4);

    // Boxes output: two physical children, the first carrying the xy
    // quantization checked below.
    let boxes = &schema.outputs[0];
    assert_eq!(boxes.outputs.len(), 2);
    let xy_quant = boxes.outputs[0].quantization.as_ref().unwrap();
    assert_eq!(xy_quant.dtype, Some(DType::Int16));
    let scale_error = (xy_quant.scale[0] - 3.211_570_6e-5).abs();
    assert!(scale_error < 1e-10);
    assert_eq!(xy_quant.zero_point_at(0), 0);

    // Third output: its logical-level quantization is Int16 as well.
    let third_quant = schema.outputs[2].quantization.as_ref().unwrap();
    assert_eq!(third_quant.dtype, Some(DType::Int16));

    schema.validate().expect("ARA-2 int16 validate");
}
1942
#[test]
/// `parse_yaml` accepts a YAML document that declares `schema_version: 2`
/// and preserves the `score_format` of its scores output.
fn parse_yaml_with_explicit_schema_version_2() {
    let doc = r#"
schema_version: 2
outputs:
  - name: scores
    type: scores
    shape: [1, 80, 8400]
    dtype: int8
    quantization:
      scale: 0.00392
      dtype: int8
    decoder: ultralytics
    score_format: per_class
"#;
    let parsed = SchemaV2::parse_yaml(doc).unwrap();
    assert_eq!(parsed.schema_version, 2);

    let scores = &parsed.outputs[0];
    assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
}
1962}