quantize_rs/onnx_utils/quantization_nodes.rs

//! Low-level builders for ONNX QDQ (Quantize-Dequantize) graph primitives.
//!
//! Each quantized weight becomes four graph elements:
//!
//! ```text
//! Initializers:
//!   "{name}_quantized"  — INT8 tensor, same shape as original
//!   "{name}_scale"      — FP32 scalar
//!   "{name}_zp"         — INT8 scalar
//!
//! Node:
//!   DequantizeLinear
//!     inputs:  ["{name}_quantized", "{name}_scale", "{name}_zp"]
//!     outputs: ["{name}"]          ← original name; downstream graph untouched
//! ```
//!
//! The DequantizeLinear op runs at inference time:
//!   `output = (input - zero_point) × scale`
//! which matches the dequantize formula already used in `QuantParams` and
//! `QuantParamsInt4`.
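//!
//! A minimal sketch of how the builders below combine for one weight
//! (per-tensor INT8 here; the weight values and shape are illustrative):
//!
//! ```ignore
//! let names = DequantLinearNames::from_original("conv1.weight");
//! let node = build_dequantize_linear_node(&names, None);
//! let weight =
//!     build_quantized_weight_tensor(&names, &[12, -7, 0, 3], &[2, 2], StorageFormat::Int8Widened);
//! let scale = build_scale_tensor(&names, &[0.02]);
//! let zp = build_zero_point_tensor(&names, &[0], StorageFormat::Int8Widened);
//! // The caller then pushes `weight`, `scale` and `zp` onto the graph's
//! // initializer list, appends `node` to the node list, and removes the
//! // original FP32 initializer named "conv1.weight".
//! ```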

use crate::onnx_proto::{attribute_proto, tensor_proto, AttributeProto, NodeProto, TensorProto};

// ---------------------------------------------------------------------------
// Name generation
// ---------------------------------------------------------------------------

/// Canonical names for the four graph elements that replace one FP32 initializer.
#[derive(Debug, Clone)]
pub struct DequantLinearNames {
    /// `"{original}_quantized"` — the INT8 weight tensor
    pub quantized_name: String,
    /// `"{original}_scale"` — FP32 scale scalar
    pub scale_name: String,
    /// `"{original}_zp"` — INT8 zero-point scalar
    pub zp_name: String,
    /// `"DequantizeLinear_{original}"` — the node name
    pub node_name: String,
    /// The original tensor name — becomes the DequantizeLinear *output*,
    /// so every downstream node (Conv, MatMul, …) sees no change.
    pub output_name: String,
}

impl DequantLinearNames {
    /// Derive all four names from the original weight tensor name.
    pub fn from_original(original_name: &str) -> Self {
        Self {
            quantized_name: format!("{}_quantized", original_name),
            scale_name: format!("{}_scale", original_name),
            zp_name: format!("{}_zp", original_name),
            node_name: format!("DequantizeLinear_{}", original_name),
            output_name: original_name.to_string(),
        }
    }
}

// ---------------------------------------------------------------------------
// Node builder
// ---------------------------------------------------------------------------

/// Build a DequantizeLinear `NodeProto`.
///
/// ONNX spec (opset ≥ 10):
///   inputs  = [x (INT8), x_scale (FP32), x_zero_point (INT8)]
///   outputs = [y (FP32)]
///   y = (x - x_zero_point) × x_scale
///
/// When `axis` is `Some(a)`, the `axis` attribute is set on the node,
/// enabling per-channel dequantization (opset ≥ 13).
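///
/// Worked example (numbers chosen purely for illustration): with `x = -53`,
/// `x_zero_point = -3`, and `x_scale = 0.02`, inference computes
/// `y = (-53 - (-3)) × 0.02 = -1.0`.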
pub fn build_dequantize_linear_node(names: &DequantLinearNames, axis: Option<usize>) -> NodeProto {
    let attribute = match axis {
        Some(a) => vec![AttributeProto {
            name: "axis".to_string(),
            r#type: attribute_proto::AttributeType::Int as i32,
            i: a as i64,
            ..Default::default()
        }],
        None => vec![],
    };

    NodeProto {
        op_type: "DequantizeLinear".to_string(),
        name: names.node_name.clone(),
        input: vec![
            names.quantized_name.clone(),
            names.scale_name.clone(),
            names.zp_name.clone(),
        ],
        output: vec![names.output_name.clone()],
        attribute,
        ..Default::default()
    }
}

// ---------------------------------------------------------------------------
// On-disk storage format
// ---------------------------------------------------------------------------

/// How quantized values are stored on disk inside the ONNX initializer.
///
/// ONNX `DequantizeLinear` has no 4-bit input type before opset 21, so 4-bit
/// values must be widened to INT8 there.  From opset 21 the op also accepts
/// native `INT4` (and `UINT4`), which is 2× smaller on disk.
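///
/// As a size sketch (illustrative numbers): a 1,000,000-element FP32 weight
/// occupies 4 MB; widened to INT8 it takes ~1 MB (4× smaller), while packed
/// as native INT4 it takes ~0.5 MB (8× smaller).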
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageFormat {
    /// INT4 values widened to INT8 bytes.  Compatible with opset 10+ — the
    /// default for backward compatibility, but gives only 4× compression.
    Int8Widened,
    /// Native `DataType::Int4` with two values packed per byte.  Requires
    /// opset 21.  Gives the full 8× compression for INT4 models.
    NativeInt4,
}

/// Pack INT4 values in ONNX wire-format layout: the element at the **even**
/// index goes into the **low** nibble, odd index into the high nibble.
///
/// See the ONNX spec for `TensorProto` `DataType::Int4`.  This is the opposite
/// nibble order from [`crate::pack_int4`], which uses the library's internal
/// layout (val1 in high nibble).
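///
/// For example, `[1, 2]` packs to the single byte `0x21`: index 0 (`1`) lands
/// in the low nibble and index 1 (`2`) in the high nibble.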
pub(crate) fn pack_int4_onnx(values: &[i8]) -> Vec<u8> {
    let mut packed = Vec::with_capacity(values.len().div_ceil(2));
    for chunk in values.chunks(2) {
        let lo = (chunk[0] & 0x0F) as u8;
        let hi = if chunk.len() > 1 {
            (chunk[1] & 0x0F) as u8
        } else {
            0
        };
        packed.push((hi << 4) | lo);
    }
    packed
}

/// Unpack INT4 values stored in ONNX wire-format layout.  Returns exactly
/// `num_values` `i8`s, sign-extended from 4 bits.
pub(crate) fn unpack_int4_onnx(packed: &[u8], num_values: usize) -> Vec<i8> {
    let mut values = Vec::with_capacity(num_values);
    for &byte in packed {
        let lo = byte & 0x0F;
        let hi = (byte >> 4) & 0x0F;
        values.push(sign_extend_nibble(lo));
        if values.len() < num_values {
            values.push(sign_extend_nibble(hi));
        }
    }
    values.truncate(num_values);
    values
}

/// Sign-extend a 4-bit two's-complement nibble to `i8` (e.g. `0xF` → `-1`).
#[inline]
fn sign_extend_nibble(nibble: u8) -> i8 {
    if nibble >= 8 {
        (nibble as i8) | !0x0F
    } else {
        nibble as i8
    }
}

// ---------------------------------------------------------------------------
// Initializer builders
// ---------------------------------------------------------------------------

/// Tensor holding the quantized weight values.
///
/// Shape (`dims`) always matches the **logical** shape of the original
/// FP32 tensor.  With [`StorageFormat::Int8Widened`] each element occupies one
/// byte; with [`StorageFormat::NativeInt4`] two elements share a byte so
/// `raw_data.len() == dims.product().div_ceil(2)`.
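///
/// For instance, six values with shape `[2, 3]` occupy 6 bytes under
/// [`StorageFormat::Int8Widened`] but only 3 packed bytes under
/// [`StorageFormat::NativeInt4`].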
pub fn build_quantized_weight_tensor(
    names: &DequantLinearNames,
    values: &[i8],
    shape: &[i64],
    format: StorageFormat,
) -> TensorProto {
    match format {
        StorageFormat::Int8Widened => TensorProto {
            name: names.quantized_name.clone(),
            data_type: tensor_proto::DataType::Int8 as i32,
            dims: shape.to_vec(),
            // Each i8 value → one byte.  Reinterpret cast, not value conversion.
            raw_data: values.iter().map(|&v| v as u8).collect(),
            ..Default::default()
        },
        StorageFormat::NativeInt4 => TensorProto {
            name: names.quantized_name.clone(),
            data_type: tensor_proto::DataType::Int4 as i32,
            dims: shape.to_vec(),
            raw_data: pack_int4_onnx(values),
            ..Default::default()
        },
    }
}

/// FP32 scale tensor.
///
/// For per-tensor quantization, `scales` has one element and the tensor
/// is rank-0 (scalar).  For per-channel, `scales` has one entry per
/// channel and the tensor is rank-1 with shape `[num_channels]`.
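///
/// A per-channel sketch (illustrative scales): `build_scale_tensor(&names, &[0.01, 0.02, 0.03])`
/// yields a rank-1 tensor with `dims == [3]`; pair it with
/// `build_dequantize_linear_node(&names, Some(0))` so that `axis = 0` selects
/// the channel dimension.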
pub fn build_scale_tensor(names: &DequantLinearNames, scales: &[f32]) -> TensorProto {
    let mut t = TensorProto {
        name: names.scale_name.clone(),
        data_type: tensor_proto::DataType::Float as i32,
        float_data: scales.to_vec(),
        ..Default::default()
    };
    if scales.len() > 1 {
        // rank-1: [num_channels]
        t.dims = vec![scales.len() as i64];
    }
    // For scalar (len == 1), dims remains empty (rank-0 scalar).
    t
}

/// Zero-point tensor.  Data type matches the quantized weight:
///   - [`StorageFormat::Int8Widened`]: `DataType::Int8`, one byte per value.
///   - [`StorageFormat::NativeInt4`]: `DataType::Int4`, packed two per byte.
///
/// For per-tensor, `zps` has one element → rank-0 scalar.
/// For per-channel, `zps` has one per channel → rank-1 `[num_channels]`.
pub fn build_zero_point_tensor(
    names: &DequantLinearNames,
    zps: &[i8],
    format: StorageFormat,
) -> TensorProto {
    let (data_type, raw_data) = match format {
        StorageFormat::Int8Widened => (
            tensor_proto::DataType::Int8 as i32,
            zps.iter().map(|&v| v as u8).collect(),
        ),
        StorageFormat::NativeInt4 => (tensor_proto::DataType::Int4 as i32, pack_int4_onnx(zps)),
    };

    let mut t = TensorProto {
        name: names.zp_name.clone(),
        data_type,
        raw_data,
        ..Default::default()
    };
    if zps.len() > 1 {
        // rank-1: [num_channels]
        t.dims = vec![zps.len() as i64];
    }
    // For scalar (len == 1), dims remains empty (rank-0 scalar).
    t
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::tensor_proto;

    #[test]
    fn test_names_from_simple_weight() {
        let n = DequantLinearNames::from_original("conv1.weight");
        assert_eq!(n.quantized_name, "conv1.weight_quantized");
        assert_eq!(n.scale_name, "conv1.weight_scale");
        assert_eq!(n.zp_name, "conv1.weight_zp");
        assert_eq!(n.node_name, "DequantizeLinear_conv1.weight");
        assert_eq!(n.output_name, "conv1.weight");
    }

    #[test]
    fn test_names_from_dotted_path() {
        // Real ResNet-18 weight names look like this
        let n = DequantLinearNames::from_original("layer1.0.conv1.weight");
        assert_eq!(n.quantized_name, "layer1.0.conv1.weight_quantized");
        assert_eq!(n.output_name, "layer1.0.conv1.weight");
    }

    #[test]
    fn test_dequantize_linear_node_inputs_outputs() {
        let names = DequantLinearNames::from_original("fc.weight");
        let node = build_dequantize_linear_node(&names, None);

        assert_eq!(node.op_type, "DequantizeLinear");
        assert_eq!(node.name, "DequantizeLinear_fc.weight");

        assert_eq!(node.input.len(), 3);
        assert_eq!(node.input[0], "fc.weight_quantized");
        assert_eq!(node.input[1], "fc.weight_scale");
        assert_eq!(node.input[2], "fc.weight_zp");

        assert_eq!(node.output.len(), 1);
        assert_eq!(node.output[0], "fc.weight");
        assert!(node.attribute.is_empty());
    }

    #[test]
    fn test_dequantize_linear_node_with_axis() {
        let names = DequantLinearNames::from_original("conv.weight");
        let node = build_dequantize_linear_node(&names, Some(0));

        assert_eq!(node.attribute.len(), 1);
        assert_eq!(node.attribute[0].name, "axis");
        assert_eq!(node.attribute[0].i, 0);
    }

    #[test]
    fn test_quantized_weight_tensor_shape_and_data() {
        let names = DequantLinearNames::from_original("w");
        let values = vec![1i8, -2, 3, -4, 5, 6];
        let shape = vec![2i64, 3];
        let t = build_quantized_weight_tensor(&names, &values, &shape, StorageFormat::Int8Widened);

        assert_eq!(t.name, "w_quantized");
        assert_eq!(t.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(t.dims.len(), 2);
        assert_eq!(t.dims[0], 2);
        assert_eq!(t.dims[1], 3);

        // Verify byte-level round-trip
        let recovered: Vec<i8> = t.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }

    #[test]
    fn test_scale_tensor_scalar() {
        let names = DequantLinearNames::from_original("w");
        let t = build_scale_tensor(&names, &[0.003921]);

        assert_eq!(t.name, "w_scale");
        assert_eq!(t.data_type, tensor_proto::DataType::Float as i32);
        assert_eq!(t.dims.len(), 0, "single scale must be rank-0 scalar");
        assert!((t.float_data[0] - 0.003921).abs() < 1e-6);
    }

    #[test]
    fn test_scale_tensor_per_channel() {
        let names = DequantLinearNames::from_original("w");
        let t = build_scale_tensor(&names, &[0.01, 0.02, 0.03]);

        assert_eq!(t.dims.len(), 1);
        assert_eq!(t.dims[0], 3);
        assert_eq!(t.float_data.len(), 3);
    }

    #[test]
    fn test_zero_point_tensor_scalar() {
        let names = DequantLinearNames::from_original("w");
        let t = build_zero_point_tensor(&names, &[-3], StorageFormat::Int8Widened);

        assert_eq!(t.name, "w_zp");
        assert_eq!(t.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(t.dims.len(), 0, "single zp must be rank-0 scalar");
        assert_eq!(t.raw_data[0], (-3i8) as u8);
    }

    #[test]
    fn test_zero_point_tensor_per_channel() {
        let names = DequantLinearNames::from_original("w");
        let t = build_zero_point_tensor(&names, &[-3, 0, 5], StorageFormat::Int8Widened);

        assert_eq!(t.dims.len(), 1);
        assert_eq!(t.dims[0], 3);
        assert_eq!(t.raw_data.len(), 3);
    }

    #[test]
    fn test_int4_range_values_round_trip() {
        // INT4 signed range: [-8, 7].  These arrive as i8; we store them as-is.
        let names = DequantLinearNames::from_original("w");
        let values = vec![-8i8, -1, 0, 7];
        let shape = vec![4i64];
        let t = build_quantized_weight_tensor(&names, &values, &shape, StorageFormat::Int8Widened);

        let recovered: Vec<i8> = t.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }

    // -----------------------------------------------------------------------
    // Native INT4 (ONNX opset 21) tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_onnx_pack_layout_even_index_in_low_nibble() {
        // ONNX spec: element at even index goes in the low nibble.
        // [0x1, 0x2] → byte = (0x2 << 4) | 0x1 = 0x21
        let packed = pack_int4_onnx(&[1, 2]);
        assert_eq!(packed, vec![0x21]);

        let packed = pack_int4_onnx(&[0, 0x7]);
        assert_eq!(packed, vec![0x70]);
    }

    #[test]
    fn test_onnx_pack_negative_values() {
        // -1 in 4-bit two's complement is 0xF.
        // [-1, -1] → byte = (0xF << 4) | 0xF = 0xFF
        assert_eq!(pack_int4_onnx(&[-1, -1]), vec![0xFF]);

        // [-8, 7] → byte = (0x7 << 4) | 0x8 = 0x78
        assert_eq!(pack_int4_onnx(&[-8, 7]), vec![0x78]);
    }

    #[test]
    fn test_onnx_pack_odd_length_zero_pads_high_nibble() {
        // Single value in the low nibble, high nibble zero.
        assert_eq!(pack_int4_onnx(&[0x3]), vec![0x03]);
        assert_eq!(pack_int4_onnx(&[-1]), vec![0x0F]);
    }

    #[test]
    fn test_onnx_pack_unpack_round_trip_all_values() {
        let values: Vec<i8> = (-8..=7).collect();
        let packed = pack_int4_onnx(&values);
        let unpacked = unpack_int4_onnx(&packed, values.len());
        assert_eq!(unpacked, values);
        assert_eq!(packed.len(), 8, "16 values must pack to exactly 8 bytes");
    }

    #[test]
    fn test_onnx_pack_unpack_round_trip_odd_length() {
        let values: Vec<i8> = vec![-8, -1, 0, 7, -3];
        let packed = pack_int4_onnx(&values);
        let unpacked = unpack_int4_onnx(&packed, values.len());
        assert_eq!(unpacked, values);
        assert_eq!(packed.len(), 3, "5 values must pack to ceil(5/2) = 3 bytes");
    }

    #[test]
    fn test_native_int4_weight_tensor_uses_int4_data_type() {
        let names = DequantLinearNames::from_original("w");
        let values = vec![-8i8, -1, 0, 7];
        let shape = vec![4i64];
        let t = build_quantized_weight_tensor(&names, &values, &shape, StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert_eq!(t.dims, vec![4], "dims should be logical element count");
        assert_eq!(t.raw_data.len(), 2, "4 values → 2 packed bytes");

        let recovered = unpack_int4_onnx(&t.raw_data, values.len());
        assert_eq!(recovered, values);
    }

    #[test]
    fn test_native_int4_zero_point_scalar() {
        let names = DequantLinearNames::from_original("w");
        let t = build_zero_point_tensor(&names, &[-3], StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert_eq!(t.dims.len(), 0, "scalar zp has rank 0");
        assert_eq!(t.raw_data.len(), 1);

        let recovered = unpack_int4_onnx(&t.raw_data, 1);
        assert_eq!(recovered, vec![-3]);
    }

    #[test]
    fn test_native_int4_zero_point_per_channel() {
        let names = DequantLinearNames::from_original("w");
        let zps = vec![-3, 0, 5, -1, 7];
        let t = build_zero_point_tensor(&names, &zps, StorageFormat::NativeInt4);

        assert_eq!(t.data_type, tensor_proto::DataType::Int4 as i32);
        assert_eq!(t.dims, vec![5], "per-channel zp has rank 1");
        assert_eq!(t.raw_data.len(), 3, "5 values → 3 packed bytes");

        let recovered = unpack_int4_onnx(&t.raw_data, zps.len());
        assert_eq!(recovered, zps);
    }
}
465}