// quantize_rs/onnx_utils/quantization_nodes.rs
//! Low-level builders for ONNX QDQ (Quantize-Dequantize) graph primitives.
//!
//! Each quantized weight becomes four graph elements:
//!
//! ```text
//! Initializers:
//!   "{name}_quantized"  — INT8 tensor, same shape as original
//!   "{name}_scale"      — FP32 scalar
//!   "{name}_zp"         — INT8 scalar
//!
//! Node:
//!   DequantizeLinear
//!     inputs:  ["{name}_quantized", "{name}_scale", "{name}_zp"]
//!     outputs: ["{name}"]          ← original name; downstream graph untouched
//! ```
//!
//! The DequantizeLinear op runs at inference time:
//!   `output = (input - zero_point) × scale`
//! which matches the dequantize formula already used in `QuantParams` and
//! `QuantParamsInt4`.
22use crate::onnx_proto::{
23    AttributeProto, NodeProto, TensorProto,
24    attribute_proto, tensor_proto,
25};
26
// ---------------------------------------------------------------------------
// Name generation
// ---------------------------------------------------------------------------
/// Canonical names for the four graph elements that replace one FP32 initializer.
///
/// Produced by [`DequantLinearNames::from_original`] and consumed by the node
/// and initializer builders below, so all four elements stay consistently named.
#[derive(Debug, Clone)]
pub struct DequantLinearNames {
    /// `"{original}_quantized"` — the INT8 weight tensor initializer.
    pub quantized_name: String,
    /// `"{original}_scale"` — FP32 scale (rank-0 scalar, or rank-1 per-channel).
    pub scale_name: String,
    /// `"{original}_zp"` — INT8 zero-point (rank-0 scalar, or rank-1 per-channel).
    pub zp_name: String,
    /// `"DequantizeLinear_{original}"` — the DequantizeLinear node name.
    pub node_name: String,
    /// The original tensor name — becomes the DequantizeLinear *output*,
    /// so every downstream node (Conv, MatMul, …) sees no change.
    pub output_name: String,
}
46
47impl DequantLinearNames {
48    /// Derive all four names from the original weight tensor name.
49    pub fn from_original(original_name: &str) -> Self {
50        Self {
51            quantized_name: format!("{}_quantized", original_name),
52            scale_name:     format!("{}_scale", original_name),
53            zp_name:        format!("{}_zp", original_name),
54            node_name:      format!("DequantizeLinear_{}", original_name),
55            output_name:    original_name.to_string(),
56        }
57    }
58}
// ---------------------------------------------------------------------------
// Node builder
// ---------------------------------------------------------------------------
64/// Build a DequantizeLinear `NodeProto`.
65///
66/// ONNX spec (opset ≥ 10):
67///   inputs  = [x (INT8), x_scale (FP32), x_zero_point (INT8)]
68///   outputs = [y (FP32)]
69///   y = (x - x_zero_point) × x_scale
70///
71/// When `axis` is `Some(a)`, the `axis` attribute is set on the node,
72/// enabling per-channel dequantization (opset ≥ 13).
73pub fn build_dequantize_linear_node(
74    names: &DequantLinearNames,
75    axis: Option<usize>,
76) -> NodeProto {
77    let attribute = match axis {
78        Some(a) => vec![AttributeProto {
79            name:   "axis".to_string(),
80            r#type: attribute_proto::AttributeType::Int as i32,
81            i:      a as i64,
82            ..Default::default()
83        }],
84        None => vec![],
85    };
86
87    NodeProto {
88        op_type:   "DequantizeLinear".to_string(),
89        name:      names.node_name.clone(),
90        input:     vec![
91            names.quantized_name.clone(),
92            names.scale_name.clone(),
93            names.zp_name.clone(),
94        ],
95        output:    vec![names.output_name.clone()],
96        attribute,
97        ..Default::default()
98    }
99}
// ---------------------------------------------------------------------------
// Initializer builders
// ---------------------------------------------------------------------------
105/// INT8 tensor holding the quantized weight values.
106///
107/// Shape matches the original FP32 tensor exactly.  For INT4-quantized values
108/// (range [-8, 7]), the i8 bytes are stored directly — see the INT4 note in
109/// `graph_builder::apply_qdq_transform`.
110pub fn build_quantized_weight_tensor(
111    names: &DequantLinearNames,
112    values: &[i8],
113    shape: &[i64],
114) -> TensorProto {
115    TensorProto {
116        name:      names.quantized_name.clone(),
117        data_type: tensor_proto::DataType::Int8 as i32,
118        dims:      shape.to_vec(),
119        // Each i8 value → one byte.  Reinterpret cast, not value conversion.
120        raw_data:  values.iter().map(|&v| v as u8).collect(),
121        ..Default::default()
122    }
123}
124
125/// FP32 scale tensor.
126///
127/// For per-tensor quantization, `scales` has one element and the tensor
128/// is rank-0 (scalar).  For per-channel, `scales` has one entry per
129/// channel and the tensor is rank-1 with shape `[num_channels]`.
130pub fn build_scale_tensor(names: &DequantLinearNames, scales: &[f32]) -> TensorProto {
131    let mut t = TensorProto {
132        name:       names.scale_name.clone(),
133        data_type:  tensor_proto::DataType::Float as i32,
134        float_data: scales.to_vec(),
135        ..Default::default()
136    };
137    if scales.len() > 1 {
138        // rank-1: [num_channels]
139        t.dims = vec![scales.len() as i64];
140    }
141    // For scalar (len == 1), dims remains empty (rank-0 scalar).
142    t
143}
144
145/// INT8 zero-point tensor.
146///
147/// For per-tensor, `zps` has one element → rank-0 scalar.
148/// For per-channel, `zps` has one per channel → rank-1 `[num_channels]`.
149pub fn build_zero_point_tensor(names: &DequantLinearNames, zps: &[i8]) -> TensorProto {
150    let mut t = TensorProto {
151        name:      names.zp_name.clone(),
152        data_type: tensor_proto::DataType::Int8 as i32,
153        raw_data:  zps.iter().map(|&v| v as u8).collect(),
154        ..Default::default()
155    };
156    if zps.len() > 1 {
157        // rank-1: [num_channels]
158        t.dims = vec![zps.len() as i64];
159    }
160    // For scalar (len == 1), dims remains empty (rank-0 scalar).
161    t
162}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::tensor_proto;

    #[test]
    fn test_names_from_simple_weight() {
        let names = DequantLinearNames::from_original("conv1.weight");
        assert_eq!(names.quantized_name, "conv1.weight_quantized");
        assert_eq!(names.scale_name, "conv1.weight_scale");
        assert_eq!(names.zp_name, "conv1.weight_zp");
        assert_eq!(names.node_name, "DequantizeLinear_conv1.weight");
        assert_eq!(names.output_name, "conv1.weight");
    }

    #[test]
    fn test_names_from_dotted_path() {
        // Weight names exported from real models (e.g. ResNet-18) contain dots.
        let names = DequantLinearNames::from_original("layer1.0.conv1.weight");
        assert_eq!(names.quantized_name, "layer1.0.conv1.weight_quantized");
        assert_eq!(names.output_name, "layer1.0.conv1.weight");
    }

    #[test]
    fn test_dequantize_linear_node_inputs_outputs() {
        let names = DequantLinearNames::from_original("fc.weight");
        let node = build_dequantize_linear_node(&names, None);

        assert_eq!(node.op_type, "DequantizeLinear");
        assert_eq!(node.name, "DequantizeLinear_fc.weight");

        // Input order is fixed by the ONNX spec: x, x_scale, x_zero_point.
        assert_eq!(node.input.len(), 3);
        assert_eq!(node.input[0], "fc.weight_quantized");
        assert_eq!(node.input[1], "fc.weight_scale");
        assert_eq!(node.input[2], "fc.weight_zp");

        assert_eq!(node.output.len(), 1);
        assert_eq!(node.output[0], "fc.weight");
        // No axis supplied → per-tensor form, which carries no attributes.
        assert!(node.attribute.is_empty());
    }

    #[test]
    fn test_dequantize_linear_node_with_axis() {
        let names = DequantLinearNames::from_original("conv.weight");
        let node = build_dequantize_linear_node(&names, Some(0));

        assert_eq!(node.attribute.len(), 1);
        let axis_attr = &node.attribute[0];
        assert_eq!(axis_attr.name, "axis");
        assert_eq!(axis_attr.i, 0);
    }

    #[test]
    fn test_quantized_weight_tensor_shape_and_data() {
        let names = DequantLinearNames::from_original("w");
        let values = vec![1i8, -2, 3, -4, 5, 6];
        let tensor = build_quantized_weight_tensor(&names, &values, &[2, 3]);

        assert_eq!(tensor.name, "w_quantized");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert_eq!(tensor.dims, vec![2i64, 3]);

        // The raw bytes must reinterpret back to the exact same i8 values.
        let recovered: Vec<i8> = tensor.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }

    #[test]
    fn test_scale_tensor_scalar() {
        let names = DequantLinearNames::from_original("w");
        let tensor = build_scale_tensor(&names, &[0.003921]);

        assert_eq!(tensor.name, "w_scale");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Float as i32);
        assert!(tensor.dims.is_empty(), "single scale must be rank-0 scalar");
        assert!((tensor.float_data[0] - 0.003921).abs() < 1e-6);
    }

    #[test]
    fn test_scale_tensor_per_channel() {
        let names = DequantLinearNames::from_original("w");
        let tensor = build_scale_tensor(&names, &[0.01, 0.02, 0.03]);

        assert_eq!(tensor.dims, vec![3i64]);
        assert_eq!(tensor.float_data.len(), 3);
    }

    #[test]
    fn test_zero_point_tensor_scalar() {
        let names = DequantLinearNames::from_original("w");
        let tensor = build_zero_point_tensor(&names, &[-3]);

        assert_eq!(tensor.name, "w_zp");
        assert_eq!(tensor.data_type, tensor_proto::DataType::Int8 as i32);
        assert!(tensor.dims.is_empty(), "single zp must be rank-0 scalar");
        assert_eq!(tensor.raw_data[0], (-3i8) as u8);
    }

    #[test]
    fn test_zero_point_tensor_per_channel() {
        let names = DequantLinearNames::from_original("w");
        let tensor = build_zero_point_tensor(&names, &[-3, 0, 5]);

        assert_eq!(tensor.dims, vec![3i64]);
        assert_eq!(tensor.raw_data.len(), 3);
    }

    #[test]
    fn test_int4_range_values_round_trip() {
        // INT4 signed values arrive widened to i8 and are stored verbatim,
        // so the full [-8, 7] range must survive a byte-level round trip.
        let names = DequantLinearNames::from_original("w");
        let values = vec![-8i8, -1, 0, 7];
        let tensor = build_quantized_weight_tensor(&names, &values, &[4]);

        let recovered: Vec<i8> = tensor.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, values);
    }
}