// edgefirst_decoder/per_scale/helper.rs
// SPDX-FileCopyrightText: Copyright 2026 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

//! Optional helper: walk a schema-v2 document and attach
//! per-tensor `Quantization` to integer input tensors via shape match.
//!
//! Use when the upstream inference layer hasn't already attached
//! quantization metadata to the tensors.

use crate::schema::SchemaV2;
use crate::{DecoderError, DecoderResult};
use edgefirst_tensor::{Quantization as TQ, TensorDyn};

14/// Walk the schema's logical outputs and their per-scale children,
15/// attaching per-tensor `Quantization` to any matching integer tensor
16/// in `tensors` (matched by shape).
17///
18/// **Behavior:**
19/// - Per-channel quantization is silently skipped — the per-scale
20///   subsystem only consumes per-tensor quant.
21/// - Float tensors are silently skipped (they don't carry quantization).
22/// - Tensors with no matching schema entry are silently left alone.
23/// - If a schema entry has a quant declaration but no tensor matches
24///   the declared shape, returns `InvalidShape`.
25///
26/// **Idempotent:** safe to call multiple times; later calls overwrite
27/// the attached quantization.
28pub fn apply_schema_quant(schema: &SchemaV2, tensors: &mut [&mut TensorDyn]) -> DecoderResult<()> {
29    for logical in &schema.outputs {
30        // Per-scale children
31        for child in &logical.outputs {
32            if let Some(q) = child.quantization.as_ref() {
33                attach_per_tensor_quant_by_shape(tensors, &child.shape, q)?;
34            }
35        }
36        // Direct logical (no children) — quant lives on the logical itself.
37        if logical.outputs.is_empty() {
38            if let Some(q) = logical.quantization.as_ref() {
39                attach_per_tensor_quant_by_shape(tensors, &logical.shape, q)?;
40            }
41        }
42    }
43    Ok(())
44}
46fn attach_per_tensor_quant_by_shape(
47    tensors: &mut [&mut TensorDyn],
48    expected_shape: &[usize],
49    schema_q: &crate::schema::Quantization,
50) -> DecoderResult<()> {
51    // Reject empty scale up-front so a malformed schema fails fast at
52    // attach time instead of silently looking "per-channel" here and
53    // surfacing as `QuantMissing` later at decode time.
54    if schema_q.scale.is_empty() {
55        return Err(DecoderError::InvalidShape(format!(
56            "apply_schema_quant: schema declares quantization for shape \
57             {expected_shape:?} but `scale` is empty"
58        )));
59    }
60    if !schema_q.is_per_tensor() {
61        // Per-channel — skip silently. The per-scale planner errors
62        // separately when it actually needs to use a per-channel quant.
63        return Ok(());
64    }
65    let scale = schema_q.scale[0];
66    let zp = schema_q.zero_point_at(0);
67
68    for t in tensors.iter_mut() {
69        if t.shape() == expected_shape {
70            // Try to attach to integer variants; silently skip floats.
71            macro_rules! try_attach {
72                ($variant:ident) => {
73                    if let TensorDyn::$variant(inner) = &mut **t {
74                        let tq = TQ::per_tensor(scale, zp);
75                        inner.set_quantization(tq).map_err(|e| {
76                            DecoderError::Internal(format!(
77                                "apply_schema_quant set_quantization failed: {e}"
78                            ))
79                        })?;
80                        return Ok(());
81                    }
82                };
83            }
84            try_attach!(I8);
85            try_attach!(U8);
86            try_attach!(I16);
87            try_attach!(U16);
88            // Float / other dtypes silently skipped.
89            return Ok(());
90        }
91    }
92    Err(DecoderError::InvalidShape(format!(
93        "apply_schema_quant: no tensor matches {expected_shape:?}"
94    )))
95}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::SchemaV2;
    use edgefirst_tensor::{Tensor, TensorMemory};

    /// Supplying only one of the many tensors the schema declares quant for
    /// must fail with `InvalidShape` instead of silently succeeding on the
    /// subset. (Was named `applies_quant_to_int8_by_shape`, which suggested
    /// the happy path even though it asserted an error.)
    #[test]
    fn errors_when_schema_shapes_lack_matching_tensors() {
        let json = include_str!("../../../../testdata/per_scale/synthetic_yolov8n_schema.json");
        let schema: SchemaV2 = serde_json::from_str(json).unwrap();

        // One tensor matching the yolov8n schema's first box child [1, 80, 80, 64].
        let t = Tensor::<i8>::new(&[1, 80, 80, 64], Some(TensorMemory::Mem), None).unwrap();
        assert!(t.quantization().is_none(), "fresh tensor has no quant");
        let mut td = TensorDyn::I8(t);
        let mut tensors: Vec<&mut TensorDyn> = vec![&mut td];

        // The schema declares quant for many children; we provided only one
        // tensor, so the helper must report the unmatched shapes.
        apply_schema_quant(&schema, &mut tensors).unwrap_err();
    }

    /// Happy path: every schema child (plus protos) has a matching tensor,
    /// so all of them end up carrying per-tensor quant.
    #[test]
    fn applies_quant_to_full_yolov8_input_set() {
        let json = include_str!("../../../../testdata/per_scale/synthetic_yolov8n_schema.json");
        let schema: SchemaV2 = serde_json::from_str(json).unwrap();

        // Build the full set of input tensors matching every schema child + protos.
        let shapes_int8 = [
            vec![1, 80, 80, 64],
            vec![1, 80, 80, 80],
            vec![1, 80, 80, 32],
            vec![1, 40, 40, 64],
            vec![1, 40, 40, 80],
            vec![1, 40, 40, 32],
            vec![1, 20, 20, 64],
            vec![1, 20, 20, 80],
            vec![1, 20, 20, 32],
            vec![1, 160, 160, 32],
        ];
        let mut owned: Vec<TensorDyn> = shapes_int8
            .iter()
            .map(|s| TensorDyn::I8(Tensor::<i8>::new(s, Some(TensorMemory::Mem), None).unwrap()))
            .collect();
        let mut refs: Vec<&mut TensorDyn> = owned.iter_mut().collect();

        apply_schema_quant(&schema, &mut refs).unwrap();

        // All 10 tensors should now carry quant.
        for td in &owned {
            if let TensorDyn::I8(t) = td {
                assert!(
                    t.quantization().is_some(),
                    "tensor missing quant after apply"
                );
            }
        }
    }

    /// Regression (Copilot review on PR #63): an empty `scale` vector used
    /// to slip past `is_per_tensor()` and return Ok, masking a malformed
    /// schema until decode time. It must be rejected at attach time.
    #[test]
    fn empty_scale_errors_instead_of_silently_skipping() {
        use crate::schema::{DType, Quantization};
        let bad_q = Quantization {
            scale: vec![],
            zero_point: None,
            axis: None,
            dtype: Some(DType::Int8),
        };
        let mut td =
            TensorDyn::I8(Tensor::<i8>::new(&[1, 2, 3], Some(TensorMemory::Mem), None).unwrap());
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        let err = attach_per_tensor_quant_by_shape(&mut refs, &[1, 2, 3], &bad_q)
            .expect_err("empty scale must be rejected");
        match err {
            DecoderError::InvalidShape(msg) => {
                assert!(msg.contains("`scale` is empty"), "msg = {msg}")
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Float tensors must be skipped without panicking. The call still
    /// errors overall (`InvalidShape`) because the remaining schema children
    /// have no matching tensors — that is the documented behaviour.
    #[test]
    fn skips_float_tensors_silently() {
        let json = include_str!("../../../../testdata/per_scale/synthetic_yolov8n_schema.json");
        let schema: SchemaV2 = serde_json::from_str(json).unwrap();

        // A float tensor in place of the int8 box child.
        let shape = vec![1, 80, 80, 64];
        let t = Tensor::<f32>::new(&shape, Some(TensorMemory::Mem), None).unwrap();
        let mut td = TensorDyn::F32(t);
        let mut refs: Vec<&mut TensorDyn> = vec![&mut td];

        let r = apply_schema_quant(&schema, &mut refs);
        assert!(r.is_err());
        // The float tensor is untouched: still F32, no panic.
        assert!(matches!(td, TensorDyn::F32(_)), "unexpected dtype");
    }
}