Skip to main content

quantize_rs/onnx_utils/
mod.rs

1// src/onnx_utils/mod.rs
2//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
3//! graph connectivity validation, and quantized-model introspection.
4
5pub mod graph_builder;
6pub mod quantization_nodes;
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{
10    tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
11};
12use prost::Message;
13use std::fs;
14use std::io::{Read, Write};
15
16// Re-export so callers don't have to reach into submodules
17pub use graph_builder::ConnectivityReport;
18
19// ===========================================================================
20// Core types
21// ===========================================================================
22
23/// An ONNX model loaded from a protobuf file.
24///
25/// Provides methods for inspecting, extracting weights, saving quantized
26/// models, and validating graph connectivity.
27pub struct OnnxModel {
28    proto: ModelProto,
29}
30
31impl std::fmt::Debug for OnnxModel {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        let name = self
34            .proto
35            .graph
36            .as_ref()
37            .map(|g| g.name.as_str())
38            .unwrap_or("");
39        let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
40        f.debug_struct("OnnxModel")
41            .field("name", &name)
42            .field("num_nodes", &num_nodes)
43            .finish()
44    }
45}
46
/// Summary of an ONNX model's structure.
#[derive(Debug)]
pub struct ModelInfo {
    /// Graph name as recorded in the protobuf.
    pub name: String,
    /// `model_version` field from the protobuf.
    pub version: i64,
    /// Number of computation nodes in the graph.
    pub num_nodes: usize,
    /// Names of the declared graph inputs.
    pub inputs: Vec<String>,
    /// Names of the declared graph outputs.
    pub outputs: Vec<String>,
}
61
/// Metadata about a quantized weight recovered from a QDQ-format model.
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Original weight name (the `_quantized` suffix stripped).
    pub name: String,
    /// Quantization bit width (4 or 8).
    pub bits: u8,
    /// Quantization scale factor.
    pub scale: f32,
    /// Quantization zero point.
    pub zero_point: i8,
    /// Element count of the quantized tensor.
    pub original_length: usize,
}
76
77// ===========================================================================
78// OnnxModel — load / inspect
79// ===========================================================================
80
81impl OnnxModel {
82    /// Load an ONNX model from a file path.
83    ///
84    /// # Errors
85    ///
86    /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
87    /// is too large (>10 GB), or contains invalid protobuf data.
88    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
89        let path = path.as_ref();
90        let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
91            path: path.to_path_buf(),
92            reason: format!("Failed to open ONNX file: {e}"),
93        })?;
94
95        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
96        let file_size = file
97            .metadata()
98            .map_err(|e| QuantizeError::ModelLoad {
99                path: path.to_path_buf(),
100                reason: format!("Failed to read metadata: {e}"),
101            })?
102            .len();
103        if file_size > MAX_MODEL_SIZE {
104            return Err(QuantizeError::ModelLoad {
105                path: path.to_path_buf(),
106                reason: format!(
107                    "Model file too large: {:.2} GB (max: 10 GB)",
108                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
109                ),
110            });
111        }
112
113        let mut buffer = Vec::with_capacity(file_size as usize);
114        file.read_to_end(&mut buffer)
115            .map_err(|e| QuantizeError::ModelLoad {
116                path: path.to_path_buf(),
117                reason: format!("Failed to read ONNX file: {e}"),
118            })?;
119
120        let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
121            path: path.to_path_buf(),
122            reason: format!("Failed to parse ONNX protobuf: {e}"),
123        })?;
124
125        Ok(Self { proto })
126    }
127
128    /// Return a summary of the model's structure.
129    pub fn info(&self) -> ModelInfo {
130        let graph = self.proto.graph.as_ref();
131
132        let inputs: Vec<String> = graph
133            .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
134            .unwrap_or_default();
135
136        let outputs: Vec<String> = graph
137            .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
138            .unwrap_or_default();
139
140        ModelInfo {
141            name: graph.map(|g| g.name.clone()).unwrap_or_default(),
142            version: self.proto.model_version,
143            num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
144            inputs,
145            outputs,
146        }
147    }
148
149    /// Return the shapes of each graph input from the protobuf type info.
150    ///
151    /// Each inner `Vec<i64>` contains the dimension values.  Dynamic dims
152    /// (symbolic or missing) are returned as -1.  Returns one entry per
153    /// `graph.input` that has tensor type information.
154    pub fn input_shapes(&self) -> Vec<Vec<i64>> {
155        let graph = match &self.proto.graph {
156            Some(g) => g,
157            None => return Vec::new(),
158        };
159
160        let mut shapes = Vec::new();
161        for inp in &graph.input {
162            if let Some(type_proto) = &inp.r#type {
163                if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
164                    if let Some(shape) = &tensor_type.shape {
165                        let dims: Vec<i64> = shape
166                            .dim
167                            .iter()
168                            .map(|d| match &d.value {
169                                Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
170                                _ => -1,
171                            })
172                            .collect();
173                        shapes.push(dims);
174                    }
175                }
176            }
177        }
178        shapes
179    }
180
181    /// Extract all FP32 weight tensors from the model's initializers.
182    pub fn extract_weights(&self) -> Vec<WeightTensor> {
183        let graph = match &self.proto.graph {
184            Some(g) => g,
185            None => return Vec::new(),
186        };
187
188        let mut weights = Vec::new();
189        for initializer in &graph.initializer {
190            // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
191            if initializer.data_type != tensor_proto::DataType::Float as i32 {
192                continue;
193            }
194
195            let name = initializer.name.clone();
196
197            let shape: Vec<usize> = initializer
198                .dims
199                .iter()
200                .map(|&d| d.max(0) as usize)
201                .collect();
202
203            let data = if !initializer.raw_data.is_empty() {
204                if initializer.raw_data.len() % 4 != 0 {
205                    // Misaligned raw_data — skip this initializer rather than panic
206                    continue;
207                }
208                initializer
209                    .raw_data
210                    .chunks_exact(4)
211                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
212                    .collect()
213            } else {
214                initializer.float_data.clone()
215            };
216
217            if !data.is_empty() {
218                weights.push(WeightTensor { name, data, shape });
219            }
220        }
221
222        weights
223    }
224
225    /// Total size of all weight tensors in bytes (float32).
226    ///
227    /// Prefer computing this from already-extracted weights when available:
228    /// `weights.iter().map(|w| w.size_bytes()).sum()` avoids reparsing.
229    pub fn total_size_bytes(&self) -> usize {
230        let graph = match &self.proto.graph {
231            Some(g) => g,
232            None => return 0,
233        };
234        graph
235            .initializer
236            .iter()
237            .map(|init| {
238                if !init.raw_data.is_empty() {
239                    init.raw_data.len()
240                } else {
241                    init.float_data.len() * std::mem::size_of::<f32>()
242                }
243            })
244            .sum()
245    }
246}
247
248// ===========================================================================
249// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
250// ===========================================================================
251
252impl OnnxModel {
253    /// Save a quantized model using the QDQ (DequantizeLinear) pattern.
254    ///
255    /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
256    /// pipeline, examples) compile without changes.
257    ///
258    /// ### What changed internally
259    ///
260    /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
261    /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
262    /// reference them.  ONNX Runtime rejected these models on load.
263    ///
264    /// v0.3.0 inserts a `DequantizeLinear` node per weight.  The node's output
265    /// carries the **original** name, so every downstream node is unchanged.
266    /// Graph connectivity is preserved by construction, and the resulting model
267    /// loads and runs in ONNX Runtime.
268    ///
269    /// ### INT4 storage note
270    ///
271    /// `DequantizeLinear` requires INT8 input (opset < 21).  INT4-quantized values
272    /// ([-8, 7]) are stored as INT8 bytes.  Quantization *accuracy* is still
273    /// INT4-level; only the on-disk size is 4× instead of the 8× that bit-packing
274    /// would give.  True INT4 packing is a v0.4.0 target.
275    pub fn save_quantized(
276        &mut self,
277        quantized_data: &[graph_builder::QdqWeightInput],
278        path: impl AsRef<std::path::Path>,
279    ) -> Result<()> {
280        let path = path.as_ref();
281        use graph_builder::{apply_qdq_transform, ensure_opset_version};
282
283        // --- 1. Opset: ≥10 for per-tensor DequantizeLinear, ≥13 for per-channel ---
284        let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
285        let min_opset = if needs_per_channel { 13 } else { 10 };
286        ensure_opset_version(&mut self.proto, min_opset);
287
288        // --- 2. Persist per-weight bits in model metadata ---
289        for inp in quantized_data.iter() {
290            self.proto.metadata_props.push(StringStringEntryProto {
291                key: format!("quantize_rs.bits.{}", inp.original_name),
292                value: inp.bits.to_string(),
293            });
294        }
295
296        // --- 3. Apply QDQ transform to the graph ---
297        let graph = self
298            .proto
299            .graph
300            .as_mut()
301            .ok_or_else(|| QuantizeError::ModelSave {
302                path: path.to_path_buf(),
303                reason: "Model has no graph".to_string(),
304            })?;
305        apply_qdq_transform(graph, quantized_data)?;
306
307        // --- 4. Encode and write to disk ---
308        let mut buf = Vec::new();
309        self.proto
310            .encode(&mut buf)
311            .map_err(|e| QuantizeError::ModelSave {
312                path: path.to_path_buf(),
313                reason: format!("Failed to encode ONNX model: {e}"),
314            })?;
315
316        let mut file = std::fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
317            path: path.to_path_buf(),
318            reason: format!("Failed to create output file: {e}"),
319        })?;
320
321        file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
322            path: path.to_path_buf(),
323            reason: format!("Failed to write ONNX model: {e}"),
324        })?;
325
326        Ok(())
327    }
328}
329
330// ===========================================================================
331// OnnxModel — validation
332// ===========================================================================
333
334impl OnnxModel {
335    /// Check that every node input in the graph resolves to a known tensor.
336    ///
337    /// A "known tensor" is one of:
338    ///   - a declared graph input
339    ///   - an initializer
340    ///   - the output of a node appearing earlier in the node list
341    ///
342    /// This is the exact check ONNX Runtime performs on load.  It's the check
343    /// that v0.2.0's `validate` command skipped, which is why the rename bug
344    /// went undetected.  Integrate `report.summary()` into the CLI validate
345    /// output alongside the existing structure / weight checks.
346    pub fn validate_connectivity(&self) -> ConnectivityReport {
347        match &self.proto.graph {
348            Some(graph) => graph_builder::validate_graph_connectivity(graph),
349            None => {
350                use crate::onnx_proto::GraphProto;
351                graph_builder::validate_graph_connectivity(&GraphProto::default())
352            }
353        }
354    }
355}
356
357// ===========================================================================
358// OnnxModel — quantized model introspection (v0.3.0 QDQ format)
359// ===========================================================================
360
361impl OnnxModel {
362    /// Extract metadata about quantized weights from a QDQ-format model.
363    ///
364    /// Looks for initializer triples:
365    ///   `{base}_quantized`, `{base}_scale`, `{base}_zp`
366    ///
367    /// Scale and zero-point values are read directly from the tensors.
368    /// Bit-width comes from `metadata_props` (written by `save_quantized`);
369    /// defaults to 8 if the metadata entry is missing.
370    pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
371        let graph = match &self.proto.graph {
372            Some(g) => g,
373            None => return Vec::new(),
374        };
375
376        let mut scale_map: std::collections::HashMap<String, f32> =
377            std::collections::HashMap::new();
378        let mut zp_map: std::collections::HashMap<String, i8> = std::collections::HashMap::new();
379        let mut quant_bases: Vec<String> = Vec::new();
380
381        for init in &graph.initializer {
382            let name = &init.name;
383
384            if let Some(base) = name.strip_suffix("_scale") {
385                // Scale is stored in float_data (rank-0 scalar)
386                let scale = if !init.float_data.is_empty() {
387                    init.float_data[0]
388                } else if init.raw_data.len() >= 4 {
389                    // Fallback: try raw_data as little-endian f32
390                    f32::from_le_bytes([
391                        init.raw_data[0],
392                        init.raw_data[1],
393                        init.raw_data[2],
394                        init.raw_data[3],
395                    ])
396                } else {
397                    1.0
398                };
399                scale_map.insert(base.to_string(), scale);
400            } else if let Some(base) = name.strip_suffix("_zp") {
401                // Zero-point is a single raw byte
402                let zp = if !init.raw_data.is_empty() {
403                    init.raw_data[0] as i8
404                } else {
405                    0
406                };
407                zp_map.insert(base.to_string(), zp);
408            } else if let Some(base) = name.strip_suffix("_quantized") {
409                quant_bases.push(base.to_string());
410            }
411        }
412
413        // Read bits from metadata_props (written by save_quantized)
414        let mut bits_map: std::collections::HashMap<String, u8> = std::collections::HashMap::new();
415        for prop in &self.proto.metadata_props {
416            if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
417                if let Ok(bits) = prop.value.parse::<u8>() {
418                    bits_map.insert(base.to_string(), bits);
419                }
420            }
421        }
422
423        // Assemble QuantizedWeightInfo from the three maps
424        quant_bases
425            .iter()
426            .map(|base| {
427                let scale = scale_map.get(base).copied().unwrap_or(1.0);
428                let zp = zp_map.get(base).copied().unwrap_or(0);
429                let bits = bits_map.get(base).copied().unwrap_or(8);
430
431                // Element count = product of dims on the _quantized tensor
432                let original_length = graph
433                    .initializer
434                    .iter()
435                    .find(|i| i.name == format!("{}_quantized", base))
436                    .map(|i| i.dims.iter().product::<i64>() as usize)
437                    .unwrap_or(0);
438
439                QuantizedWeightInfo {
440                    name: base.clone(),
441                    bits,
442                    scale,
443                    zero_point: zp,
444                    original_length,
445                }
446            })
447            .collect()
448    }
449}
450
451// ===========================================================================
452// WeightTensor (unchanged from v0.2.0)
453// ===========================================================================
454
/// An FP32 weight tensor extracted from an ONNX model.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Initializer name in the ONNX graph.
    pub name: String,
    /// FP32 weight values.
    pub data: Vec<f32>,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Size of this tensor in bytes (as FP32).
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * self.num_elements()
    }

    /// Total number of scalar elements.
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }
}
476}