edgefirst_decoder/per_scale/
helper.rs1use crate::schema::SchemaV2;
11use crate::{DecoderError, DecoderResult};
12use edgefirst_tensor::{Quantization as TQ, TensorDyn};
13
14pub fn apply_schema_quant(schema: &SchemaV2, tensors: &mut [&mut TensorDyn]) -> DecoderResult<()> {
29 for logical in &schema.outputs {
30 for child in &logical.outputs {
32 if let Some(q) = child.quantization.as_ref() {
33 attach_per_tensor_quant_by_shape(tensors, &child.shape, q)?;
34 }
35 }
36 if logical.outputs.is_empty() {
38 if let Some(q) = logical.quantization.as_ref() {
39 attach_per_tensor_quant_by_shape(tensors, &logical.shape, q)?;
40 }
41 }
42 }
43 Ok(())
44}
45
46fn attach_per_tensor_quant_by_shape(
47 tensors: &mut [&mut TensorDyn],
48 expected_shape: &[usize],
49 schema_q: &crate::schema::Quantization,
50) -> DecoderResult<()> {
51 if schema_q.scale.is_empty() {
55 return Err(DecoderError::InvalidShape(format!(
56 "apply_schema_quant: schema declares quantization for shape \
57 {expected_shape:?} but `scale` is empty"
58 )));
59 }
60 if !schema_q.is_per_tensor() {
61 return Ok(());
64 }
65 let scale = schema_q.scale[0];
66 let zp = schema_q.zero_point_at(0);
67
68 for t in tensors.iter_mut() {
69 if t.shape() == expected_shape {
70 macro_rules! try_attach {
72 ($variant:ident) => {
73 if let TensorDyn::$variant(inner) = &mut **t {
74 let tq = TQ::per_tensor(scale, zp);
75 inner.set_quantization(tq).map_err(|e| {
76 DecoderError::Internal(format!(
77 "apply_schema_quant set_quantization failed: {e}"
78 ))
79 })?;
80 return Ok(());
81 }
82 };
83 }
84 try_attach!(I8);
85 try_attach!(U8);
86 try_attach!(I16);
87 try_attach!(U16);
88 return Ok(());
90 }
91 }
92 Err(DecoderError::InvalidShape(format!(
93 "apply_schema_quant: no tensor matches {expected_shape:?}"
94 )))
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::{DType, Quantization, SchemaV2};
    use edgefirst_tensor::{Tensor, TensorMemory};

    /// Loads and parses the synthetic YOLOv8n schema fixture.
    fn yolov8n_schema() -> SchemaV2 {
        let json =
            edgefirst_bench::testdata::read_to_string("per_scale/synthetic_yolov8n_schema.json");
        serde_json::from_str(&json).unwrap()
    }

    /// Builds a single-scale int8 quantization with no zero point or axis.
    ///
    /// NOTE(review): assumes `Quantization::is_per_tensor()` is true for a
    /// single scale with no axis. Confirm if that helper changes.
    fn single_scale_q() -> Quantization {
        Quantization {
            scale: vec![0.5],
            zero_point: None,
            axis: None,
            dtype: Some(DType::Int8),
        }
    }

    /// Creates an in-memory `I8` tensor of the given shape.
    fn i8_tensor(shape: &[usize]) -> TensorDyn {
        TensorDyn::I8(Tensor::<i8>::new(shape, Some(TensorMemory::Mem), None).unwrap())
    }

    #[test]
    fn applies_quant_to_int8_by_shape() {
        let schema = yolov8n_schema();

        let t = Tensor::<i8>::new(&[1, 80, 80, 64], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none(), "fresh tensor has no quant");
        let mut td = TensorDyn::I8(t);
        let mut tensors: Vec<&mut TensorDyn> = vec![&mut td];

        // The fixture declares more quantized outputs than the single tensor
        // supplied here (see the full input set below), so the call must
        // fail on a shape that has no tensor.
        apply_schema_quant(&schema, &mut tensors).unwrap_err();
    }

    #[test]
    fn applies_quant_to_full_yolov8_input_set() {
        let schema = yolov8n_schema();

        let shapes_int8 = [
            vec![1, 80, 80, 64],
            vec![1, 80, 80, 80],
            vec![1, 80, 80, 32],
            vec![1, 40, 40, 64],
            vec![1, 40, 40, 80],
            vec![1, 40, 40, 32],
            vec![1, 20, 20, 64],
            vec![1, 20, 20, 80],
            vec![1, 20, 20, 32],
            vec![1, 160, 160, 32],
        ];
        let mut owned: Vec<TensorDyn> = shapes_int8.iter().map(|s| i8_tensor(s)).collect();
        let mut refs: Vec<&mut TensorDyn> = owned.iter_mut().collect();

        apply_schema_quant(&schema, &mut refs).unwrap();

        for td in &owned {
            if let TensorDyn::I8(t) = td {
                assert!(
                    t.quantization().is_some(),
                    "tensor missing quant after apply"
                );
            }
        }
    }

    #[test]
    fn empty_scale_errors_instead_of_silently_skipping() {
        let bad_q = Quantization {
            scale: vec![],
            zero_point: None,
            axis: None,
            dtype: Some(DType::Int8),
        };
        let mut td = i8_tensor(&[1, 2, 3]);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        let err = attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 3], &bad_q)
            .expect_err("empty scale must be rejected");
        match err {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("`scale` is empty"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn skips_float_tensors_silently() {
        let schema = yolov8n_schema();

        let shape = vec![1, 80, 80, 64];
        let t = Tensor::<f32>::new(&shape, Some(TensorMemory::Mem), None).unwrap();
        let mut td = TensorDyn::F32(t);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        // The error comes from the fixture's other outputs having no matching
        // tensor. The float tensor itself must not cause a failure or a
        // change of variant. The direct float test below covers the skip on
        // its own.
        let r = apply_schema_quant(&schema, &mut refs);
        assert!(r.is_err());
        assert!(matches!(td, TensorDyn::F32(_)), "unexpected dtype");
    }

    #[test]
    fn attach_sets_quant_only_on_tensor_with_matching_shape() {
        let mut a = i8_tensor(&[1, 2, 3]);
        let mut b = i8_tensor(&[1, 2, 4]);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut a, &mut b];

        attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 4], &single_scale_q())
            .expect("matching int8 tensor must accept quantization");

        match (&a, &b) {
            (TensorDyn::I8(ta), TensorDyn::I8(tb)) => {
                assert!(ta.quantization().is_none(), "non-matching tensor was modified");
                assert!(tb.quantization().is_some(), "matching tensor missing quant");
            }
            _ => unreachable!("both tensors were created as I8"),
        }
    }

    #[test]
    fn attach_succeeds_without_quantizing_matching_float_tensor() {
        let mut td =
            TensorDyn::F32(Tensor::<f32>::new(&[1, 2, 3], Some(TensorMemory::Mem), None).unwrap());
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 3], &single_scale_q())
            .expect("float tensors are skipped, not rejected");
        assert!(matches!(td, TensorDyn::F32(_)), "unexpected dtype");
    }

    #[test]
    fn attach_errors_when_no_tensor_has_expected_shape() {
        let mut td = i8_tensor(&[1, 2, 3]);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        let err = attach_per_tensor_quant_by_shape(&mut refs, &[4, 5, 6], &single_scale_q())
            .expect_err("missing shape must be rejected");
        match err {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("no tensor matches"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}