quantize_rs/onnx_utils/mod.rs
1// src/onnx_utils/mod.rs
2//! ONNX model utilities — loading, weight extraction, quantized save (QDQ),
3//! graph connectivity validation, and quantized-model introspection.
4
5pub mod graph_builder;
6pub mod quantization_nodes;
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{
10 ModelProto, StringStringEntryProto,
11 tensor_proto, tensor_shape_proto, type_proto,
12};
13use prost::Message;
14use std::fs;
15use std::io::{Read, Write};
16
17// Re-export so callers don't have to reach into submodules
18pub use graph_builder::ConnectivityReport;
19
20// ===========================================================================
21// Core types
22// ===========================================================================
23
/// An ONNX model loaded from a protobuf file.
///
/// Provides methods for inspecting, extracting weights, saving quantized
/// models, and validating graph connectivity.
pub struct OnnxModel {
    // Decoded protobuf root. `proto.graph` is an `Option`; every accessor
    // below handles the `None` case instead of panicking.
    proto: ModelProto,
}
31
32impl std::fmt::Debug for OnnxModel {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 let name = self.proto.graph.as_ref().map(|g| g.name.as_str()).unwrap_or("");
35 let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
36 f.debug_struct("OnnxModel")
37 .field("name", &name)
38 .field("num_nodes", &num_nodes)
39 .finish()
40 }
41}
42
/// Summary of an ONNX model's structure.
///
/// Produced by [`OnnxModel::info`]; all fields are copied out of the
/// protobuf, so the summary remains valid independently of the model.
#[derive(Debug)]
pub struct ModelInfo {
    /// Graph name from the protobuf (empty string if the model has no graph).
    pub name: String,
    /// Model version (`model_version` field of the protobuf).
    pub version: i64,
    /// Number of computation nodes in the graph.
    pub num_nodes: usize,
    /// Names of the graph inputs.
    pub inputs: Vec<String>,
    /// Names of the graph outputs.
    pub outputs: Vec<String>,
}
57
/// Metadata about a quantized weight recovered from a QDQ-format model.
///
/// Assembled by [`OnnxModel::load_quantized_info`] from the initializer
/// triple `{base}_quantized` / `{base}_scale` / `{base}_zp`.
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Original weight name (the `{base}` without the `_quantized` suffix).
    pub name: String,
    /// Quantization bit width (4 or 8); read from `metadata_props`, defaults to 8.
    pub bits: u8,
    /// Quantization scale factor (from the `{base}_scale` initializer; defaults to 1.0).
    pub scale: f32,
    /// Quantization zero point (first byte of the `{base}_zp` initializer; defaults to 0).
    pub zero_point: i8,
    /// Number of elements in the quantized tensor (product of its dims).
    pub original_length: usize,
}
72
73// ===========================================================================
74// OnnxModel — load / inspect
75// ===========================================================================
76
77impl OnnxModel {
78 /// Load an ONNX model from a file path.
79 ///
80 /// # Errors
81 ///
82 /// Returns [`QuantizeError::ModelLoad`] if the file cannot be opened,
83 /// is too large (>10 GB), or contains invalid protobuf data.
84 pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
85 let path = path.as_ref();
86 let mut file =
87 fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
88 path: path.to_path_buf(),
89 reason: format!("Failed to open ONNX file: {e}"),
90 })?;
91
92 const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
93 let file_size = file.metadata()
94 .map_err(|e| QuantizeError::ModelLoad {
95 path: path.to_path_buf(),
96 reason: format!("Failed to read metadata: {e}"),
97 })?
98 .len();
99 if file_size > MAX_MODEL_SIZE {
100 return Err(QuantizeError::ModelLoad {
101 path: path.to_path_buf(),
102 reason: format!(
103 "Model file too large: {:.2} GB (max: 10 GB)",
104 file_size as f64 / (1024.0 * 1024.0 * 1024.0)
105 ),
106 });
107 }
108
109 let mut buffer = Vec::with_capacity(file_size as usize);
110 file.read_to_end(&mut buffer)
111 .map_err(|e| QuantizeError::ModelLoad {
112 path: path.to_path_buf(),
113 reason: format!("Failed to read ONNX file: {e}"),
114 })?;
115
116 let proto = ModelProto::decode(&buffer[..])
117 .map_err(|e| QuantizeError::ModelLoad {
118 path: path.to_path_buf(),
119 reason: format!("Failed to parse ONNX protobuf: {e}"),
120 })?;
121
122 Ok(Self { proto })
123 }
124
125 /// Return a summary of the model's structure.
126 pub fn info(&self) -> ModelInfo {
127 let graph = self.proto.graph.as_ref();
128
129 let inputs: Vec<String> = graph
130 .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
131 .unwrap_or_default();
132
133 let outputs: Vec<String> = graph
134 .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
135 .unwrap_or_default();
136
137 ModelInfo {
138 name: graph.map(|g| g.name.clone()).unwrap_or_default(),
139 version: self.proto.model_version,
140 num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
141 inputs,
142 outputs,
143 }
144 }
145
146 /// Return the shapes of each graph input from the protobuf type info.
147 ///
148 /// Each inner `Vec<i64>` contains the dimension values. Dynamic dims
149 /// (symbolic or missing) are returned as -1. Returns one entry per
150 /// `graph.input` that has tensor type information.
151 pub fn input_shapes(&self) -> Vec<Vec<i64>> {
152 let graph = match &self.proto.graph {
153 Some(g) => g,
154 None => return Vec::new(),
155 };
156
157 let mut shapes = Vec::new();
158 for inp in &graph.input {
159 if let Some(type_proto) = &inp.r#type {
160 if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
161 if let Some(shape) = &tensor_type.shape {
162 let dims: Vec<i64> = shape.dim.iter().map(|d| {
163 match &d.value {
164 Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
165 _ => -1,
166 }
167 }).collect();
168 shapes.push(dims);
169 }
170 }
171 }
172 }
173 shapes
174 }
175
176 /// Extract all FP32 weight tensors from the model's initializers.
177 pub fn extract_weights(&self) -> Vec<WeightTensor> {
178 let graph = match &self.proto.graph {
179 Some(g) => g,
180 None => return Vec::new(),
181 };
182
183 let mut weights = Vec::new();
184 for initializer in &graph.initializer {
185 // Only extract FP32 tensors — skip INT8, INT64, DOUBLE, etc.
186 if initializer.data_type != tensor_proto::DataType::Float as i32 {
187 continue;
188 }
189
190 let name = initializer.name.clone();
191
192 let shape: Vec<usize> = initializer.dims.iter()
193 .map(|&d| d.max(0) as usize)
194 .collect();
195
196 let data = if !initializer.raw_data.is_empty() {
197 initializer.raw_data.chunks_exact(4)
198 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
199 .collect()
200 } else {
201 initializer.float_data.clone()
202 };
203
204 if !data.is_empty() {
205 weights.push(WeightTensor { name, data, shape });
206 }
207 }
208
209 weights
210 }
211
212 /// Total size of all weight tensors in bytes (float32).
213 ///
214 /// Prefer computing this from already-extracted weights when available:
215 /// `weights.iter().map(|w| w.size_bytes()).sum()` avoids reparsing.
216 pub fn total_size_bytes(&self) -> usize {
217 let graph = match &self.proto.graph {
218 Some(g) => g,
219 None => return 0,
220 };
221 graph.initializer.iter()
222 .map(|init| {
223 if !init.raw_data.is_empty() {
224 init.raw_data.len()
225 } else {
226 init.float_data.len() * std::mem::size_of::<f32>()
227 }
228 })
229 .sum()
230 }
231}
232
233// ===========================================================================
234// OnnxModel — quantized save (QDQ pattern, v0.3.0+)
235// ===========================================================================
236
237impl OnnxModel {
238 /// Save a quantized model using the QDQ (DequantizeLinear) pattern.
239 ///
240 /// **Signature is identical to v0.2.0** — existing callers (CLI, calibration
241 /// pipeline, examples) compile without changes.
242 ///
243 /// ### What changed internally
244 ///
245 /// v0.2.0 appended metadata to initializer names (e.g. `conv1.weight` →
246 /// `conv1.weight__qINT8_s0.001_z-3_len9408`) without updating the nodes that
247 /// reference them. ONNX Runtime rejected these models on load.
248 ///
249 /// v0.3.0 inserts a `DequantizeLinear` node per weight. The node's output
250 /// carries the **original** name, so every downstream node is unchanged.
251 /// Graph connectivity is preserved by construction, and the resulting model
252 /// loads and runs in ONNX Runtime.
253 ///
254 /// ### INT4 storage note
255 ///
256 /// `DequantizeLinear` requires INT8 input (opset < 21). INT4-quantized values
257 /// ([-8, 7]) are stored as INT8 bytes. Quantization *accuracy* is still
258 /// INT4-level; only the on-disk size is 4× instead of the 8× that bit-packing
259 /// would give. True INT4 packing is a v0.4.0 target.
260 pub fn save_quantized(
261 &mut self,
262 quantized_data: &[graph_builder::QdqWeightInput],
263 path: impl AsRef<std::path::Path>,
264 ) -> Result<()> {
265 let path = path.as_ref();
266 use graph_builder::{apply_qdq_transform, ensure_opset_version};
267
268 // --- 1. Opset ≥ 13 (DequantizeLinear per-channel needs it) ---
269 ensure_opset_version(&mut self.proto, 13);
270
271 // --- 2. Persist per-weight bits in model metadata ---
272 for inp in quantized_data.iter() {
273 self.proto.metadata_props.push(StringStringEntryProto {
274 key: format!("quantize_rs.bits.{}", inp.original_name),
275 value: inp.bits.to_string(),
276 });
277 }
278
279 // --- 3. Apply QDQ transform to the graph ---
280 let graph = self.proto.graph.as_mut().ok_or_else(|| QuantizeError::ModelSave {
281 path: path.to_path_buf(),
282 reason: "Model has no graph".to_string(),
283 })?;
284 apply_qdq_transform(graph, quantized_data)?;
285
286 // --- 4. Encode and write to disk ---
287 let mut buf = Vec::new();
288 self.proto.encode(&mut buf)
289 .map_err(|e| QuantizeError::ModelSave {
290 path: path.to_path_buf(),
291 reason: format!("Failed to encode ONNX model: {e}"),
292 })?;
293
294 let mut file = std::fs::File::create(path)
295 .map_err(|e| QuantizeError::ModelSave {
296 path: path.to_path_buf(),
297 reason: format!("Failed to create output file: {e}"),
298 })?;
299
300 file.write_all(&buf)
301 .map_err(|e| QuantizeError::ModelSave {
302 path: path.to_path_buf(),
303 reason: format!("Failed to write ONNX model: {e}"),
304 })?;
305
306 Ok(())
307 }
308}
309
310// ===========================================================================
311// OnnxModel — validation
312// ===========================================================================
313
314impl OnnxModel {
315 /// Check that every node input in the graph resolves to a known tensor.
316 ///
317 /// A "known tensor" is one of:
318 /// - a declared graph input
319 /// - an initializer
320 /// - the output of a node appearing earlier in the node list
321 ///
322 /// This is the exact check ONNX Runtime performs on load. It's the check
323 /// that v0.2.0's `validate` command skipped, which is why the rename bug
324 /// went undetected. Integrate `report.summary()` into the CLI validate
325 /// output alongside the existing structure / weight checks.
326 pub fn validate_connectivity(&self) -> ConnectivityReport {
327 match &self.proto.graph {
328 Some(graph) => graph_builder::validate_graph_connectivity(graph),
329 None => {
330 use crate::onnx_proto::GraphProto;
331 graph_builder::validate_graph_connectivity(&GraphProto::default())
332 }
333 }
334 }
335}
336
337// ===========================================================================
338// OnnxModel — quantized model introspection (v0.3.0 QDQ format)
339// ===========================================================================
340
341impl OnnxModel {
342 /// Extract metadata about quantized weights from a QDQ-format model.
343 ///
344 /// Looks for initializer triples:
345 /// `{base}_quantized`, `{base}_scale`, `{base}_zp`
346 ///
347 /// Scale and zero-point values are read directly from the tensors.
348 /// Bit-width comes from `metadata_props` (written by `save_quantized`);
349 /// defaults to 8 if the metadata entry is missing.
350 pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
351 let graph = match &self.proto.graph {
352 Some(g) => g,
353 None => return Vec::new(),
354 };
355
356 let mut scale_map: std::collections::HashMap<String, f32> =
357 std::collections::HashMap::new();
358 let mut zp_map: std::collections::HashMap<String, i8> =
359 std::collections::HashMap::new();
360 let mut quant_bases: Vec<String> = Vec::new();
361
362 for init in &graph.initializer {
363 let name = &init.name;
364
365 if let Some(base) = name.strip_suffix("_scale") {
366 // Scale is stored in float_data (rank-0 scalar)
367 let scale = if !init.float_data.is_empty() {
368 init.float_data[0]
369 } else if init.raw_data.len() >= 4 {
370 // Fallback: try raw_data as little-endian f32
371 f32::from_le_bytes([
372 init.raw_data[0], init.raw_data[1],
373 init.raw_data[2], init.raw_data[3],
374 ])
375 } else {
376 1.0
377 };
378 scale_map.insert(base.to_string(), scale);
379 } else if let Some(base) = name.strip_suffix("_zp") {
380 // Zero-point is a single raw byte
381 let zp = if !init.raw_data.is_empty() {
382 init.raw_data[0] as i8
383 } else {
384 0
385 };
386 zp_map.insert(base.to_string(), zp);
387 } else if let Some(base) = name.strip_suffix("_quantized") {
388 quant_bases.push(base.to_string());
389 }
390 }
391
392 // Read bits from metadata_props (written by save_quantized)
393 let mut bits_map: std::collections::HashMap<String, u8> =
394 std::collections::HashMap::new();
395 for prop in &self.proto.metadata_props {
396 if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
397 if let Ok(bits) = prop.value.parse::<u8>() {
398 bits_map.insert(base.to_string(), bits);
399 }
400 }
401 }
402
403 // Assemble QuantizedWeightInfo from the three maps
404 quant_bases
405 .iter()
406 .map(|base| {
407 let scale = scale_map.get(base).copied().unwrap_or(1.0);
408 let zp = zp_map.get(base).copied().unwrap_or(0);
409 let bits = bits_map.get(base).copied().unwrap_or(8);
410
411 // Element count = product of dims on the _quantized tensor
412 let original_length = graph.initializer.iter()
413 .find(|i| i.name == format!("{}_quantized", base))
414 .map(|i| i.dims.iter().product::<i64>() as usize)
415 .unwrap_or(0);
416
417 QuantizedWeightInfo {
418 name: base.clone(),
419 bits,
420 scale,
421 zero_point: zp,
422 original_length,
423 }
424 })
425 .collect()
426 }
427}
428
429// ===========================================================================
430// WeightTensor (unchanged from v0.2.0)
431// ===========================================================================
432
/// An FP32 weight tensor extracted from an ONNX model.
///
/// Produced by [`OnnxModel::extract_weights`]; `data` holds the decoded
/// float values regardless of whether the initializer stored them as raw
/// bytes or as `float_data`.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Initializer name in the ONNX graph.
    pub name: String,
    /// FP32 weight values (row-major as stored in the initializer).
    pub data: Vec<f32>,
    /// Tensor dimensions (negative dims clamped to 0 at extraction).
    pub shape: Vec<usize>,
}
443
444impl WeightTensor {
445 /// Size of this tensor in bytes (as FP32).
446 pub fn size_bytes(&self) -> usize {
447 self.data.len() * std::mem::size_of::<f32>()
448 }
449
450 /// Total number of scalar elements.
451 pub fn num_elements(&self) -> usize {
452 self.data.len()
453 }
454}