1use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{GraphProto, ModelProto, OperatorSetIdProto};
10use std::collections::{HashMap, HashSet};
11
12use super::quantization_nodes::{
13 DequantLinearNames,
14 build_dequantize_linear_node,
15 build_quantized_weight_tensor,
16 build_scale_tensor,
17 build_zero_point_tensor,
18};
19
/// One FP32 weight initializer to be replaced by a `DequantizeLinear` (QDQ) pattern,
/// as consumed by [`apply_qdq_transform`].
#[derive(Debug, Clone, PartialEq)]
pub struct QdqWeightInput {
    /// Exact name of the FP32 initializer to replace. After the transform the inserted
    /// `DequantizeLinear` node produces this same name, so existing consumers are untouched.
    pub original_name: String,
    /// Quantized weight values, one per element of the original tensor: the element count
    /// must equal the product of the initializer's dims.
    pub quantized_values: Vec<i8>,
    /// Dequantization scale(s): a single value for per-tensor quantization, or presumably one
    /// per slice along `axis` for per-axis quantization — TODO confirm with the producer.
    pub scales: Vec<f32>,
    /// Zero point(s), intended to be parallel to `scales`.
    pub zero_points: Vec<i8>,
    /// Nominal bit width of the quantization (e.g. 4 or 8). Values are stored in an int8
    /// container either way; this field is not read by [`apply_qdq_transform`].
    pub bits: u8,
    /// Axis forwarded to the `DequantizeLinear` node; `None` presumably means per-tensor —
    /// confirm in `build_dequantize_linear_node`.
    pub axis: Option<usize>,
}
44
/// Outcome of [`validate_graph_connectivity`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub struct ConnectivityReport {
    /// `true` when no dangling references were found.
    pub valid: bool,
    /// One human-readable description per dangling reference, in discovery order.
    pub broken_refs: Vec<String>,
}

impl ConnectivityReport {
    /// Renders the report as newline-terminated text.
    ///
    /// When `valid` is set this is a single "OK" line. Otherwise it is a "BROKEN" header
    /// carrying the (pluralised) number of `broken_refs`, followed by one numbered line per
    /// entry. Note the header is chosen from `valid`, not from whether `broken_refs` is empty.
    pub fn summary(&self) -> String {
        use std::fmt::Write as _;

        if self.valid {
            return " Graph connectivity: OK\n".to_string();
        }

        let count = self.broken_refs.len();
        let mut s = format!(
            " Graph connectivity: BROKEN ({} dangling reference{})\n",
            count,
            if count == 1 { "" } else { "s" }
        );
        for (i, r) in self.broken_refs.iter().enumerate() {
            // Append in place instead of allocating a temporary String per line;
            // formatting into a String cannot fail.
            let _ = writeln!(s, " {}. {}", i + 1, r);
        }
        s
    }
}
73
74pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
88 let mut known: HashSet<String> = HashSet::new();
89
90 for inp in &graph.input {
92 known.insert(inp.name.clone());
93 }
94 for init in &graph.initializer {
95 known.insert(init.name.clone());
96 }
97
98 let mut broken = Vec::new();
99
100 for node in &graph.node {
102 for name in &node.input {
103 if name.is_empty() {
104 continue; }
106 if !known.contains(name.as_str()) {
107 broken.push(format!(
108 "Node '{}' (op={}) → unknown input '{}'",
109 node.name, node.op_type, name
110 ));
111 }
112 }
113 for name in &node.output {
115 if !name.is_empty() {
116 known.insert(name.clone());
117 }
118 }
119 }
120
121 ConnectivityReport {
122 valid: broken.is_empty(),
123 broken_refs: broken,
124 }
125}
126
127pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
136 for opset in model.opset_import.iter_mut() {
138 if opset.domain.is_empty() {
139 if opset.version < min_version {
140 opset.version = min_version;
141 }
142 return; }
144 }
145
146 model.opset_import.push(OperatorSetIdProto {
148 domain: String::new(), version: min_version,
150 });
151}
152
153pub fn apply_qdq_transform(
186 graph: &mut GraphProto,
187 inputs: &[QdqWeightInput],
188) -> Result<()> {
189 let shape_map: HashMap<String, Vec<i64>> = graph
193 .initializer
194 .iter()
195 .map(|init| (init.name.clone(), init.dims.clone()))
196 .collect();
197
198 let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
199
200 graph.initializer.retain(|init| !quant_set.contains(init.name.as_str()));
204
205 graph.input.retain(|inp| !quant_set.contains(inp.name.as_str()));
212
213 let mut dq_nodes = Vec::new();
217
218 for inp in inputs {
219 let shape = shape_map
220 .get(&inp.original_name)
221 .ok_or_else(|| {
222 QuantizeError::GraphTransform {
223 reason: format!(
224 "Weight '{}' not found in model initializers — \
225 verify the name matches exactly",
226 inp.original_name
227 ),
228 }
229 })?;
230
231 let expected_len: i64 = shape.iter().product();
232 if inp.quantized_values.len() as i64 != expected_len {
233 return Err(QuantizeError::GraphTransform {
234 reason: format!(
235 "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
236 inp.original_name, inp.quantized_values.len(), shape, expected_len
237 ),
238 });
239 }
240
241 let names = DequantLinearNames::from_original(&inp.original_name);
242
243 graph.initializer.push(
244 build_quantized_weight_tensor(&names, &inp.quantized_values, shape),
245 );
246 graph.initializer.push(
247 build_scale_tensor(&names, &inp.scales),
248 );
249 graph.initializer.push(
250 build_zero_point_tensor(&names, &inp.zero_points),
251 );
252
253 dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
254 }
255
256 let existing_nodes = std::mem::take(&mut graph.node);
262 graph.node = dq_nodes;
263 graph.node.extend(existing_nodes);
264
265 Ok(())
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::{
        GraphProto, ModelProto, NodeProto, OperatorSetIdProto,
        TensorProto, ValueInfoProto, tensor_proto,
    };

    /// FP32 initializer with the given name, dims and data.
    fn float_initializer(name: &str, dims: Vec<i64>, data: Vec<f32>) -> TensorProto {
        TensorProto {
            name: name.to_string(),
            data_type: tensor_proto::DataType::Float as i32,
            dims,
            float_data: data,
            ..Default::default()
        }
    }

    /// `Conv` node reading `inputs` and writing the single `output`.
    fn conv_node(name: &str, inputs: &[&str], output: &str) -> NodeProto {
        NodeProto {
            op_type: "Conv".to_string(),
            name: name.to_string(),
            input: inputs.iter().map(|&s| s.to_owned()).collect(),
            output: vec![output.to_string()],
            ..Default::default()
        }
    }

    /// The single graph input named "input" shared by both fixtures.
    fn graph_input() -> Vec<ValueInfoProto> {
        vec![ValueInfoProto { name: "input".to_string(), ..Default::default() }]
    }

    /// `input` --Conv(w)--> `out`, where `w` is a 2x2 FP32 initializer.
    fn make_simple_graph() -> GraphProto {
        GraphProto {
            input: graph_input(),
            initializer: vec![float_initializer("w", vec![2, 2], vec![1.0, 2.0, 3.0, 4.0])],
            node: vec![conv_node("conv0", &["input", "w"], "out")],
            ..Default::default()
        }
    }

    /// `input` --Conv(w1)--> `mid` --Conv(w2)--> `out`, both weights 2x2 FP32.
    fn make_two_weight_graph() -> GraphProto {
        GraphProto {
            input: graph_input(),
            initializer: vec![
                float_initializer("w1", vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]),
                float_initializer("w2", vec![2, 2], vec![5.0, 6.0, 7.0, 8.0]),
            ],
            node: vec![
                conv_node("conv1", &["input", "w1"], "mid"),
                conv_node("conv2", &["mid", "w2"], "out"),
            ],
            ..Default::default()
        }
    }

    /// Per-tensor (`axis: None`) QDQ input with a single scale and zero point.
    fn per_tensor_input(
        name: &str,
        values: Vec<i8>,
        scale: f32,
        zero_point: i8,
        bits: u8,
    ) -> QdqWeightInput {
        QdqWeightInput {
            original_name: name.to_string(),
            quantized_values: values,
            scales: vec![scale],
            zero_points: vec![zero_point],
            bits,
            axis: None,
        }
    }

    /// Runs the QDQ transform, panicking if it fails.
    fn transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) {
        apply_qdq_transform(graph, inputs).expect("QDQ transform failed");
    }

    /// Model whose only opset import is the default domain at `version`.
    fn model_with_default_opset(version: i64) -> ModelProto {
        ModelProto {
            opset_import: vec![OperatorSetIdProto { domain: String::new(), version }],
            ..Default::default()
        }
    }

    #[test]
    fn test_connectivity_passes_on_valid_graph() {
        let report = validate_graph_connectivity(&make_simple_graph());
        assert!(
            report.valid,
            "original graph should be valid; broken: {:?}",
            report.broken_refs
        );
    }

    #[test]
    fn test_connectivity_detects_renamed_initializer() {
        let mut graph = make_simple_graph();
        // Simulate a quantizer that renames the weight without rewiring its consumer.
        if let Some(init) = graph.initializer.iter_mut().find(|i| i.name == "w") {
            init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid, "should detect broken reference to 'w'");
        assert_eq!(report.broken_refs.len(), 1);
        let msg = &report.broken_refs[0];
        assert!(msg.contains("'w'"), "error should mention 'w': {}", msg);
    }

    #[test]
    fn test_connectivity_detects_multiple_broken_refs() {
        let mut graph = make_two_weight_graph();
        for init in &mut graph.initializer {
            if init.name == "w1" || init.name == "w2" {
                init.name.push_str("_broken");
            }
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid);
        assert_eq!(report.broken_refs.len(), 2);
    }

    #[test]
    fn test_connectivity_summary_formatting() {
        let valid = ConnectivityReport { valid: true, broken_refs: Vec::new() };
        assert!(valid.summary().contains("OK"));

        let broken = ConnectivityReport {
            valid: false,
            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
        };
        let text = broken.summary();
        for needle in &["BROKEN", "1 dangling reference", "unknown input 'y'"] {
            assert!(text.contains(*needle));
        }
    }

    #[test]
    fn test_ensure_opset_bumps_low_version() {
        let mut model = model_with_default_opset(10);
        ensure_opset_version(&mut model, 13);
        assert_eq!(model.opset_import[0].version, 13);
    }

    #[test]
    fn test_ensure_opset_leaves_sufficient_version() {
        let mut model = model_with_default_opset(17);
        ensure_opset_version(&mut model, 13);
        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
    }

    #[test]
    fn test_ensure_opset_adds_missing_default_domain() {
        let mut model = ModelProto::default();
        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import.len(), 1);
        let added = &model.opset_import[0];
        assert!(added.domain.is_empty());
        assert_eq!(added.version, 13);
    }

    #[test]
    fn test_qdq_single_weight_produces_valid_graph() {
        let mut graph = make_simple_graph();
        transform(
            &mut graph,
            &[per_tensor_input("w", vec![25, 51, 76, 102], 0.039_215_686, 0, 8)],
        );

        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "graph after QDQ must be valid; broken: {:?}",
            report.broken_refs
        );
    }

    #[test]
    fn test_qdq_adds_correct_initializers() {
        let mut graph = make_simple_graph();
        transform(&mut graph, &[per_tensor_input("w", vec![10, 20, 30, 40], 0.1, -5, 8)]);

        let has_init = |name: &str| graph.initializer.iter().any(|i| i.name == name);
        assert!(has_init("w_quantized"), "missing w_quantized");
        assert!(has_init("w_scale"), "missing w_scale");
        assert!(has_init("w_zp"), "missing w_zp");
        assert!(!has_init("w"), "original FP32 'w' should be removed");
    }

    #[test]
    fn test_qdq_node_order_dequant_first() {
        let mut graph = make_simple_graph();
        transform(&mut graph, &[per_tensor_input("w", vec![10, 20, 30, 40], 0.1, 0, 8)]);

        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
        assert_eq!(ops, ["DequantizeLinear", "Conv"]);
    }

    #[test]
    fn test_qdq_dequant_output_is_original_name() {
        let mut graph = make_simple_graph();
        transform(&mut graph, &[per_tensor_input("w", vec![1, 2, 3, 4], 1.0, 0, 8)]);

        let dq = &graph.node[0];
        assert_eq!(dq.output[0], "w", "DequantizeLinear output must be original name");
    }

    #[test]
    fn test_qdq_two_weights_both_transformed() {
        let mut graph = make_two_weight_graph();
        transform(
            &mut graph,
            &[
                per_tensor_input("w1", vec![10, 20, 30, 40], 0.1, 0, 8),
                per_tensor_input("w2", vec![50, 60, 70, 80], 0.2, -1, 8),
            ],
        );

        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "two-weight graph broken: {:?}", report.broken_refs);

        assert_eq!(graph.node.len(), 4);

        // Both DequantizeLinear nodes must precede the original Conv nodes.
        let leading = &graph.node[..2];
        for n in leading {
            assert_eq!(n.op_type, "DequantizeLinear");
        }

        let dq_outputs: Vec<&str> = leading.iter().map(|n| n.output[0].as_str()).collect();
        assert!(dq_outputs.contains(&"w1"));
        assert!(dq_outputs.contains(&"w2"));
    }

    #[test]
    fn test_qdq_int4_values_stored_as_int8() {
        let mut graph = make_simple_graph();
        // INT4-range values; the stored container type is still INT8.
        transform(&mut graph, &[per_tensor_input("w", vec![-8, -1, 0, 7], 0.5, 0, 4)]);

        let quant_init = graph
            .initializer
            .iter()
            .find(|i| i.name == "w_quantized")
            .expect("w_quantized not found");

        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);

        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, vec![-8, -1, 0, 7]);
    }

    #[test]
    fn test_qdq_unknown_weight_returns_error() {
        let mut graph = make_simple_graph();
        let inputs = [per_tensor_input("does_not_exist", vec![1, 2, 3], 1.0, 0, 8)];

        let result = apply_qdq_transform(&mut graph, &inputs);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("does_not_exist"),
            "error should name the missing weight"
        );
    }

    #[test]
    fn test_qdq_non_quantized_initializers_preserved() {
        let mut graph = make_simple_graph();
        // Add an FP32 bias the transform must leave alone, and wire it into the Conv.
        graph.initializer.push(float_initializer("bias", vec![2], vec![0.1, 0.2]));
        graph.node[0].input.push("bias".to_string());

        transform(&mut graph, &[per_tensor_input("w", vec![10, 20, 30, 40], 0.1, 0, 8)]);

        let bias = graph
            .initializer
            .iter()
            .find(|i| i.name == "bias")
            .expect("non-quantized 'bias' initializer must be preserved");
        assert!((bias.float_data[0] - 0.1).abs() < 1e-6);

        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "broken: {:?}", report.broken_refs);
    }
}