Skip to main content

quantize_rs/onnx_utils/
mod.rs

1// src/onnx_utils/mod.rs
2//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
3//! graph connectivity validation, and quantized-model introspection.
4
5pub mod graph_builder;
6pub mod quantization_nodes;
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{
10    ModelProto, StringStringEntryProto,
11    tensor_proto, tensor_shape_proto, type_proto,
12};
13use prost::Message;
14use std::fs;
15use std::io::{Read, Write};
16
17// Re-export so callers don't have to reach into submodules
18pub use graph_builder::ConnectivityReport;
19
20// ===========================================================================
21// Core types
22// ===========================================================================
23
/// An ONNX model loaded from a protobuf file.
///
/// Wraps the decoded `ModelProto` and provides methods for inspecting the
/// graph, extracting FP32 weights, saving quantized (QDQ) models, validating
/// graph connectivity, and reading back quantization metadata.
///
/// The protobuf is private; `save_quantized` (which takes `&mut self`)
/// modifies it in place before writing it out.
pub struct OnnxModel {
    /// Decoded model protobuf; mutated in place by `save_quantized`.
    proto: ModelProto,
}
31
32impl std::fmt::Debug for OnnxModel {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        let name = self.proto.graph.as_ref().map(|g| g.name.as_str()).unwrap_or("");
35        let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
36        f.debug_struct("OnnxModel")
37            .field("name", &name)
38            .field("num_nodes", &num_nodes)
39            .finish()
40    }
41}
42
/// Summary of an ONNX model's structure, as returned by `OnnxModel::info`.
///
/// When the model has no graph, `name`, `inputs` and `outputs` are empty and
/// `num_nodes` is 0.
#[derive(Debug)]
pub struct ModelInfo {
    /// Graph name from the protobuf (`graph.name`).
    pub name: String,
    /// Model version from the protobuf (`model_version`).
    pub version: i64,
    /// Number of computation nodes in the graph.
    pub num_nodes: usize,
    /// Names of the graph inputs, in declaration order.
    pub inputs: Vec<String>,
    /// Names of the graph outputs, in declaration order.
    pub outputs: Vec<String>,
}
57
/// Metadata about a quantized weight recovered from a QDQ-format model.
///
/// Produced by `OnnxModel::load_quantized_info`. Missing pieces fall back to
/// defaults there: scale 1.0, zero point 0, 8 bits.
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Original weight name (the initializer name without its `_quantized` suffix).
    pub name: String,
    /// Quantization bit width (4 or 8), read from the model's metadata props.
    pub bits: u8,
    /// Quantization scale factor (first element of the `{name}_scale` tensor).
    pub scale: f32,
    /// Quantization zero point (first element of the `{name}_zp` tensor).
    pub zero_point: i8,
    /// Number of elements in the quantized tensor (product of its dims).
    pub original_length: usize,
}
72
73// ===========================================================================
74// OnnxModel — load / inspect
75// ===========================================================================
76
77impl OnnxModel {
78    /// Load an ONNX model from a file path.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
83    /// is too large (>10 GB), or contains invalid protobuf data.
84    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
85        let path = path.as_ref();
86        let mut file =
87            fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
88                path: path.to_path_buf(),
89                reason: format!("Failed to open ONNX file: {e}"),
90            })?;
91
92        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
93        let file_size = file.metadata()
94            .map_err(|e| QuantizeError::ModelLoad {
95                path: path.to_path_buf(),
96                reason: format!("Failed to read metadata: {e}"),
97            })?
98            .len();
99        if file_size > MAX_MODEL_SIZE {
100            return Err(QuantizeError::ModelLoad {
101                path: path.to_path_buf(),
102                reason: format!(
103                    "Model file too large: {:.2} GB (max: 10 GB)",
104                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
105                ),
106            });
107        }
108
109        let mut buffer = Vec::with_capacity(file_size as usize);
110        file.read_to_end(&mut buffer)
111            .map_err(|e| QuantizeError::ModelLoad {
112                path: path.to_path_buf(),
113                reason: format!("Failed to read ONNX file: {e}"),
114            })?;
115
116        let proto = ModelProto::decode(&buffer[..])
117            .map_err(|e| QuantizeError::ModelLoad {
118                path: path.to_path_buf(),
119                reason: format!("Failed to parse ONNX protobuf: {e}"),
120            })?;
121
122        Ok(Self { proto })
123    }
124
125    /// Return a summary of the model's structure.
126    pub fn info(&self) -> ModelInfo {
127        let graph = self.proto.graph.as_ref();
128
129        let inputs: Vec<String> = graph
130            .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
131            .unwrap_or_default();
132
133        let outputs: Vec<String> = graph
134            .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
135            .unwrap_or_default();
136
137        ModelInfo {
138            name: graph.map(|g| g.name.clone()).unwrap_or_default(),
139            version: self.proto.model_version,
140            num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
141            inputs,
142            outputs,
143        }
144    }
145
146    /// Return the shapes of each graph input from the protobuf type info.
147    ///
148    /// Each inner `Vec<i64>` contains the dimension values.  Dynamic dims
149    /// (symbolic or missing) are returned as -1.  Returns one entry per
150    /// `graph.input` that has tensor type information.
151    pub fn input_shapes(&self) -> Vec<Vec<i64>> {
152        let graph = match &self.proto.graph {
153            Some(g) => g,
154            None => return Vec::new(),
155        };
156
157        let mut shapes = Vec::new();
158        for inp in &graph.input {
159            if let Some(type_proto) = &inp.r#type {
160                if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
161                    if let Some(shape) = &tensor_type.shape {
162                        let dims: Vec<i64> = shape.dim.iter().map(|d| {
163                            match &d.value {
164                                Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
165                                _ => -1,
166                            }
167                        }).collect();
168                        shapes.push(dims);
169                    }
170                }
171            }
172        }
173        shapes
174    }
175
176    /// Extract all FP32 weight tensors from the model's initializers.
177    pub fn extract_weights(&self) -> Vec<WeightTensor> {
178        let graph = match &self.proto.graph {
179            Some(g) => g,
180            None => return Vec::new(),
181        };
182
183        let mut weights = Vec::new();
184        for initializer in &graph.initializer {
185            // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
186            if initializer.data_type != tensor_proto::DataType::Float as i32 {
187                continue;
188            }
189
190            let name = initializer.name.clone();
191
192            let shape: Vec<usize> = initializer.dims.iter()
193                .map(|&d| d.max(0) as usize)
194                .collect();
195
196            let data = if !initializer.raw_data.is_empty() {
197                initializer.raw_data.chunks_exact(4)
198                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
199                    .collect()
200            } else {
201                initializer.float_data.clone()
202            };
203
204            if !data.is_empty() {
205                weights.push(WeightTensor { name, data, shape });
206            }
207        }
208
209        weights
210    }
211
212    /// Total size of all weight tensors in bytes (float32).
213    ///
214    /// Prefer computing this from already-extracted weights when available:
215    /// `weights.iter().map(|w| w.size_bytes()).sum()` avoids reparsing.
216    pub fn total_size_bytes(&self) -> usize {
217        let graph = match &self.proto.graph {
218            Some(g) => g,
219            None => return 0,
220        };
221        graph.initializer.iter()
222            .map(|init| {
223                if !init.raw_data.is_empty() {
224                    init.raw_data.len()
225                } else {
226                    init.float_data.len() * std::mem::size_of::<f32>()
227                }
228            })
229            .sum()
230    }
231}
232
233// ===========================================================================
234// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
235// ===========================================================================
236
237impl OnnxModel {
238    /// Save a quantized model using the QDQ (DequantizeLinear) pattern.
239    ///
240    /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
241    /// pipeline, examples) compile without changes.
242    ///
243    /// ### What changed internally
244    ///
245    /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
246    /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
247    /// reference them.  ONNX Runtime rejected these models on load.
248    ///
249    /// v0.3.0 inserts a `DequantizeLinear` node per weight.  The node's output
250    /// carries the **original** name, so every downstream node is unchanged.
251    /// Graph connectivity is preserved by construction, and the resulting model
252    /// loads and runs in ONNX Runtime.
253    ///
254    /// ### INT4 storage note
255    ///
256    /// `DequantizeLinear` requires INT8 input (opset < 21).  INT4-quantized values
257    /// ([-8, 7]) are stored as INT8 bytes.  Quantization *accuracy* is still
258    /// INT4-level; only the on-disk size is 4× instead of the 8× that bit-packing
259    /// would give.  True INT4 packing is a v0.4.0 target.
260    pub fn save_quantized(
261        &mut self,
262        quantized_data: &[graph_builder::QdqWeightInput],
263        path: impl AsRef<std::path::Path>,
264    ) -> Result<()> {
265        let path = path.as_ref();
266        use graph_builder::{apply_qdq_transform, ensure_opset_version};
267
268        // --- 1. Opset ≥ 13 (DequantizeLinear per-channel needs it) ---
269        ensure_opset_version(&mut self.proto, 13);
270
271        // --- 2. Persist per-weight bits in model metadata ---
272        for inp in quantized_data.iter() {
273            self.proto.metadata_props.push(StringStringEntryProto {
274                key:   format!("quantize_rs.bits.{}", inp.original_name),
275                value: inp.bits.to_string(),
276            });
277        }
278
279        // --- 3. Apply QDQ transform to the graph ---
280        let graph = self.proto.graph.as_mut().ok_or_else(|| QuantizeError::ModelSave {
281            path: path.to_path_buf(),
282            reason: "Model has no graph".to_string(),
283        })?;
284        apply_qdq_transform(graph, quantized_data)?;
285
286        // --- 4. Encode and write to disk ---
287        let mut buf = Vec::new();
288        self.proto.encode(&mut buf)
289            .map_err(|e| QuantizeError::ModelSave {
290                path: path.to_path_buf(),
291                reason: format!("Failed to encode ONNX model: {e}"),
292            })?;
293
294        let mut file = std::fs::File::create(path)
295            .map_err(|e| QuantizeError::ModelSave {
296                path: path.to_path_buf(),
297                reason: format!("Failed to create output file: {e}"),
298            })?;
299
300        file.write_all(&buf)
301            .map_err(|e| QuantizeError::ModelSave {
302                path: path.to_path_buf(),
303                reason: format!("Failed to write ONNX model: {e}"),
304            })?;
305
306        Ok(())
307    }
308}
309
310// ===========================================================================
311// OnnxModel — validation
312// ===========================================================================
313
314impl OnnxModel {
315    /// Check that every node input in the graph resolves to a known tensor.
316    ///
317    /// A "known tensor" is one of:
318    ///   - a declared graph input
319    ///   - an initializer
320    ///   - the output of a node appearing earlier in the node list
321    ///
322    /// This is the exact check ONNX Runtime performs on load.  It's the check
323    /// that v0.2.0's `validate` command skipped, which is why the rename bug
324    /// went undetected.  Integrate `report.summary()` into the CLI validate
325    /// output alongside the existing structure / weight checks.
326    pub fn validate_connectivity(&self) -> ConnectivityReport {
327        match &self.proto.graph {
328            Some(graph) => graph_builder::validate_graph_connectivity(graph),
329            None => {
330                use crate::onnx_proto::GraphProto;
331                graph_builder::validate_graph_connectivity(&GraphProto::default())
332            }
333        }
334    }
335}
336
337// ===========================================================================
338// OnnxModel — quantized model introspection (v0.3.0 QDQ format)
339// ===========================================================================
340
341impl OnnxModel {
342    /// Extract metadata about quantized weights from a QDQ-format model.
343    ///
344    /// Looks for initializer triples:
345    ///   `{base}_quantized`, `{base}_scale`, `{base}_zp`
346    ///
347    /// Scale and zero-point values are read directly from the tensors.
348    /// Bit-width comes from `metadata_props` (written by `save_quantized`);
349    /// defaults to 8 if the metadata entry is missing.
350    pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
351        let graph = match &self.proto.graph {
352            Some(g) => g,
353            None => return Vec::new(),
354        };
355
356        let mut scale_map: std::collections::HashMap<String, f32> =
357            std::collections::HashMap::new();
358        let mut zp_map: std::collections::HashMap<String, i8> =
359            std::collections::HashMap::new();
360        let mut quant_bases: Vec<String> = Vec::new();
361
362        for init in &graph.initializer {
363            let name = &init.name;
364
365            if let Some(base) = name.strip_suffix("_scale") {
366                // Scale is stored in float_data (rank-0 scalar)
367                let scale = if !init.float_data.is_empty() {
368                    init.float_data[0]
369                } else if init.raw_data.len() >= 4 {
370                    // Fallback: try raw_data as little-endian f32
371                    f32::from_le_bytes([
372                        init.raw_data[0], init.raw_data[1],
373                        init.raw_data[2], init.raw_data[3],
374                    ])
375                } else {
376                    1.0
377                };
378                scale_map.insert(base.to_string(), scale);
379            } else if let Some(base) = name.strip_suffix("_zp") {
380                // Zero-point is a single raw byte
381                let zp = if !init.raw_data.is_empty() {
382                    init.raw_data[0] as i8
383                } else {
384                    0
385                };
386                zp_map.insert(base.to_string(), zp);
387            } else if let Some(base) = name.strip_suffix("_quantized") {
388                quant_bases.push(base.to_string());
389            }
390        }
391
392        // Read bits from metadata_props (written by save_quantized)
393        let mut bits_map: std::collections::HashMap<String, u8> =
394            std::collections::HashMap::new();
395        for prop in &self.proto.metadata_props {
396            if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
397                if let Ok(bits) = prop.value.parse::<u8>() {
398                    bits_map.insert(base.to_string(), bits);
399                }
400            }
401        }
402
403        // Assemble QuantizedWeightInfo from the three maps
404        quant_bases
405            .iter()
406            .map(|base| {
407                let scale = scale_map.get(base).copied().unwrap_or(1.0);
408                let zp = zp_map.get(base).copied().unwrap_or(0);
409                let bits = bits_map.get(base).copied().unwrap_or(8);
410
411                // Element count = product of dims on the _quantized tensor
412                let original_length = graph.initializer.iter()
413                    .find(|i| i.name == format!("{}_quantized", base))
414                    .map(|i| i.dims.iter().product::<i64>() as usize)
415                    .unwrap_or(0);
416
417                QuantizedWeightInfo {
418                    name: base.clone(),
419                    bits,
420                    scale,
421                    zero_point: zp,
422                    original_length,
423                }
424            })
425            .collect()
426    }
427}
428
429// ===========================================================================
430// WeightTensor (unchanged from v0.2.0)
431// ===========================================================================
432
/// A single FP32 weight tensor taken from an ONNX model's initializers.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Name of the ONNX initializer this tensor came from.
    pub name: String,
    /// Tensor values as 32-bit floats.
    pub data: Vec<f32>,
    /// Dimension sizes of the tensor.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Bytes occupied by `data`, at one `f32` (4 bytes) per element.
    pub fn size_bytes(&self) -> usize {
        let bytes_per_value = std::mem::size_of::<f32>();
        self.num_elements() * bytes_per_value
    }

    /// Number of scalar values stored in `data`.
    ///
    /// This is the length of `data`, not the product of `shape`.
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }
}