use crate::errors::{QuantizeError, Result};
use crate::onnx_proto::{
attribute_proto, tensor_proto, AttributeProto, GraphProto, ModelProto, OperatorSetIdProto,
TensorProto,
};
use std::collections::{HashMap, HashSet};
use super::quantization_nodes::{
build_dequantize_linear_node, build_quantized_weight_tensor, build_scale_tensor,
build_zero_point_tensor, DequantLinearNames,
};
/// Per-weight payload for the QDQ (quantize/dequantize) graph transform.
///
/// One instance describes a single FP32 initializer to be replaced by a
/// quantized initializer plus a `DequantizeLinear` node.
#[derive(Debug)]
pub struct QdqWeightInput {
    // Name of the FP32 initializer in the source graph; must match exactly.
    pub original_name: String,
    // Quantized weight values, flattened in the original tensor's layout.
    pub quantized_values: Vec<i8>,
    // Dequantization scale(s): one for per-tensor, many for per-axis.
    pub scales: Vec<f32>,
    // Zero point(s); must pair one-to-one with `scales`.
    pub zero_points: Vec<i8>,
    // Bit width of the quantization (e.g. 8 or 4).
    // NOTE(review): not read by apply_qdq_transform in this file — presumably
    // consumed upstream when producing `quantized_values`; confirm.
    pub bits: u8,
    // Axis for per-axis quantization; `None` means per-tensor.
    pub axis: Option<usize>,
}
/// Result of a graph-connectivity validation pass.
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    /// True when every node input resolved to a known name.
    pub valid: bool,
    /// Human-readable description of each dangling reference found.
    pub broken_refs: Vec<String>,
}

impl ConnectivityReport {
    /// Format the report as an indented, human-readable summary.
    ///
    /// A valid graph yields a single "OK" line; otherwise a header with the
    /// dangling-reference count is followed by one numbered line per break.
    pub fn summary(&self) -> String {
        if self.valid {
            return " Graph connectivity: OK\n".to_string();
        }
        let count = self.broken_refs.len();
        let suffix = if count == 1 { "" } else { "s" };
        let mut out = format!(
            " Graph connectivity: BROKEN ({} dangling reference{})\n",
            count, suffix
        );
        for (idx, reference) in self.broken_refs.iter().enumerate() {
            out.push_str(&format!(" {}. {}\n", idx + 1, reference));
        }
        out
    }
}
/// Check that every non-empty node input in `graph` resolves to a graph
/// input, an initializer, or the output of an earlier node.
///
/// Assumes nodes are topologically sorted (as the ONNX spec requires):
/// a node's outputs become "known" only once that node has been visited.
pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
    // Names legal to reference before any node runs: inputs + initializers.
    let mut defined: HashSet<String> = graph
        .input
        .iter()
        .map(|vi| vi.name.clone())
        .chain(graph.initializer.iter().map(|init| init.name.clone()))
        .collect();
    let mut dangling = Vec::new();
    for node in &graph.node {
        // Empty input names are ONNX's encoding for "optional input omitted".
        for input in node.input.iter().filter(|n| !n.is_empty()) {
            if !defined.contains(input.as_str()) {
                dangling.push(format!(
                    "Node '{}' (op={}) → unknown input '{}'",
                    node.name, node.op_type, input
                ));
            }
        }
        // Outputs produced here become visible to all later nodes.
        for output in node.output.iter().filter(|n| !n.is_empty()) {
            defined.insert(output.clone());
        }
    }
    ConnectivityReport {
        valid: dangling.is_empty(),
        broken_refs: dangling,
    }
}
/// Raise the default-domain opset of `model` to at least `min_version`,
/// inserting a default-domain entry if none exists.
///
/// When the effective version actually crosses `min_version`, nodes whose
/// schema changed between the two versions are migrated as well.
pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
    let old_version = get_opset_version(model);
    match model
        .opset_import
        .iter_mut()
        .find(|opset| opset.domain.is_empty())
    {
        // Never downgrade an already-sufficient version.
        Some(opset) => opset.version = opset.version.max(min_version),
        None => model.opset_import.push(OperatorSetIdProto {
            domain: String::new(),
            version: min_version,
        }),
    }
    // Only rewrite nodes when the version boundary was actually crossed.
    if old_version < min_version {
        if let Some(graph) = model.graph.as_mut() {
            upgrade_deprecated_ops(graph, old_version, min_version);
        }
    }
}
/// Version of the default (empty-string) ONNX domain, or 0 when no
/// default-domain opset entry is present.
fn get_opset_version(model: &ModelProto) -> i64 {
    for opset in &model.opset_import {
        if opset.domain.is_empty() {
            return opset.version;
        }
    }
    0
}
/// Rewrite nodes whose ONNX schema changed between `old_opset` and `new_opset`.
///
/// Handled migrations (each gated on actually crossing the boundary):
/// - `BatchNormalization` (opset 9): the `spatial` attribute was removed.
/// - `Dropout` (opset 12): `ratio` moved from an attribute to a second input.
/// - `Softmax` / `LogSoftmax` (opset 13): the default `axis` changed from 1
///   to -1, so the old default must be made explicit to preserve semantics.
fn upgrade_deprecated_ops(graph: &mut GraphProto, old_opset: i64, new_opset: i64) {
    // Initializers created for Dropout ratios; appended after the loop so we
    // don't mutate `graph.initializer` while iterating `graph.node`.
    let mut new_initializers: Vec<TensorProto> = Vec::new();
    for node in graph.node.iter_mut() {
        // Opset 9 dropped BatchNormalization's `spatial` attribute entirely.
        if node.op_type == "BatchNormalization" && old_opset < 9 && new_opset >= 9 {
            node.attribute.retain(|a| a.name != "spatial");
        }
        // Opset 12 moved Dropout's ratio from an attribute to input #2.
        if node.op_type == "Dropout" && old_opset < 12 && new_opset >= 12 {
            // 0.5 matches the pre-opset-12 attribute default.
            let ratio = node
                .attribute
                .iter()
                .find(|a| a.name == "ratio")
                .map(|a| a.f)
                .unwrap_or(0.5);
            node.attribute.retain(|a| a.name != "ratio");
            // Name the new initializer after the node's first output, which
            // ONNX requires to be unique across the graph.
            let init_name = format!(
                "_quantize_rs_dropout_ratio_{}",
                node.output.first().map_or("", |s| s.as_str()),
            );
            new_initializers.push(TensorProto {
                name: init_name.clone(),
                data_type: tensor_proto::DataType::Float as i32,
                float_data: vec![ratio],
                ..Default::default()
            });
            // Wire the ratio tensor in as the second input, replacing any
            // existing second input (which would be a stale ratio).
            if node.input.len() < 2 {
                node.input.push(init_name);
            } else {
                node.input[1] = init_name;
            }
        }
        // Opset 13 changed the default axis from 1 to -1; pin the old
        // default explicitly when the node relied on it.
        if (node.op_type == "Softmax" || node.op_type == "LogSoftmax")
            && old_opset < 13
            && new_opset >= 13
        {
            let has_axis = node.attribute.iter().any(|a| a.name == "axis");
            if !has_axis {
                node.attribute.push(AttributeProto {
                    name: "axis".to_string(),
                    r#type: attribute_proto::AttributeType::Int as i32,
                    i: 1,
                    ..Default::default()
                });
            }
        }
    }
    graph.initializer.extend(new_initializers);
}
/// Replace each weight listed in `inputs` with a QDQ representation: the
/// FP32 initializer is removed and replaced by three new initializers
/// (quantized values, scales, zero points) plus a `DequantizeLinear` node
/// whose output reuses the original weight name, so consuming nodes need
/// no rewiring.
///
/// The DQ nodes are prepended to the node list, which keeps the graph
/// topologically sorted (they consume only initializers).
///
/// # Errors
/// Returns [`QuantizeError::GraphTransform`] when:
/// - a weight name in `inputs` is not an initializer of `graph`;
/// - `quantized_values` length disagrees with the stored tensor shape;
/// - `scales` is empty or its length differs from `zero_points` (ONNX
///   requires scale and zero-point tensors of identical shape).
pub fn apply_qdq_transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) -> Result<()> {
    // Snapshot every initializer's shape before the originals are removed.
    let shape_map: HashMap<String, Vec<i64>> = graph
        .initializer
        .iter()
        .map(|init| (init.name.clone(), init.dims.clone()))
        .collect();
    // Drop the FP32 weights being replaced, and any graph inputs that
    // duplicated them (some exporters also list initializers as inputs).
    let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
    graph
        .initializer
        .retain(|init| !quant_set.contains(init.name.as_str()));
    graph
        .input
        .retain(|inp| !quant_set.contains(inp.name.as_str()));
    let mut dq_nodes = Vec::new();
    for inp in inputs {
        let shape =
            shape_map
                .get(&inp.original_name)
                .ok_or_else(|| QuantizeError::GraphTransform {
                    reason: format!(
                        "Weight '{}' not found in model initializers — \
                         verify the name matches exactly",
                        inp.original_name
                    ),
                })?;
        // The flattened element count must match the original tensor shape.
        let expected_len: i64 = shape.iter().product();
        if inp.quantized_values.len() as i64 != expected_len {
            return Err(QuantizeError::GraphTransform {
                reason: format!(
                    "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
                    inp.original_name,
                    inp.quantized_values.len(),
                    shape,
                    expected_len
                ),
            });
        }
        // DequantizeLinear requires scale and zero-point tensors of the same
        // shape; a silent mismatch here would emit an invalid model.
        if inp.scales.is_empty() || inp.scales.len() != inp.zero_points.len() {
            return Err(QuantizeError::GraphTransform {
                reason: format!(
                    "Weight '{}': scales has {} elements and zero_points has {} — \
                     they must be non-empty and equal in length",
                    inp.original_name,
                    inp.scales.len(),
                    inp.zero_points.len()
                ),
            });
        }
        // NOTE(review): `inp.bits` is not consulted here — presumably the
        // values are already encoded upstream; confirm with callers.
        let names = DequantLinearNames::from_original(&inp.original_name);
        graph.initializer.push(build_quantized_weight_tensor(
            &names,
            &inp.quantized_values,
            shape,
        ));
        graph
            .initializer
            .push(build_scale_tensor(&names, &inp.scales));
        graph
            .initializer
            .push(build_zero_point_tensor(&names, &inp.zero_points));
        dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
    }
    // Prepend the DQ nodes so every dequantized weight is defined before use.
    let existing_nodes = std::mem::take(&mut graph.node);
    graph.node = dq_nodes;
    graph.node.extend(existing_nodes);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::{
        tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
        ValueInfoProto,
    };

    /// Graph with one input, one 2x2 FP32 weight "w", and a single Conv node.
    fn make_simple_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            initializer: vec![TensorProto {
                name: "w".to_string(),
                data_type: tensor_proto::DataType::Float as i32,
                dims: vec![2, 2],
                float_data: vec![1.0, 2.0, 3.0, 4.0],
                ..Default::default()
            }],
            node: vec![NodeProto {
                op_type: "Conv".to_string(),
                name: "conv0".to_string(),
                input: vec!["input".to_string(), "w".to_string()],
                output: vec!["out".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    /// Graph with two Conv nodes chained via "mid", each with its own weight.
    fn make_two_weight_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            initializer: vec![
                TensorProto {
                    name: "w1".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: vec![2, 2],
                    float_data: vec![1.0, 2.0, 3.0, 4.0],
                    ..Default::default()
                },
                TensorProto {
                    name: "w2".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: vec![2, 2],
                    float_data: vec![5.0, 6.0, 7.0, 8.0],
                    ..Default::default()
                },
            ],
            node: vec![
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv1".to_string(),
                    input: vec!["input".to_string(), "w1".to_string()],
                    output: vec!["mid".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv2".to_string(),
                    input: vec!["mid".to_string(), "w2".to_string()],
                    output: vec!["out".to_string()],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }
    }

    // A well-formed graph must produce a clean connectivity report.
    #[test]
    fn test_connectivity_passes_on_valid_graph() {
        let graph = make_simple_graph();
        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "original graph should be valid; broken: {:?}",
            report.broken_refs
        );
    }

    // Renaming an initializer (as a buggy quantizer might) must be flagged
    // as a dangling reference to the old name.
    #[test]
    fn test_connectivity_detects_renamed_initializer() {
        let mut graph = make_simple_graph();
        for init in graph.initializer.iter_mut() {
            if init.name == "w" {
                init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
            }
        }
        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid, "should detect broken reference to 'w'");
        assert_eq!(report.broken_refs.len(), 1);
        assert!(
            report.broken_refs[0].contains("'w'"),
            "error should mention 'w': {}",
            report.broken_refs[0]
        );
    }

    // Multiple broken references should each get their own report entry.
    #[test]
    fn test_connectivity_detects_multiple_broken_refs() {
        let mut graph = make_two_weight_graph();
        for init in graph.initializer.iter_mut() {
            if init.name == "w1" {
                init.name = "w1_broken".to_string();
            } else if init.name == "w2" {
                init.name = "w2_broken".to_string();
            }
        }
        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid);
        assert_eq!(report.broken_refs.len(), 2);
    }

    // summary() formatting: OK line for valid, BROKEN header + entries otherwise.
    #[test]
    fn test_connectivity_summary_formatting() {
        let valid = ConnectivityReport {
            valid: true,
            broken_refs: vec![],
        };
        assert!(valid.summary().contains("OK"));
        let broken = ConnectivityReport {
            valid: false,
            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
        };
        let s = broken.summary();
        assert!(s.contains("BROKEN"));
        assert!(s.contains("1 dangling reference"));
        assert!(s.contains("unknown input 'y'"));
    }

    // An opset below the minimum must be raised to exactly the minimum.
    #[test]
    fn test_ensure_opset_bumps_low_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 10,
            }],
            ..Default::default()
        };
        ensure_opset_version(&mut model, 13);
        assert_eq!(model.opset_import[0].version, 13);
    }

    // An opset already above the minimum must not be touched.
    #[test]
    fn test_ensure_opset_leaves_sufficient_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 17,
            }],
            ..Default::default()
        };
        ensure_opset_version(&mut model, 13);
        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
    }

    // A model with no default-domain opset entry should get one created.
    #[test]
    fn test_ensure_opset_adds_missing_default_domain() {
        let mut model = ModelProto::default();
        ensure_opset_version(&mut model, 13);
        assert_eq!(model.opset_import.len(), 1);
        assert!(model.opset_import[0].domain.is_empty());
        assert_eq!(model.opset_import[0].version, 13);
    }

    // End-to-end: after the QDQ transform the graph must still be connected.
    #[test]
    fn test_qdq_single_weight_produces_valid_graph() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![25, 51, 76, 102],
            scales: vec![0.039_215_686],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "graph after QDQ must be valid; broken: {:?}",
            report.broken_refs
        );
    }

    // The transform must add the quantized/scale/zp triple and drop the original.
    #[test]
    fn test_qdq_adds_correct_initializers() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![-5],
            bits: 8,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();
        assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
        assert!(init_names.contains(&"w_scale"), "missing w_scale");
        assert!(init_names.contains(&"w_zp"), "missing w_zp");
        assert!(
            !init_names.contains(&"w"),
            "original FP32 'w' should be removed"
        );
    }

    // DQ nodes must be prepended so the graph stays topologically sorted.
    #[test]
    fn test_qdq_node_order_dequant_first() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0], "DequantizeLinear");
        assert_eq!(ops[1], "Conv");
    }

    // The DQ output must reuse the original weight name so consumers need
    // no rewiring.
    #[test]
    fn test_qdq_dequant_output_is_original_name() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![1, 2, 3, 4],
            scales: vec![1.0],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let dq = &graph.node[0];
        assert_eq!(
            dq.output[0], "w",
            "DequantizeLinear output must be original name"
        );
    }

    // Multiple weights: each gets its own DQ node, graph stays valid.
    #[test]
    fn test_qdq_two_weights_both_transformed() {
        let mut graph = make_two_weight_graph();
        let inputs = vec![
            QdqWeightInput {
                original_name: "w1".to_string(),
                quantized_values: vec![10, 20, 30, 40],
                scales: vec![0.1],
                zero_points: vec![0],
                bits: 8,
                axis: None,
            },
            QdqWeightInput {
                original_name: "w2".to_string(),
                quantized_values: vec![50, 60, 70, 80],
                scales: vec![0.2],
                zero_points: vec![-1],
                bits: 8,
                axis: None,
            },
        ];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "two-weight graph broken: {:?}",
            report.broken_refs
        );
        assert_eq!(graph.node.len(), 4);
        assert_eq!(graph.node[0].op_type, "DequantizeLinear");
        assert_eq!(graph.node[1].op_type, "DequantizeLinear");
        let dq_outputs: Vec<&str> = graph
            .node
            .iter()
            .take(2)
            .map(|n| n.output[0].as_str())
            .collect();
        assert!(dq_outputs.contains(&"w1"));
        assert!(dq_outputs.contains(&"w2"));
    }

    // INT4 payloads are stored in an INT8 tensor; raw bytes must round-trip.
    #[test]
    fn test_qdq_int4_values_stored_as_int8() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![-8, -1, 0, 7],
            scales: vec![0.5],
            zero_points: vec![0],
            bits: 4,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let quant_init = graph
            .initializer
            .iter()
            .find(|i| i.name == "w_quantized")
            .expect("w_quantized not found");
        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);
        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, vec![-8, -1, 0, 7]);
    }

    // A weight name not present in the graph must yield a descriptive error.
    #[test]
    fn test_qdq_unknown_weight_returns_error() {
        let mut graph = make_simple_graph();
        let inputs = vec![QdqWeightInput {
            original_name: "does_not_exist".to_string(),
            quantized_values: vec![1, 2, 3],
            scales: vec![1.0],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];
        let result = apply_qdq_transform(&mut graph, &inputs);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("does_not_exist"),
            "error should name the missing weight"
        );
    }

    // Initializers not listed for quantization must survive untouched.
    #[test]
    fn test_qdq_non_quantized_initializers_preserved() {
        let mut graph = make_simple_graph();
        graph.initializer.push(TensorProto {
            name: "bias".to_string(),
            data_type: tensor_proto::DataType::Float as i32,
            dims: vec![2],
            float_data: vec![0.1, 0.2],
            ..Default::default()
        });
        graph.node[0].input.push("bias".to_string());
        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];
        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
        let bias_init = graph.initializer.iter().find(|i| i.name == "bias");
        assert!(
            bias_init.is_some(),
            "non-quantized 'bias' initializer must be preserved"
        );
        assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);
        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "broken: {:?}", report.broken_refs);
    }

    // Opset 7 → 13 migration: BN loses `spatial`, Dropout's ratio becomes
    // an input-backed initializer, Softmax gets an explicit axis=1.
    #[test]
    fn test_ensure_opset_strips_deprecated_attrs() {
        use crate::onnx_proto::NodeProto;
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 7,
            }],
            graph: Some(GraphProto {
                node: vec![
                    NodeProto {
                        op_type: "BatchNormalization".to_string(),
                        input: vec!["x".into(), "s".into(), "b".into(), "m".into(), "v".into()],
                        output: vec!["bn_out".into()],
                        attribute: vec![
                            AttributeProto {
                                name: "epsilon".to_string(),
                                r#type: attribute_proto::AttributeType::Float as i32,
                                f: 1e-5,
                                ..Default::default()
                            },
                            AttributeProto {
                                name: "spatial".to_string(),
                                r#type: attribute_proto::AttributeType::Int as i32,
                                i: 1,
                                ..Default::default()
                            },
                        ],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Dropout".to_string(),
                        input: vec!["bn_out".into()],
                        output: vec!["drop_out".into(), "drop_mask".into()],
                        attribute: vec![AttributeProto {
                            name: "ratio".to_string(),
                            r#type: attribute_proto::AttributeType::Float as i32,
                            f: 0.3,
                            ..Default::default()
                        }],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Softmax".to_string(),
                        input: vec!["drop_out".into()],
                        output: vec!["sm_out".into()],
                        attribute: vec![],
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }),
            ..Default::default()
        };
        ensure_opset_version(&mut model, 13);
        let opset = model
            .opset_import
            .iter()
            .find(|o| o.domain.is_empty())
            .unwrap();
        assert_eq!(opset.version, 13);
        let graph = model.graph.as_ref().unwrap();
        let bn = &graph.node[0];
        assert!(
            !bn.attribute.iter().any(|a| a.name == "spatial"),
            "BatchNormalization.spatial should be stripped"
        );
        assert!(
            bn.attribute.iter().any(|a| a.name == "epsilon"),
            "BatchNormalization.epsilon should be preserved"
        );
        let drop = &graph.node[1];
        assert!(
            !drop.attribute.iter().any(|a| a.name == "ratio"),
            "Dropout.ratio attribute should be removed"
        );
        assert_eq!(drop.input.len(), 2, "Dropout should now have 2 inputs");
        let ratio_init_name = &drop.input[1];
        let ratio_init = graph
            .initializer
            .iter()
            .find(|i| &i.name == ratio_init_name)
            .expect("Dropout ratio initializer should exist");
        assert_eq!(ratio_init.data_type, tensor_proto::DataType::Float as i32);
        assert!(
            (ratio_init.float_data[0] - 0.3).abs() < 1e-6,
            "ratio should be 0.3"
        );
        let sm = &graph.node[2];
        assert_eq!(sm.op_type, "Softmax");
        let axis_attr = sm
            .attribute
            .iter()
            .find(|a| a.name == "axis")
            .expect("Softmax should have axis attribute added");
        assert_eq!(axis_attr.i, 1, "Softmax axis should be 1 (old default)");
    }

    // ensure_opset_version must never lower an existing version.
    #[test]
    fn test_ensure_opset_no_downgrade() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 15,
            }],
            graph: Some(GraphProto::default()),
            ..Default::default()
        };
        ensure_opset_version(&mut model, 10);
        let opset = model
            .opset_import
            .iter()
            .find(|o| o.domain.is_empty())
            .unwrap();
        assert_eq!(opset.version, 15);
    }
}