edgefirst_decoder/per_scale/
helper.rs1use crate::schema::SchemaV2;
11use crate::{DecoderError, DecoderResult};
12use edgefirst_tensor::{Quantization as TQ, TensorDyn};
13
14pub fn apply_schema_quant(schema: &SchemaV2, tensors: &mut [&mut TensorDyn]) -> DecoderResult<()> {
29 for logical in &schema.outputs {
30 for child in &logical.outputs {
32 if let Some(q) = child.quantization.as_ref() {
33 attach_per_tensor_quant_by_shape(tensors, &child.shape, q)?;
34 }
35 }
36 if logical.outputs.is_empty() {
38 if let Some(q) = logical.quantization.as_ref() {
39 attach_per_tensor_quant_by_shape(tensors, &logical.shape, q)?;
40 }
41 }
42 }
43 Ok(())
44}
45
46fn attach_per_tensor_quant_by_shape(
47 tensors: &mut [&mut TensorDyn],
48 expected_shape: &[usize],
49 schema_q: &crate::schema::Quantization,
50) -> DecoderResult<()> {
51 if schema_q.scale.is_empty() {
55 return Err(DecoderError::InvalidShape(format!(
56 "apply_schema_quant: schema declares quantization for shape \
57 {expected_shape:?} but `scale` is empty"
58 )));
59 }
60 if !schema_q.is_per_tensor() {
61 return Ok(());
64 }
65 let scale = schema_q.scale[0];
66 let zp = schema_q.zero_point_at(0);
67
68 for t in tensors.iter_mut() {
69 if t.shape() == expected_shape {
70 macro_rules! try_attach {
72 ($variant:ident) => {
73 if let TensorDyn::$variant(inner) = &mut **t {
74 let tq = TQ::per_tensor(scale, zp);
75 inner.set_quantization(tq).map_err(|e| {
76 DecoderError::Internal(format!(
77 "apply_schema_quant set_quantization failed: {e}"
78 ))
79 })?;
80 return Ok(());
81 }
82 };
83 }
84 try_attach!(I8);
85 try_attach!(U8);
86 try_attach!(I16);
87 try_attach!(U16);
88 return Ok(());
90 }
91 }
92 Err(DecoderError::InvalidShape(format!(
93 "apply_schema_quant: no tensor matches {expected_shape:?}"
94 )))
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SchemaV2;
    use edgefirst_tensor::{Tensor, TensorMemory};

    /// Raw JSON of the synthetic YOLOv8n schema fixture shared by the tests.
    const YOLOV8N_SCHEMA_JSON: &str =
        include_str!("../../../../testdata/per_scale/synthetic_yolov8n_schema.json");

    /// Parses the shared fixture into a `SchemaV2`, panicking on malformed JSON.
    fn load_yolov8n_schema() -> SchemaV2 {
        serde_json::from_str(YOLOV8N_SCHEMA_JSON).unwrap()
    }

    /// Supplying a single int8 tensor to the full schema is expected to fail.
    ///
    /// NOTE(review): presumably the fixture declares further shapes that have
    /// no matching tensor here, which triggers the "no tensor matches" error;
    /// confirm against the fixture.
    #[test]
    fn applies_quant_to_int8_by_shape() {
        let schema = load_yolov8n_schema();

        let tensor = Tensor::<i8>::new(&[1, 80, 80, 64], Some(TensorMemory::Mem), None).unwrap();
        assert!(tensor.quantization().is_none(), "fresh tensor has no quant");
        let mut dynamic = TensorDyn::I8(tensor);
        let mut tensors: Vec<&mut TensorDyn> = vec![&mut dynamic];

        apply_schema_quant(&schema, &mut tensors).unwrap_err();
    }

    /// With one int8 tensor per listed shape, applying the schema succeeds and
    /// every int8 tensor ends up with a quantization attached.
    #[test]
    fn applies_quant_to_full_yolov8_input_set() {
        const SHAPES_INT8: [[usize; 4]; 10] = [
            [1, 80, 80, 64],
            [1, 80, 80, 80],
            [1, 80, 80, 32],
            [1, 40, 40, 64],
            [1, 40, 40, 80],
            [1, 40, 40, 32],
            [1, 20, 20, 64],
            [1, 20, 20, 80],
            [1, 20, 20, 32],
            [1, 160, 160, 32],
        ];

        let schema = load_yolov8n_schema();

        let mut owned: Vec<TensorDyn> = Vec::with_capacity(SHAPES_INT8.len());
        for shape in &SHAPES_INT8 {
            let tensor = Tensor::<i8>::new(shape, Some(TensorMemory::Mem), None).unwrap();
            owned.push(TensorDyn::I8(tensor));
        }
        let mut refs: Vec<&mut TensorDyn> = owned.iter_mut().collect();

        apply_schema_quant(&schema, &mut refs).unwrap();

        let int8_tensors = owned.iter().filter_map(|td| match td {
            TensorDyn::I8(t) => Some(t),
            _ => None,
        });
        for t in int8_tensors {
            assert!(
                t.quantization().is_some(),
                "tensor missing quant after apply"
            );
        }
    }

    /// A quantization with an empty `scale` must surface as `InvalidShape`
    /// with a message mentioning the empty scale, not be skipped.
    #[test]
    fn empty_scale_errors_instead_of_silently_skipping() {
        use crate::schema::{DType, Quantization};

        let empty_scale_quant = Quantization {
            scale: vec![],
            zero_point: None,
            axis: None,
            dtype: Some(DType::Int8),
        };
        let tensor = Tensor::<i8>::new(&[1, 2, 3], Some(TensorMemory::Mem), None).unwrap();
        let mut dynamic = TensorDyn::I8(tensor);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut dynamic];

        let outcome = attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 3], &empty_scale_quant);

        match outcome.expect_err("empty scale must be rejected") {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("`scale` is empty"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// A lone f32 tensor keeps its dtype; the overall call still reports an
    /// error.
    ///
    /// NOTE(review): the error presumably comes from the other schema shapes
    /// having no matching tensor rather than from the float tensor itself;
    /// confirm against the fixture.
    #[test]
    fn skips_float_tensors_silently() {
        let schema = load_yolov8n_schema();

        let dims = vec![1, 80, 80, 64];
        let tensor = Tensor::<f32>::new(&dims, Some(TensorMemory::Mem), None).unwrap();
        let mut dynamic = TensorDyn::F32(tensor);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut dynamic];

        let result = apply_schema_quant(&schema, &mut refs);

        assert!(result.is_err());
        assert!(matches!(&dynamic, TensorDyn::F32(_)), "unexpected dtype");
    }
}