1use crate::errors::{QuantizeError, Result};
10use crate::onnx_proto::{
11 attribute_proto, tensor_proto, AttributeProto, GraphProto, ModelProto, OperatorSetIdProto,
12 TensorProto,
13};
14use std::collections::{HashMap, HashSet};
15
16use super::quantization_nodes::{
17 build_dequantize_linear_node, build_quantized_weight_tensor, build_scale_tensor,
18 build_zero_point_tensor, DequantLinearNames,
19};
20
/// Per-weight payload for `apply_qdq_transform`: everything needed to replace
/// one FP32 initializer with a quantized-weight initializer plus a
/// `DequantizeLinear` node.
#[derive(Debug)]
pub struct QdqWeightInput {
    /// Name of the FP32 initializer in the original graph; must match exactly.
    pub original_name: String,
    /// Quantized weight values; length must equal the product of the original
    /// tensor's dims. Sub-byte widths (e.g. INT4) are still stored as `i8`.
    pub quantized_values: Vec<i8>,
    /// Dequantization scale(s) — one entry for per-tensor quantization,
    /// presumably one per channel when `axis` is set (TODO confirm with
    /// `build_scale_tensor`).
    pub scales: Vec<f32>,
    /// Zero point(s), expected to mirror the cardinality of `scales`.
    pub zero_points: Vec<i8>,
    /// Quantization bit width (e.g. 4 or 8). Not consumed by the graph
    /// transform itself; carried for callers/metadata.
    pub bits: u8,
    /// Quantization axis for per-channel schemes; `None` for per-tensor.
    pub axis: Option<usize>,
}
45
/// Result of `validate_graph_connectivity`: whether every node input resolves
/// to a known value, plus a human-readable description of each dangling
/// reference found.
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    /// `true` when no dangling references were found.
    pub valid: bool,
    /// One description per unresolved node input (empty when `valid`).
    pub broken_refs: Vec<String>,
}
55
56impl ConnectivityReport {
57 pub fn summary(&self) -> String {
59 if self.valid {
60 " Graph connectivity: OK\n".to_string()
61 } else {
62 let mut s = format!(
63 " Graph connectivity: BROKEN ({} dangling reference{})\n",
64 self.broken_refs.len(),
65 if self.broken_refs.len() == 1 { "" } else { "s" }
66 );
67 for (i, r) in self.broken_refs.iter().enumerate() {
68 s.push_str(&format!(" {}. {}\n", i + 1, r));
69 }
70 s
71 }
72 }
73}
74
75pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
89 let mut known: HashSet<String> = HashSet::new();
90
91 for inp in &graph.input {
93 known.insert(inp.name.clone());
94 }
95 for init in &graph.initializer {
96 known.insert(init.name.clone());
97 }
98
99 let mut broken = Vec::new();
100
101 for node in &graph.node {
103 for name in &node.input {
104 if name.is_empty() {
105 continue; }
107 if !known.contains(name.as_str()) {
108 broken.push(format!(
109 "Node '{}' (op={}) → unknown input '{}'",
110 node.name, node.op_type, name
111 ));
112 }
113 }
114 for name in &node.output {
116 if !name.is_empty() {
117 known.insert(name.clone());
118 }
119 }
120 }
121
122 ConnectivityReport {
123 valid: broken.is_empty(),
124 broken_refs: broken,
125 }
126}
127
128pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
141 let old_version = get_opset_version(model);
142
143 let mut found = false;
145 for opset in model.opset_import.iter_mut() {
146 if opset.domain.is_empty() {
147 if opset.version < min_version {
148 opset.version = min_version;
149 }
150 found = true;
151 break;
152 }
153 }
154 if !found {
155 model.opset_import.push(OperatorSetIdProto {
156 domain: String::new(),
157 version: min_version,
158 });
159 }
160
161 if old_version < min_version {
163 if let Some(graph) = model.graph.as_mut() {
164 upgrade_deprecated_ops(graph, old_version, min_version);
165 }
166 }
167}
168
169fn get_opset_version(model: &ModelProto) -> i64 {
171 model
172 .opset_import
173 .iter()
174 .find(|o| o.domain.is_empty())
175 .map_or(0, |o| o.version)
176}
177
/// Rewrites nodes whose attribute/input contracts changed between `old_opset`
/// and `new_opset`, so an opset bump does not invalidate the graph.
///
/// Handled migrations:
/// - `BatchNormalization`: the `spatial` attribute was removed in opset 9.
/// - `Dropout`: `ratio` moved from an attribute to a tensor input in opset 12.
/// - `Softmax`/`LogSoftmax`: the default `axis` changed in opset 13, so the
///   old default (1) is pinned explicitly when no axis is set.
fn upgrade_deprecated_ops(graph: &mut GraphProto, old_opset: i64, new_opset: i64) {
    // Initializers created during the walk; appended after the loop so we
    // never mutate `graph.initializer` while iterating `graph.node`.
    let mut new_initializers: Vec<TensorProto> = Vec::new();

    for node in graph.node.iter_mut() {
        // Opset 9 dropped BatchNormalization's `spatial` attribute.
        if node.op_type == "BatchNormalization" && old_opset < 9 && new_opset >= 9 {
            node.attribute.retain(|a| a.name != "spatial");
        }

        // Opset 12 turned Dropout's `ratio` attribute into input #2.
        if node.op_type == "Dropout" && old_opset < 12 && new_opset >= 12 {
            // 0.5 is the default used when the attribute was omitted.
            let ratio = node
                .attribute
                .iter()
                .find(|a| a.name == "ratio")
                .map(|a| a.f)
                .unwrap_or(0.5);
            node.attribute.retain(|a| a.name != "ratio");

            // Derive the initializer name from the node's first output so it
            // is unique per Dropout node (empty suffix if, unusually, the
            // node has no outputs).
            let init_name = format!(
                "_quantize_rs_dropout_ratio_{}",
                node.output.first().map_or("", |s| s.as_str()),
            );
            // Scalar float tensor (no dims) carrying the ratio value.
            new_initializers.push(TensorProto {
                name: init_name.clone(),
                data_type: tensor_proto::DataType::Float as i32,
                float_data: vec![ratio],
                ..Default::default()
            });

            // Wire the new initializer in as the second (`ratio`) input.
            if node.input.len() < 2 {
                node.input.push(init_name);
            } else {
                node.input[1] = init_name;
            }
        }

        // Opset 13 changed the (Log)Softmax default axis; pin the old default
        // explicitly so numerical behavior is unchanged after the bump.
        if (node.op_type == "Softmax" || node.op_type == "LogSoftmax")
            && old_opset < 13
            && new_opset >= 13
        {
            let has_axis = node.attribute.iter().any(|a| a.name == "axis");
            if !has_axis {
                node.attribute.push(AttributeProto {
                    name: "axis".to_string(),
                    r#type: attribute_proto::AttributeType::Int as i32,
                    i: 1, ..Default::default()
                });
            }
        }
    }

    graph.initializer.extend(new_initializers);
}
244
245pub fn apply_qdq_transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) -> Result<()> {
278 let shape_map: HashMap<String, Vec<i64>> = graph
282 .initializer
283 .iter()
284 .map(|init| (init.name.clone(), init.dims.clone()))
285 .collect();
286
287 let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
288
289 graph
293 .initializer
294 .retain(|init| !quant_set.contains(init.name.as_str()));
295
296 graph
303 .input
304 .retain(|inp| !quant_set.contains(inp.name.as_str()));
305
306 let mut dq_nodes = Vec::new();
310
311 for inp in inputs {
312 let shape =
313 shape_map
314 .get(&inp.original_name)
315 .ok_or_else(|| QuantizeError::GraphTransform {
316 reason: format!(
317 "Weight '{}' not found in model initializers — \
318 verify the name matches exactly",
319 inp.original_name
320 ),
321 })?;
322
323 let expected_len: i64 = shape.iter().product();
324 if inp.quantized_values.len() as i64 != expected_len {
325 return Err(QuantizeError::GraphTransform {
326 reason: format!(
327 "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
328 inp.original_name,
329 inp.quantized_values.len(),
330 shape,
331 expected_len
332 ),
333 });
334 }
335
336 let names = DequantLinearNames::from_original(&inp.original_name);
337
338 graph.initializer.push(build_quantized_weight_tensor(
339 &names,
340 &inp.quantized_values,
341 shape,
342 ));
343 graph
344 .initializer
345 .push(build_scale_tensor(&names, &inp.scales));
346 graph
347 .initializer
348 .push(build_zero_point_tensor(&names, &inp.zero_points));
349
350 dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
351 }
352
353 let existing_nodes = std::mem::take(&mut graph.node);
359 graph.node = dq_nodes;
360 graph.node.extend(existing_nodes);
361
362 Ok(())
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::{
        tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
        ValueInfoProto,
    };

    // Fixture: one graph input, one 2x2 FP32 weight "w", one Conv consuming both.
    fn make_simple_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            initializer: vec![TensorProto {
                name: "w".to_string(),
                data_type: tensor_proto::DataType::Float as i32,
                dims: vec![2, 2],
                float_data: vec![1.0, 2.0, 3.0, 4.0],
                ..Default::default()
            }],
            node: vec![NodeProto {
                op_type: "Conv".to_string(),
                name: "conv0".to_string(),
                input: vec!["input".to_string(), "w".to_string()],
                output: vec!["out".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    // Fixture: two chained Convs, each with its own 2x2 FP32 weight.
    fn make_two_weight_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto {
                name: "input".to_string(),
                ..Default::default()
            }],
            initializer: vec![
                TensorProto {
                    name: "w1".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: vec![2, 2],
                    float_data: vec![1.0, 2.0, 3.0, 4.0],
                    ..Default::default()
                },
                TensorProto {
                    name: "w2".to_string(),
                    data_type: tensor_proto::DataType::Float as i32,
                    dims: vec![2, 2],
                    float_data: vec![5.0, 6.0, 7.0, 8.0],
                    ..Default::default()
                },
            ],
            node: vec![
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv1".to_string(),
                    input: vec!["input".to_string(), "w1".to_string()],
                    output: vec!["mid".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Conv".to_string(),
                    name: "conv2".to_string(),
                    input: vec!["mid".to_string(), "w2".to_string()],
                    output: vec!["out".to_string()],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }
    }

    // Sanity: an untouched, well-formed graph validates cleanly.
    #[test]
    fn test_connectivity_passes_on_valid_graph() {
        let graph = make_simple_graph();
        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "original graph should be valid; broken: {:?}",
            report.broken_refs
        );
    }

    // Renaming an initializer without updating its consumer must be detected.
    #[test]
    fn test_connectivity_detects_renamed_initializer() {
        let mut graph = make_simple_graph();

        for init in graph.initializer.iter_mut() {
            if init.name == "w" {
                init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
            }
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid, "should detect broken reference to 'w'");
        assert_eq!(report.broken_refs.len(), 1);
        assert!(
            report.broken_refs[0].contains("'w'"),
            "error should mention 'w': {}",
            report.broken_refs[0]
        );
    }

    // Each dangling reference is reported individually, not just the first.
    #[test]
    fn test_connectivity_detects_multiple_broken_refs() {
        let mut graph = make_two_weight_graph();

        for init in graph.initializer.iter_mut() {
            if init.name == "w1" {
                init.name = "w1_broken".to_string();
            } else if init.name == "w2" {
                init.name = "w2_broken".to_string();
            }
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid);
        assert_eq!(report.broken_refs.len(), 2);
    }

    // summary() output: OK line for valid reports, BROKEN header + numbered
    // list otherwise.
    #[test]
    fn test_connectivity_summary_formatting() {
        let valid = ConnectivityReport {
            valid: true,
            broken_refs: vec![],
        };
        assert!(valid.summary().contains("OK"));

        let broken = ConnectivityReport {
            valid: false,
            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
        };
        let s = broken.summary();
        assert!(s.contains("BROKEN"));
        assert!(s.contains("1 dangling reference"));
        assert!(s.contains("unknown input 'y'"));
    }

    // A too-low default-domain opset is raised to the minimum.
    #[test]
    fn test_ensure_opset_bumps_low_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 10,
            }],
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import[0].version, 13);
    }

    // An already-sufficient opset version is left untouched.
    #[test]
    fn test_ensure_opset_leaves_sufficient_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 17,
            }],
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
    }

    // A model with no default-domain opset gets one added at min_version.
    #[test]
    fn test_ensure_opset_adds_missing_default_domain() {
        let mut model = ModelProto::default();
        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import.len(), 1);
        assert!(model.opset_import[0].domain.is_empty());
        assert_eq!(model.opset_import[0].version, 13);
    }

    // End-to-end: after a single-weight QDQ transform, the graph must still
    // pass connectivity validation.
    #[test]
    fn test_qdq_single_weight_produces_valid_graph() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![25, 51, 76, 102],
            scales: vec![0.039_215_686], zero_points: vec![0],
            bits: 8,
            axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "graph after QDQ must be valid; broken: {:?}",
            report.broken_refs
        );
    }

    // The FP32 weight is replaced by the _quantized/_scale/_zp triple.
    #[test]
    fn test_qdq_adds_correct_initializers() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![-5],
            bits: 8,
            axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();

        assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
        assert!(init_names.contains(&"w_scale"), "missing w_scale");
        assert!(init_names.contains(&"w_zp"), "missing w_zp");
        assert!(
            !init_names.contains(&"w"),
            "original FP32 'w' should be removed"
        );
    }

    // DequantizeLinear must precede its consumer to keep topological order.
    #[test]
    fn test_qdq_node_order_dequant_first() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();

        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0], "DequantizeLinear");
        assert_eq!(ops[1], "Conv");
    }

    // The DQ node's output reuses the original weight name so consumers are
    // left untouched.
    #[test]
    fn test_qdq_dequant_output_is_original_name() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![1, 2, 3, 4],
            scales: vec![1.0],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let dq = &graph.node[0]; assert_eq!(
            dq.output[0], "w",
            "DequantizeLinear output must be original name"
        );
    }

    // Multiple weights: both get DQ nodes, and the graph stays valid.
    #[test]
    fn test_qdq_two_weights_both_transformed() {
        let mut graph = make_two_weight_graph();

        let inputs = vec![
            QdqWeightInput {
                original_name: "w1".to_string(),
                quantized_values: vec![10, 20, 30, 40],
                scales: vec![0.1],
                zero_points: vec![0],
                bits: 8,
                axis: None,
            },
            QdqWeightInput {
                original_name: "w2".to_string(),
                quantized_values: vec![50, 60, 70, 80],
                scales: vec![0.2],
                zero_points: vec![-1],
                bits: 8,
                axis: None,
            },
        ];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "two-weight graph broken: {:?}",
            report.broken_refs
        );

        // 2 original Convs + 2 prepended DQ nodes.
        assert_eq!(graph.node.len(), 4);

        assert_eq!(graph.node[0].op_type, "DequantizeLinear");
        assert_eq!(graph.node[1].op_type, "DequantizeLinear");

        let dq_outputs: Vec<&str> = graph
            .node
            .iter()
            .take(2)
            .map(|n| n.output[0].as_str())
            .collect();
        assert!(dq_outputs.contains(&"w1"));
        assert!(dq_outputs.contains(&"w2"));
    }

    // bits=4 values still round-trip through the INT8 raw_data storage.
    #[test]
    fn test_qdq_int4_values_stored_as_int8() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![-8, -1, 0, 7],
            scales: vec![0.5],
            zero_points: vec![0],
            bits: 4, axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let quant_init = graph
            .initializer
            .iter()
            .find(|i| i.name == "w_quantized")
            .expect("w_quantized not found");

        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);

        // raw_data bytes reinterpreted as i8 must equal the original values.
        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, vec![-8, -1, 0, 7]);
    }

    // Asking to quantize a nonexistent weight is a named, descriptive error.
    #[test]
    fn test_qdq_unknown_weight_returns_error() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name: "does_not_exist".to_string(),
            quantized_values: vec![1, 2, 3],
            scales: vec![1.0],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];

        let result = apply_qdq_transform(&mut graph, &inputs);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("does_not_exist"),
            "error should name the missing weight"
        );
    }

    // Initializers that are not being quantized (e.g. biases) must survive
    // the transform untouched.
    #[test]
    fn test_qdq_non_quantized_initializers_preserved() {
        let mut graph = make_simple_graph();

        graph.initializer.push(TensorProto {
            name: "bias".to_string(),
            data_type: tensor_proto::DataType::Float as i32,
            dims: vec![2],
            float_data: vec![0.1, 0.2],
            ..Default::default()
        });

        graph.node[0].input.push("bias".to_string());

        let inputs = vec![QdqWeightInput {
            original_name: "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales: vec![0.1],
            zero_points: vec![0],
            bits: 8,
            axis: None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let bias_init = graph.initializer.iter().find(|i| i.name == "bias");

        assert!(
            bias_init.is_some(),
            "non-quantized 'bias' initializer must be preserved"
        );
        assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);

        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "broken: {:?}", report.broken_refs);
    }

    // Full opset-7 → opset-13 migration across all three handled ops:
    // BatchNormalization.spatial stripped, Dropout.ratio moved to an input,
    // Softmax axis pinned to the old default.
    #[test]
    fn test_ensure_opset_strips_deprecated_attrs() {
        use crate::onnx_proto::NodeProto;

        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 7,
            }],
            graph: Some(GraphProto {
                node: vec![
                    NodeProto {
                        op_type: "BatchNormalization".to_string(),
                        input: vec!["x".into(), "s".into(), "b".into(), "m".into(), "v".into()],
                        output: vec!["bn_out".into()],
                        attribute: vec![
                            AttributeProto {
                                name: "epsilon".to_string(),
                                r#type: attribute_proto::AttributeType::Float as i32,
                                f: 1e-5,
                                ..Default::default()
                            },
                            AttributeProto {
                                name: "spatial".to_string(),
                                r#type: attribute_proto::AttributeType::Int as i32,
                                i: 1,
                                ..Default::default()
                            },
                        ],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Dropout".to_string(),
                        input: vec!["bn_out".into()],
                        output: vec!["drop_out".into(), "drop_mask".into()],
                        attribute: vec![AttributeProto {
                            name: "ratio".to_string(),
                            r#type: attribute_proto::AttributeType::Float as i32,
                            f: 0.3,
                            ..Default::default()
                        }],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Softmax".to_string(),
                        input: vec!["drop_out".into()],
                        output: vec!["sm_out".into()],
                        attribute: vec![],
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }),
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        let opset = model
            .opset_import
            .iter()
            .find(|o| o.domain.is_empty())
            .unwrap();
        assert_eq!(opset.version, 13);

        let graph = model.graph.as_ref().unwrap();

        let bn = &graph.node[0];
        assert!(
            !bn.attribute.iter().any(|a| a.name == "spatial"),
            "BatchNormalization.spatial should be stripped"
        );
        assert!(
            bn.attribute.iter().any(|a| a.name == "epsilon"),
            "BatchNormalization.epsilon should be preserved"
        );

        let drop = &graph.node[1];
        assert!(
            !drop.attribute.iter().any(|a| a.name == "ratio"),
            "Dropout.ratio attribute should be removed"
        );
        assert_eq!(drop.input.len(), 2, "Dropout should now have 2 inputs");
        let ratio_init_name = &drop.input[1];

        let ratio_init = graph
            .initializer
            .iter()
            .find(|i| &i.name == ratio_init_name)
            .expect("Dropout ratio initializer should exist");
        assert_eq!(ratio_init.data_type, tensor_proto::DataType::Float as i32);
        assert!(
            (ratio_init.float_data[0] - 0.3).abs() < 1e-6,
            "ratio should be 0.3"
        );

        let sm = &graph.node[2];
        assert_eq!(sm.op_type, "Softmax");
        let axis_attr = sm
            .attribute
            .iter()
            .find(|a| a.name == "axis")
            .expect("Softmax should have axis attribute added");
        assert_eq!(axis_attr.i, 1, "Softmax axis should be 1 (old default)");
    }

    // A min_version lower than the current opset must not downgrade it.
    #[test]
    fn test_ensure_opset_no_downgrade() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 15,
            }],
            graph: Some(GraphProto::default()),
            ..Default::default()
        };

        ensure_opset_version(&mut model, 10);
        let opset = model
            .opset_import
            .iter()
            .find(|o| o.domain.is_empty())
            .unwrap();
        assert_eq!(opset.version, 15);
    }
}