pub mod graph_builder;
pub mod quantization_nodes;
use crate::errors::{QuantizeError, Result};
use crate::onnx_proto::{
tensor_proto, tensor_shape_proto, type_proto, ModelProto, StringStringEntryProto,
};
use prost::Message;
use std::fs;
use std::io::{Read, Write};
pub use graph_builder::ConnectivityReport;
/// An ONNX model backed by its decoded protobuf representation.
///
/// Constructed via [`OnnxModel::load`]; all accessors treat a missing
/// `graph` field defensively (returning empty/zero values).
pub struct OnnxModel {
    // Decoded `ModelProto`; optional sub-fields (notably `graph`) may be `None`.
    proto: ModelProto,
}
impl std::fmt::Debug for OnnxModel {
    /// Renders a compact summary: the graph name and node count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A model without a graph shows an empty name and zero nodes.
        let (name, num_nodes) = match self.proto.graph.as_ref() {
            Some(graph) => (graph.name.as_str(), graph.node.len()),
            None => ("", 0),
        };
        f.debug_struct("OnnxModel")
            .field("name", &name)
            .field("num_nodes", &num_nodes)
            .finish()
    }
}
/// Summary information about a loaded ONNX model, as produced by
/// [`OnnxModel::info`].
#[derive(Debug)]
pub struct ModelInfo {
    /// Graph name (empty when the model has no graph).
    pub name: String,
    /// The proto's `model_version` field.
    pub version: i64,
    /// Number of nodes in the graph (0 when the model has no graph).
    pub num_nodes: usize,
    /// Names of the graph inputs.
    pub inputs: Vec<String>,
    /// Names of the graph outputs.
    pub outputs: Vec<String>,
}
/// Quantization parameters recovered for one weight of a quantized model,
/// as produced by [`OnnxModel::load_quantized_info`].
#[derive(Debug, Clone)]
pub struct QuantizedWeightInfo {
    /// Base name of the original (pre-quantization) weight.
    pub name: String,
    /// Bit width read from `quantize_rs.bits.*` metadata (defaults to 8).
    pub bits: u8,
    /// Dequantization scale from the `{name}_scale` initializer (defaults to 1.0).
    pub scale: f32,
    /// Zero point from the `{name}_zp` initializer (defaults to 0).
    pub zero_point: i8,
    /// Element count of the `{name}_quantized` initializer.
    pub original_length: usize,
}
impl OnnxModel {
    /// Loads an ONNX model from `path`.
    ///
    /// The file size is checked against a 10 GB cap before the whole file is
    /// buffered in memory and decoded as a `ModelProto`.
    ///
    /// # Errors
    /// Returns [`QuantizeError::ModelLoad`] when the file cannot be opened,
    /// its metadata cannot be read, it exceeds the size cap, reading fails,
    /// or the bytes do not parse as an ONNX protobuf.
    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
        let path = path.as_ref();
        let mut file = fs::File::open(path).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to open ONNX file: {e}"),
        })?;

        // Reject oversized files up front so we never try to buffer them.
        const MAX_MODEL_SIZE: u64 = 10 * 1024 * 1024 * 1024; // 10 GB
        let file_size = file
            .metadata()
            .map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to read metadata: {e}"),
            })?
            .len();
        if file_size > MAX_MODEL_SIZE {
            return Err(QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!(
                    "Model file too large: {:.2} GB (max: 10 GB)",
                    file_size as f64 / (1024.0 * 1024.0 * 1024.0)
                ),
            });
        }

        // `file_size as usize` would silently truncate on targets where usize
        // is narrower than u64 (the 10 GB cap exceeds 32-bit usize); fall back
        // to an unsized Vec and let `read_to_end` grow it instead.
        let mut buffer = Vec::with_capacity(usize::try_from(file_size).unwrap_or(0));
        file.read_to_end(&mut buffer)
            .map_err(|e| QuantizeError::ModelLoad {
                path: path.to_path_buf(),
                reason: format!("Failed to read ONNX file: {e}"),
            })?;

        let proto = ModelProto::decode(&buffer[..]).map_err(|e| QuantizeError::ModelLoad {
            path: path.to_path_buf(),
            reason: format!("Failed to parse ONNX protobuf: {e}"),
        })?;
        Ok(Self { proto })
    }

    /// Returns summary information (name, version, node count, I/O names).
    ///
    /// All graph-derived fields default to empty/zero when the model has no
    /// graph.
    pub fn info(&self) -> ModelInfo {
        let graph = self.proto.graph.as_ref();
        let inputs: Vec<String> = graph
            .map(|g| g.input.iter().map(|i| i.name.clone()).collect())
            .unwrap_or_default();
        let outputs: Vec<String> = graph
            .map(|g| g.output.iter().map(|o| o.name.clone()).collect())
            .unwrap_or_default();
        ModelInfo {
            name: graph.map(|g| g.name.clone()).unwrap_or_default(),
            version: self.proto.model_version,
            num_nodes: graph.map(|g| g.node.len()).unwrap_or(0),
            inputs,
            outputs,
        }
    }

    /// Returns the dimensions of every tensor-typed graph input.
    ///
    /// Symbolic or absent dimensions are reported as `-1`. Inputs without a
    /// tensor type, or without a shape, are skipped entirely.
    pub fn input_shapes(&self) -> Vec<Vec<i64>> {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };
        let mut shapes = Vec::new();
        for inp in &graph.input {
            if let Some(type_proto) = &inp.r#type {
                if let Some(type_proto::Value::TensorType(tensor_type)) = &type_proto.value {
                    if let Some(shape) = &tensor_type.shape {
                        let dims: Vec<i64> = shape
                            .dim
                            .iter()
                            .map(|d| match &d.value {
                                Some(tensor_shape_proto::dimension::Value::DimValue(v)) => *v,
                                // Named (symbolic) or missing dims map to -1.
                                _ => -1,
                            })
                            .collect();
                        shapes.push(dims);
                    }
                }
            }
        }
        shapes
    }

    /// Extracts every float32 initializer as an owned [`WeightTensor`].
    ///
    /// Non-float initializers, initializers whose `raw_data` length is not a
    /// multiple of 4 (cannot be f32-decoded), and empty tensors are skipped.
    pub fn extract_weights(&self) -> Vec<WeightTensor> {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };
        let mut weights = Vec::new();
        for initializer in &graph.initializer {
            if initializer.data_type != tensor_proto::DataType::Float as i32 {
                continue;
            }
            let name = initializer.name.clone();
            // Clamp (invalid) negative dims to 0 instead of wrapping the cast.
            let shape: Vec<usize> = initializer
                .dims
                .iter()
                .map(|&d| d.max(0) as usize)
                .collect();
            // Weights live either packed little-endian in `raw_data` or
            // directly in `float_data`.
            let data = if !initializer.raw_data.is_empty() {
                if initializer.raw_data.len() % 4 != 0 {
                    // Corrupt / non-f32 payload: skip rather than misparse.
                    continue;
                }
                initializer
                    .raw_data
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect()
            } else {
                initializer.float_data.clone()
            };
            if !data.is_empty() {
                weights.push(WeightTensor { name, data, shape });
            }
        }
        weights
    }

    /// Total byte size of all initializer payloads: raw bytes where present,
    /// otherwise `float_data` length times `size_of::<f32>()`.
    pub fn total_size_bytes(&self) -> usize {
        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return 0,
        };
        graph
            .initializer
            .iter()
            .map(|init| {
                if !init.raw_data.is_empty() {
                    init.raw_data.len()
                } else {
                    init.float_data.len() * std::mem::size_of::<f32>()
                }
            })
            .sum()
    }
}
impl OnnxModel {
    /// Applies the QDQ (QuantizeLinear/DequantizeLinear) transform for the
    /// given quantized weights, records per-weight bit widths in the model
    /// metadata, and serializes the result to `path`.
    ///
    /// # Errors
    /// Returns [`QuantizeError::ModelSave`] when the model has no graph, the
    /// proto cannot be encoded, or the output file cannot be created or
    /// written; propagates any error from `apply_qdq_transform`.
    pub fn save_quantized(
        &mut self,
        quantized_data: &[graph_builder::QdqWeightInput],
        path: impl AsRef<std::path::Path>,
    ) -> Result<()> {
        let path = path.as_ref();
        use graph_builder::{apply_qdq_transform, ensure_opset_version};

        // Per-channel (axis-based) Q/DQ requires opset 13; per-tensor only 10.
        let needs_per_channel = quantized_data.iter().any(|w| w.axis.is_some());
        let min_opset = if needs_per_channel { 13 } else { 10 };
        ensure_opset_version(&mut self.proto, min_opset);

        // Transform the graph BEFORE touching metadata so a missing graph or
        // a failed transform does not leave half-applied `quantize_rs.bits.*`
        // entries on the (mutated) model.
        {
            let graph = self
                .proto
                .graph
                .as_mut()
                .ok_or_else(|| QuantizeError::ModelSave {
                    path: path.to_path_buf(),
                    reason: "Model has no graph".to_string(),
                })?;
            apply_qdq_transform(graph, quantized_data)?;
        }

        // Record each weight's bit width so `load_quantized_info` can
        // recover it later.
        for inp in quantized_data.iter() {
            self.proto.metadata_props.push(StringStringEntryProto {
                key: format!("quantize_rs.bits.{}", inp.original_name),
                value: inp.bits.to_string(),
            });
        }

        let mut buf = Vec::new();
        self.proto
            .encode(&mut buf)
            .map_err(|e| QuantizeError::ModelSave {
                path: path.to_path_buf(),
                reason: format!("Failed to encode ONNX model: {e}"),
            })?;
        let mut file = std::fs::File::create(path).map_err(|e| QuantizeError::ModelSave {
            path: path.to_path_buf(),
            reason: format!("Failed to create output file: {e}"),
        })?;
        file.write_all(&buf).map_err(|e| QuantizeError::ModelSave {
            path: path.to_path_buf(),
            reason: format!("Failed to write ONNX model: {e}"),
        })?;
        Ok(())
    }
}
impl OnnxModel {
    /// Runs graph-connectivity validation over this model's graph.
    ///
    /// A model without a graph is validated as an empty (default) graph.
    pub fn validate_connectivity(&self) -> ConnectivityReport {
        use crate::onnx_proto::GraphProto;
        let empty = GraphProto::default();
        let graph = self.proto.graph.as_ref().unwrap_or(&empty);
        graph_builder::validate_graph_connectivity(graph)
    }
}
impl OnnxModel {
    /// Recovers quantization parameters for every `{base}_quantized`
    /// initializer in the graph.
    ///
    /// Scale, zero point, and bit width come from the `{base}_scale` and
    /// `{base}_zp` initializers and the `quantize_rs.bits.{base}` metadata
    /// entries respectively, defaulting to 1.0 / 0 / 8 when absent. Returns
    /// an empty vector when the model has no graph.
    pub fn load_quantized_info(&self) -> Vec<QuantizedWeightInfo> {
        use std::collections::HashMap;

        let graph = match &self.proto.graph {
            Some(g) => g,
            None => return Vec::new(),
        };

        let mut scale_map: HashMap<String, f32> = HashMap::new();
        let mut zp_map: HashMap<String, i8> = HashMap::new();
        // (base name, element count of the `{base}_quantized` initializer).
        let mut quant_bases: Vec<(String, usize)> = Vec::new();

        for init in &graph.initializer {
            let name = &init.name;
            if let Some(base) = name.strip_suffix("_scale") {
                // The scale may live in `float_data` or packed little-endian
                // in `raw_data`.
                let scale = if !init.float_data.is_empty() {
                    init.float_data[0]
                } else if init.raw_data.len() >= 4 {
                    f32::from_le_bytes([
                        init.raw_data[0],
                        init.raw_data[1],
                        init.raw_data[2],
                        init.raw_data[3],
                    ])
                } else {
                    1.0
                };
                scale_map.insert(base.to_string(), scale);
            } else if let Some(base) = name.strip_suffix("_zp") {
                // NOTE(review): only `raw_data` is consulted; a zero point
                // stored in a typed proto field would silently default to 0 —
                // confirm against how the QDQ writer emits it.
                let zp = if !init.raw_data.is_empty() {
                    init.raw_data[0] as i8
                } else {
                    0
                };
                zp_map.insert(base.to_string(), zp);
            } else if let Some(base) = name.strip_suffix("_quantized") {
                // Record the element count during this single pass instead of
                // re-scanning the initializer list per base later (was O(n^2)).
                // Negative (invalid) dims are clamped to 0 before multiplying
                // so the product cannot wrap through the usize cast.
                let len: usize = init.dims.iter().map(|&d| d.max(0) as usize).product();
                quant_bases.push((base.to_string(), len));
            }
        }

        // Bit widths recorded by `save_quantized` in the model metadata.
        let mut bits_map: HashMap<String, u8> = HashMap::new();
        for prop in &self.proto.metadata_props {
            if let Some(base) = prop.key.strip_prefix("quantize_rs.bits.") {
                if let Ok(bits) = prop.value.parse::<u8>() {
                    bits_map.insert(base.to_string(), bits);
                }
            }
        }

        quant_bases
            .into_iter()
            .map(|(base, original_length)| {
                let scale = scale_map.get(&base).copied().unwrap_or(1.0);
                let zero_point = zp_map.get(&base).copied().unwrap_or(0);
                let bits = bits_map.get(&base).copied().unwrap_or(8);
                QuantizedWeightInfo {
                    name: base,
                    bits,
                    scale,
                    zero_point,
                    original_length,
                }
            })
            .collect()
    }
}
/// A named float32 weight tensor extracted from a model's initializers.
#[derive(Debug, Clone)]
pub struct WeightTensor {
    /// Initializer name as it appears in the graph.
    pub name: String,
    /// Flattened tensor values.
    pub data: Vec<f32>,
    /// Tensor dimensions (negative proto dims clamped to 0 at extraction).
    pub shape: Vec<usize>,
}

impl WeightTensor {
    /// Size of the tensor payload in bytes.
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * self.num_elements()
    }

    /// Total number of f32 elements stored.
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }
}