use ndarray::{Array, Array3, ArrayD, ArrayViewD, Axis, IxDyn};
use super::dfl;
use crate::configs::DimName;
use crate::schema::{
self, padding_axes, squeeze_padding_dims, Activation, LogicalOutput, LogicalType, SchemaV2,
Stride,
};
use crate::{dequantize_cpu_chunked, DecoderError, DecoderResult, Quantization};
use edgefirst_tensor::{TensorDyn, TensorMapTrait, TensorTrait};
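/// Compiled plan for merging a model's physical output tensors into the
/// logical outputs declared by a [`SchemaV2`]: one merge step per typed
/// logical, kept in schema order.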
#[derive(Debug, Clone)]
pub(crate) struct DecodeProgram {
merges: Vec<LogicalMerge>,
}
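/// Merge strategy for one logical output: `Direct` passes a single tensor
/// through (dequantize plus padding squeeze), `PerScale` concatenates
/// stride-sorted detection heads along the box axis (optionally DFL-decoded),
/// and `ChannelConcat` stitches channel sub-splits back together.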
#[derive(Debug, Clone)]
enum LogicalMerge {
Direct {
#[allow(dead_code)]
name: Option<String>,
shape: Vec<usize>,
padding_axes: Vec<usize>,
quant: Option<Quantization>,
},
PerScale {
children: Vec<PhysicalBinding>,
logical_shape: Vec<usize>,
feature_axis_logical: usize,
box_axis_logical: usize,
dfl: Option<DflConfig>,
},
ChannelConcat {
children: Vec<PhysicalBinding>,
logical_shape: Vec<usize>,
channel_axis: usize,
padding_axes: Vec<usize>,
},
}
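/// Everything the executor needs to locate, dequantize, and lay out one
/// physical child tensor at runtime.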
#[derive(Debug, Clone)]
struct PhysicalBinding {
name: String,
shape: Vec<usize>,
dshape: Vec<(DimName, usize)>,
quant: Option<Quantization>,
stride: Option<Stride>,
#[allow(dead_code)]
activation_applied: Option<Activation>,
}
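/// Pre-computed DFL (Distribution Focal Loss) decode state: the shared
/// `reg_max` plus one anchor grid per child scale, in ascending-stride order.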
#[derive(Debug, Clone)]
pub(crate) struct DflConfig {
pub(crate) reg_max: usize,
grids: Vec<DflChildGrid>,
}
#[derive(Debug, Clone)]
struct DflChildGrid {
stride: f32,
grid_x: Vec<f32>,
grid_y: Vec<f32>,
}
impl DecodeProgram {
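/// Builds a decode program from a v2 schema. Returns `Ok(None)` when no typed
/// logical output has physical children, i.e. there is nothing to merge.
///
/// A minimal sketch of the intended call pattern (assuming a parsed schema
/// and bound runtime tensors are already in hand):
///
/// ```ignore
/// if let Some(program) = DecodeProgram::try_from_schema(&schema)? {
///     // One ArrayD<f32> per typed logical, in schema order.
///     let merged = program.execute(&inputs)?;
/// }
/// ```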
pub fn try_from_schema(schema: &SchemaV2) -> DecoderResult<Option<Self>> {
let needs_merge = schema
.outputs
.iter()
.any(|l| l.type_.is_some() && !l.outputs.is_empty());
if !needs_merge {
return Ok(None);
}
let merges = schema
.outputs
.iter()
.filter(|l| l.type_.is_some())
.map(plan_logical)
.collect::<DecoderResult<Vec<_>>>()?;
Ok(Some(Self { merges }))
}
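/// Executes every planned merge against `inputs`. Tensors are matched to
/// children greedily by shape and each input is bound at most once, so pass
/// inputs in schema child order when shapes are ambiguous.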
pub fn execute(&self, inputs: &[&TensorDyn]) -> DecoderResult<Vec<ArrayD<f32>>> {
let mut used: Vec<usize> = Vec::new();
self.merges
.iter()
.map(|m| execute_merge(m, inputs, &mut used))
.collect()
}
#[cfg(test)]
pub(crate) fn boxes_reg_max(&self) -> Option<usize> {
for m in &self.merges {
if let LogicalMerge::PerScale { dfl: Some(d), .. } = m {
return Some(d.reg_max);
}
}
None
}
}
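/// Chooses a merge strategy for one logical output: `Direct` when it has no
/// children, `PerScale` when the first child declares a stride, otherwise
/// `ChannelConcat`.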
fn plan_logical(logical: &LogicalOutput) -> DecoderResult<LogicalMerge> {
if logical.outputs.is_empty() {
let pad = padding_axes(&logical.dshape);
return Ok(LogicalMerge::Direct {
name: logical.name.clone(),
shape: logical.shape.clone(),
padding_axes: pad,
quant: logical
.quantization
.as_ref()
.map(schema_quant_to_runtime)
.transpose()?,
});
}
let first_has_stride = logical.outputs[0].stride.is_some();
if first_has_stride {
plan_per_scale(logical)
} else {
plan_channel_concat(logical)
}
}
fn plan_per_scale(logical: &LogicalOutput) -> DecoderResult<LogicalMerge> {
let mut children = logical
.outputs
.iter()
.map(physical_binding)
.collect::<DecoderResult<Vec<_>>>()?;
children.sort_by_key(|c| c.stride.map(|s| s.x()).unwrap_or(0));
let (feature_axis_logical, box_axis_logical) = logical_per_scale_axes(logical)?;
let dfl = if logical.type_ == Some(LogicalType::Boxes)
&& logical.encoding == Some(schema::BoxEncoding::Dfl)
{
Some(plan_dfl(logical, &children)?)
} else {
None
};
Ok(LogicalMerge::PerScale {
children,
logical_shape: logical.shape.clone(),
feature_axis_logical,
box_axis_logical,
dfl,
})
}
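/// Validates DFL children and pre-computes a per-scale anchor grid for each.
/// `reg_max` is derived from the first child's feature count (features =
/// 4 * reg_max) and must be shared by every child.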
fn plan_dfl(logical: &LogicalOutput, children: &[PhysicalBinding]) -> DecoderResult<DflConfig> {
let first_feat = child_feature_count(&children[0])?;
if first_feat == 0 || first_feat % 4 != 0 {
return Err(DecoderError::InvalidConfig(format!(
"DFL logical `{}` first child feature count {first_feat} is not a positive multiple of 4",
logical.name.as_deref().unwrap_or("<anonymous>")
)));
}
let reg_max = first_feat / 4;
let mut grids = Vec::with_capacity(children.len());
for child in children {
let feat = child_feature_count(child)?;
// Compare the exact count; `feat / 4` would accept non-multiples of 4.
if feat != reg_max * 4 {
return Err(DecoderError::NotSupported(format!(
"DFL logical `{}` has heterogeneous reg_max across children \
(child `{}` feature count {feat}, expected {}). \
Per-child reg_max is not yet supported.",
logical.name.as_deref().unwrap_or("<anonymous>"),
child.name,
reg_max * 4,
)));
}
let (h, w) = child_hw(child)?;
let stride = child.stride.map(|s| s.x() as f32).ok_or_else(|| {
DecoderError::InvalidConfig(format!(
"DFL child `{}` has no stride — required for anchor-grid pre-compute",
child.name
))
})?;
let (gx, gy) = dfl::make_anchor_grid(h, w);
grids.push(DflChildGrid {
stride,
grid_x: gx,
grid_y: gy,
});
}
Ok(DflConfig { reg_max, grids })
}
fn child_feature_count(child: &PhysicalBinding) -> DecoderResult<usize> {
for (i, (name, _)) in child.dshape.iter().enumerate() {
if matches!(
name,
DimName::NumFeatures
| DimName::NumClasses
| DimName::NumProtos
| DimName::BoxCoords
| DimName::NumAnchorsXFeatures
) {
return Ok(child.shape[i]);
}
}
Err(DecoderError::InvalidConfig(format!(
"per-scale child `{}` dshape {:?} lacks a feature axis",
child.name, child.dshape
)))
}
fn child_hw(child: &PhysicalBinding) -> DecoderResult<(usize, usize)> {
let mut h = None;
let mut w = None;
for (i, (name, _)) in child.dshape.iter().enumerate() {
match name {
DimName::Height => h = Some(child.shape[i]),
DimName::Width => w = Some(child.shape[i]),
_ => {}
}
}
match (h, w) {
(Some(h), Some(w)) => Ok((h, w)),
_ => Err(DecoderError::InvalidConfig(format!(
"DFL per-scale child `{}` dshape {:?} must name both `height` and `width`",
child.name, child.dshape
))),
}
}
fn plan_channel_concat(logical: &LogicalOutput) -> DecoderResult<LogicalMerge> {
let children = logical
.outputs
.iter()
.map(physical_binding)
.collect::<DecoderResult<Vec<_>>>()?;
let channel_axis = channel_axis_in_logical(logical)?;
let pad = padding_axes(&logical.dshape);
let (squeezed_shape, _) = squeeze_padding_dims(logical.shape.clone(), logical.dshape.clone());
Ok(LogicalMerge::ChannelConcat {
children,
logical_shape: squeezed_shape,
channel_axis,
padding_axes: pad,
})
}
fn physical_binding(p: &schema::PhysicalOutput) -> DecoderResult<PhysicalBinding> {
let quant = p
.quantization
.as_ref()
.map(schema_quant_to_runtime)
.transpose()?;
Ok(PhysicalBinding {
name: p.name.clone(),
shape: p.shape.clone(),
dshape: p.dshape.clone(),
quant,
stride: p.stride,
activation_applied: p.activation_applied,
})
}
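/// Converts schema quantization to the runtime representation. Only
/// per-tensor (scalar scale / zero-point) parameters are supported here;
/// per-channel quantization is rejected.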
fn schema_quant_to_runtime(q: &schema::Quantization) -> DecoderResult<Quantization> {
if q.is_per_channel() {
return Err(DecoderError::NotSupported(format!(
"per-channel quantization (axis {:?}, {} scales) is not yet \
supported by the HAL merge path",
q.axis,
q.scale.len(),
)));
}
Ok(Quantization::new(
*q.scale.first().unwrap_or(&0.0),
q.zero_point_at(0),
))
}
fn logical_per_scale_axes(logical: &LogicalOutput) -> DecoderResult<(usize, usize)> {
if logical.dshape.is_empty() {
return Err(DecoderError::InvalidConfig(format!(
"logical `{}` has per-scale children but no `dshape`; cannot \
infer feature / box axes for merge",
logical.name.as_deref().unwrap_or("<anonymous>")
)));
}
let feature = logical.dshape.iter().position(|(n, _)| {
matches!(
n,
DimName::NumFeatures
| DimName::NumClasses
| DimName::NumProtos
| DimName::BoxCoords
| DimName::NumAnchorsXFeatures
)
});
let boxes = logical
.dshape
.iter()
.position(|(n, _)| matches!(n, DimName::NumBoxes));
match (feature, boxes) {
(Some(f), Some(b)) => Ok((f, b)),
_ => Err(DecoderError::InvalidConfig(format!(
"logical `{}` dshape {:?} must name both a feature dim and `num_boxes`",
logical.name.as_deref().unwrap_or("<anonymous>"),
logical.dshape,
))),
}
}
fn channel_axis_in_logical(logical: &LogicalOutput) -> DecoderResult<usize> {
if !logical.dshape.is_empty() {
for (i, (name, _)) in logical.dshape.iter().enumerate() {
if matches!(
name,
DimName::BoxCoords
| DimName::NumFeatures
| DimName::NumClasses
| DimName::NumProtos
| DimName::NumAnchorsXFeatures
) {
return Ok(i);
}
}
}
Err(DecoderError::InvalidConfig(format!(
"logical `{}` has channel-sub-split children; `dshape` must name \
a channel axis (box_coords, num_features, num_classes, num_protos)",
logical.name.as_deref().unwrap_or("<anonymous>")
)))
}
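/// Runs one planned merge, consuming matching tensors from `inputs` and
/// recording their indices in `used`.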
fn execute_merge(
merge: &LogicalMerge,
inputs: &[&TensorDyn],
used: &mut Vec<usize>,
) -> DecoderResult<ArrayD<f32>> {
match merge {
LogicalMerge::Direct {
shape,
padding_axes,
quant,
..
} => {
let mut arr = find_and_dequantize(inputs, shape, *quant, used)?;
for &ax in padding_axes {
if ax < arr.ndim() && arr.shape()[ax] == 1 {
arr = arr.remove_axis(Axis(ax));
}
}
Ok(arr)
}
LogicalMerge::ChannelConcat {
children,
logical_shape,
channel_axis,
padding_axes,
} => execute_channel_concat(
inputs,
children,
logical_shape,
*channel_axis,
padding_axes,
used,
),
LogicalMerge::PerScale {
children,
logical_shape,
feature_axis_logical,
box_axis_logical,
dfl,
} => execute_per_scale(
inputs,
children,
logical_shape,
*feature_axis_logical,
*box_axis_logical,
dfl.as_ref(),
used,
),
}
}
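/// Finds the first not-yet-bound input whose shape matches exactly and marks
/// it used.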
fn find_unused_tensor_by_shape<'a>(
inputs: &'a [&'a TensorDyn],
shape: &[usize],
used: &mut Vec<usize>,
) -> DecoderResult<&'a TensorDyn> {
for (i, t) in inputs.iter().enumerate() {
if used.contains(&i) {
continue;
}
if t.shape() == shape {
used.push(i);
return Ok(*t);
}
}
Err(DecoderError::InvalidShape(format!(
"no remaining input tensor matches shape {shape:?} (already \
bound tensors are excluded; pass inputs in schema child order, \
or use name-keyed decode once available)"
)))
}
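/// Locates an unbound input for `expected_shape` and converts it to f32.
/// Falls back from an exact shape match to an element-count match plus an
/// axis permutation; the candidate shape is right-padded with 1s first, which
/// also covers tensors with stripped trailing unit dims.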
fn find_and_dequantize(
inputs: &[&TensorDyn],
expected_shape: &[usize],
quant: Option<Quantization>,
used: &mut Vec<usize>,
) -> DecoderResult<ArrayD<f32>> {
if let Ok(t) = find_unused_tensor_by_shape(inputs, expected_shape, used) {
return tensor_to_f32(t, quant);
}
let expected_count: usize = expected_shape.iter().product();
for (i, t) in inputs.iter().enumerate() {
if used.contains(&i) {
continue;
}
let count: usize = t.shape().iter().product();
if count != expected_count {
continue;
}
let mut padded_shape = t.shape().to_vec();
while padded_shape.len() < expected_shape.len() {
padded_shape.push(1);
}
if let Some(perm) = find_axis_permutation(&padded_shape, expected_shape) {
used.push(i);
let arr = tensor_to_f32(t, quant)?;
let mut arr = if arr.ndim() < expected_shape.len() {
arr.into_shape_with_order(IxDyn(&padded_shape))
.map_err(DecoderError::NDArrayShape)?
} else {
arr
};
arr = arr.permuted_axes(IxDyn(&perm));
arr = arr.as_standard_layout().to_owned();
debug_assert_eq!(arr.shape(), expected_shape);
return Ok(arr);
}
}
Err(DecoderError::InvalidShape(format!(
"no remaining input tensor matches shape {expected_shape:?} \
(tried exact match and element-count + permutation; already \
bound tensors are excluded)"
)))
}
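/// Greedy left-to-right search for a permutation `perm` with
/// `from[perm[i]] == to[i]`. With repeated dim sizes the choice is ambiguous
/// and the first unbound match wins.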
fn find_axis_permutation(from: &[usize], to: &[usize]) -> Option<Vec<usize>> {
if from.len() != to.len() {
return None;
}
let n = from.len();
let mut perm = vec![0usize; n];
let mut bound = vec![false; n];
for (i, &target_dim) in to.iter().enumerate() {
let mut found = false;
for (j, &source_dim) in from.iter().enumerate() {
if !bound[j] && source_dim == target_dim {
perm[i] = j;
bound[j] = true;
found = true;
break;
}
}
if !found {
return None;
}
}
Some(perm)
}
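/// Converts a runtime tensor to an owned f32 array, widening float dtypes
/// and dequantizing integer dtypes with `quant` (identity scale when absent).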
fn tensor_to_f32(t: &TensorDyn, quant: Option<Quantization>) -> DecoderResult<ArrayD<f32>> {
let shape = t.shape().to_vec();
match t {
TensorDyn::F16(tensor) => {
let m = tensor
.map()
.map_err(|e| DecoderError::Internal(format!("tensor map: {e}")))?;
use half::slice::HalfFloatSliceExt;
let total: usize = shape.iter().product();
let mut out = vec![0.0_f32; total];
m.as_slice().convert_to_f32_slice(&mut out);
Ok(Array::from_shape_vec(IxDyn(&shape), out)?)
}
TensorDyn::F32(tensor) => {
let m = tensor
.map()
.map_err(|e| DecoderError::Internal(format!("tensor map: {e}")))?;
let view = ArrayViewD::from_shape(IxDyn(&shape), m.as_slice())?;
Ok(view.to_owned())
}
TensorDyn::F64(tensor) => {
let m = tensor
.map()
.map_err(|e| DecoderError::Internal(format!("tensor map: {e}")))?;
let view = ArrayViewD::from_shape(IxDyn(&shape), m.as_slice())?;
Ok(view.mapv(|v| v as f32))
}
TensorDyn::U8(_)
| TensorDyn::I8(_)
| TensorDyn::U16(_)
| TensorDyn::I16(_)
| TensorDyn::U32(_)
| TensorDyn::I32(_) => dequantize_integer_tensor(t, quant, &shape),
other => Err(DecoderError::NotSupported(format!(
"merge: unsupported tensor dtype {:?}",
other.dtype()
))),
}
}
fn dequantize_integer_tensor(
t: &TensorDyn,
quant: Option<Quantization>,
shape: &[usize],
) -> DecoderResult<ArrayD<f32>> {
let quant = quant.unwrap_or(Quantization::new(1.0, 0));
let total: usize = shape.iter().product();
let mut out = vec![0.0_f32; total];
macro_rules! dq {
($tensor:expr) => {{
let m = $tensor
.map()
.map_err(|e| DecoderError::Internal(format!("tensor map: {e}")))?;
dequantize_cpu_chunked(m.as_slice(), quant, &mut out);
}};
}
match t {
TensorDyn::U8(tensor) => dq!(tensor),
TensorDyn::I8(tensor) => dq!(tensor),
TensorDyn::U16(tensor) => dq!(tensor),
TensorDyn::I16(tensor) => dq!(tensor),
TensorDyn::U32(tensor) => dq!(tensor),
TensorDyn::I32(tensor) => dq!(tensor),
_ => unreachable!("dequantize_integer_tensor called on non-integer dtype"),
}
let arr = Array::from_shape_vec(IxDyn(shape), out)?;
Ok(arr)
}
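/// Dequantizes each child, concatenates along the channel axis, squeezes the
/// declared padding axes, and checks the result against the logical shape.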
fn execute_channel_concat(
inputs: &[&TensorDyn],
children: &[PhysicalBinding],
logical_shape: &[usize],
channel_axis: usize,
padding_axes: &[usize],
used: &mut Vec<usize>,
) -> DecoderResult<ArrayD<f32>> {
let mut parts = Vec::with_capacity(children.len());
for child in children {
parts.push(find_and_dequantize(
inputs,
&child.shape,
child.quant,
used,
)?);
}
let views: Vec<_> = parts.iter().map(|a| a.view()).collect();
let mut merged =
ndarray::concatenate(Axis(channel_axis), &views).map_err(DecoderError::NDArrayShape)?;
for &ax in padding_axes {
if ax < merged.ndim() && merged.shape()[ax] == 1 {
merged = merged.remove_axis(Axis(ax));
}
}
if merged.shape() != logical_shape {
return Err(DecoderError::InvalidShape(format!(
"channel-concat produced shape {:?} but logical expected {:?}",
merged.shape(),
logical_shape
)));
}
Ok(merged)
}
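/// Flattens each detection head to (batch, features, height * width),
/// optionally DFL-decodes it, concatenates all heads along the box axis, and
/// reshapes into the declared logical layout.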
fn execute_per_scale(
inputs: &[&TensorDyn],
children: &[PhysicalBinding],
logical_shape: &[usize],
feature_axis_logical: usize,
box_axis_logical: usize,
dfl_cfg: Option<&DflConfig>,
used: &mut Vec<usize>,
) -> DecoderResult<ArrayD<f32>> {
if children.is_empty() {
return Err(DecoderError::InvalidConfig(
"per-scale merge with zero children".into(),
));
}
let mut per_scale_parts: Vec<Array3<f32>> = Vec::with_capacity(children.len());
let mut feature_count: Option<usize> = None;
let mut batch: Option<usize> = None;
for (idx, child) in children.iter().enumerate() {
let arr = find_and_dequantize(inputs, &child.shape, child.quant, used)?;
let (b, features, part) = child_to_batch_feature_spatial(arr, child)?;
match batch {
None => batch = Some(b),
Some(prev) if prev != b => {
return Err(DecoderError::InvalidShape(format!(
"per-scale children have inconsistent batch: {prev} vs {b}"
)));
}
_ => {}
}
match feature_count {
None => feature_count = Some(features),
Some(prev) if prev != features => {
return Err(DecoderError::InvalidShape(format!(
"per-scale children have inconsistent feature count: {prev} vs {features}"
)));
}
_ => {}
}
let part = match dfl_cfg {
Some(cfg) => dfl_decode_child(part, cfg, idx)?,
None => part,
};
per_scale_parts.push(part);
}
let views: Vec<_> = per_scale_parts.iter().map(|a| a.view()).collect();
let merged = ndarray::concatenate(Axis(2), &views).map_err(DecoderError::NDArrayShape)?;
reshape_to_logical(
merged.into_dyn(),
logical_shape,
feature_axis_logical,
box_axis_logical,
)
}
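/// Decodes one DFL head from (1, 4 * reg_max, n) into a (1, 4, n) box tensor
/// in pixel coordinates, using the child's pre-computed anchor grid and
/// stride. Only batch = 1 is supported.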
fn dfl_decode_child(
part: Array3<f32>,
cfg: &DflConfig,
child_idx: usize,
) -> DecoderResult<Array3<f32>> {
let (batch, features, n) = part.dim();
let expected_feat = 4 * cfg.reg_max;
if features != expected_feat {
return Err(DecoderError::InvalidShape(format!(
"DFL child {child_idx}: feature count {features} != 4 × reg_max ({expected_feat})"
)));
}
if batch != 1 {
return Err(DecoderError::NotSupported(format!(
"DFL decode with batch={batch} is not supported (only batch=1 today)"
)));
}
let grid = cfg
.grids
.get(child_idx)
.ok_or_else(|| DecoderError::Internal(format!("DFL grid missing for child {child_idx}")))?;
if grid.grid_x.len() != n {
return Err(DecoderError::InvalidShape(format!(
"DFL child {child_idx}: anchor count {n} != precomputed grid {}",
grid.grid_x.len()
)));
}
let transposed = part.permuted_axes([0, 2, 1]);
let contiguous = transposed.as_standard_layout().to_owned();
let flat = contiguous
.as_slice()
.ok_or_else(|| DecoderError::Internal("DFL transposed slice not contiguous".into()))?;
let decoded = dfl::decode_dfl_level(
flat,
1,
n,
cfg.reg_max,
&grid.grid_x,
&grid.grid_y,
grid.stride,
);
let decoded_nhwc = Array::from_shape_vec(ndarray::Ix3(1, n, 4), decoded)
.map_err(DecoderError::NDArrayShape)?;
let out = decoded_nhwc
.permuted_axes([0, 2, 1])
.as_standard_layout()
.to_owned();
Ok(out)
}
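/// Normalizes a child tensor to (batch, features, spatial) regardless of its
/// declared layout (NCHW vs NHWC), using `dshape` to identify each axis.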
fn child_to_batch_feature_spatial(
arr: ArrayD<f32>,
child: &PhysicalBinding,
) -> DecoderResult<(usize, usize, Array3<f32>)> {
if child.dshape.is_empty() {
return Err(DecoderError::InvalidConfig(format!(
"per-scale child `{}` must declare `dshape` for layout \
disambiguation (NCHW vs NHWC)",
child.name
)));
}
let shape = arr.shape().to_vec();
if shape.len() != child.dshape.len() {
return Err(DecoderError::InvalidShape(format!(
"per-scale child `{}` shape rank {} does not match dshape rank {}",
child.name,
shape.len(),
child.dshape.len()
)));
}
let mut batch = 1usize;
let mut height = 1usize;
let mut width = 1usize;
let mut features = 1usize;
for (i, (name, _)) in child.dshape.iter().enumerate() {
let size = shape[i];
match name {
DimName::Batch => batch = size,
DimName::Height => height = size,
DimName::Width => width = size,
DimName::NumFeatures
| DimName::NumClasses
| DimName::NumProtos
| DimName::BoxCoords
| DimName::NumAnchorsXFeatures => features = size,
DimName::NumBoxes | DimName::Padding => {}
}
}
let b_axis = axis_index(&child.dshape, &[DimName::Batch]).ok_or_else(|| {
DecoderError::InvalidConfig(format!(
"per-scale child `{}` dshape {:?} lacks a `batch` axis",
child.name, child.dshape
))
})?;
let f_axis = axis_index(
&child.dshape,
&[
DimName::NumFeatures,
DimName::NumClasses,
DimName::NumProtos,
DimName::BoxCoords,
DimName::NumAnchorsXFeatures,
],
)
.ok_or_else(|| {
DecoderError::InvalidConfig(format!(
"per-scale child `{}` dshape {:?} lacks a feature axis",
child.name, child.dshape
))
})?;
let h_axis = axis_index(&child.dshape, &[DimName::Height]);
let w_axis = axis_index(&child.dshape, &[DimName::Width]);
let mut perm: Vec<usize> = vec![b_axis, f_axis];
if let Some(h) = h_axis {
perm.push(h);
}
if let Some(w) = w_axis {
perm.push(w);
}
for i in 0..child.dshape.len() {
if !perm.contains(&i) {
perm.push(i);
}
}
debug_assert_eq!(perm.len(), child.dshape.len());
let permuted = arr.permuted_axes(IxDyn(&perm));
// `permuted` is already `IxDyn`, so only a contiguous copy is needed here.
let contiguous = permuted.as_standard_layout().to_owned();
let spatial = height * width;
let reshaped = contiguous
.into_shape_with_order(ndarray::Ix3(batch, features, spatial))
.map_err(|e| {
DecoderError::InvalidShape(format!(
"per-scale: failed to reshape permuted child `{}` to \
(batch,features,spatial)=({batch},{features},{spatial}): {e}",
child.name
))
})?;
Ok((batch, features, reshaped))
}
fn axis_index(dshape: &[(DimName, usize)], any_of: &[DimName]) -> Option<usize> {
dshape.iter().position(|(n, _)| any_of.contains(n))
}
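/// Permutes the merged (batch, features, boxes) array into the rank-3 logical
/// layout indicated by the feature and box axis positions.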
fn reshape_to_logical(
merged: ArrayD<f32>,
logical_shape: &[usize],
feature_axis_logical: usize,
box_axis_logical: usize,
) -> DecoderResult<ArrayD<f32>> {
if logical_shape.len() == 3
&& feature_axis_logical == 1
&& box_axis_logical == 2
&& merged.shape() == logical_shape
{
return Ok(merged);
}
let total: usize = logical_shape.iter().product();
if merged.len() != total {
return Err(DecoderError::InvalidShape(format!(
"merged shape {:?} has {} elements; logical shape {:?} \
expects {} elements",
merged.shape(),
merged.len(),
logical_shape,
total
)));
}
let batch_pos_logical = (0..logical_shape.len())
.find(|&i| i != feature_axis_logical && i != box_axis_logical)
.unwrap_or(0);
let mut target = vec![0usize; logical_shape.len()];
target[batch_pos_logical] = 0;
target[feature_axis_logical] = 1;
target[box_axis_logical] = 2;
if logical_shape.len() != 3 {
return Err(DecoderError::NotSupported(format!(
"per-scale merge into logical rank {} is not yet supported \
(only rank-3 [batch, features, boxes] today)",
logical_shape.len()
)));
}
// `target[i]` is the merged axis ([batch, features, boxes] = [0, 1, 2]) that
// belongs at logical position `i`, which is exactly the permutation
// `permuted_axes` expects (output axis i = input axis perm[i]). Inverting it
// would misplace axes for cyclic layouts such as [features, boxes, batch].
let out = merged.permuted_axes(IxDyn(&target));
Ok(out.as_standard_layout().to_owned())
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use crate::schema::{
BoxEncoding, DType, DecoderKind, LogicalOutput, LogicalType, PhysicalOutput, PhysicalType,
Quantization as SchemaQuant, SchemaV2, ScoreFormat, Stride as SchemaStride,
};
use edgefirst_tensor::{Tensor, TensorDyn, TensorMapTrait, TensorMemory, TensorTrait};
fn make_u8_tensor(shape: &[usize], values: &[u8]) -> TensorDyn {
let t = Tensor::<u8>::new(shape, Some(TensorMemory::Mem), None).unwrap();
let mut m = t.map().unwrap();
let slice = m.as_mut_slice();
slice[..values.len()].copy_from_slice(values);
drop(m);
TensorDyn::U8(t)
}
fn make_i16_tensor(shape: &[usize], values: &[i16]) -> TensorDyn {
let t = Tensor::<i16>::new(shape, Some(TensorMemory::Mem), None).unwrap();
let mut m = t.map().unwrap();
let slice = m.as_mut_slice();
slice[..values.len()].copy_from_slice(values);
drop(m);
TensorDyn::I16(t)
}
fn make_f32_tensor(shape: &[usize], values: &[f32]) -> TensorDyn {
let t = Tensor::<f32>::new(shape, Some(TensorMemory::Mem), None).unwrap();
let mut m = t.map().unwrap();
let slice = m.as_mut_slice();
slice[..values.len()].copy_from_slice(values);
drop(m);
TensorDyn::F32(t)
}
fn make_f16_tensor(shape: &[usize], values: &[f32]) -> TensorDyn {
let t = Tensor::<half::f16>::new(shape, Some(TensorMemory::Mem), None).unwrap();
let mut m = t.map().unwrap();
let slice = m.as_mut_slice();
for (dst, &src) in slice.iter_mut().zip(values.iter()) {
*dst = half::f16::from_f32(src);
}
drop(m);
TensorDyn::F16(t)
}
#[test]
fn tensor_to_f32_widens_f16_natively() {
let t = make_f16_tensor(&[2, 3], &[1.0, -2.0, 0.5, 0.25, -0.125, 4.0]);
let arr = super::tensor_to_f32(&t, None).unwrap();
assert_eq!(arr.shape(), &[2, 3]);
let flat: Vec<f32> = arr.iter().copied().collect();
assert_eq!(flat, vec![1.0, -2.0, 0.5, 0.25, -0.125, 4.0]);
}
fn per_tensor_q(scale: f32, zp: i32, dt: DType) -> SchemaQuant {
SchemaQuant {
scale: vec![scale],
zero_point: Some(vec![zp]),
axis: None,
dtype: Some(dt),
}
}
#[test]
fn typeless_logical_not_included_in_decode_program() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![
LogicalOutput {
name: Some("user_custom".into()),
type_: None,
shape: vec![1, 32],
dshape: vec![],
decoder: None,
encoding: None,
score_format: None,
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![],
},
LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![PhysicalOutput {
name: "boxes_raw".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 4, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 3),
],
dtype: DType::Float32,
quantization: None,
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
}],
},
],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema)
.unwrap()
.expect("typed boxes has children → program must be Some");
let legacy = schema.to_legacy_config_outputs().unwrap();
assert_eq!(legacy.outputs.len(), 1);
let boxes_tensor = make_f32_tensor(&[1, 4, 3], &[1.0; 12]);
let inputs: Vec<&TensorDyn> = vec![&boxes_tensor];
let merged = program.execute(&inputs).unwrap();
assert_eq!(
merged.len(),
1,
"decode program must emit one tensor per typed logical, not \
per schema-order logical (would otherwise misalign with \
legacy ConfigOutputs passed to decode_float)"
);
}
#[test]
fn flat_schema_has_no_decode_program() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 8400],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 8400),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: Some(DType::Float32),
quantization: None,
outputs: vec![],
}],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap();
assert!(program.is_none());
}
#[test]
fn channel_concat_merges_xy_and_wh_to_logical_shape() {
let boxes_logical = LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "xy".into(),
type_: Some(PhysicalType::BoxesXy),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 2),
(DimName::NumBoxes, 3),
],
dtype: DType::Int16,
quantization: Some(per_tensor_q(0.01, 0, DType::Int16)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "wh".into(),
type_: Some(PhysicalType::BoxesWh),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 2),
(DimName::NumBoxes, 3),
],
dtype: DType::Int16,
quantization: Some(per_tensor_q(0.02, 0, DType::Int16)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
],
};
let mut schema = SchemaV2::default();
schema.outputs.push(LogicalOutput {
shape: vec![1, 3, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 3),
(DimName::NumBoxes, 3),
],
outputs: vec![
PhysicalOutput {
name: "xy".into(),
type_: Some(PhysicalType::BoxesXy),
shape: vec![1, 1, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 1),
(DimName::NumBoxes, 3),
],
dtype: DType::Int16,
quantization: Some(per_tensor_q(0.01, 0, DType::Int16)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "wh".into(),
type_: Some(PhysicalType::BoxesWh),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 2),
(DimName::NumBoxes, 3),
],
dtype: DType::Int16,
quantization: Some(per_tensor_q(0.02, 0, DType::Int16)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
],
..boxes_logical
});
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let xy = make_i16_tensor(&[1, 1, 3], &[100, 200, 300]);
let wh = make_i16_tensor(&[1, 2, 3], &[10, 20, 30, 40, 50, 60]);
let inputs: Vec<&TensorDyn> = vec![&xy, &wh];
let merged = program.execute(&inputs).unwrap();
assert_eq!(merged.len(), 1);
assert_eq!(merged[0].shape(), &[1, 3, 3]);
let arr = &merged[0];
assert!((arr[[0, 0, 0]] - 1.0).abs() < 1e-5);
assert!((arr[[0, 0, 2]] - 3.0).abs() < 1e-5);
assert!((arr[[0, 1, 0]] - 0.2).abs() < 1e-5);
assert!((arr[[0, 2, 2]] - 1.2).abs() < 1e-5);
}
#[test]
fn per_scale_merge_nhwc_to_nchw() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 5],
dshape: vec![
(DimName::Batch, 1),
(DimName::NumFeatures, 4),
(DimName::NumBoxes, 5),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "b0".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 2, 2, 4],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 2),
(DimName::Width, 2),
(DimName::NumFeatures, 4),
],
dtype: DType::Uint8,
quantization: Some(per_tensor_q(1.0, 0, DType::Uint8)),
stride: Some(SchemaStride::Square(8)),
scale_index: Some(0),
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "b1".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 1, 1, 4],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 1),
(DimName::Width, 1),
(DimName::NumFeatures, 4),
],
dtype: DType::Uint8,
quantization: Some(per_tensor_q(1.0, 0, DType::Uint8)),
stride: Some(SchemaStride::Square(16)),
scale_index: Some(1),
activation_applied: None,
activation_required: None,
},
],
}],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let b0 = make_u8_tensor(
&[1, 2, 2, 4],
&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
);
let b1 = make_u8_tensor(&[1, 1, 1, 4], &[100, 101, 102, 103]);
let inputs: Vec<&TensorDyn> = vec![&b0, &b1];
let merged = program.execute(&inputs).unwrap();
assert_eq!(merged.len(), 1);
assert_eq!(merged[0].shape(), &[1, 4, 5]);
let arr = &merged[0];
assert_eq!(arr[[0, 0, 0]], 1.0);
assert_eq!(arr[[0, 0, 1]], 5.0);
assert_eq!(arr[[0, 0, 2]], 9.0);
assert_eq!(arr[[0, 0, 3]], 13.0);
assert_eq!(arr[[0, 0, 4]], 100.0);
assert_eq!(arr[[0, 3, 0]], 4.0);
assert_eq!(arr[[0, 3, 4]], 103.0);
}
#[test]
fn direct_logical_with_float_tensor_pass_through() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![
LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: Some(DType::Float32),
quantization: None,
outputs: vec![],
},
LogicalOutput {
name: Some("scores".into()),
type_: Some(LogicalType::Scores),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::NumClasses, 2),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: None,
score_format: Some(ScoreFormat::PerClass),
normalized: None,
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "s0".into(),
type_: Some(PhysicalType::Scores),
shape: vec![1, 1, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::NumClasses, 1),
(DimName::NumBoxes, 3),
],
dtype: DType::Uint8,
quantization: Some(per_tensor_q(0.5, 0, DType::Uint8)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "s1".into(),
type_: Some(PhysicalType::Scores),
shape: vec![1, 1, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::NumClasses, 1),
(DimName::NumBoxes, 3),
],
dtype: DType::Uint8,
quantization: Some(per_tensor_q(0.25, 0, DType::Uint8)),
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
],
},
],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let boxes = make_f32_tensor(
&[1, 4, 3],
&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
);
let s0 = make_u8_tensor(&[1, 1, 3], &[2, 4, 6]);
let s1 = make_u8_tensor(&[1, 1, 3], &[8, 16, 24]);
let inputs: Vec<&TensorDyn> = vec![&boxes, &s0, &s1];
let merged = program.execute(&inputs).unwrap();
assert_eq!(merged.len(), 2);
assert!((merged[0][[0, 0, 0]] - 0.1).abs() < 1e-6);
assert!((merged[0][[0, 3, 2]] - 1.2).abs() < 1e-6);
assert!((merged[1][[0, 0, 0]] - 1.0).abs() < 1e-6);
assert!((merged[1][[0, 0, 2]] - 3.0).abs() < 1e-6);
assert!((merged[1][[0, 1, 0]] - 2.0).abs() < 1e-6);
assert!((merged[1][[0, 1, 2]] - 6.0).abs() < 1e-6);
}
#[test]
fn dfl_split_with_per_scale_children_is_accepted_and_exposes_reg_max() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 6400],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 6400),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Dfl),
score_format: None,
normalized: Some(false),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![PhysicalOutput {
name: "b0".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 80, 80, 64],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 80),
(DimName::Width, 80),
(DimName::NumFeatures, 64),
],
dtype: DType::Uint8,
quantization: Some(per_tensor_q(0.01, 128, DType::Uint8)),
stride: Some(SchemaStride::Square(8)),
scale_index: Some(0),
activation_applied: None,
activation_required: None,
}],
}],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
assert_eq!(program.boxes_reg_max(), Some(16));
}
#[test]
fn per_scale_dfl_merge_produces_4ch_pixel_coordinates() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 5],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 5),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Dfl),
score_format: None,
normalized: Some(false),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "b0".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 2, 2, 16],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 2),
(DimName::Width, 2),
(DimName::NumFeatures, 16),
],
dtype: DType::Float32,
quantization: None,
stride: Some(SchemaStride::Square(8)),
scale_index: Some(0),
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "b1".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 1, 1, 16],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 1),
(DimName::Width, 1),
(DimName::NumFeatures, 16),
],
dtype: DType::Float32,
quantization: None,
stride: Some(SchemaStride::Square(16)),
scale_index: Some(1),
activation_applied: None,
activation_required: None,
},
],
}],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let b0 = make_f32_tensor(&[1, 2, 2, 16], &[1.0f32; 2 * 2 * 16]);
let b1 = make_f32_tensor(&[1, 1, 1, 16], &[1.0f32; 16]);
let inputs: Vec<&TensorDyn> = vec![&b0, &b1];
let merged = program.execute(&inputs).unwrap();
assert_eq!(merged.len(), 1);
assert_eq!(merged[0].shape(), &[1, 4, 5]);
let arr = &merged[0];
assert!(
(arr[[0, 0, 0]] - 4.0).abs() < 1e-3,
"xc[0]={}",
arr[[0, 0, 0]]
);
assert!(
(arr[[0, 0, 1]] - 12.0).abs() < 1e-3,
"xc[1]={}",
arr[[0, 0, 1]]
);
assert!(
(arr[[0, 0, 2]] - 4.0).abs() < 1e-3,
"xc[2]={}",
arr[[0, 0, 2]]
);
assert!(
(arr[[0, 0, 3]] - 12.0).abs() < 1e-3,
"xc[3]={}",
arr[[0, 0, 3]]
);
assert!(
(arr[[0, 0, 4]] - 8.0).abs() < 1e-3,
"xc[4]={}",
arr[[0, 0, 4]]
);
assert!((arr[[0, 1, 0]] - 4.0).abs() < 1e-3);
assert!((arr[[0, 1, 2]] - 12.0).abs() < 1e-3);
assert!((arr[[0, 1, 4]] - 8.0).abs() < 1e-3);
for a in 0..4 {
assert!((arr[[0, 2, a]] - 24.0).abs() < 1e-3);
assert!((arr[[0, 3, a]] - 24.0).abs() < 1e-3);
}
assert!((arr[[0, 2, 4]] - 48.0).abs() < 1e-3);
assert!((arr[[0, 3, 4]] - 48.0).abs() < 1e-3);
}
#[test]
fn dfl_children_declared_out_of_stride_order_are_sorted_ascending() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 5],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 5),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Dfl),
score_format: None,
normalized: Some(false),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "b_big".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 1, 1, 16],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 1),
(DimName::Width, 1),
(DimName::NumFeatures, 16),
],
dtype: DType::Float32,
quantization: None,
stride: Some(SchemaStride::Square(16)),
scale_index: Some(1),
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "b_small".into(),
type_: Some(PhysicalType::Boxes),
shape: vec![1, 2, 2, 16],
dshape: vec![
(DimName::Batch, 1),
(DimName::Height, 2),
(DimName::Width, 2),
(DimName::NumFeatures, 16),
],
dtype: DType::Float32,
quantization: None,
stride: Some(SchemaStride::Square(8)),
scale_index: Some(0),
activation_applied: None,
activation_required: None,
},
],
}],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let big = make_f32_tensor(&[1, 1, 1, 16], &[1.0f32; 16]);
let small = make_f32_tensor(&[1, 2, 2, 16], &[1.0f32; 2 * 2 * 16]);
let inputs: Vec<&TensorDyn> = vec![&big, &small];
let merged = program.execute(&inputs).unwrap();
let arr = &merged[0];
for a in 0..4 {
assert!(
(arr[[0, 2, a]] - 24.0).abs() < 1e-3,
"anchor {a} w={} (expected 24 stride-8)",
arr[[0, 2, a]]
);
}
assert!(
(arr[[0, 2, 4]] - 48.0).abs() < 1e-3,
"last anchor w={} (expected 48 stride-16)",
arr[[0, 2, 4]]
);
}
#[test]
fn dequantize_affine_reference_values_match_validator() {
let schema = SchemaV2 {
schema_version: 2,
outputs: vec![
LogicalOutput {
name: Some("scores".into()),
type_: Some(LogicalType::Scores),
shape: vec![1, 1, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::NumClasses, 1),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: None,
score_format: Some(ScoreFormat::PerClass),
normalized: None,
anchors: None,
stride: None,
dtype: Some(DType::Uint8),
quantization: Some(per_tensor_q(0.130, 70, DType::Uint8)),
outputs: vec![],
},
LogicalOutput {
name: Some("boxes".into()),
type_: Some(LogicalType::Boxes),
shape: vec![1, 4, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 4),
(DimName::NumBoxes, 3),
],
decoder: Some(DecoderKind::Ultralytics),
encoding: Some(BoxEncoding::Direct),
score_format: None,
normalized: Some(true),
anchors: None,
stride: None,
dtype: None,
quantization: None,
outputs: vec![
PhysicalOutput {
name: "b0".into(),
type_: Some(PhysicalType::BoxesXy),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 2),
(DimName::NumBoxes, 3),
],
dtype: DType::Float32,
quantization: None,
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
PhysicalOutput {
name: "b1".into(),
type_: Some(PhysicalType::BoxesWh),
shape: vec![1, 2, 3],
dshape: vec![
(DimName::Batch, 1),
(DimName::BoxCoords, 2),
(DimName::NumBoxes, 3),
],
dtype: DType::Float32,
quantization: None,
stride: None,
scale_index: None,
activation_applied: None,
activation_required: None,
},
],
},
],
..Default::default()
};
let program = DecodeProgram::try_from_schema(&schema).unwrap().unwrap();
let scores = make_u8_tensor(&[1, 1, 3], &[0, 70, 255]);
let xy = make_f32_tensor(&[1, 2, 3], &[0.0f32; 6]);
let wh = make_f32_tensor(&[1, 2, 3], &[0.0f32; 6]);
let inputs: Vec<&TensorDyn> = vec![&scores, &xy, &wh];
let merged = program.execute(&inputs).unwrap();
let scores_out = &merged[0];
assert!(
(scores_out[[0, 0, 0]] - (-9.10)).abs() < 1e-4,
"{}",
scores_out[[0, 0, 0]]
);
assert!(
(scores_out[[0, 0, 1]] - 0.00).abs() < 1e-4,
"{}",
scores_out[[0, 0, 1]]
);
assert!(
(scores_out[[0, 0, 2]] - 24.05).abs() < 1e-3,
"{}",
scores_out[[0, 0, 2]]
);
}
#[test]
fn find_and_dequantize_exact_match_preferred() {
let t = make_f32_tensor(&[1, 3, 4], &[1.0; 12]);
let inputs: Vec<&TensorDyn> = vec![&t];
let mut used = Vec::new();
let arr = find_and_dequantize(&inputs, &[1, 3, 4], None, &mut used).unwrap();
assert_eq!(arr.shape(), &[1, 3, 4]);
assert_eq!(used, vec![0]);
}
#[test]
fn find_and_dequantize_permuted_nchw_to_nhwc() {
let t = make_f32_tensor(
&[1, 3, 2, 2],
&[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
);
let inputs: Vec<&TensorDyn> = vec![&t];
let mut used = Vec::new();
let arr = find_and_dequantize(&inputs, &[1, 2, 2, 3], None, &mut used).unwrap();
assert_eq!(arr.shape(), &[1, 2, 2, 3]);
assert_eq!(arr[[0, 0, 0, 0]], 1.0);
assert_eq!(arr[[0, 0, 0, 1]], 5.0);
assert_eq!(arr[[0, 0, 0, 2]], 9.0);
assert_eq!(arr[[0, 0, 1, 0]], 2.0);
assert_eq!(arr[[0, 0, 1, 1]], 6.0);
assert_eq!(arr[[0, 0, 1, 2]], 10.0);
assert_eq!(arr[[0, 1, 0, 0]], 3.0);
assert_eq!(arr[[0, 1, 0, 1]], 7.0);
assert_eq!(arr[[0, 1, 0, 2]], 11.0);
}
#[test]
fn find_and_dequantize_stripped_trailing_unit_dim() {
let t = make_f32_tensor(&[1, 6], &[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]);
let inputs: Vec<&TensorDyn> = vec![&t];
let mut used = Vec::new();
let arr = find_and_dequantize(&inputs, &[1, 6, 1], None, &mut used).unwrap();
assert_eq!(arr.shape(), &[1, 6, 1]);
assert_eq!(arr[[0, 0, 0]], 10.0);
assert_eq!(arr[[0, 5, 0]], 60.0);
}
#[test]
fn find_and_dequantize_skips_already_used() {
let t0 = make_f32_tensor(&[2, 3], &[0.0; 6]);
let t1 = make_f32_tensor(&[3, 2], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let inputs: Vec<&TensorDyn> = vec![&t0, &t1];
let mut used = vec![0];
let arr = find_and_dequantize(&inputs, &[2, 3], None, &mut used).unwrap();
assert_eq!(arr.shape(), &[2, 3]);
assert!(used.contains(&1));
assert_eq!(arr[[0, 0]], 1.0);
assert_eq!(arr[[0, 1]], 3.0);
assert_eq!(arr[[0, 2]], 5.0);
assert_eq!(arr[[1, 0]], 2.0);
assert_eq!(arr[[1, 1]], 4.0);
assert_eq!(arr[[1, 2]], 6.0);
}
#[test]
fn find_and_dequantize_no_match_returns_error() {
let t = make_f32_tensor(&[2, 5], &[0.0; 10]);
let inputs: Vec<&TensorDyn> = vec![&t];
let mut used = Vec::new();
let result = find_and_dequantize(&inputs, &[3, 4], None, &mut used);
assert!(result.is_err());
}
#[test]
fn find_axis_permutation_identity() {
let perm = find_axis_permutation(&[1, 3, 4], &[1, 3, 4]);
assert_eq!(perm, Some(vec![0, 1, 2]));
}
#[test]
fn find_axis_permutation_nchw_to_nhwc() {
let perm = find_axis_permutation(&[1, 3, 2, 2], &[1, 2, 2, 3]);
assert_eq!(perm, Some(vec![0, 2, 3, 1]));
}
#[test]
fn find_axis_permutation_no_match() {
assert_eq!(find_axis_permutation(&[1, 3, 4], &[1, 4, 5]), None);
}
#[test]
fn find_axis_permutation_different_lengths() {
assert_eq!(find_axis_permutation(&[1, 3], &[1, 3, 1]), None);
}
}