1use crate::errors::{QuantizeError, Result};
10use crate::onnx_proto::{
11 attribute_proto, tensor_proto, AttributeProto, GraphProto, ModelProto, OperatorSetIdProto,
12 TensorProto,
13};
14use std::collections::{HashMap, HashSet};
15
16use super::quantization_nodes::{
17 build_dequantize_linear_node, build_quantized_weight_tensor, build_scale_tensor,
18 build_zero_point_tensor, DequantLinearNames, StorageFormat,
19};
20
/// Per-weight payload for the QDQ (quantize–dequantize) graph transform.
///
/// One instance describes how a single FP32 initializer is replaced by a
/// quantized tensor plus a `DequantizeLinear` node.
#[derive(Debug, Clone)]
pub struct QdqWeightInput {
    // Name of the FP32 initializer to replace; must match the graph exactly.
    pub original_name: String,
    // Quantized weight values; element count must equal the product of the
    // original initializer's dims.
    pub quantized_values: Vec<i8>,
    // Dequantization scale(s): one entry for per-tensor quantization,
    // one per channel when `axis` is set.
    pub scales: Vec<f32>,
    // Zero point(s); NOTE(review): presumably paired 1:1 with `scales` —
    // confirm against `build_zero_point_tensor`.
    pub zero_points: Vec<i8>,
    // Quantization bit width (the transform special-cases 4 vs 8).
    pub bits: u8,
    // Channel axis for per-channel quantization; `None` means per-tensor.
    pub axis: Option<usize>,
}
45
/// Options controlling how quantized tensors are serialized.
#[derive(Debug, Clone, Copy, Default)]
pub struct SaveOptions {
    // When true, 4-bit weights are stored as native INT4 tensors
    // (`StorageFormat::NativeInt4`) instead of being widened to INT8.
    pub native_int4: bool,
}
61
62impl SaveOptions {
63 pub fn with_native_int4(mut self, enabled: bool) -> Self {
65 self.native_int4 = enabled;
66 self
67 }
68}
69
/// Result of a graph-connectivity validation pass.
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    // True when every non-empty node input resolves to a known name
    // (graph input, initializer, or an earlier node's output).
    pub valid: bool,
    // Human-readable description of each dangling reference found.
    pub broken_refs: Vec<String>,
}
79
80impl ConnectivityReport {
81 pub fn summary(&self) -> String {
83 if self.valid {
84 " Graph connectivity: OK\n".to_string()
85 } else {
86 let mut s = format!(
87 " Graph connectivity: BROKEN ({} dangling reference{})\n",
88 self.broken_refs.len(),
89 if self.broken_refs.len() == 1 { "" } else { "s" }
90 );
91 for (i, r) in self.broken_refs.iter().enumerate() {
92 s.push_str(&format!(" {}. {}\n", i + 1, r));
93 }
94 s
95 }
96 }
97}
98
99pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
113 let mut known: HashSet<String> = HashSet::new();
114
115 for inp in &graph.input {
117 known.insert(inp.name.clone());
118 }
119 for init in &graph.initializer {
120 known.insert(init.name.clone());
121 }
122
123 let mut broken = Vec::new();
124
125 for node in &graph.node {
127 for name in &node.input {
128 if name.is_empty() {
129 continue; }
131 if !known.contains(name.as_str()) {
132 broken.push(format!(
133 "Node '{}' (op={}) → unknown input '{}'",
134 node.name, node.op_type, name
135 ));
136 }
137 }
138 for name in &node.output {
140 if !name.is_empty() {
141 known.insert(name.clone());
142 }
143 }
144 }
145
146 ConnectivityReport {
147 valid: broken.is_empty(),
148 broken_refs: broken,
149 }
150}
151
152pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
165 let old_version = get_opset_version(model);
166
167 let mut found = false;
169 for opset in model.opset_import.iter_mut() {
170 if opset.domain.is_empty() {
171 if opset.version < min_version {
172 opset.version = min_version;
173 }
174 found = true;
175 break;
176 }
177 }
178 if !found {
179 model.opset_import.push(OperatorSetIdProto {
180 domain: String::new(),
181 version: min_version,
182 });
183 }
184
185 if old_version < min_version {
187 if let Some(graph) = model.graph.as_mut() {
188 upgrade_deprecated_ops(graph, old_version, min_version);
189 }
190 }
191}
192
193fn get_opset_version(model: &ModelProto) -> i64 {
195 model
196 .opset_import
197 .iter()
198 .find(|o| o.domain.is_empty())
199 .map_or(0, |o| o.version)
200}
201
/// Rewrites ops whose attribute contracts changed between `old_opset` and
/// `new_opset` so the bumped graph stays loadable.
///
/// Handled migrations:
/// - BatchNormalization: drops the `spatial` attribute (removed at opset 9).
///   NOTE(review): this also strips `spatial = 0`, which had distinct
///   semantics pre-9 — confirm non-spatial batch norm never reaches here.
/// - Dropout: moves the `ratio` attribute into the second input (opset 12).
/// - Softmax/LogSoftmax: pins `axis = 1` (the pre-13 default) because the
///   default changed to -1 at opset 13.
fn upgrade_deprecated_ops(graph: &mut GraphProto, old_opset: i64, new_opset: i64) {
    // Initializers created during the walk; appended after the loop because
    // `graph.node` holds a mutable borrow of `graph` inside it.
    let mut new_initializers: Vec<TensorProto> = Vec::new();

    for node in graph.node.iter_mut() {
        if node.op_type == "BatchNormalization" && old_opset < 9 && new_opset >= 9 {
            node.attribute.retain(|a| a.name != "spatial");
        }

        if node.op_type == "Dropout" && old_opset < 12 && new_opset >= 12 {
            // Opset 12 turned `ratio` from an attribute into an optional
            // second input; 0.5 is the old attribute default.
            let ratio = node
                .attribute
                .iter()
                .find(|a| a.name == "ratio")
                .map(|a| a.f)
                .unwrap_or(0.5);
            node.attribute.retain(|a| a.name != "ratio");

            // Derive a unique-ish initializer name from the node's first
            // output (outputs are unique in a valid ONNX graph).
            let init_name = format!(
                "_quantize_rs_dropout_ratio_{}",
                node.output.first().map_or("", |s| s.as_str()),
            );
            new_initializers.push(TensorProto {
                name: init_name.clone(),
                data_type: tensor_proto::DataType::Float as i32,
                float_data: vec![ratio],
                ..Default::default()
            });

            // Slot 1 is the ratio input: append if absent, overwrite if present.
            if node.input.len() < 2 {
                node.input.push(init_name);
            } else {
                node.input[1] = init_name;
            }
        }

        if (node.op_type == "Softmax" || node.op_type == "LogSoftmax")
            && old_opset < 13
            && new_opset >= 13
        {
            // Make the old default explicit so numerics are unchanged under
            // the new (-1) default.
            let has_axis = node.attribute.iter().any(|a| a.name == "axis");
            if !has_axis {
                node.attribute.push(AttributeProto {
                    name: "axis".to_string(),
                    r#type: attribute_proto::AttributeType::Int as i32,
                    i: 1,
                    ..Default::default()
                });
            }
        }
    }

    graph.initializer.extend(new_initializers);
}
268
/// Applies the QDQ weight transform with default [`SaveOptions`]
/// (INT4 weights widened to INT8 storage).
///
/// Convenience wrapper around [`apply_qdq_transform_with_options`]; see that
/// function for the error conditions.
pub fn apply_qdq_transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) -> Result<()> {
    apply_qdq_transform_with_options(graph, inputs, SaveOptions::default())
}
304
305pub fn apply_qdq_transform_with_options(
310 graph: &mut GraphProto,
311 inputs: &[QdqWeightInput],
312 options: SaveOptions,
313) -> Result<()> {
314 let shape_map: HashMap<String, Vec<i64>> = graph
318 .initializer
319 .iter()
320 .map(|init| (init.name.clone(), init.dims.clone()))
321 .collect();
322
323 let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
324
325 graph
329 .initializer
330 .retain(|init| !quant_set.contains(init.name.as_str()));
331
332 graph
339 .input
340 .retain(|inp| !quant_set.contains(inp.name.as_str()));
341
342 let mut dq_nodes = Vec::new();
346
347 for inp in inputs {
348 let shape =
349 shape_map
350 .get(&inp.original_name)
351 .ok_or_else(|| QuantizeError::GraphTransform {
352 reason: format!(
353 "Weight '{}' not found in model initializers — \
354 verify the name matches exactly",
355 inp.original_name
356 ),
357 })?;
358
359 let expected_len: i64 = shape.iter().product();
360 if inp.quantized_values.len() as i64 != expected_len {
361 return Err(QuantizeError::GraphTransform {
362 reason: format!(
363 "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
364 inp.original_name,
365 inp.quantized_values.len(),
366 shape,
367 expected_len
368 ),
369 });
370 }
371
372 let names = DequantLinearNames::from_original(&inp.original_name);
373
374 let format = if options.native_int4 && inp.bits == 4 {
375 StorageFormat::NativeInt4
376 } else {
377 StorageFormat::Int8Widened
378 };
379
380 graph.initializer.push(build_quantized_weight_tensor(
381 &names,
382 &inp.quantized_values,
383 shape,
384 format,
385 ));
386 graph
387 .initializer
388 .push(build_scale_tensor(&names, &inp.scales));
389 graph
390 .initializer
391 .push(build_zero_point_tensor(&names, &inp.zero_points, format));
392
393 dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
394 }
395
396 let existing_nodes = std::mem::take(&mut graph.node);
402 graph.node = dq_nodes;
403 graph.node.extend(existing_nodes);
404
405 Ok(())
406}
407
408#[cfg(test)]
413mod tests {
414 use super::*;
415 use crate::onnx_proto::{
416 tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
417 ValueInfoProto,
418 };
419
420 fn make_simple_graph() -> GraphProto {
427 GraphProto {
428 input: vec![ValueInfoProto {
429 name: "input".to_string(),
430 ..Default::default()
431 }],
432 initializer: vec![TensorProto {
433 name: "w".to_string(),
434 data_type: tensor_proto::DataType::Float as i32,
435 dims: vec![2, 2],
436 float_data: vec![1.0, 2.0, 3.0, 4.0],
437 ..Default::default()
438 }],
439 node: vec![NodeProto {
440 op_type: "Conv".to_string(),
441 name: "conv0".to_string(),
442 input: vec!["input".to_string(), "w".to_string()],
443 output: vec!["out".to_string()],
444 ..Default::default()
445 }],
446 ..Default::default()
447 }
448 }
449
450 fn make_two_weight_graph() -> GraphProto {
452 GraphProto {
453 input: vec![ValueInfoProto {
454 name: "input".to_string(),
455 ..Default::default()
456 }],
457 initializer: vec![
458 TensorProto {
459 name: "w1".to_string(),
460 data_type: tensor_proto::DataType::Float as i32,
461 dims: vec![2, 2],
462 float_data: vec![1.0, 2.0, 3.0, 4.0],
463 ..Default::default()
464 },
465 TensorProto {
466 name: "w2".to_string(),
467 data_type: tensor_proto::DataType::Float as i32,
468 dims: vec![2, 2],
469 float_data: vec![5.0, 6.0, 7.0, 8.0],
470 ..Default::default()
471 },
472 ],
473 node: vec![
474 NodeProto {
475 op_type: "Conv".to_string(),
476 name: "conv1".to_string(),
477 input: vec!["input".to_string(), "w1".to_string()],
478 output: vec!["mid".to_string()],
479 ..Default::default()
480 },
481 NodeProto {
482 op_type: "Conv".to_string(),
483 name: "conv2".to_string(),
484 input: vec!["mid".to_string(), "w2".to_string()],
485 output: vec!["out".to_string()],
486 ..Default::default()
487 },
488 ],
489 ..Default::default()
490 }
491 }
492
493 #[test]
498 fn test_connectivity_passes_on_valid_graph() {
499 let graph = make_simple_graph();
500 let report = validate_graph_connectivity(&graph);
501 assert!(
502 report.valid,
503 "original graph should be valid; broken: {:?}",
504 report.broken_refs
505 );
506 }
507
508 #[test]
509 fn test_connectivity_detects_renamed_initializer() {
510 let mut graph = make_simple_graph();
513
514 for init in graph.initializer.iter_mut() {
515 if init.name == "w" {
516 init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
517 }
518 }
519
520 let report = validate_graph_connectivity(&graph);
521 assert!(!report.valid, "should detect broken reference to 'w'");
522 assert_eq!(report.broken_refs.len(), 1);
523 assert!(
524 report.broken_refs[0].contains("'w'"),
525 "error should mention 'w': {}",
526 report.broken_refs[0]
527 );
528 }
529
530 #[test]
531 fn test_connectivity_detects_multiple_broken_refs() {
532 let mut graph = make_two_weight_graph();
533
534 for init in graph.initializer.iter_mut() {
535 if init.name == "w1" {
536 init.name = "w1_broken".to_string();
537 } else if init.name == "w2" {
538 init.name = "w2_broken".to_string();
539 }
540 }
541
542 let report = validate_graph_connectivity(&graph);
543 assert!(!report.valid);
544 assert_eq!(report.broken_refs.len(), 2);
545 }
546
547 #[test]
548 fn test_connectivity_summary_formatting() {
549 let valid = ConnectivityReport {
550 valid: true,
551 broken_refs: vec![],
552 };
553 assert!(valid.summary().contains("OK"));
554
555 let broken = ConnectivityReport {
556 valid: false,
557 broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
558 };
559 let s = broken.summary();
560 assert!(s.contains("BROKEN"));
561 assert!(s.contains("1 dangling reference"));
562 assert!(s.contains("unknown input 'y'"));
563 }
564
565 #[test]
570 fn test_ensure_opset_bumps_low_version() {
571 let mut model = ModelProto {
572 opset_import: vec![OperatorSetIdProto {
573 domain: String::new(),
574 version: 10,
575 }],
576 ..Default::default()
577 };
578
579 ensure_opset_version(&mut model, 13);
580
581 assert_eq!(model.opset_import[0].version, 13);
582 }
583
584 #[test]
585 fn test_ensure_opset_leaves_sufficient_version() {
586 let mut model = ModelProto {
587 opset_import: vec![OperatorSetIdProto {
588 domain: String::new(),
589 version: 17,
590 }],
591 ..Default::default()
592 };
593
594 ensure_opset_version(&mut model, 13);
595
596 assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
597 }
598
599 #[test]
600 fn test_ensure_opset_adds_missing_default_domain() {
601 let mut model = ModelProto::default();
602 ensure_opset_version(&mut model, 13);
604
605 assert_eq!(model.opset_import.len(), 1);
606 assert!(model.opset_import[0].domain.is_empty());
607 assert_eq!(model.opset_import[0].version, 13);
608 }
609
610 #[test]
615 fn test_qdq_single_weight_produces_valid_graph() {
616 let mut graph = make_simple_graph();
617
618 let inputs = vec![QdqWeightInput {
619 original_name: "w".to_string(),
620 quantized_values: vec![25, 51, 76, 102],
621 scales: vec![0.039_215_686], zero_points: vec![0],
623 bits: 8,
624 axis: None,
625 }];
626
627 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
628
629 let report = validate_graph_connectivity(&graph);
630 assert!(
631 report.valid,
632 "graph after QDQ must be valid; broken: {:?}",
633 report.broken_refs
634 );
635 }
636
637 #[test]
638 fn test_qdq_adds_correct_initializers() {
639 let mut graph = make_simple_graph();
640
641 let inputs = vec![QdqWeightInput {
642 original_name: "w".to_string(),
643 quantized_values: vec![10, 20, 30, 40],
644 scales: vec![0.1],
645 zero_points: vec![-5],
646 bits: 8,
647 axis: None,
648 }];
649
650 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
651
652 let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();
653
654 assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
655 assert!(init_names.contains(&"w_scale"), "missing w_scale");
656 assert!(init_names.contains(&"w_zp"), "missing w_zp");
657 assert!(
658 !init_names.contains(&"w"),
659 "original FP32 'w' should be removed"
660 );
661 }
662
663 #[test]
664 fn test_qdq_node_order_dequant_first() {
665 let mut graph = make_simple_graph();
666
667 let inputs = vec![QdqWeightInput {
668 original_name: "w".to_string(),
669 quantized_values: vec![10, 20, 30, 40],
670 scales: vec![0.1],
671 zero_points: vec![0],
672 bits: 8,
673 axis: None,
674 }];
675
676 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
677
678 let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
679
680 assert_eq!(ops.len(), 2);
681 assert_eq!(ops[0], "DequantizeLinear");
682 assert_eq!(ops[1], "Conv");
683 }
684
685 #[test]
686 fn test_qdq_dequant_output_is_original_name() {
687 let mut graph = make_simple_graph();
688
689 let inputs = vec![QdqWeightInput {
690 original_name: "w".to_string(),
691 quantized_values: vec![1, 2, 3, 4],
692 scales: vec![1.0],
693 zero_points: vec![0],
694 bits: 8,
695 axis: None,
696 }];
697
698 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
699
700 let dq = &graph.node[0]; assert_eq!(
702 dq.output[0], "w",
703 "DequantizeLinear output must be original name"
704 );
705 }
706
707 #[test]
708 fn test_qdq_two_weights_both_transformed() {
709 let mut graph = make_two_weight_graph();
710
711 let inputs = vec![
712 QdqWeightInput {
713 original_name: "w1".to_string(),
714 quantized_values: vec![10, 20, 30, 40],
715 scales: vec![0.1],
716 zero_points: vec![0],
717 bits: 8,
718 axis: None,
719 },
720 QdqWeightInput {
721 original_name: "w2".to_string(),
722 quantized_values: vec![50, 60, 70, 80],
723 scales: vec![0.2],
724 zero_points: vec![-1],
725 bits: 8,
726 axis: None,
727 },
728 ];
729
730 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
731
732 let report = validate_graph_connectivity(&graph);
734 assert!(
735 report.valid,
736 "two-weight graph broken: {:?}",
737 report.broken_refs
738 );
739
740 assert_eq!(graph.node.len(), 4);
742
743 assert_eq!(graph.node[0].op_type, "DequantizeLinear");
745 assert_eq!(graph.node[1].op_type, "DequantizeLinear");
746
747 let dq_outputs: Vec<&str> = graph
749 .node
750 .iter()
751 .take(2)
752 .map(|n| n.output[0].as_str())
753 .collect();
754 assert!(dq_outputs.contains(&"w1"));
755 assert!(dq_outputs.contains(&"w2"));
756 }
757
758 #[test]
759 fn test_qdq_int4_values_stored_as_int8() {
760 let mut graph = make_simple_graph();
761
762 let inputs = vec![QdqWeightInput {
764 original_name: "w".to_string(),
765 quantized_values: vec![-8, -1, 0, 7],
766 scales: vec![0.5],
767 zero_points: vec![0],
768 bits: 4, axis: None,
770 }];
771
772 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
773
774 let quant_init = graph
775 .initializer
776 .iter()
777 .find(|i| i.name == "w_quantized")
778 .expect("w_quantized not found");
779
780 assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);
782
783 let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
785 assert_eq!(recovered, vec![-8, -1, 0, 7]);
786 }
787
788 #[test]
789 fn test_qdq_unknown_weight_returns_error() {
790 let mut graph = make_simple_graph();
791
792 let inputs = vec![QdqWeightInput {
793 original_name: "does_not_exist".to_string(),
794 quantized_values: vec![1, 2, 3],
795 scales: vec![1.0],
796 zero_points: vec![0],
797 bits: 8,
798 axis: None,
799 }];
800
801 let result = apply_qdq_transform(&mut graph, &inputs);
802 assert!(result.is_err());
803 assert!(
804 result.unwrap_err().to_string().contains("does_not_exist"),
805 "error should name the missing weight"
806 );
807 }
808
809 #[test]
810 fn test_qdq_non_quantized_initializers_preserved() {
811 let mut graph = make_simple_graph();
814
815 graph.initializer.push(TensorProto {
816 name: "bias".to_string(),
817 data_type: tensor_proto::DataType::Float as i32,
818 dims: vec![2],
819 float_data: vec![0.1, 0.2],
820 ..Default::default()
821 });
822
823 graph.node[0].input.push("bias".to_string());
825
826 let inputs = vec![QdqWeightInput {
827 original_name: "w".to_string(),
828 quantized_values: vec![10, 20, 30, 40],
829 scales: vec![0.1],
830 zero_points: vec![0],
831 bits: 8,
832 axis: None,
833 }];
834
835 apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
836
837 let bias_init = graph.initializer.iter().find(|i| i.name == "bias");
839
840 assert!(
841 bias_init.is_some(),
842 "non-quantized 'bias' initializer must be preserved"
843 );
844 assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);
845
846 let report = validate_graph_connectivity(&graph);
848 assert!(report.valid, "broken: {:?}", report.broken_refs);
849 }
850
    #[test]
    fn test_ensure_opset_strips_deprecated_attrs() {
        use crate::onnx_proto::NodeProto;

        // Opset-7 model exercising all three migrations at once:
        // BatchNormalization (spatial attr), Dropout (ratio attr),
        // Softmax (implicit axis default).
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto {
                domain: String::new(),
                version: 7,
            }],
            graph: Some(GraphProto {
                node: vec![
                    NodeProto {
                        op_type: "BatchNormalization".to_string(),
                        input: vec!["x".into(), "s".into(), "b".into(), "m".into(), "v".into()],
                        output: vec!["bn_out".into()],
                        attribute: vec![
                            AttributeProto {
                                name: "epsilon".to_string(),
                                r#type: attribute_proto::AttributeType::Float as i32,
                                f: 1e-5,
                                ..Default::default()
                            },
                            AttributeProto {
                                name: "spatial".to_string(),
                                r#type: attribute_proto::AttributeType::Int as i32,
                                i: 1,
                                ..Default::default()
                            },
                        ],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Dropout".to_string(),
                        input: vec!["bn_out".into()],
                        output: vec!["drop_out".into(), "drop_mask".into()],
                        attribute: vec![AttributeProto {
                            name: "ratio".to_string(),
                            r#type: attribute_proto::AttributeType::Float as i32,
                            f: 0.3,
                            ..Default::default()
                        }],
                        ..Default::default()
                    },
                    NodeProto {
                        op_type: "Softmax".to_string(),
                        input: vec!["drop_out".into()],
                        output: vec!["sm_out".into()],
                        attribute: vec![],
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }),
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        // Default-domain opset entry bumped to the requested minimum.
        let opset = model
            .opset_import
            .iter()
            .find(|o| o.domain.is_empty())
            .unwrap();
        assert_eq!(opset.version, 13);

        let graph = model.graph.as_ref().unwrap();

        // BatchNormalization: `spatial` removed, other attributes kept.
        let bn = &graph.node[0];
        assert!(
            !bn.attribute.iter().any(|a| a.name == "spatial"),
            "BatchNormalization.spatial should be stripped"
        );
        assert!(
            bn.attribute.iter().any(|a| a.name == "epsilon"),
            "BatchNormalization.epsilon should be preserved"
        );

        // Dropout: ratio attribute converted into a second input backed by
        // a new initializer holding the original value (0.3).
        let drop = &graph.node[1];
        assert!(
            !drop.attribute.iter().any(|a| a.name == "ratio"),
            "Dropout.ratio attribute should be removed"
        );
        assert_eq!(drop.input.len(), 2, "Dropout should now have 2 inputs");
        let ratio_init_name = &drop.input[1];

        let ratio_init = graph
            .initializer
            .iter()
            .find(|i| &i.name == ratio_init_name)
            .expect("Dropout ratio initializer should exist");
        assert_eq!(ratio_init.data_type, tensor_proto::DataType::Float as i32);
        assert!(
            (ratio_init.float_data[0] - 0.3).abs() < 1e-6,
            "ratio should be 0.3"
        );

        // Softmax: explicit axis=1 added to preserve pre-13 default behavior.
        let sm = &graph.node[2];
        assert_eq!(sm.op_type, "Softmax");
        let axis_attr = sm
            .attribute
            .iter()
            .find(|a| a.name == "axis")
            .expect("Softmax should have axis attribute added");
        assert_eq!(axis_attr.i, 1, "Softmax axis should be 1 (old default)");
    }
966
967 #[test]
968 fn test_ensure_opset_no_downgrade() {
969 let mut model = ModelProto {
970 opset_import: vec![OperatorSetIdProto {
971 domain: String::new(),
972 version: 15,
973 }],
974 graph: Some(GraphProto::default()),
975 ..Default::default()
976 };
977
978 ensure_opset_version(&mut model, 10);
980 let opset = model
981 .opset_import
982 .iter()
983 .find(|o| o.domain.is_empty())
984 .unwrap();
985 assert_eq!(opset.version, 15);
986 }
987}