// src/onnx_utils/mod.rs
//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
//! graph connectivity validation, and quantized-model introspection.

pub mod graph_builder;
pub mod quantization_nodes;

use crate::errors::{QuantizeError, Result};
use crate::onnx_proto::{
    tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
};
use prost::Message;
use std::fs;
use std::io::{Read, Write};

// Re-export so callers don't have to reach into submodules.
pub use graph_builder::{ConnectivityReport, SaveOptions};

// ===========================================================================
// Core types
// ===========================================================================

/// An ONNX model loaded from a protobuf file.
///
/// Provides methods for inspecting the graph, extracting weights, saving
/// quantized models, and validating graph connectivity.
pub struct OnnxModel {
    proto: ModelProto,
}

impl std::fmt::Debug for OnnxModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self
            .proto
            .graph
            .as_ref()
            .map(|g| g.name.as_str())
            .unwrap_or("");
        let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
        f.debug_struct("OnnxModel")
            .field("name", &name)
            .field("num_nodes", &num_nodes)
            .finish()
    }
}

/// Summary of an ONNX model's structure.
#[derive(Debug)]
pub struct ModelInfo {
    /// Graph name from the protobuf.
    pub name: String,
    /// Model version from the protobuf.
    pub version: i64,
    /// Number of computation nodes in the graph.
    pub num_nodes: usize,
    /// Names of the graph inputs.
    pub inputs: Vec<String>,
    /// Names of the graph outputs.
    pub outputs: Vec<String>,
}

/// Metadata about a quantized weight recovered from a QDQ-format model.
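///
/// The parameters encode the standard ONNX affine mapping
/// `real = scale * (q - zero_point)`. A minimal per-tensor sketch, assuming
/// the crate is published as `quantize_rs` with `onnx_utils` public:
///
/// ```
/// use quantize_rs::onnx_utils::QuantizedWeightInfo;
///
/// // Dequantize one per-tensor INT8 value back to f32.
/// fn dequantize(info: &QuantizedWeightInfo, q: i8) -> f32 {
///     info.scale() * (q as i32 - info.zero_point() as i32) as f32
/// }
/// ```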
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Original weight name (without the `_quantized` suffix).
    pub name: String,
    /// Quantization bit width (4 or 8).
    pub bits: u8,
    /// Quantization scales. `len() == 1` for per-tensor quantization;
    /// `len() == num_channels` for per-channel.
    pub scales: Vec<f32>,
    /// Quantization zero points. Same length as [`scales`](Self::scales).
    pub zero_points: Vec<i8>,
    /// Number of elements in the quantized tensor.
    pub original_length: usize,
    /// Actual on-disk byte count of the quantized initializer's `raw_data`.
    /// For INT8 storage this equals `original_length`; for native INT4
    /// (opset 21) it is `ceil(original_length / 2)`.
    pub storage_bytes: usize,
}

impl QuantizedWeightInfo {
    /// `true` if the weight was quantized per-channel (more than one scale).
    pub fn is_per_channel(&self) -> bool {
        self.scales.len() > 1
    }

    /// Per-tensor convenience accessor: returns the first scale. Panics if
    /// `scales` is empty.
    ///
    /// For per-channel tensors, iterate over [`scales`](Self::scales) instead.
    pub fn scale(&self) -> f32 {
        self.scales[0]
    }

    /// Per-tensor convenience accessor: returns the first zero point. Panics
    /// if `zero_points` is empty.
    ///
    /// For per-channel tensors, iterate over [`zero_points`](Self::zero_points) instead.
    pub fn zero_point(&self) -> i8 {
        self.zero_points[0]
    }
}

// ===========================================================================
// OnnxModel — load / inspect
// ===========================================================================

impl OnnxModel {
    /// Load an ONNX model from a file path.
    ///
    /// Reads the entire file into a `Vec<u8>` before decoding. For
    /// multi-gigabyte models consider [`load_mmap`](Self::load_mmap)
    /// (requires the `mmap` feature) to avoid the extra heap buffer.
    ///
    /// # Errors
    ///
    /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
    /// is too large (>10 GB), or contains invalid protobuf data.
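    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the crate is published as `quantize_rs`
    /// with `onnx_utils` public and [`QuantizeError`] implementing
    /// `std::error::Error` (the file name is illustrative):
    ///
    /// ```no_run
    /// use quantize_rs::onnx_utils::OnnxModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = OnnxModel::load("model.onnx")?;
    /// println!("{:?}", model.info()); // name, version, node count
    /// # Ok(())
    /// # }
    /// ```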
    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
        let path = path.as_ref();
        let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to open ONNX file: {e}"),
        })?;

        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
        let file_size = file
            .metadata()
            .map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to read metadata: {e}"),
            })?
            .len();
        if file_size > MAX_MODEL_SIZE {
            return Err(QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!(
                    "Model file too large: {:.2} GB (max: 10 GB)",
                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
                ),
            });
        }

        let mut buffer = Vec::with_capacity(file_size as usize);
        file.read_to_end(&mut buffer)
            .map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to read ONNX file: {e}"),
            })?;

        let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to parse ONNX protobuf: {e}"),
        })?;

        Ok(Self { proto })
    }

    /// Decode an ONNX model directly from a byte slice.
    ///
    /// Useful for in-memory or fuzzing scenarios where the source isn't a
    /// filesystem path. Performs the same validation as [`load`](Self::load),
    /// but without the file-size gate.
    ///
    /// # Errors
    ///
    /// Returns [`QuantizeError::ModelLoad`] if `bytes` cannot be decoded as a
    /// `ModelProto`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        let proto = ModelProto::decode(bytes).map_err(|e| QuantizeError::ModelLoad {
            path: std::path::PathBuf::new(),
            reason: format!("Failed to parse ONNX protobuf: {e}"),
        })?;
        Ok(Self { proto })
    }

    /// Load an ONNX model by memory-mapping the file (requires the `mmap`
    /// feature).
    ///
    /// Compared to [`load`](Self::load), this avoids the intermediate
    /// `Vec<u8>` buffer — useful for multi-gigabyte models where doubling
    /// the working set during decode is a problem. Peak RAM during load
    /// falls from roughly `2 × file_size` to `1 × file_size + mmap overhead`.
    ///
    /// # Safety
    ///
    /// Memory-mapping requires that the file is not modified for the
    /// duration of the load. Another process truncating or rewriting the
    /// file while it is being decoded would be undefined behaviour. This
    /// function uses the `unsafe { Mmap::map(&file) }` call under the hood;
    /// its invariants are the caller's responsibility.
    ///
    /// # Errors
    ///
    /// Returns [`QuantizeError::ModelLoad`] on I/O failure, invalid size,
    /// or malformed protobuf.
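    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the crate is built with `--features mmap`
    /// (crate path and file name as in the `load` example):
    ///
    /// ```no_run
    /// # #[cfg(feature = "mmap")]
    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// use quantize_rs::onnx_utils::OnnxModel;
    ///
    /// let model = OnnxModel::load_mmap("large_model.onnx")?;
    /// println!("{} nodes", model.info().num_nodes);
    /// # Ok(())
    /// # }
    /// ```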
    #[cfg(feature = "mmap")]
    pub fn load_mmap(path: impl AsRef<std::path::Path>) -> Result<Self> {
        let path = path.as_ref();
        let file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to open ONNX file: {e}"),
        })?;

        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
        let file_size = file
            .metadata()
            .map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to read metadata: {e}"),
            })?
            .len();
        if file_size > MAX_MODEL_SIZE {
            return Err(QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!(
                    "Model file too large: {:.2} GB (max: 10 GB)",
                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
                ),
            });
        }

        // SAFETY: see method-level docs — the caller guarantees the file is
        // not modified while it is mapped.
        let mmap = unsafe {
            memmap2::Mmap::map(&file).map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to mmap ONNX file: {e}"),
            })?
        };

        let proto = ModelProto::decode(&mmap[..]).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to parse ONNX protobuf: {e}"),
        })?;

        // The mmap is dropped here; `proto` owns all of its data (prost
        // copies bytes out of the source during decode), so this is sound.
        Ok(Self { proto })
    }

    /// Return a summary of the model's structure.
    pub fn info(&self) -> ModelInfo {
        let graph = self.proto.graph.as_ref();

        let inputs: Vec<String> = graph
            .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
            .unwrap_or_default();

        let outputs: Vec<String> = graph
            .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
            .unwrap_or_default();

        ModelInfo {
            name: graph.map(|g| g.name.clone()).unwrap_or_default(),
            version: self.proto.model_version,
            num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
            inputs,
            outputs,
        }
    }

    /// Return the shapes of each graph input from the protobuf type info.
    ///
    /// Each inner `Vec<i64>` contains the dimension values. Dynamic dims
    /// (symbolic or missing) are returned as -1. Returns one entry per
    /// `graph.input` that has tensor type information.
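    ///
    /// A minimal usage sketch (crate path, error bound, and file name
    /// assumptions as in the `load` example above):
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::OnnxModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = OnnxModel::load("model.onnx")?;
    /// for dims in model.input_shapes() {
    ///     println!("{dims:?}"); // e.g. [-1, 3, 224, 224] for a dynamic batch dim
    /// }
    /// # Ok(())
    /// # }
    /// ```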
    pub fn input_shapes(&self) -> Vec<Vec<i64>> {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };

        let mut shapes = Vec::new();
        for inp in &graph.input {
            if let Some(type_proto) = &inp.r#type {
                if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
                    if let Some(shape) = &tensor_type.shape {
                        let dims: Vec<i64> = shape
                            .dim
                            .iter()
                            .map(|d| match &d.value {
                                Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
                                _ => -1,
                            })
                            .collect();
                        shapes.push(dims);
                    }
                }
            }
        }
        shapes
    }

    /// Extract all FP32 weight tensors from the model's initializers.
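    ///
    /// Non-FP32 initializers (INT8, INT64, DOUBLE, ...) are skipped, as is
    /// any FP32 initializer with misaligned `raw_data`. A minimal usage
    /// sketch (crate path, error bound, and file name assumptions as in the
    /// `load` example above):
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::OnnxModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = OnnxModel::load("model.onnx")?;
    /// for w in model.extract_weights() {
    ///     println!("{}: {:?} ({} bytes)", w.name, w.shape, w.size_bytes());
    /// }
    /// # Ok(())
    /// # }
    /// ```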
    pub fn extract_weights(&self) -> Vec<WeightTensor> {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };

        let mut weights = Vec::new();
        for initializer in &graph.initializer {
            // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
            if initializer.data_type != tensor_proto::DataType::Float as i32 {
                continue;
            }

            let name = initializer.name.clone();

            let shape: Vec<usize> = initializer
                .dims
                .iter()
                .map(|&d| d.max(0) as usize)
                .collect();

            let data = if !initializer.raw_data.is_empty() {
                if initializer.raw_data.len() % 4 != 0 {
                    // Misaligned raw_data — skip this initializer rather than panic.
                    continue;
                }
                initializer
                    .raw_data
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect()
            } else {
                initializer.float_data.clone()
            };

            if !data.is_empty() {
                weights.push(WeightTensor { name, data, shape });
            }
        }

        weights
    }

    /// Total size of all initializers in bytes, as stored in the protobuf:
    /// `raw_data` length when present, otherwise `float_data` length × 4.
    ///
    /// Prefer computing this from already-extracted weights when available:
    /// `weights.iter().map(|w| w.size_bytes()).sum::<usize>()` avoids
    /// reparsing the initializers.
    pub fn total_size_bytes(&self) -> usize {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return 0,
        };
        graph
            .initializer
            .iter()
            .map(|init| {
                if !init.raw_data.is_empty() {
                    init.raw_data.len()
                } else {
                    init.float_data.len() * std::mem::size_of::<f32>()
                }
            })
            .sum()
    }
}

// ===========================================================================
// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
// ===========================================================================

impl OnnxModel {
    /// Save a quantized model using the QDQ (`DequantizeLinear`) pattern.
    ///
    /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
    /// pipeline, examples) compile without changes.
    ///
    /// ### What changed internally
    ///
    /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
    /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
    /// reference them. ONNX Runtime rejected these models on load.
    ///
    /// v0.3.0 inserts a `DequantizeLinear` node per weight. The node's output
    /// carries the **original** name, so every downstream node is unchanged.
    /// Graph connectivity is preserved by construction, and the resulting model
    /// loads and runs in ONNX Runtime.
    ///
    /// ### INT4 storage note
    ///
    /// `DequantizeLinear` requires INT8 input in opsets < 21. By default,
    /// INT4-quantized values ([-8, 7]) are widened to INT8 bytes — 4×
    /// compression from FP32. For true 8× compression, call
    /// [`save_quantized_with_options`](Self::save_quantized_with_options) with
    /// [`SaveOptions::with_native_int4(true)`], which emits native `INT4`
    /// initializers and bumps the opset to 21.
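    ///
    /// # Example
    ///
    /// A minimal sketch. How each `QdqWeightInput` is produced depends on
    /// your quantizer; the `quantize` helper below is hypothetical, and the
    /// crate path and file names are illustrative:
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::{OnnxModel, WeightTensor};
    /// # use quantize_rs::onnx_utils::graph_builder::QdqWeightInput;
    /// # fn quantize(w: &WeightTensor) -> QdqWeightInput { unimplemented!() }
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let mut model = OnnxModel::load("model.onnx")?;
    /// let inputs: Vec<QdqWeightInput> =
    ///     model.extract_weights().iter().map(quantize).collect();
    /// model.save_quantized(&inputs, "model.int8.onnx")?;
    /// # Ok(())
    /// # }
    /// ```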
    pub fn save_quantized(
        &mut self,
        quantized_data: &[graph_builder::QdqWeightInput],
        path: impl AsRef<std::path::Path>,
    ) -> Result<()> {
        self.save_quantized_with_options(quantized_data, path, SaveOptions::default())
    }

    /// Save a quantized model with explicit [`SaveOptions`] control.
    ///
    /// See [`save_quantized`](Self::save_quantized) for the transform details.
    /// Enabling [`SaveOptions::native_int4`] for INT4 weights bumps the
    /// required opset to 21 automatically.
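    ///
    /// A minimal sketch, mirroring the `SaveOptions::with_native_int4(true)`
    /// call form referenced above (whether that is a constructor or chains
    /// off `default()` depends on `SaveOptions`; file names are illustrative):
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::{OnnxModel, SaveOptions};
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let mut model = OnnxModel::load("model.onnx")?;
    /// # let inputs = Vec::new();
    /// model.save_quantized_with_options(
    ///     &inputs,
    ///     "model.int4.onnx",
    ///     SaveOptions::with_native_int4(true),
    /// )?;
    /// # Ok(())
    /// # }
    /// ```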
    pub fn save_quantized_with_options(
        &mut self,
        quantized_data: &[graph_builder::QdqWeightInput],
        path: impl AsRef<std::path::Path>,
        options: SaveOptions,
    ) -> Result<()> {
        use graph_builder::{apply_qdq_transform_with_options, ensure_opset_version};
        let path = path.as_ref();

        // --- 1. Opset: ≥10 for per-tensor, ≥13 for per-channel, ≥21 for native INT4 ---
        let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
        let uses_native_int4 = options.native_int4 && quantized_data.iter().any(|w| w.bits == 4);
        let min_opset = if uses_native_int4 {
            21
        } else if needs_per_channel {
            13
        } else {
            10
        };
        ensure_opset_version(&mut self.proto, min_opset);

        // --- 2. Persist per-weight bits in model metadata ---
        for inp in quantized_data.iter() {
            self.proto.metadata_props.push(StringStringEntryProto {
                key: format!("quantize_rs.bits.{}", inp.original_name),
                value: inp.bits.to_string(),
            });
        }

        // --- 3. Apply QDQ transform to the graph ---
        let graph = self
            .proto
            .graph
            .as_mut()
            .ok_or_else(|| QuantizeError::ModelSave {
                path: path.to_path_buf(),
                reason: "Model has no graph".to_string(),
            })?;
        apply_qdq_transform_with_options(graph, quantized_data, options)?;

        // --- 4. Encode and write to disk ---
        let mut buf = Vec::new();
        self.proto
            .encode(&mut buf)
            .map_err(|e| QuantizeError::ModelSave {
                path: path.to_path_buf(),
                reason: format!("Failed to encode ONNX model: {e}"),
            })?;

        let mut file = std::fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
            path: path.to_path_buf(),
            reason: format!("Failed to create output file: {e}"),
        })?;

        file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
            path: path.to_path_buf(),
            reason: format!("Failed to write ONNX model: {e}"),
        })?;

        Ok(())
    }
}

// ===========================================================================
// OnnxModel — validation
// ===========================================================================

impl OnnxModel {
    /// Check that every node input in the graph resolves to a known tensor.
    ///
    /// A "known tensor" is one of:
    /// - a declared graph input
    /// - an initializer
    /// - the output of a node appearing earlier in the node list
    ///
    /// This is the same resolution check ONNX Runtime performs on load. It is
    /// also the check that v0.2.0's `validate` command skipped, which is why
    /// the rename bug went undetected. Integrate `report.summary()` into the
    /// CLI validate output alongside the existing structure / weight checks.
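    ///
    /// A minimal usage sketch, assuming `ConnectivityReport::summary` (which
    /// is referenced above) renders as a displayable string; crate path and
    /// file name as in the `load` example:
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::OnnxModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = OnnxModel::load("model.quant.onnx")?;
    /// let report = model.validate_connectivity();
    /// println!("{}", report.summary());
    /// # Ok(())
    /// # }
    /// ```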
    pub fn validate_connectivity(&self) -> ConnectivityReport {
        match &self.proto.graph {
            Some(graph) => graph_builder::validate_graph_connectivity(graph),
            None => {
                use crate::onnx_proto::GraphProto;
                graph_builder::validate_graph_connectivity(&GraphProto::default())
            }
        }
    }
}

// ===========================================================================
// OnnxModel — quantized-model introspection (v0.3.0 QDQ format)
// ===========================================================================

impl OnnxModel {
    /// Extract metadata about quantized weights from a QDQ-format model.
    ///
    /// Looks for initializer triples:
    /// `{base}_quantized`, `{base}_scale`, `{base}_zp`
    ///
    /// Scale and zero-point tensors are decoded in full — per-tensor yields a
    /// single element; per-channel yields one entry per channel. Bit width
    /// comes from `metadata_props` (written by `save_quantized`) and defaults
    /// to 8 if the metadata entry is missing.
    ///
    /// Native INT4 zero-point tensors (`DataType::Int4`) are unpacked from
    /// their two-per-byte on-disk layout automatically.
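    ///
    /// A minimal usage sketch (crate path, error bound, and file name
    /// assumptions as in the `load` example above):
    ///
    /// ```no_run
    /// # use quantize_rs::onnx_utils::OnnxModel;
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = OnnxModel::load("model.quant.onnx")?;
    /// for w in model.load_quantized_info() {
    ///     println!(
    ///         "{}: INT{}, {} scale(s), {} elements, {} bytes on disk",
    ///         w.name, w.bits, w.scales.len(), w.original_length, w.storage_bytes,
    ///     );
    /// }
    /// # Ok(())
    /// # }
    /// ```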
    pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };

        let mut scale_map: std::collections::HashMap<String, Vec<f32>> =
            std::collections::HashMap::new();
        let mut zp_map: std::collections::HashMap<String, Vec<i8>> =
            std::collections::HashMap::new();
        let mut quant_bases: Vec<String> = Vec::new();

        for init in &graph.initializer {
            let name = &init.name;

            if let Some(base) = name.strip_suffix("_scale") {
                scale_map.insert(base.to_string(), decode_scale_tensor(init));
            } else if let Some(base) = name.strip_suffix("_zp") {
                zp_map.insert(base.to_string(), decode_zero_point_tensor(init));
            } else if let Some(base) = name.strip_suffix("_quantized") {
                quant_bases.push(base.to_string());
            }
        }

        // Read bits from metadata_props (written by save_quantized).
        let mut bits_map: std::collections::HashMap<String, u8> = std::collections::HashMap::new();
        for prop in &self.proto.metadata_props {
            if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
                if let Ok(bits) = prop.value.parse::<u8>() {
                    bits_map.insert(base.to_string(), bits);
                }
            }
        }

        quant_bases
            .iter()
            .map(|base| {
                let scales = scale_map.get(base).cloned().unwrap_or_else(|| vec![1.0]);
                let zero_points = zp_map.get(base).cloned().unwrap_or_else(|| vec![0]);
                let bits = bits_map.get(base).copied().unwrap_or(8);

                // Element count = product of dims on the _quantized tensor;
                // byte count = actual raw_data length (accounts for native
                // INT4 packing).
                let quant_name = format!("{base}_quantized");
                let quant_init = graph.initializer.iter().find(|i| i.name == quant_name);
                let original_length = quant_init
                    .map(|i| i.dims.iter().product::<i64>() as usize)
                    .unwrap_or(0);
                let storage_bytes = quant_init.map(|i| i.raw_data.len()).unwrap_or(0);

                QuantizedWeightInfo {
                    name: base.clone(),
                    bits,
                    scales,
                    zero_points,
                    original_length,
                    storage_bytes,
                }
            })
            .collect()
    }
}

// ---------------------------------------------------------------------------
// Helpers for load_quantized_info
// ---------------------------------------------------------------------------

/// Expected element count for a 1-D or scalar tensor: rank-0 → 1, rank-1 → `dims[0]`.
fn expected_element_count(init: &crate::onnx_proto::TensorProto) -> usize {
    if init.dims.is_empty() {
        1
    } else {
        init.dims
            .iter()
            .copied()
            .filter(|&d| d > 0)
            .product::<i64>() as usize
    }
}

fn decode_scale_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<f32> {
    let expected = expected_element_count(init).max(1);

    if !init.float_data.is_empty() {
        return init.float_data.clone();
    }

    if !init.raw_data.is_empty() && init.raw_data.len() >= 4 * expected {
        return init
            .raw_data
            .chunks_exact(4)
            .take(expected)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
    }

    // Malformed or missing — fall back to a safe default so callers can still
    // report that the weight exists without a division-by-zero risk.
    vec![1.0; expected]
}

fn decode_zero_point_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<i8> {
    use crate::onnx_proto::tensor_proto::DataType;
    use crate::onnx_utils::quantization_nodes::unpack_int4_onnx;

    let expected = expected_element_count(init).max(1);

    // Native INT4: raw_data is packed two per byte; the logical count comes
    // from dims.
    if init.data_type == DataType::Int4 as i32 {
        return unpack_int4_onnx(&init.raw_data, expected);
    }

    // INT8 / widened INT4 / UINT8: raw_data is one byte per value.
    if !init.raw_data.is_empty() {
        return init
            .raw_data
            .iter()
            .take(expected)
            .map(|&b| b as i8)
            .collect();
    }

    // int32_data carries int-type scalars when raw_data is absent.
    if !init.int32_data.is_empty() {
        return init
            .int32_data
            .iter()
            .take(expected)
            .map(|&v| v as i8)
            .collect();
    }

    vec![0; expected]
}

// ===========================================================================
// WeightTensor (unchanged from v0.2.0)
// ===========================================================================

/// An FP32 weight tensor extracted from an ONNX model.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Initializer name in the ONNX graph.
    pub name: String,
    /// FP32 weight values.
    pub data: Vec<f32>,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Size of this tensor in bytes (as FP32).
    pub fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }

    /// Total number of scalar elements.
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }
}
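
// A hedged sketch of unit tests for the decode helpers above. Assumptions:
// `TensorProto` is a prost-generated message (so it implements `Default` and
// has public fields), and its `raw_data` field converts from `Vec<u8>` via
// `Into` (true for both `Vec<u8>` and `bytes::Bytes` field types).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::TensorProto;

    #[test]
    fn zero_points_decode_from_raw_and_int32_data() {
        // INT8 path: raw_data holds one byte per value; 255u8 reads as -1i8.
        let mut init = TensorProto::default();
        init.dims = vec![3];
        init.raw_data = vec![0u8, 255, 5].into();
        assert_eq!(decode_zero_point_tensor(&init), vec![0, -1, 5]);

        // Fallback path: int32_data carries the values when raw_data is absent.
        let mut init = TensorProto::default();
        init.dims = vec![2];
        init.int32_data = vec![-3, 7];
        assert_eq!(decode_zero_point_tensor(&init), vec![-3, 7]);
    }

    #[test]
    fn scales_fall_back_to_one_when_missing() {
        // A missing or malformed scale tensor decodes to a safe default of 1.0.
        let mut init = TensorProto::default();
        init.dims = vec![4];
        assert_eq!(decode_scale_tensor(&init), vec![1.0; 4]);
    }
}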