quantize_rs/onnx_utils/mod.rs
1// src/onnx_utils/mod.rs
2//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
3//! graph connectivity validation, and quantized-model introspection.
4
5pub mod graph_builder;
6pub mod quantization_nodes;
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{
10 tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
11};
12use prost::Message;
13use std::fs;
14use std::io::{Read, Write};
15
16// Re-export so callers don't have to reach into submodules
17pub use graph_builder::ConnectivityReport;
18
19// ===========================================================================
20// Core types
21// ===========================================================================
22
23/// An ONNX model loaded from a protobuf file.
24///
25/// Provides methods for inspecting, extracting weights, saving quantized
26/// models, and validating graph connectivity.
27pub struct OnnxModel {
28 proto: ModelProto,
29}
30
31impl std::fmt::Debug for OnnxModel {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 let name = self
34 .proto
35 .graph
36 .as_ref()
37 .map(|g| g.name.as_str())
38 .unwrap_or("");
39 let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
40 f.debug_struct("OnnxModel")
41 .field("name", &name)
42 .field("num_nodes", &num_nodes)
43 .finish()
44 }
45}
46
/// Structural overview of an ONNX model, as produced by `OnnxModel::info`.
#[derive(Debug)]
pub struct ModelInfo {
    /// The graph's name (`GraphProto.name`).
    pub name: String,
    /// The `ModelProto.model_version` field.
    pub version: i64,
    /// How many nodes the graph contains.
    pub num_nodes: usize,
    /// Graph input names, in declaration order.
    pub inputs: Vec<String>,
    /// Graph output names, in declaration order.
    pub outputs: Vec<String>,
}
61
/// Quantization parameters of one weight, recovered from a QDQ-format model.
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Name of the original weight, i.e. with the `_quantized` suffix removed.
    pub name: String,
    /// Bit width used for quantization (4 or 8).
    pub bits: u8,
    /// Scale factor.
    pub scale: f32,
    /// Zero point.
    pub zero_point: i8,
    /// Element count of the quantized tensor.
    pub original_length: usize,
}
76
77// ===========================================================================
78// OnnxModel — load / inspect
79// ===========================================================================
80
81impl OnnxModel {
82 /// Load an ONNX model from a file path.
83 ///
84 /// # Errors
85 ///
86 /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
87 /// is too large (>10 GB), or contains invalid protobuf data.
88 pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
89 let path = path.as_ref();
90 let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
91 path: path.to_path_buf(),
92 reason: format!("Failed to open ONNX file: {e}"),
93 })?;
94
95 const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
96 let file_size = file
97 .metadata()
98 .map_err(|e| QuantizeError::ModelLoad {
99 path: path.to_path_buf(),
100 reason: format!("Failed to read metadata: {e}"),
101 })?
102 .len();
103 if file_size > MAX_MODEL_SIZE {
104 return Err(QuantizeError::ModelLoad {
105 path: path.to_path_buf(),
106 reason: format!(
107 "Model file too large: {:.2} GB (max: 10 GB)",
108 file_size as f64 / (1024.0 * 1024.0 * 1024.0)
109 ),
110 });
111 }
112
113 let mut buffer = Vec::with_capacity(file_size as usize);
114 file.read_to_end(&mut buffer)
115 .map_err(|e| QuantizeError::ModelLoad {
116 path: path.to_path_buf(),
117 reason: format!("Failed to read ONNX file: {e}"),
118 })?;
119
120 let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
121 path: path.to_path_buf(),
122 reason: format!("Failed to parse ONNX protobuf: {e}"),
123 })?;
124
125 Ok(Self { proto })
126 }
127
128 /// Return a summary of the model's structure.
129 pub fn info(&self) -> ModelInfo {
130 let graph = self.proto.graph.as_ref();
131
132 let inputs: Vec<String> = graph
133 .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
134 .unwrap_or_default();
135
136 let outputs: Vec<String> = graph
137 .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
138 .unwrap_or_default();
139
140 ModelInfo {
141 name: graph.map(|g| g.name.clone()).unwrap_or_default(),
142 version: self.proto.model_version,
143 num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
144 inputs,
145 outputs,
146 }
147 }
148
149 /// Return the shapes of each graph input from the protobuf type info.
150 ///
151 /// Each inner `Vec<i64>` contains the dimension values. Dynamic dims
152 /// (symbolic or missing) are returned as -1. Returns one entry per
153 /// `graph.input` that has tensor type information.
154 pub fn input_shapes(&self) -> Vec<Vec<i64>> {
155 let graph = match &self.proto.graph {
156 Some(g) => g,
157 None => return Vec::new(),
158 };
159
160 let mut shapes = Vec::new();
161 for inp in &graph.input {
162 if let Some(type_proto) = &inp.r#type {
163 if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
164 if let Some(shape) = &tensor_type.shape {
165 let dims: Vec<i64> = shape
166 .dim
167 .iter()
168 .map(|d| match &d.value {
169 Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
170 _ => -1,
171 })
172 .collect();
173 shapes.push(dims);
174 }
175 }
176 }
177 }
178 shapes
179 }
180
181 /// Extract all FP32 weight tensors from the model's initializers.
182 pub fn extract_weights(&self) -> Vec<WeightTensor> {
183 let graph = match &self.proto.graph {
184 Some(g) => g,
185 None => return Vec::new(),
186 };
187
188 let mut weights = Vec::new();
189 for initializer in &graph.initializer {
190 // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
191 if initializer.data_type != tensor_proto::DataType::Float as i32 {
192 continue;
193 }
194
195 let name = initializer.name.clone();
196
197 let shape: Vec<usize> = initializer
198 .dims
199 .iter()
200 .map(|&d| d.max(0) as usize)
201 .collect();
202
203 let data = if !initializer.raw_data.is_empty() {
204 if initializer.raw_data.len() % 4 != 0 {
205 // Misaligned raw_data — skip this initializer rather than panic
206 continue;
207 }
208 initializer
209 .raw_data
210 .chunks_exact(4)
211 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
212 .collect()
213 } else {
214 initializer.float_data.clone()
215 };
216
217 if !data.is_empty() {
218 weights.push(WeightTensor { name, data, shape });
219 }
220 }
221
222 weights
223 }
224
225 /// Total size of all weight tensors in bytes (float32).
226 ///
227 /// Prefer computing this from already-extracted weights when available:
228 /// `weights.iter().map(|w| w.size_bytes()).sum()` avoids reparsing.
229 pub fn total_size_bytes(&self) -> usize {
230 let graph = match &self.proto.graph {
231 Some(g) => g,
232 None => return 0,
233 };
234 graph
235 .initializer
236 .iter()
237 .map(|init| {
238 if !init.raw_data.is_empty() {
239 init.raw_data.len()
240 } else {
241 init.float_data.len() * std::mem::size_of::<f32>()
242 }
243 })
244 .sum()
245 }
246}
247
248// ===========================================================================
249// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
250// ===========================================================================
251
252impl OnnxModel {
253 /// Save a quantized model using the QDQ (DequantizeLinear) pattern.
254 ///
255 /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
256 /// pipeline, examples) compile without changes.
257 ///
258 /// ### What changed internally
259 ///
260 /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
261 /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
262 /// reference them. ONNX Runtime rejected these models on load.
263 ///
264 /// v0.3.0 inserts a `DequantizeLinear` node per weight. The node's output
265 /// carries the **original** name, so every downstream node is unchanged.
266 /// Graph connectivity is preserved by construction, and the resulting model
267 /// loads and runs in ONNX Runtime.
268 ///
269 /// ### INT4 storage note
270 ///
271 /// `DequantizeLinear` requires INT8 input (opset < 21). INT4-quantized values
272 /// ([-8, 7]) are stored as INT8 bytes. Quantization *accuracy* is still
273 /// INT4-level; only the on-disk size is 4× instead of the 8× that bit-packing
274 /// would give. True INT4 packing is a v0.4.0 target.
275 pub fn save_quantized(
276 &mut self,
277 quantized_data: &[graph_builder::QdqWeightInput],
278 path: impl AsRef<std::path::Path>,
279 ) -> Result<()> {
280 let path = path.as_ref();
281 use graph_builder::{apply_qdq_transform, ensure_opset_version};
282
283 // --- 1. Opset: ≥10 for per-tensor DequantizeLinear, ≥13 for per-channel ---
284 let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
285 let min_opset = if needs_per_channel { 13 } else { 10 };
286 ensure_opset_version(&mut self.proto, min_opset);
287
288 // --- 2. Persist per-weight bits in model metadata ---
289 for inp in quantized_data.iter() {
290 self.proto.metadata_props.push(StringStringEntryProto {
291 key: format!("quantize_rs.bits.{}", inp.original_name),
292 value: inp.bits.to_string(),
293 });
294 }
295
296 // --- 3. Apply QDQ transform to the graph ---
297 let graph = self
298 .proto
299 .graph
300 .as_mut()
301 .ok_or_else(|| QuantizeError::ModelSave {
302 path: path.to_path_buf(),
303 reason: "Model has no graph".to_string(),
304 })?;
305 apply_qdq_transform(graph, quantized_data)?;
306
307 // --- 4. Encode and write to disk ---
308 let mut buf = Vec::new();
309 self.proto
310 .encode(&mut buf)
311 .map_err(|e| QuantizeError::ModelSave {
312 path: path.to_path_buf(),
313 reason: format!("Failed to encode ONNX model: {e}"),
314 })?;
315
316 let mut file = std::fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
317 path: path.to_path_buf(),
318 reason: format!("Failed to create output file: {e}"),
319 })?;
320
321 file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
322 path: path.to_path_buf(),
323 reason: format!("Failed to write ONNX model: {e}"),
324 })?;
325
326 Ok(())
327 }
328}
329
330// ===========================================================================
331// OnnxModel — validation
332// ===========================================================================
333
334impl OnnxModel {
335 /// Check that every node input in the graph resolves to a known tensor.
336 ///
337 /// A "known tensor" is one of:
338 /// - a declared graph input
339 /// - an initializer
340 /// - the output of a node appearing earlier in the node list
341 ///
342 /// This is the exact check ONNX Runtime performs on load. It's the check
343 /// that v0.2.0's `validate` command skipped, which is why the rename bug
344 /// went undetected. Integrate `report.summary()` into the CLI validate
345 /// output alongside the existing structure / weight checks.
346 pub fn validate_connectivity(&self) -> ConnectivityReport {
347 match &self.proto.graph {
348 Some(graph) => graph_builder::validate_graph_connectivity(graph),
349 None => {
350 use crate::onnx_proto::GraphProto;
351 graph_builder::validate_graph_connectivity(&GraphProto::default())
352 }
353 }
354 }
355}
356
357// ===========================================================================
358// OnnxModel — quantized model introspection (v0.3.0 QDQ format)
359// ===========================================================================
360
361impl OnnxModel {
362 /// Extract metadata about quantized weights from a QDQ-format model.
363 ///
364 /// Looks for initializer triples:
365 /// `{base}_quantized`, `{base}_scale`, `{base}_zp`
366 ///
367 /// Scale and zero-point values are read directly from the tensors.
368 /// Bit-width comes from `metadata_props` (written by `save_quantized`);
369 /// defaults to 8 if the metadata entry is missing.
370 pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
371 let graph = match &self.proto.graph {
372 Some(g) => g,
373 None => return Vec::new(),
374 };
375
376 let mut scale_map: std::collections::HashMap<String, f32> =
377 std::collections::HashMap::new();
378 let mut zp_map: std::collections::HashMap<String, i8> = std::collections::HashMap::new();
379 let mut quant_bases: Vec<String> = Vec::new();
380
381 for init in &graph.initializer {
382 let name = &init.name;
383
384 if let Some(base) = name.strip_suffix("_scale") {
385 // Scale is stored in float_data (rank-0 scalar)
386 let scale = if !init.float_data.is_empty() {
387 init.float_data[0]
388 } else if init.raw_data.len() >= 4 {
389 // Fallback: try raw_data as little-endian f32
390 f32::from_le_bytes([
391 init.raw_data[0],
392 init.raw_data[1],
393 init.raw_data[2],
394 init.raw_data[3],
395 ])
396 } else {
397 1.0
398 };
399 scale_map.insert(base.to_string(), scale);
400 } else if let Some(base) = name.strip_suffix("_zp") {
401 // Zero-point is a single raw byte
402 let zp = if !init.raw_data.is_empty() {
403 init.raw_data[0] as i8
404 } else {
405 0
406 };
407 zp_map.insert(base.to_string(), zp);
408 } else if let Some(base) = name.strip_suffix("_quantized") {
409 quant_bases.push(base.to_string());
410 }
411 }
412
413 // Read bits from metadata_props (written by save_quantized)
414 let mut bits_map: std::collections::HashMap<String, u8> = std::collections::HashMap::new();
415 for prop in &self.proto.metadata_props {
416 if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
417 if let Ok(bits) = prop.value.parse::<u8>() {
418 bits_map.insert(base.to_string(), bits);
419 }
420 }
421 }
422
423 // Assemble QuantizedWeightInfo from the three maps
424 quant_bases
425 .iter()
426 .map(|base| {
427 let scale = scale_map.get(base).copied().unwrap_or(1.0);
428 let zp = zp_map.get(base).copied().unwrap_or(0);
429 let bits = bits_map.get(base).copied().unwrap_or(8);
430
431 // Element count = product of dims on the _quantized tensor
432 let original_length = graph
433 .initializer
434 .iter()
435 .find(|i| i.name == format!("{}_quantized", base))
436 .map(|i| i.dims.iter().product::<i64>() as usize)
437 .unwrap_or(0);
438
439 QuantizedWeightInfo {
440 name: base.clone(),
441 bits,
442 scale,
443 zero_point: zp,
444 original_length,
445 }
446 })
447 .collect()
448 }
449}
450
451// ===========================================================================
452// WeightTensor (unchanged from v0.2.0)
453// ===========================================================================
454
/// A single FP32 initializer taken from an ONNX graph.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Name of the initializer in the ONNX graph.
    pub name: String,
    /// The tensor's values as `f32`.
    pub data: Vec<f32>,
    /// The tensor's dimensions.
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Number of bytes `data` occupies as FP32.
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of_val(self.data.as_slice())
    }

    /// Number of scalar values held in `data` (not derived from `shape`).
    pub fn num_elements(&self) -> usize {
        let values: &[f32] = &self.data;
        values.len()
    }
}