Skip to main content

quantize_rs/onnx_utils/
mod.rs

// src/onnx_utils/mod.rs
//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
//! graph connectivity validation, and quantized-model introspection.
4
5pub mod graph_builder;
6pub mod quantization_nodes;
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{
10    tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
11};
12use prost::Message;
13use std::fs;
14use std::io::{Read, Write};
15
16// Re-export so callers don't have to reach into submodules
17pub use graph_builder::{ConnectivityReport, SaveOptions};
18
19// ===========================================================================
20// Core types
21// ===========================================================================
22
23/// An ONNX model loaded from a protobuf file.
24///
25/// Provides methods for inspecting, extracting weights, saving quantized
26/// models, and validating graph connectivity.
27pub struct OnnxModel {
28    proto: ModelProto,
29}
30
31impl std::fmt::Debug for OnnxModel {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        let name = self
34            .proto
35            .graph
36            .as_ref()
37            .map(|g| g.name.as_str())
38            .unwrap_or("");
39        let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
40        f.debug_struct("OnnxModel")
41            .field("name", &name)
42            .field("num_nodes", &num_nodes)
43            .finish()
44    }
45}
46
/// Summary of an ONNX model's structure, as produced by [`OnnxModel::info`].
///
/// Plain data: cloneable and comparable so callers can cache or diff it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelInfo {
    /// Graph name from the protobuf.
    pub name: String,
    /// Model version from the protobuf (`ModelProto::model_version`).
    pub version: i64,
    /// Number of computation nodes in the graph.
    pub num_nodes: usize,
    /// Names of the graph inputs, in declaration order.
    pub inputs: Vec<String>,
    /// Names of the graph outputs, in declaration order.
    pub outputs: Vec<String>,
}
61
/// Quantization parameters for one weight, recovered from a QDQ-format model.
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Weight name as it was before quantization (the `_quantized` suffix
    /// stripped off).
    pub name: String,
    /// Quantization bit width (4 or 8).
    pub bits: u8,
    /// Scale factors: exactly one for per-tensor quantization, one per
    /// channel for per-channel quantization.
    pub scales: Vec<f32>,
    /// Zero points, parallel to [`scales`](Self::scales) (same length).
    pub zero_points: Vec<i8>,
    /// Element count of the quantized tensor.
    pub original_length: usize,
    /// Byte length of the quantized initializer's `raw_data` as stored on
    /// disk: equal to `original_length` for INT8 storage and
    /// `ceil(original_length / 2)` for native INT4 (opset 21).
    pub storage_bytes: usize,
}

impl QuantizedWeightInfo {
    /// Whether the weight carries more than one scale, i.e. was quantized
    /// per-channel.
    pub fn is_per_channel(&self) -> bool {
        !matches!(self.scales.len(), 0 | 1)
    }

    /// The first scale — the only one for a per-tensor weight.
    ///
    /// For per-channel weights, iterate [`scales`](Self::scales) instead.
    ///
    /// # Panics
    ///
    /// Panics if `scales` is empty.
    pub fn scale(&self) -> f32 {
        self.scales[0]
    }

    /// The first zero point — the only one for a per-tensor weight.
    ///
    /// For per-channel weights, iterate [`zero_points`](Self::zero_points)
    /// instead.
    ///
    /// # Panics
    ///
    /// Panics if `zero_points` is empty.
    pub fn zero_point(&self) -> i8 {
        self.zero_points[0]
    }
}
102
103// ===========================================================================
104// OnnxModel — load / inspect
105// ===========================================================================
106
107impl OnnxModel {
108    /// Load an ONNX model from a file path.
109    ///
110    /// Reads the entire file into a `Vec<u8>` before decoding.  For
111    /// multi-gigabyte models consider [`load_mmap`](Self::load_mmap)
112    /// (requires the `mmap` feature) to avoid the extra heap buffer.
113    ///
114    /// # Errors
115    ///
116    /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
117    /// is too large (>10 GB), or contains invalid protobuf data.
118    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
119        let path = path.as_ref();
120        let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
121            path: path.to_path_buf(),
122            reason: format!("Failed to open ONNX file: {e}"),
123        })?;
124
125        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
126        let file_size = file
127            .metadata()
128            .map_err(|e| QuantizeError::ModelLoad {
129                path: path.to_path_buf(),
130                reason: format!("Failed to read metadata: {e}"),
131            })?
132            .len();
133        if file_size > MAX_MODEL_SIZE {
134            return Err(QuantizeError::ModelLoad {
135                path: path.to_path_buf(),
136                reason: format!(
137                    "Model file too large: {:.2} GB (max: 10 GB)",
138                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
139                ),
140            });
141        }
142
143        let mut buffer = Vec::with_capacity(file_size as usize);
144        file.read_to_end(&mut buffer)
145            .map_err(|e| QuantizeError::ModelLoad {
146                path: path.to_path_buf(),
147                reason: format!("Failed to read ONNX file: {e}"),
148            })?;
149
150        let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
151            path: path.to_path_buf(),
152            reason: format!("Failed to parse ONNX protobuf: {e}"),
153        })?;
154
155        Ok(Self { proto })
156    }
157
158    /// Decode an ONNX model directly from a byte slice.
159    ///
160    /// Useful for in-memory or fuzzing scenarios where the source isn't a
161    /// filesystem path.  Same validation as [`load`](Self::load) but without
162    /// the file-size gate.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`QuantizeError::ModelLoad`] if `bytes` cannot be decoded as a
167    /// `ModelProto`.
168    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
169        let proto = ModelProto::decode(bytes).map_err(|e| QuantizeError::ModelLoad {
170            path: std::path::PathBuf::new(),
171            reason: format!("Failed to parse ONNX protobuf: {e}"),
172        })?;
173        Ok(Self { proto })
174    }
175
176    /// Load an ONNX model by memory-mapping the file (requires the `mmap`
177    /// feature).
178    ///
179    /// Compared to [`load`](Self::load), this avoids the intermediate
180    /// `Vec<u8>` buffer — useful for multi-gigabyte models where doubling
181    /// the working set during decode is a problem.  Peak RAM during load
182    /// falls from roughly `2 × file_size` to `1 × file_size + mmap overhead`.
183    ///
184    /// # Safety
185    ///
186    /// Memory-mapping requires that the file is not modified for the
187    /// duration of the load.  Another process truncating or rewriting the
188    /// file while decoding would be undefined behaviour.  This function
189    /// uses the `unsafe { Mmap::map(&file) }` call under the hood; its
190    /// invariants are the caller's responsibility.
191    ///
192    /// # Errors
193    ///
194    /// Returns [`QuantizeError::ModelLoad`] on I/O failure, invalid size,
195    /// or malformed protobuf.
196    #[cfg(feature = "mmap")]
197    pub fn load_mmap(path: impl AsRef<std::path::Path>) -> Result<Self> {
198        let path = path.as_ref();
199        let file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
200            path: path.to_path_buf(),
201            reason: format!("Failed to open ONNX file: {e}"),
202        })?;
203
204        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
205        let file_size = file
206            .metadata()
207            .map_err(|e| QuantizeError::ModelLoad {
208                path: path.to_path_buf(),
209                reason: format!("Failed to read metadata: {e}"),
210            })?
211            .len();
212        if file_size > MAX_MODEL_SIZE {
213            return Err(QuantizeError::ModelLoad {
214                path: path.to_path_buf(),
215                reason: format!(
216                    "Model file too large: {:.2} GB (max: 10 GB)",
217                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
218                ),
219            });
220        }
221
222        // SAFETY: see method-level docs — caller guarantees the file is
223        // not modified while it is mapped.
224        let mmap = unsafe {
225            memmap2::Mmap::map(&file).map_err(|e| QuantizeError::ModelLoad {
226                path: path.to_path_buf(),
227                reason: format!("Failed to mmap ONNX file: {e}"),
228            })?
229        };
230
231        let proto = ModelProto::decode(&mmap[..]).map_err(|e| QuantizeError::ModelLoad {
232            path: path.to_path_buf(),
233            reason: format!("Failed to parse ONNX protobuf: {e}"),
234        })?;
235
236        // mmap is dropped here; `proto` owns all its data (prost copies
237        // bytes out of the source during decode), so this is sound.
238        Ok(Self { proto })
239    }
240
241    /// Return a summary of the model's structure.
242    pub fn info(&self) -> ModelInfo {
243        let graph = self.proto.graph.as_ref();
244
245        let inputs: Vec<String> = graph
246            .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
247            .unwrap_or_default();
248
249        let outputs: Vec<String> = graph
250            .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
251            .unwrap_or_default();
252
253        ModelInfo {
254            name: graph.map(|g| g.name.clone()).unwrap_or_default(),
255            version: self.proto.model_version,
256            num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
257            inputs,
258            outputs,
259        }
260    }
261
262    /// Return the shapes of each graph input from the protobuf type info.
263    ///
264    /// Each inner `Vec<i64>` contains the dimension values.  Dynamic dims
265    /// (symbolic or missing) are returned as -1.  Returns one entry per
266    /// `graph.input` that has tensor type information.
267    pub fn input_shapes(&self) -> Vec<Vec<i64>> {
268        let graph = match &self.proto.graph {
269            Some(g) => g,
270            None => return Vec::new(),
271        };
272
273        let mut shapes = Vec::new();
274        for inp in &graph.input {
275            if let Some(type_proto) = &inp.r#type {
276                if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
277                    if let Some(shape) = &tensor_type.shape {
278                        let dims: Vec<i64> = shape
279                            .dim
280                            .iter()
281                            .map(|d| match &d.value {
282                                Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
283                                _ => -1,
284                            })
285                            .collect();
286                        shapes.push(dims);
287                    }
288                }
289            }
290        }
291        shapes
292    }
293
294    /// Extract all FP32 weight tensors from the model's initializers.
295    pub fn extract_weights(&self) -> Vec<WeightTensor> {
296        let graph = match &self.proto.graph {
297            Some(g) => g,
298            None => return Vec::new(),
299        };
300
301        let mut weights = Vec::new();
302        for initializer in &graph.initializer {
303            // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
304            if initializer.data_type != tensor_proto::DataType::Float as i32 {
305                continue;
306            }
307
308            let name = initializer.name.clone();
309
310            let shape: Vec<usize> = initializer
311                .dims
312                .iter()
313                .map(|&d| d.max(0) as usize)
314                .collect();
315
316            let data = if !initializer.raw_data.is_empty() {
317                if initializer.raw_data.len() % 4 != 0 {
318                    // Misaligned raw_data — skip this initializer rather than panic
319                    continue;
320                }
321                initializer
322                    .raw_data
323                    .chunks_exact(4)
324                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
325                    .collect()
326            } else {
327                initializer.float_data.clone()
328            };
329
330            if !data.is_empty() {
331                weights.push(WeightTensor { name, data, shape });
332            }
333        }
334
335        weights
336    }
337
338    /// Total size of all weight tensors in bytes (float32).
339    ///
340    /// Prefer computing this from already-extracted weights when available:
341    /// `weights.iter().map(|w| w.size_bytes()).sum()` avoids reparsing.
342    pub fn total_size_bytes(&self) -> usize {
343        let graph = match &self.proto.graph {
344            Some(g) => g,
345            None => return 0,
346        };
347        graph
348            .initializer
349            .iter()
350            .map(|init| {
351                if !init.raw_data.is_empty() {
352                    init.raw_data.len()
353                } else {
354                    init.float_data.len() * std::mem::size_of::<f32>()
355                }
356            })
357            .sum()
358    }
359}
360
361// ===========================================================================
362// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
363// ===========================================================================
364
365impl OnnxModel {
366    /// Save a quantized model using the QDQ (DequantizeLinear) pattern.
367    ///
368    /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
369    /// pipeline, examples) compile without changes.
370    ///
371    /// ### What changed internally
372    ///
373    /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
374    /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
375    /// reference them.  ONNX Runtime rejected these models on load.
376    ///
377    /// v0.3.0 inserts a `DequantizeLinear` node per weight.  The node's output
378    /// carries the **original** name, so every downstream node is unchanged.
379    /// Graph connectivity is preserved by construction, and the resulting model
380    /// loads and runs in ONNX Runtime.
381    ///
382    /// ### INT4 storage note
383    ///
384    /// `DequantizeLinear` requires INT8 input in opsets &lt; 21.  By default,
385    /// INT4-quantized values ([-8, 7]) are widened to INT8 bytes — 4×
386    /// compression from FP32.  For true 8× compression, call
387    /// [`save_quantized_with_options`](Self::save_quantized_with_options) with
388    /// [`SaveOptions::with_native_int4(true)`], which emits native `INT4`
389    /// initializers and bumps the opset to 21.
390    pub fn save_quantized(
391        &mut self,
392        quantized_data: &[graph_builder::QdqWeightInput],
393        path: impl AsRef<std::path::Path>,
394    ) -> Result<()> {
395        self.save_quantized_with_options(quantized_data, path, SaveOptions::default())
396    }
397
398    /// Save a quantized model with explicit [`SaveOptions`] control.
399    ///
400    /// See [`save_quantized`](Self::save_quantized) for the transform details.
401    /// Enabling [`SaveOptions::native_int4`] for INT4 weights bumps the
402    /// required opset to 21 automatically.
403    pub fn save_quantized_with_options(
404        &mut self,
405        quantized_data: &[graph_builder::QdqWeightInput],
406        path: impl AsRef<std::path::Path>,
407        options: SaveOptions,
408    ) -> Result<()> {
409        let path = path.as_ref();
410        use graph_builder::{apply_qdq_transform_with_options, ensure_opset_version};
411
412        // --- 1. Opset: ≥10 for per-tensor, ≥13 for per-channel, ≥21 for native INT4 ---
413        let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
414        let uses_native_int4 = options.native_int4 && quantized_data.iter().any(|w| w.bits == 4);
415        let min_opset = if uses_native_int4 {
416            21
417        } else if needs_per_channel {
418            13
419        } else {
420            10
421        };
422        ensure_opset_version(&mut self.proto, min_opset);
423
424        // --- 2. Persist per-weight bits in model metadata ---
425        for inp in quantized_data.iter() {
426            self.proto.metadata_props.push(StringStringEntryProto {
427                key: format!("quantize_rs.bits.{}", inp.original_name),
428                value: inp.bits.to_string(),
429            });
430        }
431
432        // --- 3. Apply QDQ transform to the graph ---
433        let graph = self
434            .proto
435            .graph
436            .as_mut()
437            .ok_or_else(|| QuantizeError::ModelSave {
438                path: path.to_path_buf(),
439                reason: "Model has no graph".to_string(),
440            })?;
441        apply_qdq_transform_with_options(graph, quantized_data, options)?;
442
443        // --- 4. Encode and write to disk ---
444        let mut buf = Vec::new();
445        self.proto
446            .encode(&mut buf)
447            .map_err(|e| QuantizeError::ModelSave {
448                path: path.to_path_buf(),
449                reason: format!("Failed to encode ONNX model: {e}"),
450            })?;
451
452        let mut file = std::fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
453            path: path.to_path_buf(),
454            reason: format!("Failed to create output file: {e}"),
455        })?;
456
457        file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
458            path: path.to_path_buf(),
459            reason: format!("Failed to write ONNX model: {e}"),
460        })?;
461
462        Ok(())
463    }
464}
465
466// ===========================================================================
467// OnnxModel — validation
468// ===========================================================================
469
470impl OnnxModel {
471    /// Check that every node input in the graph resolves to a known tensor.
472    ///
473    /// A "known tensor" is one of:
474    ///   - a declared graph input
475    ///   - an initializer
476    ///   - the output of a node appearing earlier in the node list
477    ///
478    /// This is the exact check ONNX Runtime performs on load.  It's the check
479    /// that v0.2.0's `validate` command skipped, which is why the rename bug
480    /// went undetected.  Integrate `report.summary()` into the CLI validate
481    /// output alongside the existing structure / weight checks.
482    pub fn validate_connectivity(&self) -> ConnectivityReport {
483        match &self.proto.graph {
484            Some(graph) => graph_builder::validate_graph_connectivity(graph),
485            None => {
486                use crate::onnx_proto::GraphProto;
487                graph_builder::validate_graph_connectivity(&GraphProto::default())
488            }
489        }
490    }
491}
492
493// ===========================================================================
494// OnnxModel — quantized model introspection (v0.3.0 QDQ format)
495// ===========================================================================
496
497impl OnnxModel {
498    /// Extract metadata about quantized weights from a QDQ-format model.
499    ///
500    /// Looks for initializer triples:
501    ///   `{base}_quantized`, `{base}_scale`, `{base}_zp`
502    ///
503    /// Scale and zero-point are decoded in full — per-tensor yields a single
504    /// element; per-channel yields one entry per channel.  Bit-width comes
505    /// from `metadata_props` (written by `save_quantized`); defaults to 8 if
506    /// the metadata entry is missing.
507    ///
508    /// Native INT4 zero-point tensors (`DataType::Int4`) are unpacked from
509    /// their two-per-byte on-disk layout automatically.
510    pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
511        let graph = match &self.proto.graph {
512            Some(g) => g,
513            None => return Vec::new(),
514        };
515
516        let mut scale_map: std::collections::HashMap<String, Vec<f32>> =
517            std::collections::HashMap::new();
518        let mut zp_map: std::collections::HashMap<String, Vec<i8>> =
519            std::collections::HashMap::new();
520        let mut quant_bases: Vec<String> = Vec::new();
521
522        for init in &graph.initializer {
523            let name = &init.name;
524
525            if let Some(base) = name.strip_suffix("_scale") {
526                scale_map.insert(base.to_string(), decode_scale_tensor(init));
527            } else if let Some(base) = name.strip_suffix("_zp") {
528                zp_map.insert(base.to_string(), decode_zero_point_tensor(init));
529            } else if let Some(base) = name.strip_suffix("_quantized") {
530                quant_bases.push(base.to_string());
531            }
532        }
533
534        // Read bits from metadata_props (written by save_quantized)
535        let mut bits_map: std::collections::HashMap<String, u8> = std::collections::HashMap::new();
536        for prop in &self.proto.metadata_props {
537            if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
538                if let Ok(bits) = prop.value.parse::<u8>() {
539                    bits_map.insert(base.to_string(), bits);
540                }
541            }
542        }
543
544        quant_bases
545            .iter()
546            .map(|base| {
547                let scales = scale_map.get(base).cloned().unwrap_or_else(|| vec![1.0]);
548                let zero_points = zp_map.get(base).cloned().unwrap_or_else(|| vec![0]);
549                let bits = bits_map.get(base).copied().unwrap_or(8);
550
551                // Element count = product of dims on the _quantized tensor;
552                // byte count = actual raw_data length (accounts for native INT4 packing).
553                let quant_init = graph
554                    .initializer
555                    .iter()
556                    .find(|i| i.name == format!("{}_quantized", base));
557                let original_length = quant_init
558                    .map(|i| i.dims.iter().product::<i64>() as usize)
559                    .unwrap_or(0);
560                let storage_bytes = quant_init.map(|i| i.raw_data.len()).unwrap_or(0);
561
562                QuantizedWeightInfo {
563                    name: base.clone(),
564                    bits,
565                    scales,
566                    zero_points,
567                    original_length,
568                    storage_bytes,
569                }
570            })
571            .collect()
572    }
573}
574
575// ---------------------------------------------------------------------------
576// Helpers for load_quantized_info
577// ---------------------------------------------------------------------------
578
579/// Expected element count for a 1-D or scalar tensor: rank-0 → 1, rank-1 → dims[0].
580fn expected_element_count(init: &crate::onnx_proto::TensorProto) -> usize {
581    if init.dims.is_empty() {
582        1
583    } else {
584        init.dims
585            .iter()
586            .copied()
587            .filter(|&d| d > 0)
588            .product::<i64>() as usize
589    }
590}
591
592fn decode_scale_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<f32> {
593    let expected = expected_element_count(init).max(1);
594
595    if !init.float_data.is_empty() {
596        return init.float_data.clone();
597    }
598
599    if !init.raw_data.is_empty() && init.raw_data.len() >= 4 * expected {
600        return init
601            .raw_data
602            .chunks_exact(4)
603            .take(expected)
604            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
605            .collect();
606    }
607
608    // Malformed or missing — fall back to a safe default so callers can still
609    // report the weight exists without a division-by-zero risk.
610    vec![1.0; expected]
611}
612
613fn decode_zero_point_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<i8> {
614    use crate::onnx_proto::tensor_proto::DataType;
615    use crate::onnx_utils::quantization_nodes::unpack_int4_onnx;
616
617    let expected = expected_element_count(init).max(1);
618
619    // Native INT4: raw_data is packed two-per-byte, logical count in dims.
620    if init.data_type == DataType::Int4 as i32 {
621        return unpack_int4_onnx(&init.raw_data, expected);
622    }
623
624    // INT8 / widened INT4 / UINT8: raw_data is one byte per value.
625    if !init.raw_data.is_empty() {
626        return init
627            .raw_data
628            .iter()
629            .take(expected)
630            .map(|&b| b as i8)
631            .collect();
632    }
633
634    // int32_data carries int-type scalars when raw_data is absent.
635    if !init.int32_data.is_empty() {
636        return init
637            .int32_data
638            .iter()
639            .take(expected)
640            .map(|&v| v as i8)
641            .collect();
642    }
643
644    vec![0; expected]
645}
646
647// ===========================================================================
648// WeightTensor (unchanged from v0.2.0)
649// ===========================================================================
650
/// One FP32 weight tensor taken from an ONNX model's initializers.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Name of the source initializer in the ONNX graph.
    pub name: String,
    /// The tensor's FP32 values.
    pub data: Vec<f32>,
    /// Dimension sizes.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Bytes the values occupy when stored as FP32 (`4 × num_elements()`).
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * self.num_elements()
    }

    /// Count of scalar values held in [`data`](Self::data).
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }
}