use super::kernels::grids::make_anchor_grid;
use super::kernels::CpuFeatures;
use super::{Activation, DecodeDtype};
use crate::configs::DimName;
use crate::schema::BoxEncoding;
use crate::schema::{LogicalOutput, LogicalType, PhysicalOutput, SchemaV2};
use crate::{DecoderError, DecoderResult};
/// Memory layout of a physical output tensor for one feature-map level.
///
/// Detected per child by `detect_layout`; `Nhwc` is also the default when no
/// named channel axis is present in the child's `dshape`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Layout {
    /// Channels-last (channel axis is the final dimension).
    Nhwc,
    /// Channels-first (channel axis at position 1 of a 4-D shape).
    Nchw,
}
/// Precomputed, immutable plan for decoding a per-scale (multi-level)
/// detection head: per-level geometry plus selected compute kernels.
///
/// Built once per model via [`PerScalePlan::try_from_schema`] and reused
/// across inferences.
#[derive(Debug)]
#[allow(dead_code)] // NOTE(review): presumably gated behind a phase flag; confirm which fields are read yet
pub(crate) struct PerScalePlan {
    /// One entry per feature-map level, sorted by ascending stride
    /// (guaranteed by the `BTreeMap` grouping in `build_levels`).
    pub(crate) levels: Vec<LevelPlan>,
    /// Sum of `h * w` over all levels.
    pub(crate) total_anchors: usize,
    /// Class count, taken from the last dimension of the first score child
    /// (0 when unavailable).
    pub(crate) num_classes: usize,
    /// Mask-coefficient count from the last dimension of the first
    /// mask-coefs child; 0 when the model has no mask head.
    pub(crate) num_mask_coefs: usize,
    /// Box encoding for the whole head; only `Dfl` or `Direct` reach a plan.
    pub(crate) box_encoding: BoxEncoding,
    /// Activation that must be applied to raw scores before thresholding.
    pub(crate) score_activation: Activation,
    /// Requested output dtype for decoded tensors.
    pub(crate) out_dtype: DecodeDtype,
    /// CPU feature probe result used when selecting the dispatches below.
    pub(crate) cpu_features: CpuFeatures,
    /// Kernel selected for per-level box decoding.
    pub(crate) box_dispatch: super::kernels::dispatch::BoxLevelDispatch,
    /// Kernel selected for per-level score processing.
    pub(crate) score_dispatch: super::kernels::dispatch::ScoreLevelDispatch,
    /// Kernel for mask coefficients; `None` when the model has no mask head.
    pub(crate) mc_dispatch: Option<super::kernels::dispatch::MaskCoefDispatch>,
    /// Kernel for prototype masks; `None` when the model has no protos output.
    pub(crate) proto_dispatch: Option<super::kernels::dispatch::ProtoDispatch>,
    /// Shape of the protos tensor (child shape, or parent shape when the
    /// protos logical output has no children); `None` without protos.
    pub(crate) proto_shape: Option<Box<[usize]>>,
}
/// Per-level slice of a [`PerScalePlan`]: geometry, precomputed anchor grid,
/// and the raw tensor shapes for this stride.
#[derive(Debug)]
#[allow(dead_code)] // NOTE(review): presumably gated behind a phase flag; confirm which fields are read yet
pub(crate) struct LevelPlan {
    /// Stride of this level in input pixels (stored as f32 for decode math).
    pub(crate) stride: f32,
    /// Feature-map height.
    pub(crate) h: usize,
    /// Feature-map width.
    pub(crate) w: usize,
    /// DFL bin count per box side (`feature_count / 4`); 1 for `Direct`
    /// encoding. Capped at 64 by `try_from_schema`.
    pub(crate) reg_max: usize,
    /// Number of anchors in all preceding levels (cumulative offset into the
    /// flattened output buffers).
    pub(crate) anchor_offset: usize,
    /// Precomputed grid x-coordinates from `make_anchor_grid`, `h * w` long.
    pub(crate) grid_x: Box<[f32]>,
    /// Precomputed grid y-coordinates from `make_anchor_grid`, `h * w` long.
    pub(crate) grid_y: Box<[f32]>,
    /// Raw shape of this level's box child tensor.
    pub(crate) box_shape: Box<[usize]>,
    /// Raw shape of this level's score child tensor.
    pub(crate) score_shape: Box<[usize]>,
    /// Raw shape of this level's mask-coefs child, when present.
    pub(crate) mc_shape: Option<Box<[usize]>>,
    /// Memory layout shared by all of this level's children (validated in
    /// `build_levels`; mixed layouts within a level are rejected).
    pub(crate) layout: Layout,
}
impl PerScalePlan {
    /// Attempts to build a per-scale decode plan from a v2 schema.
    ///
    /// Returns `Ok(None)` — signalling the caller to fall back to the legacy
    /// (flat) decode path — when:
    /// * there is no `Boxes` or `Scores` logical output with at least one
    ///   physical child, or
    /// * the box encoding is neither `Dfl` nor `Direct`.
    ///
    /// # Errors
    /// * [`DecoderError::NotSupported`] when a DFL level's `reg_max` exceeds
    ///   the Phase 1 stack-scratch cap of 64.
    /// * [`DecoderError::InvalidConfig`] when a required child lacks a
    ///   quantization dtype, or when `build_levels` finds missing/mismatched
    ///   per-stride children.
    #[allow(dead_code)]
    pub(crate) fn try_from_schema(
        schema: &SchemaV2,
        out_dtype: DecodeDtype,
    ) -> DecoderResult<Option<Self>> {
        // Per-scale decoding requires boxes and scores logical outputs that
        // each carry physical (per-level) children; otherwise fall through.
        let boxes = match find_logical_by_type(schema, LogicalType::Boxes) {
            Some(b) if !b.outputs.is_empty() => b,
            _ => return Ok(None),
        };
        let scores = match find_logical_by_type(schema, LogicalType::Scores) {
            Some(s) if !s.outputs.is_empty() => s,
            _ => return Ok(None),
        };
        // Mask coefficients and prototype masks are optional (segmentation
        // heads only).
        let mc = find_logical_by_type(schema, LogicalType::MaskCoefs);
        let protos = find_logical_by_type(schema, LogicalType::Protos);
        // An unspecified encoding defaults to DFL; anything other than
        // DFL/Direct is not handled by this planner.
        let box_encoding = boxes.encoding.unwrap_or(BoxEncoding::Dfl);
        if !matches!(box_encoding, BoxEncoding::Dfl | BoxEncoding::Direct) {
            return Ok(None);
        }
        let levels = build_levels(boxes, scores, mc, box_encoding)?;
        // The DFL kernels use fixed-size stack scratch; reject models whose
        // bin count would overflow it.
        for lvl in &levels {
            if box_encoding == BoxEncoding::Dfl && lvl.reg_max > 64 {
                return Err(DecoderError::NotSupported(format!(
                    "DFL reg_max={} exceeds Phase 1 stack-scratch cap of 64",
                    lvl.reg_max
                )));
            }
        }
        let total_anchors: usize = levels.iter().map(|l| l.h * l.w).sum();
        // Class / coefficient counts come from the last dimension of the
        // first child; 0 is used as a fallback when the shape is empty.
        let num_classes = scores
            .outputs
            .first()
            .and_then(|s| s.shape.last())
            .copied()
            .unwrap_or(0);
        let num_mask_coefs = mc
            .and_then(|m| m.outputs.first())
            .and_then(|c| c.shape.last())
            .copied()
            .unwrap_or(0);
        // Logical-level activation takes precedence over the first child's.
        let score_activation = Activation::from_schema(
            scores
                .activation_required
                .or_else(|| scores.outputs.first().and_then(|s| s.activation_required)),
        );
        // Protos may be declared flat (shape on the parent) or per-child;
        // prefer the child shape when children exist.
        let proto_shape = protos.map(|p| {
            if p.outputs.is_empty() {
                p.shape.clone().into_boxed_slice()
            } else {
                p.outputs[0].shape.clone().into_boxed_slice()
            }
        });
        let cpu_features = CpuFeatures::from_env_or_probe()?;
        // Kernel dispatch needs concrete input dtypes, which are carried on
        // each child's quantization record; their absence is a config error.
        let box_dtype = boxes
            .outputs
            .first()
            .and_then(|b| b.quantization.as_ref())
            .and_then(|q| q.dtype)
            .ok_or_else(|| {
                DecoderError::InvalidConfig(
                    "per-scale boxes child missing quantization dtype".into(),
                )
            })?;
        let box_dtype = schema_dtype_to_tensor_dtype(box_dtype)?;
        let score_dtype = scores
            .outputs
            .first()
            .and_then(|s| s.quantization.as_ref())
            .and_then(|q| q.dtype)
            .ok_or_else(|| {
                DecoderError::InvalidConfig(
                    "per-scale scores child missing quantization dtype".into(),
                )
            })?;
        let score_dtype = schema_dtype_to_tensor_dtype(score_dtype)?;
        // Select concrete kernels (scalar or SIMD) for each tensor family
        // based on encoding, input/output dtypes, and probed CPU features.
        let box_dispatch = super::kernels::dispatch::BoxLevelDispatch::select(
            box_encoding,
            box_dtype,
            out_dtype,
            &cpu_features,
        )?;
        let score_dispatch = super::kernels::dispatch::ScoreLevelDispatch::select(
            score_dtype,
            out_dtype,
            &cpu_features,
        )?;
        let mc_dispatch = if let Some(m) = mc {
            let mc_dtype = m
                .outputs
                .first()
                .and_then(|c| c.quantization.as_ref())
                .and_then(|q| q.dtype)
                .ok_or_else(|| {
                    DecoderError::InvalidConfig(
                        "per-scale mask_coefs child missing quantization dtype".into(),
                    )
                })?;
            let mc_dtype = schema_dtype_to_tensor_dtype(mc_dtype)?;
            Some(super::kernels::dispatch::MaskCoefDispatch::select(
                mc_dtype,
                out_dtype,
                &cpu_features,
            )?)
        } else {
            None
        };
        let proto_dispatch = if let Some(p) = protos {
            // Mirror proto_shape: flat protos keep quantization on the
            // parent, per-level protos on the first child.
            let proto_dtype = if p.outputs.is_empty() {
                p.quantization.as_ref().and_then(|q| q.dtype)
            } else {
                p.outputs
                    .first()
                    .and_then(|c| c.quantization.as_ref())
                    .and_then(|q| q.dtype)
            }
            .ok_or_else(|| {
                DecoderError::InvalidConfig("per-scale protos missing quantization dtype".into())
            })?;
            let proto_dtype = schema_dtype_to_tensor_dtype(proto_dtype)?;
            Some(super::kernels::dispatch::ProtoDispatch::select(
                proto_dtype,
                out_dtype,
                &cpu_features,
            )?)
        } else {
            None
        };
        Ok(Some(Self {
            levels,
            total_anchors,
            num_classes,
            num_mask_coefs,
            box_encoding,
            score_activation,
            out_dtype,
            cpu_features,
            box_dispatch,
            score_dispatch,
            mc_dispatch,
            proto_dispatch,
            proto_shape,
        }))
    }
}
/// Maps a schema-level dtype tag onto the tensor crate's dtype enum.
///
/// The mapping is currently total (no arm errors), but the `DecoderResult`
/// return keeps the signature stable should unsupported dtypes appear later.
fn schema_dtype_to_tensor_dtype(d: crate::schema::DType) -> DecoderResult<edgefirst_tensor::DType> {
    use crate::schema::DType as S;
    use edgefirst_tensor::DType as T;
    let mapped = match d {
        S::Float32 => T::F32,
        S::Float16 => T::F16,
        S::Int32 => T::I32,
        S::Uint32 => T::U32,
        S::Int16 => T::I16,
        S::Uint16 => T::U16,
        S::Int8 => T::I8,
        S::Uint8 => T::U8,
    };
    Ok(mapped)
}
/// Returns the first logical output whose declared type equals `t`, if any.
fn find_logical_by_type(schema: &SchemaV2, t: LogicalType) -> Option<&LogicalOutput> {
    for logical in &schema.outputs {
        if logical.type_ == Some(t) {
            return Some(logical);
        }
    }
    None
}
/// Groups physical children by stride and validates them into per-level
/// plans.
///
/// Children from `boxes`, `scores`, and (optionally) `mc` are bucketed by
/// their x-stride into a `BTreeMap`, so the resulting levels are sorted by
/// ascending stride. Each level must have a box and a score child with the
/// same layout (and a matching mask-coefs layout when present).
///
/// # Errors
/// `InvalidConfig` when a child lacks a stride, a stride bucket is missing
/// its box or score child, layouts within a level disagree, or a DFL box
/// feature count is not a positive multiple of 4.
fn build_levels(
    boxes: &LogicalOutput,
    scores: &LogicalOutput,
    mc: Option<&LogicalOutput>,
    encoding: BoxEncoding,
) -> DecoderResult<Vec<LevelPlan>> {
    use std::collections::BTreeMap;
    // (box child, score child, mask-coefs child) for one stride.
    type ChildTriple<'a> = (
        Option<&'a PhysicalOutput>,
        Option<&'a PhysicalOutput>,
        Option<&'a PhysicalOutput>,
    );
    // BTreeMap keyed by stride gives deterministic ascending-stride order.
    let mut by_stride: BTreeMap<u32, ChildTriple<'_>> = BTreeMap::new();
    for child in &boxes.outputs {
        // Only the x-stride is used as the bucket key.
        // NOTE(review): assumes square strides (x == y) — confirm upstream.
        let s = child.stride.map(|s| s.x()).ok_or_else(|| {
            DecoderError::InvalidConfig(format!("box child {:?} missing stride", child.name))
        })?;
        by_stride.entry(s).or_insert((None, None, None)).0 = Some(child);
    }
    for child in &scores.outputs {
        let s = child.stride.map(|s| s.x()).ok_or_else(|| {
            DecoderError::InvalidConfig(format!("score child {:?} missing stride", child.name))
        })?;
        by_stride.entry(s).or_insert((None, None, None)).1 = Some(child);
    }
    if let Some(mc) = mc {
        for child in &mc.outputs {
            let s = child.stride.map(|s| s.x()).ok_or_else(|| {
                DecoderError::InvalidConfig(format!("mc child {:?} missing stride", child.name))
            })?;
            by_stride.entry(s).or_insert((None, None, None)).2 = Some(child);
        }
    }
    let mut levels = Vec::with_capacity(by_stride.len());
    // Running count of anchors in preceding levels (offset into flat output).
    let mut anchor_offset = 0;
    for (stride, (b, s, m)) in by_stride {
        // Box and score children are mandatory per level; mask-coefs is not.
        let b = b.ok_or_else(|| {
            DecoderError::InvalidConfig(format!("stride {stride}: missing box child"))
        })?;
        let s = s.ok_or_else(|| {
            DecoderError::InvalidConfig(format!("stride {stride}: missing score child"))
        })?;
        // Grid geometry is taken from the box child; all children of a level
        // must share its layout.
        let (h, w) = extract_hw(b)?;
        let layout = detect_layout(b)?;
        let s_layout = detect_layout(s)?;
        if s_layout != layout {
            return Err(DecoderError::InvalidConfig(format!(
                "stride {stride}: score child layout ({s_layout:?}) differs from box \
                 child layout ({layout:?}); per-level mixed layouts are not supported"
            )));
        }
        if let Some(mc_child) = m {
            let mc_layout = detect_layout(mc_child)?;
            if mc_layout != layout {
                return Err(DecoderError::InvalidConfig(format!(
                    "stride {stride}: mask_coefs child layout ({mc_layout:?}) differs \
                     from box child layout ({layout:?}); per-level mixed layouts are \
                     not supported"
                )));
            }
        }
        let reg_max = match encoding {
            BoxEncoding::Dfl => {
                // DFL distributes each of the 4 box sides over reg_max bins,
                // so the channel count must be a positive multiple of 4.
                let feat = box_channel_count(b)?;
                if feat == 0 || feat % 4 != 0 {
                    return Err(DecoderError::InvalidConfig(format!(
                        "DFL box feature count {feat} is not a positive multiple of 4"
                    )));
                }
                feat / 4
            }
            BoxEncoding::Direct => 1,
            // try_from_schema rejects every other encoding before calling us.
            BoxEncoding::Anchor => unreachable!("filtered above"),
        };
        // Precompute the anchor-center grid once per level.
        let (gx, gy) = make_anchor_grid(h, w);
        levels.push(LevelPlan {
            stride: stride as f32,
            h,
            w,
            reg_max,
            anchor_offset,
            grid_x: gx,
            grid_y: gy,
            box_shape: b.shape.clone().into_boxed_slice(),
            score_shape: s.shape.clone().into_boxed_slice(),
            mc_shape: m.map(|c| c.shape.clone().into_boxed_slice()),
            layout,
        });
        anchor_offset += h * w;
    }
    Ok(levels)
}
/// Infers the memory layout of a physical output from where its channel-like
/// axis (`BoxCoords`/`NumClasses`/`NumFeatures`/`NumProtos`) sits in `dshape`.
///
/// * channel axis last → `Nhwc`;
/// * channel axis at position 1 of a 4-D shape → `Nchw`;
/// * channel axis at position 1 of a 3-D shape → `Nhwc` (kept from existing
///   behavior; NOTE(review): presumably 3-D children are consumed
///   channel-compatible either way — confirm against the kernels);
/// * no named channel axis → defaults to `Nhwc`.
///
/// # Errors
/// `InvalidConfig` when a channel axis is named but matches neither
/// convention.
fn detect_layout(p: &PhysicalOutput) -> DecoderResult<Layout> {
    use DimName::*;
    let channel_pos = p
        .dshape
        .iter()
        .position(|(name, _)| matches!(name, BoxCoords | NumClasses | NumFeatures | NumProtos));
    let rank = p.shape.len();
    match channel_pos {
        // `idx + 1 == rank` instead of `idx == rank - 1`: avoids the usize
        // underflow panic when `shape` is empty but `dshape` names a channel
        // axis; such a malformed child now falls through to the error arm.
        Some(idx) if idx + 1 == rank => Ok(Layout::Nhwc),
        Some(idx) if idx == 1 && rank >= 4 => Ok(Layout::Nchw),
        Some(idx) if idx == 1 && rank == 3 => Ok(Layout::Nhwc),
        Some(idx) => Err(DecoderError::InvalidConfig(format!(
            "child {:?}: channel axis at position {idx} of {rank}-D dshape; \
             expected last position (NHWC) or position 1 (NCHW for 4-D)",
            p.name
        ))),
        None => Ok(Layout::Nhwc),
    }
}
/// Number of box channels for a physical box child: the size of the first
/// `dshape` axis named `BoxCoords` or `NumFeatures`, falling back to the
/// last `shape` dimension when no such axis is named.
///
/// # Errors
/// `InvalidConfig` when no axis is named and `shape` is empty.
fn box_channel_count(p: &PhysicalOutput) -> DecoderResult<usize> {
    use DimName::*;
    let named_axis = p
        .dshape
        .iter()
        .position(|(name, _)| matches!(name, BoxCoords | NumFeatures));
    match named_axis {
        Some(idx) => Ok(p.shape[idx]),
        None => p.shape.last().copied().ok_or_else(|| {
            DecoderError::InvalidConfig(format!("box child {:?}: empty shape", p.name))
        }),
    }
}
/// Extracts the (height, width) feature-map dimensions of a physical child.
///
/// Prefers `dshape` axes named `Height`/`Width` (the last occurrence wins);
/// when neither is named and the shape is 4-D, falls back to the NHWC
/// convention of `shape[1]`/`shape[2]`.
///
/// # Errors
/// `InvalidConfig` when height or width cannot be determined (height is
/// reported first).
fn extract_hw(p: &PhysicalOutput) -> DecoderResult<(usize, usize)> {
    let mut height = None;
    let mut width = None;
    for ((name, _), idx) in p.dshape.iter().zip(0..) {
        if matches!(name, DimName::Height) {
            height = Some(p.shape[idx]);
        } else if matches!(name, DimName::Width) {
            width = Some(p.shape[idx]);
        }
    }
    if height.is_none() && width.is_none() && p.shape.len() == 4 {
        height = Some(p.shape[1]);
        width = Some(p.shape[2]);
    }
    match (height, width) {
        (Some(h), Some(w)) => Ok((h, w)),
        (None, _) => Err(DecoderError::InvalidConfig(format!(
            "child {:?}: missing height",
            p.name
        ))),
        (_, None) => Err(DecoderError::InvalidConfig(format!(
            "child {:?}: missing width",
            p.name
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the embedded synthetic yolov8n per-scale schema fixture.
    fn fixture_yolov8n_schema() -> SchemaV2 {
        serde_json::from_str(include_str!(
            "../../../../testdata/per_scale/synthetic_yolov8n_schema.json"
        ))
        .expect("yolov8n fixture must parse")
    }

    /// Parses the embedded synthetic yolo26n per-scale schema fixture.
    fn fixture_yolo26n_schema() -> SchemaV2 {
        serde_json::from_str(include_str!(
            "../../../../testdata/per_scale/synthetic_yolo26n_schema.json"
        ))
        .expect("yolo26n fixture must parse")
    }

    /// Parses the embedded flat (non-per-scale) schema fixture.
    fn fixture_flat_schema() -> SchemaV2 {
        serde_json::from_str(include_str!(
            "../../../../testdata/per_scale/synthetic_flat_schema.json"
        ))
        .expect("flat fixture must parse")
    }

    #[test]
    fn try_from_schema_yolov8n_returns_some() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .expect("should plan successfully")
            .expect("yolov8n schema is per-scale");
        assert_eq!(plan.levels.len(), 3);
        assert_eq!(plan.total_anchors, 80 * 80 + 40 * 40 + 20 * 20);
        assert_eq!(plan.num_classes, 80);
        assert_eq!(plan.num_mask_coefs, 32);
        assert_eq!(plan.box_encoding, BoxEncoding::Dfl);
        assert_eq!(plan.score_activation, Activation::Sigmoid);
        assert!(plan.proto_shape.is_some());
    }

    #[test]
    fn try_from_schema_yolo26n_returns_some_with_ltrb_encoding() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolo26n_schema(), DecodeDtype::F32)
            .expect("should plan successfully")
            .expect("yolo26n schema is per-scale");
        assert_eq!(plan.levels.len(), 3);
        assert_eq!(plan.box_encoding, BoxEncoding::Direct);
        assert!(plan
            .levels
            .iter()
            .all(|lvl| lvl.box_shape.last().copied() == Some(4)));
    }

    #[test]
    fn try_from_schema_returns_none_for_flat_schema() {
        let plan =
            PerScalePlan::try_from_schema(&fixture_flat_schema(), DecodeDtype::F32).unwrap();
        assert!(plan.is_none(), "flat schema should fall through to legacy");
    }

    #[test]
    fn try_from_schema_strides_sorted_ascending() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        let strides: Vec<f32> = plan.levels.iter().map(|l| l.stride).collect();
        assert_eq!(strides, [8.0, 16.0, 32.0]);
    }

    #[test]
    fn try_from_schema_anchor_offsets_cumulative() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        let offsets: Vec<usize> = plan.levels.iter().map(|l| l.anchor_offset).collect();
        assert_eq!(offsets, [0, 80 * 80, 80 * 80 + 40 * 40]);
    }

    #[test]
    fn try_from_schema_grids_pre_computed() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        for lvl in &plan.levels {
            let cells = lvl.h * lvl.w;
            assert_eq!(lvl.grid_x.len(), cells);
            assert_eq!(lvl.grid_y.len(), cells);
        }
    }

    #[test]
    fn try_from_schema_dfl_reg_max_is_16() {
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        assert!(plan.levels.iter().all(|lvl| lvl.reg_max == 16));
    }

    #[test]
    fn try_from_schema_yolov8n_selects_dfl_i8_to_f32_dispatch() {
        use crate::per_scale::kernels::dispatch::BoxLevelDispatch;
        let plan = PerScalePlan::try_from_schema(&fixture_yolov8n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        // NEON variants of the DFL kernel only exist on aarch64.
        #[cfg(target_arch = "aarch64")]
        let ok = matches!(
            plan.box_dispatch,
            BoxLevelDispatch::DflI8ToF32Scalar
                | BoxLevelDispatch::DflI8ToF32NeonBase
                | BoxLevelDispatch::DflI8ToF32NeonFp16
        );
        #[cfg(not(target_arch = "aarch64"))]
        let ok = matches!(plan.box_dispatch, BoxLevelDispatch::DflI8ToF32Scalar);
        assert!(ok, "unexpected dispatch: {:?}", plan.box_dispatch);
        assert!(plan.mc_dispatch.is_some());
        assert!(plan.proto_dispatch.is_some());
    }

    #[test]
    fn try_from_schema_yolo26n_selects_ltrb_dispatch() {
        use crate::per_scale::kernels::dispatch::BoxLevelDispatch;
        let plan = PerScalePlan::try_from_schema(&fixture_yolo26n_schema(), DecodeDtype::F32)
            .unwrap()
            .unwrap();
        #[cfg(target_arch = "aarch64")]
        let ok = matches!(
            plan.box_dispatch,
            BoxLevelDispatch::LtrbI8ToF32Scalar | BoxLevelDispatch::LtrbI8ToF32NeonBase
        );
        #[cfg(not(target_arch = "aarch64"))]
        let ok = matches!(plan.box_dispatch, BoxLevelDispatch::LtrbI8ToF32Scalar);
        assert!(ok, "unexpected dispatch: {:?}", plan.box_dispatch);
    }
}