use crate::configs::{self, deserialize_dshape, DimName, QuantTuple};
use crate::{ConfigOutput, ConfigOutputs, DecoderError, DecoderResult};
use serde::{Deserialize, Serialize};
/// Highest `schema_version` this module accepts.
///
/// `SchemaV2::from_json_value` returns `DecoderError::NotSupported` for
/// documents declaring a larger version, and `SchemaV2::validate` uses it as
/// the upper bound of the accepted range.
pub const MAX_SUPPORTED_SCHEMA_VERSION: u32 = 2;
/// Top-level schema document in the version-2 layout.
///
/// `schema_version` is the only field without a serde default; every other
/// field falls back to `None` / an empty list when absent from the input and
/// is omitted from serialized output when it is `None` / empty.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SchemaV2 {
/// Version number declared by the document. `validate` accepts values in
/// `1..=MAX_SUPPORTED_SCHEMA_VERSION`.
pub schema_version: u32,
/// Optional description of the model input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input: Option<InputSpec>,
/// Logical outputs described by the document.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub outputs: Vec<LogicalOutput>,
/// Optional NMS mode.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub nms: Option<NmsMode>,
/// Optional decoder version tag.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decoder_version: Option<DecoderVersion>,
}
impl Default for SchemaV2 {
fn default() -> Self {
Self {
schema_version: 2,
input: None,
outputs: Vec::new(),
nms: None,
decoder_version: None,
}
}
}
/// Description of the model input tensor.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InputSpec {
/// Dimension sizes of the input; required when deserializing.
pub shape: Vec<usize>,
/// Named dimensions paired with their sizes, parsed by
/// `deserialize_dshape`. Defaults to empty and is omitted from serialized
/// output when empty.
#[serde(
default,
deserialize_with = "deserialize_dshape",
skip_serializing_if = "Vec::is_empty"
)]
pub dshape: Vec<(DimName, usize)>,
/// Optional free-form string (the in-file test fixture uses `"rgb"`).
// NOTE(review): presumably names the camera/colour adaptor applied to the
// input — confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cameraadaptor: Option<String>,
}
/// One logical model output.
///
/// Only `shape` is required when deserializing; every other field defaults to
/// `None` / empty and is skipped on serialization when unset. A logical
/// output may carry physical children in `outputs` (see `is_split`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LogicalOutput {
/// Optional name; error messages fall back to `<anonymous>` when absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Kind of output; stored under the key `type`. Outputs without a type are
/// skipped by `SchemaV2::to_legacy_config_outputs`.
#[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
pub type_: Option<LogicalType>,
/// Dimension sizes of the logical tensor.
pub shape: Vec<usize>,
/// Named dimensions paired with their sizes, parsed by
/// `deserialize_dshape`.
#[serde(
default,
deserialize_with = "deserialize_dshape",
skip_serializing_if = "Vec::is_empty"
)]
pub dshape: Vec<(DimName, usize)>,
/// Decoder family; legacy conversion falls back to Ultralytics when unset.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decoder: Option<DecoderKind>,
/// Box encoding, when specified.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub encoding: Option<BoxEncoding>,
/// Score layout, when specified.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub score_format: Option<ScoreFormat>,
/// Optional flag carried through to/from the legacy `normalized` field.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub normalized: Option<bool>,
/// Optional list of two-element `f32` anchor pairs.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub anchors: Option<Vec<[f32; 2]>>,
/// Optional stride, either a single value or a pair.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stride: Option<Stride>,
/// Optional element type.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dtype: Option<DType>,
/// Optional quantization parameters.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub quantization: Option<Quantization>,
/// Physical children of this logical output; empty for a flat output.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub outputs: Vec<PhysicalOutput>,
// NOTE(review): presumably the activation already applied inside the model
// graph — confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub activation_applied: Option<Activation>,
// NOTE(review): presumably the activation the consumer still has to apply —
// confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub activation_required: Option<Activation>,
}
impl LogicalOutput {
    /// Reports whether this logical output has at least one physical child
    /// in `outputs`.
    pub fn is_split(&self) -> bool {
        self.outputs.first().is_some()
    }
}
/// One physical tensor that is a child of a `LogicalOutput`.
///
/// `name`, `shape` and `dtype` are required when deserializing; the other
/// fields default to `None` / empty and are skipped on serialization when
/// unset.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PhysicalOutput {
/// Tensor name; `validate_logical` rejects an empty name because it is
/// needed for tensor binding.
pub name: String,
/// Kind of physical tensor; stored under the key `type`.
#[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
pub type_: Option<PhysicalType>,
/// Dimension sizes of the tensor.
pub shape: Vec<usize>,
/// Named dimensions paired with their sizes, parsed by
/// `deserialize_dshape`.
#[serde(
default,
deserialize_with = "deserialize_dshape",
skip_serializing_if = "Vec::is_empty"
)]
pub dshape: Vec<(DimName, usize)>,
/// Element type of the tensor.
pub dtype: DType,
/// Optional quantization parameters.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub quantization: Option<Quantization>,
/// Optional stride. `validate_logical` requires that either all children of
/// a logical output carry a stride or none do.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stride: Option<Stride>,
/// Optional scale index.
// NOTE(review): presumably the position of this tensor among the per-scale
// children — confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub scale_index: Option<usize>,
// NOTE(review): presumably the activation already applied inside the model
// graph — confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub activation_applied: Option<Activation>,
// NOTE(review): presumably the activation the consumer still has to apply —
// confirm against the schema specification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub activation_required: Option<Activation>,
}
/// Quantization parameters attached to an output.
///
/// `scale` and `zero_point` each accept either a single number or a list in
/// the serialized form; both are stored as vectors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Quantization {
/// Scale values. Exactly one entry means per-tensor, more than one means
/// per-channel (see `is_per_tensor` / `is_per_channel`).
#[serde(deserialize_with = "deserialize_scalar_or_vec_f32")]
pub scale: Vec<f32>,
/// Zero points. When absent, `zero_point_at` yields 0 and `is_symmetric`
/// returns true.
#[serde(
default,
deserialize_with = "deserialize_opt_scalar_or_vec_i32",
skip_serializing_if = "Option::is_none"
)]
pub zero_point: Option<Vec<i32>>,
/// Axis the per-channel values apply to; converting a per-channel
/// quantization to `edgefirst_tensor::Quantization` fails without it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub axis: Option<usize>,
/// Optional element type recorded alongside the parameters.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dtype: Option<DType>,
}
impl Quantization {
    /// True when exactly one scale value is present.
    pub fn is_per_tensor(&self) -> bool {
        matches!(self.scale.as_slice(), [_])
    }

    /// True when two or more scale values are present.
    pub fn is_per_channel(&self) -> bool {
        matches!(self.scale.as_slice(), [_, _, ..])
    }

    /// True when no zero points are stored or every stored zero point is 0.
    pub fn is_symmetric(&self) -> bool {
        self.zero_point.iter().flatten().all(|&zp| zp == 0)
    }

    /// Zero point for `channel`.
    ///
    /// A single stored zero point applies to every channel; a missing list
    /// or an out-of-range channel yields 0.
    pub fn zero_point_at(&self, channel: usize) -> i32 {
        match self.zero_point.as_deref() {
            Some([shared]) => *shared,
            Some(per_channel) => per_channel.get(channel).copied().unwrap_or(0),
            None => 0,
        }
    }

    /// Scale for `channel`.
    ///
    /// A single stored scale applies to every channel; otherwise an
    /// out-of-range channel (including an empty scale list) yields 0.0.
    pub fn scale_at(&self, channel: usize) -> f32 {
        match self.scale.as_slice() {
            [shared] => *shared,
            per_channel => per_channel.get(channel).copied().unwrap_or(0.0),
        }
    }
}
impl TryFrom<&Quantization> for edgefirst_tensor::Quantization {
type Error = edgefirst_tensor::Error;
fn try_from(q: &Quantization) -> Result<Self, Self::Error> {
match (q.scale.as_slice(), q.zero_point.as_deref(), q.axis) {
([scale], None, None) => Ok(Self::per_tensor_symmetric(*scale)),
([scale], Some([zp]), None) => Ok(Self::per_tensor(*scale, *zp)),
([scale], Some([zp]), Some(_)) => Ok(Self::per_tensor(*scale, *zp)),
([scale], None, Some(_)) => Ok(Self::per_tensor_symmetric(*scale)),
(scales, None, Some(axis)) if scales.len() > 1 => {
Self::per_channel_symmetric(scales.to_vec(), axis)
}
(scales, Some(zps), Some(axis)) if scales.len() > 1 => {
Self::per_channel(scales.to_vec(), zps.to_vec(), axis)
}
(scales, _, None) if scales.len() > 1 => {
Err(edgefirst_tensor::Error::QuantizationInvalid {
field: "axis",
expected: "Some(axis) for per-channel".into(),
got: "None".into(),
})
}
_ => Err(edgefirst_tensor::Error::QuantizationInvalid {
field: "scale",
expected: "non-empty".into(),
got: format!("len={}", q.scale.len()),
}),
}
}
}
/// Serde helper that reads either a single `f32` or a sequence of `f32` and
/// always returns a vector (a single number becomes a one-element vector).
fn deserialize_scalar_or_vec_f32<'de, D>(de: D) -> Result<Vec<f32>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Untagged: serde tries the variants in declaration order, so the scalar
    // form is attempted before the sequence form.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum OneOrMany {
        Scalar(f32),
        List(Vec<f32>),
    }

    OneOrMany::deserialize(de).map(|parsed| match parsed {
        OneOrMany::Scalar(value) => vec![value],
        OneOrMany::List(values) => values,
    })
}
/// Serde helper that reads an optional value which is either a single `i32`
/// or a sequence of `i32`. A present value is returned as a vector (a single
/// number becomes a one-element vector); an absent/null value gives `None`.
fn deserialize_opt_scalar_or_vec_i32<'de, D>(de: D) -> Result<Option<Vec<i32>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Untagged: serde tries the variants in declaration order, so the scalar
    // form is attempted before the sequence form.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum OneOrMany {
        Scalar(i32),
        List(Vec<i32>),
    }

    let parsed = Option::<OneOrMany>::deserialize(de)?;
    Ok(parsed.map(|present| match present {
        OneOrMany::Scalar(value) => vec![value],
        OneOrMany::List(values) => values,
    }))
}
/// Stride given either as one value or as a two-element array.
///
/// Untagged in serialized form: a bare integer maps to `Square`, a
/// two-element list maps to `Rect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Stride {
/// Single value returned by both `x()` and `y()`.
Square(u32),
/// Pair whose first element is returned by `x()` and second by `y()`.
Rect([u32; 2]),
}
impl Stride {
    /// First component: the single value of `Square`, or element 0 of `Rect`.
    pub fn x(self) -> u32 {
        let (Stride::Square(sx) | Stride::Rect([sx, _])) = self;
        sx
    }

    /// Second component: the single value of `Square`, or element 1 of `Rect`.
    pub fn y(self) -> u32 {
        let (Stride::Square(sy) | Stride::Rect([_, sy])) = self;
        sy
    }
}
/// Kind of a logical output, serialized in snake_case (for example
/// `MaskCoefs` is written as `mask_coefs`).
///
/// Both `Detection` and `Detections` exist as separate variants; the legacy
/// conversion in this file maps either one to `ConfigOutput::Detection`, and
/// rejects `Objectness` and `Landmarks` as having no v1 equivalent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LogicalType {
Boxes,
Scores,
Objectness,
Classes,
MaskCoefs,
Protos,
Landmarks,
Detections,
Segmentation,
Masks,
Detection,
}
/// Kind of a physical output, serialized in snake_case.
///
/// Carries the same variants as `LogicalType` plus `BoxesXy` and `BoxesWh`
/// (written as `boxes_xy` and `boxes_wh`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PhysicalType {
Boxes,
Scores,
Objectness,
Classes,
MaskCoefs,
Protos,
Landmarks,
Detections,
Segmentation,
Masks,
Detection,
BoxesXy,
BoxesWh,
}
/// Box encoding tag, serialized in snake_case (`dfl`, `direct`, `anchor`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BoxEncoding {
Dfl,
/// Also accepted under the name `ltrb` when deserializing; always written
/// back as `direct`.
#[serde(alias = "ltrb")]
Direct,
Anchor,
}
/// Score layout tag, serialized in snake_case (`per_class`, `obj_x_class`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ScoreFormat {
PerClass,
// NOTE(review): presumably objectness multiplied by class score — confirm
// against the schema specification.
ObjXClass,
}
/// Activation function tag, serialized in snake_case (`sigmoid`, `softmax`,
/// `tanh`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Activation {
Sigmoid,
Softmax,
Tanh,
}
/// Decoder family tag. Each variant has an explicit serialized name rather
/// than a derived snake_case one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DecoderKind {
/// Serialized as `modelpack`.
#[serde(rename = "modelpack")]
ModelPack,
/// Serialized as `ultralytics`.
#[serde(rename = "ultralytics")]
Ultralytics,
}
/// Decoder version tag, serialized in snake_case (`yolov5`, `yolov8`,
/// `yolo11`, `yolo26`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DecoderVersion {
Yolov5,
Yolov8,
Yolo11,
/// The only variant for which `is_end_to_end` returns true.
Yolo26,
}
impl DecoderVersion {
pub fn is_end_to_end(self) -> bool {
matches!(self, DecoderVersion::Yolo26)
}
}
/// NMS mode tag, serialized in snake_case (`class_agnostic`, `class_aware`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NmsMode {
/// Also the target of the legacy `Auto` value in `NmsMode::from_v1`.
ClassAgnostic,
ClassAware,
}
/// Element type tag, serialized in snake_case (`int8`, `uint8`, …,
/// `float32`). `size_bytes`, `is_integer` and `is_float` classify the
/// variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DType {
Int8,
Uint8,
Int16,
Uint16,
Int32,
Uint32,
Float16,
Float32,
}
impl DType {
pub fn size_bytes(self) -> usize {
match self {
DType::Int8 | DType::Uint8 => 1,
DType::Int16 | DType::Uint16 | DType::Float16 => 2,
DType::Int32 | DType::Uint32 | DType::Float32 => 4,
}
}
pub fn is_integer(self) -> bool {
matches!(
self,
DType::Int8
| DType::Uint8
| DType::Int16
| DType::Uint16
| DType::Int32
| DType::Uint32
)
}
pub fn is_float(self) -> bool {
matches!(self, DType::Float16 | DType::Float32)
}
}
impl SchemaV2 {
    /// Parses a schema document from JSON text.
    ///
    /// # Errors
    /// Fails when `s` is not valid JSON or when [`Self::from_json_value`]
    /// rejects the parsed value.
    pub fn parse_json(s: &str) -> DecoderResult<Self> {
        let value: serde_json::Value = serde_json::from_str(s)?;
        Self::from_json_value(value)
    }

    /// Parses a schema document from YAML text.
    ///
    /// The YAML value is converted to a JSON value first so that both formats
    /// share [`Self::from_json_value`].
    ///
    /// # Errors
    /// Fails when `s` is not valid YAML, when the YAML value cannot be
    /// represented as JSON (`InvalidConfig`), or when
    /// [`Self::from_json_value`] rejects the converted value.
    pub fn parse_yaml(s: &str) -> DecoderResult<Self> {
        let value: serde_yaml::Value = serde_yaml::from_str(s)?;
        let json = serde_json::to_value(value)
            .map_err(|e| DecoderError::InvalidConfig(format!("yaml→json bridge failed: {e}")))?;
        Self::from_json_value(json)
    }

    /// Reads `path` and parses it according to its extension.
    ///
    /// `.json` selects JSON and `.yaml` / `.yml` select YAML (compared
    /// case-insensitively). Any other or missing extension tries JSON first
    /// and falls back to YAML; in that case only the YAML error is reported.
    ///
    /// # Errors
    /// Returns `InvalidConfig` when the file cannot be read as UTF-8 text,
    /// otherwise whatever the selected parser returns.
    pub fn parse_file(path: impl AsRef<std::path::Path>) -> DecoderResult<Self> {
        let path = path.as_ref();
        let content = std::fs::read_to_string(path)
            .map_err(|e| DecoderError::InvalidConfig(format!("read {}: {e}", path.display())))?;
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .map(str::to_ascii_lowercase);
        match ext.as_deref() {
            Some("json") => Self::parse_json(&content),
            Some("yaml") | Some("yml") => Self::parse_yaml(&content),
            _ => Self::parse_json(&content).or_else(|_| Self::parse_yaml(&content)),
        }
    }

    /// Builds a schema from an already-parsed JSON value.
    ///
    /// The `schema_version` key selects the layout: a missing key, or one
    /// that is not an unsigned integer, is treated as version 1. Version 2
    /// and above is deserialized directly; anything lower is deserialized as
    /// the legacy `ConfigOutputs` and upgraded through [`Self::from_v1`].
    ///
    /// # Errors
    /// Returns `NotSupported` when the declared version exceeds
    /// [`MAX_SUPPORTED_SCHEMA_VERSION`], and `Json` when deserialization
    /// fails.
    pub fn from_json_value(value: serde_json::Value) -> DecoderResult<Self> {
        // Compare as u64: narrowing to u32 first would wrap very large
        // versions (e.g. 2^32 + 2) back into the supported range.
        let version = value
            .get("schema_version")
            .and_then(|v| v.as_u64())
            .unwrap_or(1);
        if version > u64::from(MAX_SUPPORTED_SCHEMA_VERSION) {
            return Err(DecoderError::NotSupported(format!(
                "schema_version {version} is not supported by this HAL \
                 (maximum supported version is {MAX_SUPPORTED_SCHEMA_VERSION}); \
                 upgrade the HAL or downgrade the metadata"
            )));
        }
        if version >= 2 {
            serde_json::from_value(value).map_err(DecoderError::Json)
        } else {
            let v1: ConfigOutputs = serde_json::from_value(value).map_err(DecoderError::Json)?;
            Self::from_v1(&v1)
        }
    }

    /// Upgrades a legacy `ConfigOutputs` to the version-2 layout.
    ///
    /// Every output is converted with `logical_from_v1`; `input` is left
    /// unset, `schema_version` is set to 2, and `nms` / `decoder_version` are
    /// mapped to their v2 counterparts when present.
    ///
    /// # Errors
    /// Propagates the first error returned while converting an output.
    pub fn from_v1(v1: &ConfigOutputs) -> DecoderResult<Self> {
        let outputs = v1
            .outputs
            .iter()
            .map(logical_from_v1)
            .collect::<DecoderResult<Vec<_>>>()?;
        Ok(SchemaV2 {
            schema_version: 2,
            input: None,
            outputs,
            nms: v1.nms.as_ref().map(NmsMode::from_v1),
            decoder_version: v1.decoder_version.as_ref().map(DecoderVersion::from_v1),
        })
    }
}
impl SchemaV2 {
pub fn to_legacy_config_outputs(&self) -> DecoderResult<ConfigOutputs> {
let mut outputs = Vec::with_capacity(self.outputs.len());
for logical in &self.outputs {
if logical.type_.is_none() {
continue;
}
if logical.type_ == Some(LogicalType::Boxes)
&& logical.encoding == Some(BoxEncoding::Dfl)
&& logical.outputs.is_empty()
{
return Err(DecoderError::NotSupported(format!(
"`boxes` output `{}` has `encoding: dfl` on a flat \
logical (no per-scale children); the HAL's DFL \
decode kernel only runs inside the per-scale merge \
path. Split the boxes output into per-FPN-level \
children (Hailo convention) or pre-decode to 4 \
channels in the model graph (TFLite convention).",
logical.name.as_deref().unwrap_or("<anonymous>"),
)));
}
if let Some(q) = &logical.quantization {
if q.is_per_channel() {
return Err(DecoderError::NotSupported(format!(
"logical `{}` uses per-channel quantization \
(axis {:?}, {} scales); the v1 decoder only \
supports per-tensor quantization",
logical.name.as_deref().unwrap_or("<anonymous>"),
q.axis,
q.scale.len(),
)));
}
}
outputs.push(logical_to_legacy_config_output(logical)?);
}
Ok(ConfigOutputs {
outputs,
nms: self.nms.map(NmsMode::to_v1),
decoder_version: self.decoder_version.map(|v| v.to_v1()),
})
}
pub fn validate(&self) -> DecoderResult<()> {
if self.schema_version == 0 || self.schema_version > MAX_SUPPORTED_SCHEMA_VERSION {
return Err(DecoderError::InvalidConfig(format!(
"schema_version {} outside supported range [1, {MAX_SUPPORTED_SCHEMA_VERSION}]",
self.schema_version
)));
}
for logical in &self.outputs {
validate_logical(logical)?;
}
Ok(())
}
}
/// Validates the physical children of one logical output.
///
/// A logical output without children always passes. Otherwise the checks
/// run in this order, returning `InvalidConfig` on the first failure:
/// 1. every child has a non-empty `name`;
/// 2. no two typed children share both `shape` and `type`;
/// 3. either all children carry a `stride` or none do;
/// 4. for `boxes` with `encoding: dfl`, each child's feature axis (as found
///    by `last_feature_axis`) is divisible by 4.
fn validate_logical(logical: &LogicalOutput) -> DecoderResult<()> {
    let children = logical.outputs.as_slice();
    if children.is_empty() {
        return Ok(());
    }

    if children.iter().any(|c| c.name.is_empty()) {
        return Err(DecoderError::InvalidConfig(format!(
            "physical child of logical `{}` is missing `name`; name is \
             required for tensor binding",
            logical.name.as_deref().unwrap_or("<anonymous>")
        )));
    }

    // Compare each child with every later sibling; untyped children are
    // ignored for this check.
    let mut remaining = children;
    while let Some((first, later)) = remaining.split_first() {
        for other in later {
            if let (Some(first_ty), Some(other_ty)) = (first.type_, other.type_) {
                if first_ty == other_ty && first.shape == other.shape {
                    return Err(DecoderError::InvalidConfig(format!(
                        "physical children `{}` and `{}` share shape {:?} and \
                         type; tensor binding cannot be resolved",
                        first.name, other.name, first.shape
                    )));
                }
            }
        }
        remaining = later;
    }

    let strided_children = children.iter().filter(|c| c.stride.is_some()).count();
    if strided_children != 0 && strided_children != children.len() {
        return Err(DecoderError::InvalidConfig(format!(
            "logical `{}` mixes per-scale children (with stride) and \
             channel sub-split children (without stride); decomposition \
             must be uniform",
            logical.name.as_deref().unwrap_or("<anonymous>")
        )));
    }

    let is_dfl_boxes = matches!(logical.type_, Some(LogicalType::Boxes))
        && matches!(logical.encoding, Some(BoxEncoding::Dfl));
    if is_dfl_boxes {
        for child in children {
            match last_feature_axis(child) {
                Some(feat) if feat % 4 != 0 => {
                    return Err(DecoderError::InvalidConfig(format!(
                        "DFL boxes child `{}` feature axis {feat} is not \
                         divisible by 4 (reg_max×4)",
                        child.name
                    )));
                }
                _ => {}
            }
        }
    }
    Ok(())
}
/// Size of the feature axis of a physical output.
///
/// Returns the size of the first `dshape` entry named `NumFeatures`,
/// `NumClasses`, `NumProtos`, `BoxCoords` or `NumAnchorsXFeatures`. When no
/// such entry exists, falls back to the last element of `shape`, or `None`
/// if `shape` is empty too.
pub(crate) fn last_feature_axis(child: &PhysicalOutput) -> Option<usize> {
    let named = child.dshape.iter().find_map(|(name, size)| match name {
        DimName::NumFeatures
        | DimName::NumClasses
        | DimName::NumProtos
        | DimName::BoxCoords
        | DimName::NumAnchorsXFeatures => Some(*size),
        _ => None,
    });
    named.or_else(|| child.shape.last().copied())
}
/// Lifts a legacy `(scale, zero_point)` tuple into a v2 `Quantization`.
///
/// The result holds a single scale and a single zero point, with `axis` and
/// `dtype` left unset. `None` stays `None`.
fn quantization_from_v1(q: Option<QuantTuple>) -> Option<Quantization> {
    let QuantTuple(scale, zero_point) = q?;
    Some(Quantization {
        scale: vec![scale],
        zero_point: Some(vec![zero_point]),
        axis: None,
        dtype: None,
    })
}
fn logical_from_v1(v1: &ConfigOutput) -> DecoderResult<LogicalOutput> {
match v1 {
ConfigOutput::Detection(d) => {
let encoding = match (d.decoder, d.anchors.is_some()) {
(configs::DecoderType::ModelPack, true) => Some(BoxEncoding::Anchor),
(configs::DecoderType::Ultralytics, _) => Some(BoxEncoding::Direct),
(configs::DecoderType::ModelPack, false) => None,
};
Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Detection),
shape: d.shape.clone(),
dshape: d.dshape.clone(),
decoder: Some(DecoderKind::from_v1(d.decoder)),
encoding,
score_format: None,
normalized: d.normalized,
anchors: d.anchors.clone(),
stride: None,
dtype: None,
quantization: quantization_from_v1(d.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
})
}
ConfigOutput::Boxes(b) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Boxes),
shape: b.shape.clone(),
dshape: b.dshape.clone(),
decoder: Some(DecoderKind::from_v1(b.decoder)),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: b.normalized,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(b.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::Scores(s) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Scores),
shape: s.shape.clone(),
dshape: s.dshape.clone(),
decoder: Some(DecoderKind::from_v1(s.decoder)),
encoding: None,
score_format: Some(ScoreFormat::PerClass),
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(s.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::Protos(p) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Protos),
shape: p.shape.clone(),
dshape: p.dshape.clone(),
decoder: Some(DecoderKind::from_v1(p.decoder)),
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(p.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::MaskCoefficients(m) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::MaskCoefs),
shape: m.shape.clone(),
dshape: m.dshape.clone(),
decoder: Some(DecoderKind::from_v1(m.decoder)),
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(m.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::Segmentation(seg) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Segmentation),
shape: seg.shape.clone(),
dshape: seg.dshape.clone(),
decoder: Some(DecoderKind::from_v1(seg.decoder)),
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(seg.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::Mask(m) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Masks),
shape: m.shape.clone(),
dshape: m.dshape.clone(),
decoder: Some(DecoderKind::from_v1(m.decoder)),
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(m.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
ConfigOutput::Classes(c) => Ok(LogicalOutput {
name: None,
type_: Some(LogicalType::Classes),
shape: c.shape.clone(),
dshape: c.dshape.clone(),
decoder: Some(DecoderKind::from_v1(c.decoder)),
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: quantization_from_v1(c.quantization),
outputs: Vec::new(),
activation_applied: None,
activation_required: None,
}),
}
}
impl DecoderKind {
    /// Maps the legacy decoder type to the same-named v2 variant.
    pub fn from_v1(v: configs::DecoderType) -> Self {
        match v {
            configs::DecoderType::Ultralytics => Self::Ultralytics,
            configs::DecoderType::ModelPack => Self::ModelPack,
        }
    }

    /// Maps this v2 variant back to the same-named legacy decoder type.
    pub fn to_v1(self) -> configs::DecoderType {
        match self {
            Self::Ultralytics => configs::DecoderType::Ultralytics,
            Self::ModelPack => configs::DecoderType::ModelPack,
        }
    }
}
impl DecoderVersion {
    /// Maps the legacy decoder version to the same-named v2 variant.
    pub fn from_v1(v: &configs::DecoderVersion) -> Self {
        match v {
            configs::DecoderVersion::Yolo26 => Self::Yolo26,
            configs::DecoderVersion::Yolo11 => Self::Yolo11,
            configs::DecoderVersion::Yolov8 => Self::Yolov8,
            configs::DecoderVersion::Yolov5 => Self::Yolov5,
        }
    }

    /// Maps this v2 variant back to the same-named legacy decoder version.
    pub fn to_v1(self) -> configs::DecoderVersion {
        match self {
            Self::Yolo26 => configs::DecoderVersion::Yolo26,
            Self::Yolo11 => configs::DecoderVersion::Yolo11,
            Self::Yolov8 => configs::DecoderVersion::Yolov8,
            Self::Yolov5 => configs::DecoderVersion::Yolov5,
        }
    }
}
impl NmsMode {
    /// Maps the legacy NMS setting to v2. The legacy `Auto` value has no v2
    /// counterpart and is mapped to `ClassAgnostic`.
    pub fn from_v1(v: &configs::Nms) -> Self {
        match v {
            configs::Nms::ClassAware => Self::ClassAware,
            configs::Nms::ClassAgnostic | configs::Nms::Auto => Self::ClassAgnostic,
        }
    }

    /// Maps this v2 mode to the same-named legacy value (never `Auto`).
    pub fn to_v1(self) -> configs::Nms {
        match self {
            Self::ClassAware => configs::Nms::ClassAware,
            Self::ClassAgnostic => configs::Nms::ClassAgnostic,
        }
    }
}
/// Collapses a v2 `Quantization` into the legacy `(scale, zero_point)` tuple.
///
/// Only per-tensor quantization (exactly one scale) can be expressed. The
/// zero point is taken from `zero_point_at(0)`, so a missing zero point
/// becomes 0.
///
/// # Errors
/// * `NotSupported` when more than one scale is present (per-channel);
/// * `InvalidConfig` when `scale` is empty, instead of silently emitting a
///   scale of 0.0 that would dequantize every value to zero.
fn quantization_to_legacy(q: &Quantization) -> DecoderResult<QuantTuple> {
    let scale = match q.scale.as_slice() {
        [scale] => *scale,
        [] => {
            return Err(DecoderError::InvalidConfig(
                "quantization has an empty `scale`; a v1 QuantTuple needs exactly one scale"
                    .into(),
            ));
        }
        _ => {
            return Err(DecoderError::NotSupported(
                "per-channel quantization cannot be expressed as a v1 QuantTuple".into(),
            ));
        }
    };
    Ok(QuantTuple(scale, q.zero_point_at(0)))
}
/// Removes every axis whose `dshape` name is `DimName::Padding`.
///
/// Axes are matched by position: entry `i` of `shape` is kept when entry `i`
/// of `dshape` is not padding. When `dshape` is empty both inputs are
/// returned untouched. If `shape` is longer than `dshape`, the extra trailing
/// sizes are dropped; `dshape` itself is always filtered in full.
pub(crate) fn squeeze_padding_dims(
    shape: Vec<usize>,
    dshape: Vec<(DimName, usize)>,
) -> (Vec<usize>, Vec<(DimName, usize)>) {
    if dshape.is_empty() {
        return (shape, dshape);
    }
    let is_real = |name: &DimName| !matches!(name, DimName::Padding);
    let kept_shape: Vec<usize> = shape
        .into_iter()
        .zip(dshape.iter())
        .filter(|(_, (name, _))| is_real(name))
        .map(|(size, _)| size)
        .collect();
    let kept_dshape: Vec<(DimName, usize)> = dshape
        .into_iter()
        .filter(|(name, _)| is_real(name))
        .collect();
    (kept_shape, kept_dshape)
}
/// Positions of all `DimName::Padding` entries in `dshape`, highest index
/// first.
pub(crate) fn padding_axes(dshape: &[(DimName, usize)]) -> Vec<usize> {
    // Walking the slice back to front yields the indices already in
    // descending order.
    dshape
        .iter()
        .enumerate()
        .rev()
        .filter(|(_, (name, _))| matches!(name, DimName::Padding))
        .map(|(axis, _)| axis)
        .collect()
}
fn logical_to_legacy_config_output(logical: &LogicalOutput) -> DecoderResult<ConfigOutput> {
let decoder = logical
.decoder
.map(|d| d.to_v1())
.unwrap_or(configs::DecoderType::Ultralytics);
let quantization = logical
.quantization
.as_ref()
.map(quantization_to_legacy)
.transpose()?;
let (shape, dshape) = squeeze_padding_dims(logical.shape.clone(), logical.dshape.clone());
let ty = logical.type_.ok_or_else(|| {
DecoderError::InvalidConfig(format!(
"logical output `{}` has no type; typeless outputs should be \
filtered before legacy conversion",
logical.name.as_deref().unwrap_or("<anonymous>")
))
})?;
Ok(match ty {
LogicalType::Boxes => ConfigOutput::Boxes(configs::Boxes {
decoder,
quantization,
shape,
dshape,
normalized: logical.normalized,
}),
LogicalType::Scores => ConfigOutput::Scores(configs::Scores {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::Protos => ConfigOutput::Protos(configs::Protos {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::MaskCoefs => ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::Segmentation => ConfigOutput::Segmentation(configs::Segmentation {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::Masks => ConfigOutput::Mask(configs::Mask {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::Classes => ConfigOutput::Classes(configs::Classes {
decoder,
quantization,
shape,
dshape,
}),
LogicalType::Detection | LogicalType::Detections => {
ConfigOutput::Detection(configs::Detection {
anchors: logical.anchors.clone(),
decoder,
quantization,
shape,
dshape,
normalized: logical.normalized,
})
}
LogicalType::Objectness | LogicalType::Landmarks => {
return Err(DecoderError::NotSupported(format!(
"logical type {:?} has no legacy v1 equivalent; use the \
native v2 decoder path",
ty
)));
}
})
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
#[test]
fn schema_default_is_v2() {
let s = SchemaV2::default();
assert_eq!(s.schema_version, 2);
assert!(s.outputs.is_empty());
}
#[test]
fn fixtures_round_trip_through_serde() {
let yolov8 = include_str!("../../../testdata/per_scale/synthetic_yolov8n_schema.json");
let _: super::SchemaV2 = serde_json::from_str(yolov8).expect("yolov8n fixture must parse");
let yolo26 = include_str!("../../../testdata/per_scale/synthetic_yolo26n_schema.json");
let _: super::SchemaV2 = serde_json::from_str(yolo26).expect("yolo26n fixture must parse");
let flat = include_str!("../../../testdata/per_scale/synthetic_flat_schema.json");
let _: super::SchemaV2 = serde_json::from_str(flat).expect("flat fixture must parse");
}
#[test]
fn box_encoding_accepts_ltrb_alias_for_direct() {
let dfl: BoxEncoding = serde_json::from_str("\"dfl\"").unwrap();
assert_eq!(dfl, BoxEncoding::Dfl);
let direct: BoxEncoding = serde_json::from_str("\"direct\"").unwrap();
assert_eq!(direct, BoxEncoding::Direct);
let ltrb: BoxEncoding = serde_json::from_str("\"ltrb\"").unwrap();
assert_eq!(ltrb, BoxEncoding::Direct);
}
#[test]
fn dtype_roundtrip() {
    // Exercise every `DType` variant, including the 32-bit integer types.
    for d in [
        DType::Int8,
        DType::Uint8,
        DType::Int16,
        DType::Uint16,
        DType::Int32,
        DType::Uint32,
        DType::Float16,
        DType::Float32,
    ] {
        let j = serde_json::to_string(&d).unwrap();
        let back: DType = serde_json::from_str(&j).unwrap();
        assert_eq!(back, d);
    }
}
#[test]
fn dtype_widths() {
assert_eq!(DType::Int8.size_bytes(), 1);
assert_eq!(DType::Float16.size_bytes(), 2);
assert_eq!(DType::Float32.size_bytes(), 4);
}
#[test]
fn stride_accepts_scalar_or_pair() {
let a: Stride = serde_json::from_str("8").unwrap();
let b: Stride = serde_json::from_str("[8, 16]").unwrap();
assert_eq!(a, Stride::Square(8));
assert_eq!(b, Stride::Rect([8, 16]));
assert_eq!(a.x(), 8);
assert_eq!(a.y(), 8);
assert_eq!(b.x(), 8);
assert_eq!(b.y(), 16);
}
#[test]
fn quantization_scalar_scale() {
let j = r#"{"scale": 0.00392, "zero_point": 0, "dtype": "int8"}"#;
let q: Quantization = serde_json::from_str(j).unwrap();
assert!(q.is_per_tensor());
assert!(q.is_symmetric());
assert_eq!(q.scale_at(0), 0.00392);
assert_eq!(q.scale_at(5), 0.00392);
assert_eq!(q.zero_point_at(0), 0);
}
#[test]
fn quantization_per_channel() {
let j = r#"{"scale": [0.054, 0.089, 0.195], "axis": 0, "dtype": "int8"}"#;
let q: Quantization = serde_json::from_str(j).unwrap();
assert!(q.is_per_channel());
assert!(q.is_symmetric());
assert_eq!(q.axis, Some(0));
assert_eq!(q.scale_at(0), 0.054);
assert_eq!(q.scale_at(2), 0.195);
}
#[test]
fn quantization_asymmetric_per_tensor() {
let j = r#"{"scale": 0.176, "zero_point": 198, "dtype": "uint8"}"#;
let q: Quantization = serde_json::from_str(j).unwrap();
assert!(!q.is_symmetric());
assert_eq!(q.zero_point_at(0), 198);
assert_eq!(q.zero_point_at(10), 198);
}
#[test]
fn quantization_symmetric_default_zero_point() {
let j = r#"{"scale": 0.00392, "dtype": "int8"}"#;
let q: Quantization = serde_json::from_str(j).unwrap();
assert!(q.is_symmetric());
assert_eq!(q.zero_point_at(0), 0);
}
#[test]
fn quantization_to_tensor_per_tensor_asymmetric() {
let q = Quantization {
scale: vec![0.1],
zero_point: Some(vec![-5]),
axis: None,
dtype: Some(DType::Int8),
};
let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
assert!(t.is_per_tensor());
assert!(!t.is_symmetric());
assert_eq!(t.scale(), &[0.1]);
assert_eq!(t.zero_point(), Some(&[-5][..]));
}
#[test]
fn quantization_to_tensor_per_tensor_symmetric() {
let q = Quantization {
scale: vec![0.05],
zero_point: None,
axis: None,
dtype: Some(DType::Int8),
};
let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
assert!(t.is_per_tensor());
assert!(t.is_symmetric());
}
#[test]
fn quantization_to_tensor_per_channel_asymmetric() {
let q = Quantization {
scale: vec![0.1, 0.2, 0.3],
zero_point: Some(vec![-1, 0, 1]),
axis: Some(2),
dtype: Some(DType::Int8),
};
let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
assert!(t.is_per_channel());
assert_eq!(t.axis(), Some(2));
assert_eq!(t.scale().len(), 3);
assert_eq!(t.zero_point().map(|z| z.len()), Some(3));
}
#[test]
fn quantization_to_tensor_per_channel_symmetric() {
let q = Quantization {
scale: vec![0.054, 0.089, 0.195],
zero_point: None,
axis: Some(0),
dtype: Some(DType::Int8),
};
let t: edgefirst_tensor::Quantization = (&q).try_into().unwrap();
assert!(t.is_per_channel());
assert!(t.is_symmetric());
assert_eq!(t.axis(), Some(0));
}
#[test]
fn quantization_to_tensor_per_channel_missing_axis_errors() {
let q = Quantization {
scale: vec![0.1, 0.2, 0.3],
zero_point: None,
axis: None,
dtype: None,
};
let err = edgefirst_tensor::Quantization::try_from(&q).unwrap_err();
assert!(matches!(
err,
edgefirst_tensor::Error::QuantizationInvalid { .. }
));
}
#[test]
fn logical_output_flat_tflite_boxes() {
let j = r#"{
"name": "boxes", "type": "boxes",
"shape": [1, 64, 8400],
"dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
"dtype": "int8",
"quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
"decoder": "ultralytics",
"encoding": "dfl",
"normalized": true
}"#;
let lo: LogicalOutput = serde_json::from_str(j).unwrap();
assert_eq!(lo.type_, Some(LogicalType::Boxes));
assert_eq!(lo.encoding, Some(BoxEncoding::Dfl));
assert_eq!(lo.normalized, Some(true));
assert!(!lo.is_split());
assert_eq!(lo.dtype, Some(DType::Int8));
}
#[test]
fn logical_output_hailo_per_scale_split() {
let j = r#"{
"name": "boxes", "type": "boxes",
"shape": [1, 64, 8400],
"encoding": "dfl", "decoder": "ultralytics", "normalized": true,
"outputs": [
{
"name": "boxes_0", "type": "boxes",
"stride": 8, "scale_index": 0,
"shape": [1, 80, 80, 64],
"dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 64}],
"dtype": "uint8",
"quantization": {"scale": 0.0234, "zero_point": 128, "dtype": "uint8"}
}
]
}"#;
let lo: LogicalOutput = serde_json::from_str(j).unwrap();
assert!(lo.is_split());
assert_eq!(lo.outputs.len(), 1);
let child = &lo.outputs[0];
assert_eq!(child.name, "boxes_0");
assert_eq!(child.type_, Some(PhysicalType::Boxes));
assert_eq!(child.stride, Some(Stride::Square(8)));
assert_eq!(child.scale_index, Some(0));
assert_eq!(child.dtype, DType::Uint8);
}
#[test]
fn logical_output_ara2_xy_wh_channel_split() {
let j = r#"{
"name": "boxes", "type": "boxes",
"shape": [1, 4, 8400, 1],
"encoding": "direct", "decoder": "ultralytics", "normalized": true,
"outputs": [
{
"name": "_model_22_Div_1_output_0", "type": "boxes_xy",
"shape": [1, 2, 8400, 1],
"dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
"dtype": "int16",
"quantization": {"scale": 3.129e-5, "zero_point": 0, "dtype": "int16"}
},
{
"name": "_model_22_Sub_1_output_0", "type": "boxes_wh",
"shape": [1, 2, 8400, 1],
"dshape": [{"batch": 1}, {"box_coords": 2}, {"num_boxes": 8400}, {"padding": 1}],
"dtype": "int16",
"quantization": {"scale": 3.149e-5, "zero_point": 0, "dtype": "int16"}
}
]
}"#;
let lo: LogicalOutput = serde_json::from_str(j).unwrap();
assert_eq!(lo.encoding, Some(BoxEncoding::Direct));
assert_eq!(lo.outputs.len(), 2);
assert_eq!(lo.outputs[0].type_, Some(PhysicalType::BoxesXy));
assert_eq!(lo.outputs[1].type_, Some(PhysicalType::BoxesWh));
assert!(lo.outputs[0].stride.is_none());
assert!(lo.outputs[1].stride.is_none());
}
#[test]
fn logical_output_hailo_scores_sigmoid_applied() {
let j = r#"{
"name": "scores", "type": "scores",
"shape": [1, 80, 8400],
"decoder": "ultralytics", "score_format": "per_class",
"outputs": [
{
"name": "scores_0", "type": "scores",
"stride": 8, "scale_index": 0,
"shape": [1, 80, 80, 80],
"dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_classes": 80}],
"dtype": "uint8",
"quantization": {"scale": 0.003922, "dtype": "uint8"},
"activation_applied": "sigmoid"
}
]
}"#;
let lo: LogicalOutput = serde_json::from_str(j).unwrap();
assert_eq!(lo.score_format, Some(ScoreFormat::PerClass));
let child = &lo.outputs[0];
assert_eq!(child.activation_applied, Some(Activation::Sigmoid));
assert!(child.activation_required.is_none());
}
#[test]
fn yolo26_end_to_end_detections() {
let j = r#"{
"schema_version": 2,
"decoder_version": "yolo26",
"outputs": [{
"name": "output0", "type": "detections",
"shape": [1, 100, 6],
"dshape": [{"batch": 1}, {"num_boxes": 100}, {"num_features": 6}],
"dtype": "int8",
"quantization": {"scale": 0.0078, "zero_point": 0, "dtype": "int8"},
"normalized": false,
"decoder": "ultralytics"
}]
}"#;
let s: SchemaV2 = serde_json::from_str(j).unwrap();
assert_eq!(s.decoder_version, Some(DecoderVersion::Yolo26));
assert!(s.decoder_version.unwrap().is_end_to_end());
assert_eq!(s.outputs[0].type_, Some(LogicalType::Detections));
assert_eq!(s.outputs[0].normalized, Some(false));
assert!(s.nms.is_none());
}
/// A ModelPack anchor-encoded detection output with a two-element `stride`
/// array deserializes to `Stride::Rect` and keeps all three anchors.
#[test]
fn modelpack_anchor_detection_with_rect_stride() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name": "output_0", "type": "detection",
"shape": [1, 40, 40, 54],
"dshape": [{"batch": 1}, {"height": 40}, {"width": 40}, {"num_anchors_x_features": 54}],
"dtype": "uint8",
"quantization": {"scale": 0.176, "zero_point": 198, "dtype": "uint8"},
"decoder": "modelpack",
"encoding": "anchor",
"stride": [16, 16],
"anchors": [[0.054, 0.065], [0.089, 0.139], [0.195, 0.196]]
}]
}"#;
    let schema = serde_json::from_str::<SchemaV2>(raw).unwrap();
    let detection = &schema.outputs[0];

    assert_eq!(detection.encoding, Some(BoxEncoding::Anchor));
    assert_eq!(detection.stride, Some(Stride::Rect([16, 16])));

    let anchor_count = detection.anchors.as_ref().map(|list| list.len());
    assert_eq!(anchor_count, Some(3));
}
/// A logical `objectness` output with one sigmoid-activated physical child
/// deserializes with the expected logical type and child activation.
#[test]
fn yolov5_obj_x_class_objectness_logical() {
    let raw = r#"{
"name": "objectness", "type": "objectness",
"shape": [1, 3, 8400],
"decoder": "ultralytics",
"outputs": [{
"name": "objectness_0", "type": "objectness",
"stride": 8, "scale_index": 0,
"shape": [1, 80, 80, 3],
"dshape": [{"batch": 1}, {"height": 80}, {"width": 80}, {"num_features": 3}],
"dtype": "uint8",
"quantization": {"scale": 0.0039, "zero_point": 0, "dtype": "uint8"},
"activation_applied": "sigmoid"
}]
}"#;
    let logical = serde_json::from_str::<LogicalOutput>(raw).unwrap();

    assert_eq!(logical.type_, Some(LogicalType::Objectness));

    let first_child = &logical.outputs[0];
    assert_eq!(first_child.activation_applied, Some(Activation::Sigmoid));
}
/// A `protos` output that omits `decoder` deserializes with `decoder` unset
/// and a scalar `stride` mapped to `Stride::Square`.
#[test]
fn direct_protos_no_decoder() {
    let raw = r#"{
"name": "protos", "type": "protos",
"shape": [1, 32, 160, 160],
"dshape": [{"batch": 1}, {"num_protos": 32}, {"height": 160}, {"width": 160}],
"dtype": "uint8",
"quantization": {"scale": 0.0203, "zero_point": 45, "dtype": "uint8"},
"stride": 4
}"#;
    let protos = serde_json::from_str::<LogicalOutput>(raw).unwrap();

    assert_eq!(protos.type_, Some(LogicalType::Protos));
    assert!(protos.decoder.is_none());
    assert_eq!(protos.stride, Some(Stride::Square(4)));
}
/// A complete flat YOLOv8 document (input spec, NMS mode, decoder version,
/// separate `boxes` and `scores` outputs) deserializes with every top-level
/// field populated as written in the JSON.
#[test]
fn full_yolov8_tflite_flat_detection() {
    let raw = r#"{
"schema_version": 2,
"decoder_version": "yolov8",
"nms": "class_agnostic",
"input": { "shape": [1, 640, 640, 3], "cameraadaptor": "rgb" },
"outputs": [
{
"name": "boxes", "type": "boxes",
"shape": [1, 64, 8400],
"dshape": [{"batch": 1}, {"num_features": 64}, {"num_boxes": 8400}],
"dtype": "int8",
"quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
"decoder": "ultralytics",
"encoding": "dfl",
"normalized": true
},
{
"name": "scores", "type": "scores",
"shape": [1, 80, 8400],
"dshape": [{"batch": 1}, {"num_classes": 80}, {"num_boxes": 8400}],
"dtype": "int8",
"quantization": {"scale": 0.00392, "zero_point": 0, "dtype": "int8"},
"decoder": "ultralytics",
"score_format": "per_class"
}
]
}"#;
    let schema = serde_json::from_str::<SchemaV2>(raw).unwrap();

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.decoder_version, Some(DecoderVersion::Yolov8));
    assert_eq!(schema.nms, Some(NmsMode::ClassAgnostic));

    let input = schema.input.as_ref().unwrap();
    assert_eq!(input.shape, vec![1, 640, 640, 3]);

    assert_eq!(schema.outputs.len(), 2);
}
/// Raw serde deserialization does not gate on `schema_version`: a document
/// declaring version 99 still deserializes and keeps that number.
#[test]
fn schema_unknown_version_parses_without_validation() {
    let raw = r#"{"schema_version": 99, "outputs": []}"#;
    let schema = serde_json::from_str::<SchemaV2>(raw).unwrap();
    assert_eq!(schema.schema_version, 99);
}
/// Serializing a populated `SchemaV2` to JSON and deserializing it again
/// yields a value equal to the one we started with.
#[test]
fn serde_roundtrip_preserves_fields() {
    let input = InputSpec {
        shape: vec![1, 3, 640, 640],
        dshape: vec![],
        cameraadaptor: Some("rgb".into()),
    };
    let boxes = LogicalOutput {
        name: Some("boxes".into()),
        type_: Some(LogicalType::Boxes),
        shape: vec![1, 4, 8400],
        dshape: vec![],
        decoder: Some(DecoderKind::Ultralytics),
        encoding: Some(BoxEncoding::Direct),
        score_format: None,
        normalized: Some(true),
        anchors: None,
        stride: None,
        dtype: Some(DType::Float32),
        quantization: None,
        outputs: vec![],
        activation_applied: None,
        activation_required: None,
    };
    let original = SchemaV2 {
        schema_version: 2,
        input: Some(input),
        outputs: vec![boxes],
        nms: Some(NmsMode::ClassAgnostic),
        decoder_version: Some(DecoderVersion::Yolov8),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded = serde_json::from_str::<SchemaV2>(&encoded).unwrap();

    assert_eq!(decoded, original);
}
/// Runs `SchemaV2::parse_yaml` over the `yolov8_seg.yaml` fixture and checks
/// the resulting schema: version 2, a detection output with its
/// quantization parameters, and a protos output.
#[test]
fn parse_v1_yaml_yolov8_seg_testdata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_seg.yaml"
    ));
    let schema = SchemaV2::parse_yaml(fixture).expect("parse v1 yaml");

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    // First output: the detection head.
    let detection = &schema.outputs[0];
    assert_eq!(detection.type_, Some(LogicalType::Detection));
    assert_eq!(detection.shape, vec![1, 116, 8400]);
    assert_eq!(detection.decoder, Some(DecoderKind::Ultralytics));
    assert_eq!(detection.encoding, Some(BoxEncoding::Direct));

    let quant = detection.quantization.as_ref().unwrap();
    assert_eq!(quant.scale.len(), 1);
    assert!((quant.scale[0] - 0.021_287_762).abs() < 1e-6);
    assert_eq!(quant.zero_point, Some(vec![31]));

    // Second output: the mask prototypes.
    let mask_protos = &schema.outputs[1];
    assert_eq!(mask_protos.type_, Some(LogicalType::Protos));
    assert_eq!(mask_protos.shape, vec![1, 160, 160, 32]);
}
/// Runs `SchemaV2::parse_json` over the `modelpack_split.json` fixture and
/// checks that both resulting outputs are ModelPack anchor detections with
/// three anchors each.
#[test]
fn parse_v1_json_modelpack_split_testdata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/modelpack_split.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("parse v1 json");

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    for detection in schema.outputs.iter() {
        assert_eq!(detection.type_, Some(LogicalType::Detection));
        assert_eq!(detection.decoder, Some(DecoderKind::ModelPack));
        assert_eq!(detection.encoding, Some(BoxEncoding::Anchor));

        let anchor_count = detection.anchors.as_ref().map(|list| list.len());
        assert_eq!(anchor_count, Some(3));
    }
}
/// A document that explicitly declares `schema_version: 2` goes through
/// `SchemaV2::parse_json` and comes back with the version and the output
/// type as written.
#[test]
fn parse_v2_json_direct_when_schema_version_present() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name": "boxes", "type": "boxes",
"shape": [1, 4, 8400],
"dshape": [{"batch": 1}, {"box_coords": 4}, {"num_boxes": 8400}],
"dtype": "float32",
"decoder": "ultralytics",
"encoding": "direct",
"normalized": true
}]
}"#;
    let parsed = SchemaV2::parse_json(raw).unwrap();

    assert_eq!(parsed.schema_version, 2);

    let first = &parsed.outputs[0];
    assert_eq!(first.type_, Some(LogicalType::Boxes));
}
/// `SchemaV2::parse_json` must refuse a document whose `schema_version` (99)
/// is one this build does not support, and the failure must be reported as
/// `DecoderError::NotSupported`.
#[test]
fn parse_rejects_future_schema_version() {
    let j = r#"{"schema_version": 99, "outputs": []}"#;
    let err = SchemaV2::parse_json(j).unwrap_err();
    // `matches!` only yields a bool; without `assert!` the variant check was
    // silently discarded and any error kind would have passed this test.
    assert!(
        matches!(err, DecoderError::NotSupported(_)),
        "expected DecoderError::NotSupported, got: {err:?}"
    );
}
/// A document with no `schema_version` key and tuple-style `quantization`
/// entries is accepted by `SchemaV2::parse_json`; the result reports schema
/// version 2, keeps both outputs, and gives the scores output a per-class
/// score format.
#[test]
fn parse_absent_schema_version_treats_as_v1() {
    let legacy_doc = r#"{
"outputs": [
{
"type": "boxes", "decoder": "ultralytics",
"shape": [1, 4, 8400],
"quantization": [0.00392, 0]
},
{
"type": "scores", "decoder": "ultralytics",
"shape": [1, 80, 8400],
"quantization": [0.00392, 0]
}
]
}"#;
    let schema = SchemaV2::parse_json(legacy_doc).expect("v1 legacy parse");

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 2);

    let boxes = &schema.outputs[0];
    assert_eq!(boxes.type_, Some(LogicalType::Boxes));

    let scores = &schema.outputs[1];
    assert_eq!(scores.type_, Some(LogicalType::Scores));
    assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
}
/// `SchemaV2::from_v1` carries the legacy `nms` and `decoder_version` over,
/// keeps `normalized`, and turns a `QuantTuple(scale, zero_point)` into
/// one-element `scale` / `zero_point` vectors with no `dtype`.
#[test]
fn from_v1_preserves_nms_and_decoder_version() {
    let legacy_boxes = crate::configs::Boxes {
        decoder: crate::configs::DecoderType::Ultralytics,
        quantization: Some(crate::configs::QuantTuple(0.01, 5)),
        shape: vec![1, 4, 8400],
        dshape: vec![],
        normalized: Some(true),
    };
    let legacy = ConfigOutputs {
        outputs: vec![ConfigOutput::Boxes(legacy_boxes)],
        nms: Some(crate::configs::Nms::ClassAware),
        decoder_version: Some(crate::configs::DecoderVersion::Yolo11),
    };

    let upgraded = SchemaV2::from_v1(&legacy).unwrap();

    // Schema-level fields.
    assert_eq!(upgraded.nms, Some(NmsMode::ClassAware));
    assert_eq!(upgraded.decoder_version, Some(DecoderVersion::Yolo11));

    // Output-level fields.
    let boxes = &upgraded.outputs[0];
    assert_eq!(boxes.normalized, Some(true));

    let quant = boxes.quantization.as_ref().unwrap();
    assert_eq!(quant.scale, vec![0.01]);
    assert_eq!(quant.zero_point, Some(vec![5]));
    assert_eq!(quant.dtype, None);
}
/// An output with no `type` key deserializes with `type_ == None`, and when
/// the schema is serialized again the first output object must not contain
/// a `"type"` key.
#[test]
fn typeless_logical_output_parses_and_roundtrips() {
    let raw = r#"{
"schema_version": 2,
"outputs": [
{
"name": "extra_telemetry",
"shape": [1, 16]
},
{
"name": "boxes",
"type": "boxes",
"shape": [1, 4, 8400]
}
]
}"#;
    let schema = serde_json::from_str::<SchemaV2>(raw).unwrap();
    assert_eq!(schema.outputs.len(), 2);

    let telemetry = &schema.outputs[0];
    assert_eq!(telemetry.type_, None);
    assert_eq!(telemetry.name.as_deref(), Some("extra_telemetry"));

    let boxes = &schema.outputs[1];
    assert_eq!(boxes.type_, Some(LogicalType::Boxes));

    // Re-serialize and isolate the text between the start of the `outputs`
    // array and the first closing brace, i.e. the first output object.
    let serialized = serde_json::to_string(&schema).unwrap();
    let after_outputs = serialized.split("\"outputs\":[").nth(1);
    let first_obj = after_outputs
        .and_then(|rest| rest.split('}').next())
        .expect("outputs array");
    assert!(
        !first_obj.contains("\"type\""),
        "typeless output must not serialize a `type` field, got: {first_obj}"
    );
}
/// Lowering a schema that mixes a typeless output with a typed `boxes`
/// output via `to_legacy_config_outputs` keeps only the typed one.
#[test]
fn typeless_outputs_filtered_from_legacy_config() {
    // Output with no `type_`: expected to be dropped by the lowering.
    let untyped = LogicalOutput {
        name: Some("diagnostic_histogram".into()),
        type_: None,
        shape: vec![1, 256],
        dshape: vec![],
        decoder: None,
        encoding: None,
        score_format: None,
        normalized: None,
        anchors: None,
        stride: None,
        dtype: None,
        quantization: None,
        outputs: vec![],
        activation_applied: None,
        activation_required: None,
    };
    // Typed `boxes` output: expected to survive.
    let typed_boxes = LogicalOutput {
        name: Some("boxes".into()),
        type_: Some(LogicalType::Boxes),
        shape: vec![1, 4, 8400],
        dshape: vec![],
        decoder: Some(DecoderKind::Ultralytics),
        encoding: Some(BoxEncoding::Direct),
        score_format: None,
        normalized: Some(true),
        anchors: None,
        stride: None,
        dtype: None,
        quantization: None,
        outputs: vec![],
        activation_applied: None,
        activation_required: None,
    };
    let schema = SchemaV2 {
        schema_version: 2,
        input: None,
        outputs: vec![untyped, typed_boxes],
        nms: None,
        decoder_version: None,
    };

    let legacy = schema.to_legacy_config_outputs().unwrap();

    assert_eq!(
        legacy.outputs.len(),
        1,
        "typeless output must be filtered from legacy config"
    );
    assert!(
        matches!(legacy.outputs[0], ConfigOutput::Boxes(_)),
        "only the typed `boxes` output should survive lowering"
    );
}
/// When the only output in a schema has no `type_`, lowering through
/// `to_legacy_config_outputs` succeeds and yields no legacy outputs.
#[test]
fn all_typeless_schema_produces_empty_legacy_config() {
    let untyped = LogicalOutput {
        name: Some("aux".into()),
        type_: None,
        shape: vec![1, 8],
        dshape: vec![],
        decoder: None,
        encoding: None,
        score_format: None,
        normalized: None,
        anchors: None,
        stride: None,
        dtype: None,
        quantization: None,
        outputs: vec![],
        activation_applied: None,
        activation_required: None,
    };
    let schema = SchemaV2 {
        schema_version: 2,
        input: None,
        outputs: vec![untyped],
        nms: None,
        decoder_version: None,
    };

    let lowered = schema.to_legacy_config_outputs().unwrap();
    assert!(lowered.outputs.is_empty());
}
/// A logical output may hold a typed and a typeless physical child with the
/// same shape: both deserialize, `validate` accepts the schema, and the
/// typeless child is serialized back without a `"type"` key.
#[test]
fn typeless_physical_child_parses_and_skips_uniqueness() {
    let raw = r#"{
"name": "boxes",
"type": "boxes",
"shape": [1, 8400, 4],
"outputs": [
{
"name": "boxes_xy",
"type": "boxes_xy",
"shape": [1, 8400, 2],
"dtype": "float32"
},
{
"name": "aux_user_managed",
"shape": [1, 8400, 2],
"dtype": "float32"
}
]
}"#;
    let logical = serde_json::from_str::<LogicalOutput>(raw).unwrap();
    assert_eq!(logical.outputs.len(), 2);
    assert_eq!(logical.outputs[0].type_, Some(PhysicalType::BoxesXy));
    assert_eq!(logical.outputs[1].type_, None);

    // Wrap the parsed output in a schema so it can be validated.
    let schema = SchemaV2 {
        schema_version: 2,
        input: None,
        outputs: vec![logical],
        nms: None,
        decoder_version: None,
    };
    schema.validate().expect(
        "typed + typeless children with equal shape must not trigger \
         uniqueness error",
    );

    let s = serde_json::to_string(&schema).unwrap();
    assert!(
        s.contains("\"aux_user_managed\""),
        "typeless child must survive round-trip: {s}"
    );

    // Text from just after the child's name up to the next closing brace,
    // i.e. the remainder of the typeless child's JSON object.
    let after_name = s.split("\"aux_user_managed\"").nth(1);
    let aux_obj = after_name
        .and_then(|rest| rest.split('}').next())
        .unwrap_or("");
    assert!(
        !aux_obj.contains("\"type\""),
        "typeless child must not serialize `type`, got: {aux_obj}"
    );
}
/// `SchemaV2::from_v1` on a legacy ModelPack `Detection` with anchors
/// produces an output with anchor encoding, the ModelPack decoder, and the
/// same number of anchors.
#[test]
fn from_v1_modelpack_anchor_detection_maps_encoding() {
    let legacy_detection = crate::configs::Detection {
        anchors: Some(vec![[0.1, 0.2], [0.3, 0.4]]),
        decoder: crate::configs::DecoderType::ModelPack,
        quantization: Some(crate::configs::QuantTuple(0.176, 198)),
        shape: vec![1, 40, 40, 54],
        dshape: vec![],
        normalized: None,
    };
    let legacy = ConfigOutputs {
        outputs: vec![ConfigOutput::Detection(legacy_detection)],
        nms: None,
        decoder_version: None,
    };

    let upgraded = SchemaV2::from_v1(&legacy).unwrap();
    let detection = &upgraded.outputs[0];

    assert_eq!(detection.encoding, Some(BoxEncoding::Anchor));
    assert_eq!(detection.decoder, Some(DecoderKind::ModelPack));

    let anchor_count = detection.anchors.as_ref().map(|list| list.len());
    assert_eq!(anchor_count, Some(2));
}
/// A flat schema with one DFL `boxes` output and one per-class `scores`
/// output (no physical children) parses and passes `validate`.
#[test]
fn validate_accepts_flat_v2_yolov8_detection() {
    let raw = r#"{
"schema_version": 2,
"outputs": [
{"name":"boxes","type":"boxes","shape":[1,64,8400],
"dtype":"int8","decoder":"ultralytics","encoding":"dfl"},
{"name":"scores","type":"scores","shape":[1,80,8400],
"dtype":"int8","decoder":"ultralytics","score_format":"per_class"}
]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    schema.validate().unwrap();
}
/// A physical child whose `name` is the empty string must make `validate`
/// fail with a message mentioning the missing name.
#[test]
fn validate_rejects_unnamed_physical_child() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,64,8400],
"encoding":"dfl","decoder":"ultralytics",
"outputs": [{
"name":"","type":"boxes","stride":8,
"shape":[1,80,80,64],"dtype":"uint8"
}]
}]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let msg = schema.validate().unwrap_err().to_string();
    assert!(msg.contains("missing `name`"), "got: {msg}");
}
/// Two typed physical children of the same logical output that declare the
/// same shape must make `validate` fail with a "share shape" message.
#[test]
fn validate_rejects_duplicate_physical_shapes() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,64,8400],
"encoding":"dfl","decoder":"ultralytics",
"outputs": [
{"name":"a","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"},
{"name":"b","type":"boxes","stride":16,"shape":[1,80,80,64],"dtype":"uint8"}
]
}]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let msg = schema.validate().unwrap_err().to_string();
    assert!(msg.contains("share shape"), "got: {msg}");
}
/// A logical `boxes` output whose children mix a `boxes_xy` part with a
/// strided `boxes` part must make `validate` fail with a message containing
/// "uniform".
#[test]
fn validate_rejects_mixed_decomposition() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,4,8400,1],
"encoding":"direct","decoder":"ultralytics",
"outputs": [
{"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],"dtype":"int16"},
{"name":"p0","type":"boxes","stride":8,"shape":[1,80,80,64],"dtype":"uint8"}
]
}]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let msg = schema.validate().unwrap_err().to_string();
    assert!(msg.contains("uniform"), "got: {msg}");
}
/// A DFL `boxes` child with 63 features must make `validate` fail with a
/// message containing both "not" and "divisible by 4".
#[test]
fn validate_rejects_dfl_boxes_feature_not_divisible_by_4() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,63,8400],
"encoding":"dfl","decoder":"ultralytics",
"outputs": [{
"name":"b0","type":"boxes","stride":8,
"shape":[1,80,80,63],
"dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":63}],
"dtype":"uint8"
}]
}]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let msg = schema.validate().unwrap_err().to_string();
    assert!(msg.contains("not"), "got: {msg}");
    assert!(msg.contains("divisible by 4"), "got: {msg}");
}
/// A DFL `boxes` output split into three strided children (strides 8, 16
/// and 32, each with a distinct shape and its own quantization) parses and
/// passes `validate`.
#[test]
fn validate_accepts_hailo_per_scale_yolov8() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,64,8400],
"encoding":"dfl","decoder":"ultralytics","normalized":true,
"outputs": [
{"name":"b0","type":"boxes","stride":8,
"shape":[1,80,80,64],
"dshape":[{"batch":1},{"height":80},{"width":80},{"num_features":64}],
"dtype":"uint8",
"quantization":{"scale":0.0234,"zero_point":128,"dtype":"uint8"}},
{"name":"b1","type":"boxes","stride":16,
"shape":[1,40,40,64],
"dshape":[{"batch":1},{"height":40},{"width":40},{"num_features":64}],
"dtype":"uint8",
"quantization":{"scale":0.0198,"zero_point":130,"dtype":"uint8"}},
{"name":"b2","type":"boxes","stride":32,
"shape":[1,20,20,64],
"dshape":[{"batch":1},{"height":20},{"width":20},{"num_features":64}],
"dtype":"uint8",
"quantization":{"scale":0.0312,"zero_point":125,"dtype":"uint8"}}
]
}]
}"#;
    SchemaV2::parse_json(raw).unwrap().validate().unwrap();
}
/// A direct-encoded `boxes` output decomposed into `boxes_xy` and
/// `boxes_wh` children (identical shapes, different types) parses and
/// passes `validate`.
#[test]
fn validate_accepts_ara2_xy_wh() {
    let raw = r#"{
"schema_version": 2,
"outputs": [{
"name":"boxes","type":"boxes","shape":[1,4,8400,1],
"encoding":"direct","decoder":"ultralytics","normalized":true,
"outputs": [
{"name":"xy","type":"boxes_xy","shape":[1,2,8400,1],
"dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
"dtype":"int16",
"quantization":{"scale":3.1e-5,"zero_point":0,"dtype":"int16"}},
{"name":"wh","type":"boxes_wh","shape":[1,2,8400,1],
"dshape":[{"batch":1},{"box_coords":2},{"num_boxes":8400},{"padding":1}],
"dtype":"int16",
"quantization":{"scale":3.2e-5,"zero_point":0,"dtype":"int16"}}
]
}]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    schema.validate().unwrap();
}
/// Writes a minimal JSON schema to a `.json` file in the system temp
/// directory and checks that `SchemaV2::parse_file` loads it.
#[test]
fn parse_file_auto_detects_json() {
    // The process id keeps the file name distinct across concurrent test runs.
    let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.json", std::process::id()));
    std::fs::write(&tmp, r#"{"schema_version":2,"outputs":[]}"#).unwrap();

    // Capture the result and delete the file BEFORE unwrapping/asserting, so
    // a parse failure or failed assertion does not leave the temp file behind.
    let parsed = SchemaV2::parse_file(&tmp);
    let _ = std::fs::remove_file(&tmp);

    let s = parsed.unwrap();
    assert_eq!(s.schema_version, 2);
}
/// Writes a minimal YAML schema to a `.yaml` file in the system temp
/// directory and checks that `SchemaV2::parse_file` loads it.
#[test]
fn parse_file_auto_detects_yaml() {
    // The process id keeps the file name distinct across concurrent test runs.
    let tmp = std::env::temp_dir().join(format!("schema_v2_test_{}.yaml", std::process::id()));
    std::fs::write(&tmp, "schema_version: 2\noutputs: []\n").unwrap();

    // Capture the result and delete the file BEFORE unwrapping/asserting, so
    // a parse failure or failed assertion does not leave the temp file behind.
    let parsed = SchemaV2::parse_file(&tmp);
    let _ = std::fs::remove_file(&tmp);

    let s = parsed.unwrap();
    assert_eq!(s.schema_version, 2);
}
/// Parses the `ara2_int8_edgefirst.json` fixture and checks the schema-level
/// fields, the four logical outputs (boxes split into xy/wh children,
/// scores, mask coefficients, protos), the int8 quantization of the xy
/// child, and finally that the schema validates.
#[test]
fn parse_real_ara2_int8_dvm_metadata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/ara2_int8_edgefirst.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("ARA-2 int8 parse");

    // Schema-level fields.
    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.decoder_version, Some(DecoderVersion::Yolov8));
    assert_eq!(schema.nms, Some(NmsMode::ClassAgnostic));
    assert_eq!(schema.input.as_ref().unwrap().shape, vec![1, 3, 640, 640]);
    assert_eq!(schema.outputs.len(), 4);

    // Output 0: boxes, decomposed into xy / wh children.
    let boxes = &schema.outputs[0];
    assert_eq!(boxes.type_, Some(LogicalType::Boxes));
    assert_eq!(boxes.encoding, Some(BoxEncoding::Direct));
    assert_eq!(boxes.normalized, Some(true));
    assert_eq!(boxes.shape, vec![1, 4, 8400, 1]);
    assert_eq!(boxes.outputs.len(), 2);
    assert_eq!(boxes.outputs[0].type_, Some(PhysicalType::BoxesXy));
    assert_eq!(boxes.outputs[1].type_, Some(PhysicalType::BoxesWh));

    let xy_quant = boxes.outputs[0].quantization.as_ref().unwrap();
    assert_eq!(xy_quant.dtype, Some(DType::Int8));
    assert!((xy_quant.scale[0] - 0.004_177_792).abs() < 1e-6);
    assert_eq!(xy_quant.zero_point_at(0), -122);

    // Output 1: scores.
    let scores = &schema.outputs[1];
    assert_eq!(scores.type_, Some(LogicalType::Scores));
    assert_eq!(scores.score_format, Some(ScoreFormat::PerClass));
    assert_eq!(scores.shape, vec![1, 80, 8400, 1]);

    // Output 2: mask coefficients.
    let coefs = &schema.outputs[2];
    assert_eq!(coefs.type_, Some(LogicalType::MaskCoefs));
    assert_eq!(coefs.shape, vec![1, 32, 8400, 1]);

    // Output 3: prototypes.
    let protos = &schema.outputs[3];
    assert_eq!(protos.type_, Some(LogicalType::Protos));
    assert_eq!(protos.shape, vec![1, 32, 160, 160]);

    schema.validate().expect("ARA-2 int8 validate");
}
/// Parses the `ara2_int16_edgefirst.json` fixture and checks the output
/// count, the int16 quantization on the boxes-xy child and on the third
/// logical output, and that the schema validates.
#[test]
fn parse_real_ara2_int16_dvm_metadata() {
    let fixture = include_str!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/ara2_int16_edgefirst.json"
    ));
    let schema = SchemaV2::parse_json(fixture).expect("ARA-2 int16 parse");

    assert_eq!(schema.schema_version, 2);
    assert_eq!(schema.outputs.len(), 4);

    // First logical output and its first physical child.
    let boxes = &schema.outputs[0];
    assert_eq!(boxes.outputs.len(), 2);

    let xy_quant = boxes.outputs[0].quantization.as_ref().unwrap();
    assert_eq!(xy_quant.dtype, Some(DType::Int16));
    assert!((xy_quant.scale[0] - 3.211_570_6e-5).abs() < 1e-10);
    assert_eq!(xy_quant.zero_point_at(0), 0);

    // Third logical output carries its own quantization block.
    let third_quant = schema.outputs[2].quantization.as_ref().unwrap();
    assert_eq!(third_quant.dtype, Some(DType::Int16));

    schema.validate().expect("ARA-2 int16 validate");
}
// Feeds an inline YAML document that declares `schema_version: 2` to
// `SchemaV2::parse_yaml` and checks the parsed version and the first
// output's score format.
#[test]
fn parse_yaml_with_explicit_schema_version_2() {
// NOTE(review): YAML nesting is indentation-sensitive. The nested keys of
// this literal appear flush-left in this view; confirm the literal's real
// indentation before editing it.
let yaml = r#"
schema_version: 2
outputs:
- name: scores
type: scores
shape: [1, 80, 8400]
dtype: int8
quantization:
scale: 0.00392
dtype: int8
decoder: ultralytics
score_format: per_class
"#;
// Parsing must succeed; `unwrap` fails the test on any parse error.
let schema = SchemaV2::parse_yaml(yaml).unwrap();
assert_eq!(schema.schema_version, 2);
// `score_format: per_class` must map to `ScoreFormat::PerClass`.
assert_eq!(schema.outputs[0].score_format, Some(ScoreFormat::PerClass));
}
/// With an empty `dshape`, `squeeze_padding_dims` hands the shape back
/// unchanged and the returned `dshape` stays empty.
#[test]
fn squeeze_padding_dims_preserves_shape_when_dshape_absent() {
    let (out_shape, out_dshape) = squeeze_padding_dims(vec![1, 4, 8400], vec![]);

    assert_eq!(out_shape, vec![1, 4, 8400]);
    assert!(out_dshape.is_empty());
}
/// Lowering a flat boxes + scores schema that has no `dshape` entries keeps
/// each output's shape exactly as written in the JSON.
#[test]
fn to_legacy_preserves_shape_for_v2_split_boxes_without_dshape() {
    let raw = r#"{
"schema_version": 2,
"outputs": [
{"name":"boxes","type":"boxes","shape":[1,4,8400],
"dtype":"float32","decoder":"ultralytics","encoding":"direct"},
{"name":"scores","type":"scores","shape":[1,80,8400],
"dtype":"float32","decoder":"ultralytics","score_format":"per_class"}
]
}"#;
    let schema = SchemaV2::parse_json(raw).unwrap();
    let lowered = schema.to_legacy_config_outputs().expect("lowers cleanly");

    // First lowered output must be the boxes variant with its shape intact.
    let legacy_boxes = match &lowered.outputs[0] {
        crate::ConfigOutput::Boxes(inner) => inner,
        other => panic!("expected Boxes, got {other:?}"),
    };
    assert_eq!(legacy_boxes.shape, vec![1, 4, 8400]);

    // Second lowered output must be the scores variant with its shape intact.
    let legacy_scores = match &lowered.outputs[1] {
        crate::ConfigOutput::Scores(inner) => inner,
        other => panic!("expected Scores, got {other:?}"),
    };
    assert_eq!(legacy_scores.shape, vec![1, 80, 8400]);
}
}