//! Loading, inspecting, and re-saving ONNX models through their decoded
//! protobuf representation.

pub mod graph_builder;
pub mod quantization_nodes;
use crate::errors::{QuantizeError, Result};
use crate::onnx_proto::{
tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
};
use prost::Message;
use std::collections::HashMap;
use std::fs;
use std::io::{Read, Write};
pub use graph_builder::{ConnectivityReport, SaveOptions};
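
/// An ONNX model held in memory as a decoded `ModelProto`, with helpers for
/// inspecting weights and writing back quantized variants.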
pub struct OnnxModel {
proto: ModelProto,
}
impl std::fmt::Debug for OnnxModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = self
.proto
.graph
.as_ref()
.map(|g| g.name.as_str())
.unwrap_or("");
let num_nodes = self.proto.graph.as_ref().map(|g| g.node.len()).unwrap_or(0);
f.debug_struct("OnnxModel")
.field("name", &name)
.field("num_nodes", &num_nodes)
.finish()
}
}
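
/// Summary of a model's graph: its name, declared model version, node count,
/// and the names of its graph-level inputs and outputs.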
#[derive(Debug)]
pub struct ModelInfo {
pub name: String,
pub version: i64,
pub num_nodes: usize,
pub inputs: Vec<String>,
pub outputs: Vec<String>,
}
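
/// Metadata recovered for a single quantized weight: the bit width, the
/// per-tensor or per-channel scales and zero points, the logical element
/// count, and the bytes actually stored in the graph.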
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
pub name: String,
pub bits: u8,
pub scales: Vec<f32>,
pub zero_points: Vec<i8>,
pub original_length: usize,
pub storage_bytes: usize,
}
impl QuantizedWeightInfo {
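/// Returns `true` when more than one scale is present, i.e. the weight was
/// quantized per-channel rather than per-tensor.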
pub fn is_per_channel(&self) -> bool {
self.scales.len() > 1
}
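/// Returns the per-tensor scale (the first channel's scale for per-channel
/// weights). Panics if `scales` is empty.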
pub fn scale(&self) -> f32 {
self.scales[0]
}
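/// Returns the per-tensor zero point (the first channel's for per-channel
/// weights). Panics if `zero_points` is empty.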
pub fn zero_point(&self) -> i8 {
self.zero_points[0]
}
}
impl OnnxModel {
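/// Loads and decodes an ONNX model from disk. Files larger than 10 GB are
/// rejected before any bytes are read.
///
/// A minimal usage sketch ("model.onnx" is a placeholder path):
///
/// ```ignore
/// let model = OnnxModel::load("model.onnx")?;
/// println!("{} nodes", model.info().num_nodes);
/// ```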
pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
let path = path.as_ref();
let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to open ONNX file: {e}"),
})?;
const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024;
let file_size = file
.metadata()
.map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to read metadata: {e}"),
})?
.len();
if file_size > MAX_MODEL_SIZE {
return Err(QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!(
"Model file too large: {:.2} GB (max: 10 GB)",
file_size as f64 / (1024.0 * 1024.0 * 1024.0)
),
});
}
let mut buffer = Vec::with_capacity(file_size as usize);
file.read_to_end(&mut buffer)
.map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to read ONNX file: {e}"),
})?;
let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to parse ONNX protobuf: {e}"),
})?;
Ok(Self { proto })
}
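/// Decodes a model from an in-memory protobuf buffer. Errors report an
/// empty path, since no file is involved.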
pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
let proto = ModelProto::decode(bytes).map_err(|e| QuantizeError::ModelLoad {
path: std::path::PathBuf::new(),
reason: format!("Failed to parse ONNX protobuf: {e}"),
})?;
Ok(Self { proto })
}
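/// Like [`load`](Self::load), but memory-maps the file instead of reading it
/// into a heap buffer, which avoids a copy for large models.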
#[cfg(feature = "mmap")]
pub fn load_mmap(path: impl AsRef<std::path::Path>) -> Result<Self> {
let path = path.as_ref();
let file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to open ONNX file: {e}"),
})?;
const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024;
let file_size = file
.metadata()
.map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to read metadata: {e}"),
})?
.len();
if file_size > MAX_MODEL_SIZE {
return Err(QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!(
"Model file too large: {:.2} GB (max: 10 GB)",
file_size as f64 / (1024.0 * 1024.0 * 1024.0)
),
});
}
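// SAFETY: the mapping is only read while `mmap` is alive; the caller must
// ensure the underlying file is not truncated or modified concurrently,
// which memmap2 cannot guarantee on its own.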
let mmap = unsafe {
memmap2::Mmap::map(&file).map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to mmap ONNX file: {e}"),
})?
};
let proto = ModelProto::decode(&mmap[..]).map_err(|e| QuantizeError::ModelLoad {
path: path.to_path_buf(),
reason: format!("Failed to parse ONNX protobuf: {e}"),
})?;
Ok(Self { proto })
}
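/// Builds a [`ModelInfo`] summary; a model without a graph yields an empty
/// name, zero nodes, and empty input/output lists.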
pub fn info(&self) -> ModelInfo {
let graph = self.proto.graph.as_ref();
let inputs: Vec<String> = graph
.map(|g| g.input.iter().map(|i| i.name.clone()).collect())
.unwrap_or_default();
let outputs: Vec<String> = graph
.map(|g| g.output.iter().map(|o| o.name.clone()).collect())
.unwrap_or_default();
ModelInfo {
name: graph.map(|g| g.name.clone()).unwrap_or_default(),
version: self.proto.model_version,
num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
inputs,
outputs,
}
}
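/// Returns the declared shape of each graph input. Symbolic or missing
/// dimensions are reported as `-1`.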
pub fn input_shapes(&self) -> Vec<Vec<i64>> {
let graph = match &self.proto.graph {
Some(g) => g,
None => return Vec::new(),
};
let mut shapes = Vec::new();
for inp in &graph.input {
if let Some(type_proto) = &inp.r#type {
if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
if let Some(shape) = &tensor_type.shape {
let dims: Vec<i64> = shape
.dim
.iter()
.map(|d| match &d.value {
Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
_ => -1,
})
.collect();
shapes.push(dims);
}
}
}
}
shapes
}
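/// Collects every float32 initializer as an owned [`WeightTensor`]. Raw
/// tensor bytes are decoded as little-endian `f32`; initializers whose
/// `raw_data` length is not a multiple of 4 are skipped, as are empty ones.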
pub fn extract_weights(&self) -> Vec<WeightTensor> {
let graph = match &self.proto.graph {
Some(g) => g,
None => return Vec::new(),
};
let mut weights = Vec::new();
for initializer in &graph.initializer {
if initializer.data_type != tensor_proto::DataType::Float as i32 {
continue;
}
let name = initializer.name.clone();
let shape: Vec<usize> = initializer
.dims
.iter()
.map(|&d| d.max(0) as usize)
.collect();
let data = if !initializer.raw_data.is_empty() {
if initializer.raw_data.len() % 4 != 0 {
continue;
}
initializer
.raw_data
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
} else {
initializer.float_data.clone()
};
if !data.is_empty() {
weights.push(WeightTensor { name, data, shape });
}
}
weights
}
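/// Sums the storage footprint of all initializers, preferring `raw_data`
/// length and falling back to `float_data` length times `size_of::<f32>()`.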
pub fn total_size_bytes(&self) -> usize {
let graph = match &self.proto.graph {
Some(g) => g,
None => return 0,
};
graph
.initializer
.iter()
.map(|init| {
if !init.raw_data.is_empty() {
init.raw_data.len()
} else {
init.float_data.len() * std::mem::size_of::<f32>()
}
})
.sum()
}
}
impl OnnxModel {
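/// Applies the QDQ transform with default [`SaveOptions`] and writes the
/// result to `path`.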
pub fn save_quantized(
&mut self,
quantized_data: &[graph_builder::QdqWeightInput],
path: impl AsRef<std::path::Path>,
) -> Result<()> {
self.save_quantized_with_options(quantized_data, path, SaveOptions::default())
}
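/// Applies the QDQ transform and writes the model to `path`. The opset is
/// raised to the minimum the transform needs: 21 for native INT4 tensors,
/// 13 for per-channel quantization, and 10 otherwise. Per-weight bit widths
/// are recorded as `quantize_rs.bits.<name>` metadata entries so that
/// [`load_quantized_info`](Self::load_quantized_info) can recover them.
///
/// A sketch of a call site (assumes `weights` came from an earlier
/// quantization pass and that `SaveOptions` can be built with struct-update
/// syntax from its `Default`):
///
/// ```ignore
/// let options = SaveOptions { native_int4: true, ..SaveOptions::default() };
/// model.save_quantized_with_options(&weights, "model.q4.onnx", options)?;
/// ```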
pub fn save_quantized_with_options(
&mut self,
quantized_data: &[graph_builder::QdqWeightInput],
path: impl AsRef<std::path::Path>,
options: SaveOptions,
) -> Result<()> {
let path = path.as_ref();
use graph_builder::{apply_qdq_transform_with_options, ensure_opset_version};
let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
let uses_native_int4 = options.native_int4 && quantized_data.iter().any(|w| w.bits == 4);
let min_opset = if uses_native_int4 {
21
} else if needs_per_channel {
13
} else {
10
};
ensure_opset_version(&mut self.proto, min_opset);
for inp in quantized_data.iter() {
self.proto.metadata_props.push(StringStringEntryProto {
key: format!("quantize_rs.bits.{}", inp.original_name),
value: inp.bits.to_string(),
});
}
let graph = self
.proto
.graph
.as_mut()
.ok_or_else(|| QuantizeError::ModelSave {
path: path.to_path_buf(),
reason: "Model has no graph".to_string(),
})?;
apply_qdq_transform_with_options(graph, quantized_data, options)?;
let mut buf = Vec::new();
self.proto
.encode(&mut buf)
.map_err(|e| QuantizeError::ModelSave {
path: path.to_path_buf(),
reason: format!("Failed to encode ONNX model: {e}"),
})?;
let mut file = fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
path: path.to_path_buf(),
reason: format!("Failed to create output file: {e}"),
})?;
file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
path: path.to_path_buf(),
reason: format!("Failed to write ONNX model: {e}"),
})?;
Ok(())
}
}
impl OnnxModel {
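/// Runs the graph-connectivity validation from `graph_builder` over the
/// model's graph; a model without a graph is validated against an empty
/// graph.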
pub fn validate_connectivity(&self) -> ConnectivityReport {
match &self.proto.graph {
Some(graph) => graph_builder::validate_graph_connectivity(graph),
None => {
use crate::onnx_proto::GraphProto;
graph_builder::validate_graph_connectivity(&GraphProto::default())
}
}
}
}
impl OnnxModel {
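/// Reconstructs [`QuantizedWeightInfo`] entries by scanning initializers
/// that follow the crate's naming convention (`<base>_quantized`,
/// `<base>_scale`, `<base>_zp`) together with the `quantize_rs.bits.<base>`
/// metadata written at save time. Missing pieces fall back to a scale of
/// `1.0`, a zero point of `0`, and 8 bits.
///
/// A minimal inspection sketch:
///
/// ```ignore
/// for info in model.load_quantized_info() {
///     println!("{}: {}-bit, per-channel: {}", info.name, info.bits, info.is_per_channel());
/// }
/// ```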
pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
let graph = match &self.proto.graph {
Some(g) => g,
None => return Vec::new(),
};
let mut scale_map: HashMap<String, Vec<f32>> = HashMap::new();
let mut zp_map: HashMap<String, Vec<i8>> = HashMap::new();
let mut quant_bases: Vec<String> = Vec::new();
for init in &graph.initializer {
let name = &init.name;
if let Some(base) = name.strip_suffix("_scale") {
scale_map.insert(base.to_string(), decode_scale_tensor(init));
} else if let Some(base) = name.strip_suffix("_zp") {
zp_map.insert(base.to_string(), decode_zero_point_tensor(init));
} else if let Some(base) = name.strip_suffix("_quantized") {
quant_bases.push(base.to_string());
}
}
let mut bits_map: HashMap<String, u8> = HashMap::new();
for prop in &self.proto.metadata_props {
if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
if let Ok(bits) = prop.value.parse::<u8>() {
bits_map.insert(base.to_string(), bits);
}
}
}
quant_bases
.iter()
.map(|base| {
let scales = scale_map.get(base).cloned().unwrap_or_else(|| vec![1.0]);
let zero_points = zp_map.get(base).cloned().unwrap_or_else(|| vec![0]);
let bits = bits_map.get(base).copied().unwrap_or(8);
let quant_name = format!("{base}_quantized");
let quant_init = graph
.initializer
.iter()
.find(|i| i.name == quant_name);
let original_length = quant_init
.map(|i| i.dims.iter().product::<i64>() as usize)
.unwrap_or(0);
let storage_bytes = quant_init.map(|i| i.raw_data.len()).unwrap_or(0);
QuantizedWeightInfo {
name: base.clone(),
bits,
scales,
zero_points,
original_length,
storage_bytes,
}
})
.collect()
}
}
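/// Computes the element count implied by a tensor's dims, treating a
/// dims-less tensor as a scalar and ignoring non-positive (dynamic or zero)
/// dimensions.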
fn expected_element_count(init: &crate::onnx_proto::TensorProto) -> usize {
if init.dims.is_empty() {
1
} else {
init.dims
.iter()
.copied()
.filter(|&d| d > 0)
.product::<i64>() as usize
}
}
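/// Reads scales from `float_data` when present, otherwise from little-endian
/// `raw_data`; falls back to `1.0` per expected element if neither holds
/// enough values.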
fn decode_scale_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<f32> {
let expected = expected_element_count(init).max(1);
if !init.float_data.is_empty() {
return init.float_data.clone();
}
if !init.raw_data.is_empty() && init.raw_data.len() >= 4 * expected {
return init
.raw_data
.chunks_exact(4)
.take(expected)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
}
vec![1.0; expected]
}
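/// Reads zero points, handling ONNX INT4 packing, plain `raw_data` bytes,
/// and the `int32_data` field, in that order; falls back to zeros. Values
/// are narrowed to `i8`, which wraps for UINT8 zero points above 127.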
fn decode_zero_point_tensor(init: &crate::onnx_proto::TensorProto) -> Vec<i8> {
use crate::onnx_proto::tensor_proto::DataType;
use crate::onnx_utils::quantization_nodes::unpack_int4_onnx;
let expected = expected_element_count(init).max(1);
if init.data_type == DataType::Int4 as i32 {
return unpack_int4_onnx(&init.raw_data, expected);
}
if !init.raw_data.is_empty() {
return init
.raw_data
.iter()
.take(expected)
.map(|&b| b as i8)
.collect();
}
if !init.int32_data.is_empty() {
return init
.int32_data
.iter()
.take(expected)
.map(|&v| v as i8)
.collect();
}
vec![0; expected]
}
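
/// A float32 weight extracted from the graph: its initializer name, flat
/// data, and shape.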
#[derive(Debug, Clone)]
pub struct WeightTensor {
pub name: String,
pub data: Vec<f32>,
pub shape: Vec<usize>,
}
impl WeightTensor {
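/// Size of the tensor data in bytes (`f32` elements times four).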
pub fn size_bytes(&self) -> usize {
self.data.len() * std::mem::size_of::<f32>()
}
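/// Number of `f32` elements in the tensor.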
pub fn num_elements(&self) -> usize {
self.data.len()
}
}