use crate::schema::SchemaV2;
use crate::{DecoderError, DecoderResult};
use edgefirst_tensor::{Quantization as TQ, TensorDyn};
/// Copies quantization parameters declared in `schema` onto the tensors in
/// `tensors`, matching each schema entry to a tensor purely by shape.
///
/// For every logical output in `schema.outputs`:
/// - if it has child outputs, every child that declares a quantization is
///   applied using the child's shape;
/// - if it has no children, its own quantization (if any) is applied using
///   the logical output's shape.
///
/// Entries without a quantization are ignored.
///
/// # Errors
///
/// Returns the first error produced while attaching a single entry. That
/// covers an empty `scale`, a declared shape that no tensor in `tensors`
/// has, or a failure reported by the tensor's `set_quantization`. There is
/// no rollback: tensors updated before the failing entry keep the
/// quantization already attached.
pub fn apply_schema_quant(schema: &SchemaV2, tensors: &mut [&mut TensorDyn]) -> DecoderResult<()> {
    for logical in &schema.outputs {
        if logical.outputs.is_empty() {
            // No children: the logical output's own quantization applies.
            if let Some(quant) = &logical.quantization {
                attach_per_tensor_quant_by_shape(tensors, &logical.shape, quant)?;
            }
        } else {
            for child_out in &logical.outputs {
                if let Some(quant) = &child_out.quantization {
                    attach_per_tensor_quant_by_shape(tensors, &child_out.shape, quant)?;
                }
            }
        }
    }
    Ok(())
}
fn attach_per_tensor_quant_by_shape(
tensors: &mut [&mut TensorDyn],
expected_shape: &[usize],
schema_q: &crate::schema::Quantization,
) -> DecoderResult<()> {
if schema_q.scale.is_empty() {
return Err(DecoderError::InvalidShape(format!(
"apply_schema_quant: schema declares quantization for shape \
{expected_shape:?} but `scale` is empty"
)));
}
if !schema_q.is_per_tensor() {
return Ok(());
}
let scale = schema_q.scale[0];
let zp = schema_q.zero_point_at(0);
for t in tensors.iter_mut() {
if t.shape() == expected_shape {
macro_rules! try_attach {
($variant:ident) => {
if let TensorDyn::$variant(inner) = &mut **t {
let tq = TQ::per_tensor(scale, zp);
inner.set_quantization(tq).map_err(|e| {
DecoderError::Internal(format!(
"apply_schema_quant set_quantization failed: {e}"
))
})?;
return Ok(());
}
};
}
try_attach!(I8);
try_attach!(U8);
try_attach!(I16);
try_attach!(U16);
return Ok(());
}
}
Err(DecoderError::InvalidShape(format!(
"apply_schema_quant: no tensor matches {expected_shape:?}"
)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SchemaV2;
    use edgefirst_tensor::{Tensor, TensorMemory};

    /// Parses the synthetic YOLOv8n schema fixture shared by these tests.
    fn load_schema() -> SchemaV2 {
        let json = include_str!("../../../../testdata/per_scale/synthetic_yolov8n_schema.json");
        serde_json::from_str(json).unwrap()
    }

    /// Tensor shapes that together satisfy every quantized entry of the
    /// synthetic schema (`applies_quant_to_full_yolov8_input_set` relies on it).
    fn yolov8_shapes() -> Vec<Vec<usize>> {
        vec![
            vec![1, 80, 80, 64],
            vec![1, 80, 80, 80],
            vec![1, 80, 80, 32],
            vec![1, 40, 40, 64],
            vec![1, 40, 40, 80],
            vec![1, 40, 40, 32],
            vec![1, 20, 20, 64],
            vec![1, 20, 20, 80],
            vec![1, 20, 20, 32],
            vec![1, 160, 160, 32],
        ]
    }

    #[test]
    fn applies_quant_to_int8_by_shape() {
        let schema = load_schema();
        let t = Tensor::<i8>::new(&[1, 80, 80, 64], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none(), "fresh tensor has no quant");
        let mut td = TensorDyn::I8(t);
        let mut tensors: Vec<&mut TensorDyn> = vec![&mut td];
        // The schema describes more outputs than this single tensor, so the
        // call must fail, and specifically because some declared shape has
        // no matching tensor (not for an unrelated reason).
        let err = apply_schema_quant(&schema, &mut tensors).unwrap_err();
        match err {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("no tensor matches"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn applies_quant_to_full_yolov8_input_set() {
        let schema = load_schema();
        let mut owned: Vec<TensorDyn> = yolov8_shapes()
            .iter()
            .map(|s| TensorDyn::I8(Tensor::<i8>::new(s, Some(TensorMemory::Mem), None).unwrap()))
            .collect();
        let mut refs: Vec<&mut TensorDyn> = owned.iter_mut().collect();
        apply_schema_quant(&schema, &mut refs).unwrap();
        for td in &owned {
            if let TensorDyn::I8(t) = td {
                assert!(
                    t.quantization().is_some(),
                    "tensor missing quant after apply"
                );
            }
        }
    }

    #[test]
    fn empty_scale_errors_instead_of_silently_skipping() {
        use crate::schema::{DType, Quantization};
        let bad_q = Quantization {
            scale: vec![],
            zero_point: None,
            axis: None,
            dtype: Some(DType::Int8),
        };
        let mut td =
            TensorDyn::I8(Tensor::<i8>::new(&[1, 2, 3], Some(TensorMemory::Mem), None).unwrap());
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];
        let err = attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 3], &bad_q)
            .expect_err("empty scale must be rejected");
        match err {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("`scale` is empty"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn skips_float_tensors_silently() {
        let schema = load_schema();
        // Same input set as the full-set test, except the [1, 80, 80, 64]
        // slot is a float tensor. Its schema entry must be skipped without
        // error, while every integer tensor still receives its quant.
        let float_shape = [1usize, 80, 80, 64];
        let mut owned: Vec<TensorDyn> = yolov8_shapes()
            .iter()
            .map(|s| {
                if *s == float_shape {
                    TensorDyn::F32(
                        Tensor::<f32>::new(s, Some(TensorMemory::Mem), None).unwrap(),
                    )
                } else {
                    TensorDyn::I8(Tensor::<i8>::new(s, Some(TensorMemory::Mem), None).unwrap())
                }
            })
            .collect();
        let mut refs: Vec<&mut TensorDyn> = owned.iter_mut().collect();
        apply_schema_quant(&schema, &mut refs)
            .expect("a float tensor matching a schema shape must be skipped, not rejected");
        let mut float_count = 0;
        for td in &owned {
            match td {
                TensorDyn::F32(_) => float_count += 1,
                TensorDyn::I8(t) => assert!(
                    t.quantization().is_some(),
                    "tensor missing quant after apply"
                ),
                _ => panic!("unexpected dtype"),
            }
        }
        assert_eq!(float_count, 1, "float tensor must keep its dtype");
    }
}