Skip to main content

quantize_rs/onnx_utils/
graph_builder.rs

1//! Graph-level operations for quantized ONNX models.
2//!
3//! Three responsibilities:
4//!   1. **QDQ transform** — replace FP32 initializers with INT8 + DequantizeLinear
5//!   2. **Connectivity validation** — walk the graph and verify every edge resolves
6//!   3. **Opset management** — ensure the model declares a sufficient opset and
7//!      upgrades deprecated op attributes when bumping across breaking opset boundaries
8
9use crate::errors::{QuantizeError, Result};
10use crate::onnx_proto::{
11    attribute_proto, tensor_proto, AttributeProto, GraphProto, ModelProto, OperatorSetIdProto,
12    TensorProto,
13};
14use std::collections::{HashMap, HashSet};
15
16use super::quantization_nodes::{
17    build_dequantize_linear_node, build_quantized_weight_tensor, build_scale_tensor,
18    build_zero_point_tensor, DequantLinearNames,
19};
20
21// ===========================================================================
22// Public types
23// ===========================================================================
24
/// One weight to convert: FP32 initializer → INT8 + DequantizeLinear block.
///
/// Consumed by `apply_qdq_transform`, one instance per initializer being replaced.
#[derive(Debug)]
pub struct QdqWeightInput {
    /// Original initializer name (e.g., `"conv1.weight"`).
    /// Must match an entry in `graph.initializer` exactly, or the transform errors.
    pub original_name: String,
    /// Quantized values as i8.  For INT4 these are in [-8, 7]; for INT8 in [-128, 127].
    /// Always unpacked — one value per element; the length must equal the product
    /// of the weight's dims (validated by `apply_qdq_transform`).
    pub quantized_values: Vec<i8>,
    /// Quantization scales (FP32).
    /// Length 1 for per-tensor; one per channel for per-channel.
    pub scales: Vec<f32>,
    /// Zero points (INT8).
    /// Same length as `scales`.
    pub zero_points: Vec<i8>,
    /// Original bit-width (4 or 8).  Informational only — ONNX storage is always INT8.
    /// Persisted in model metadata so `load_quantized_info` can recover it.
    pub bits: u8,
    /// Per-channel quantization axis, or `None` for per-tensor.
    pub axis: Option<usize>,
}
45
/// Result of a graph-connectivity check.
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    /// `true` if every node input resolves to a known tensor.
    pub valid: bool,
    /// Human-readable description of every dangling reference.  Empty when valid.
    pub broken_refs: Vec<String>,
}

impl ConnectivityReport {
    /// Render the report as a printable string (useful for CLI `validate` output).
    ///
    /// Valid graphs get a one-line "OK"; broken graphs get a header with the
    /// dangling-reference count followed by a numbered list of each reference.
    pub fn summary(&self) -> String {
        if self.valid {
            return "  Graph connectivity: OK\n".to_string();
        }
        let count = self.broken_refs.len();
        let suffix = if count == 1 { "" } else { "s" };
        let mut out = format!(
            "  Graph connectivity: BROKEN ({} dangling reference{})\n",
            count, suffix
        );
        for (idx, reference) in self.broken_refs.iter().enumerate() {
            out.push_str(&format!("    {}. {}\n", idx + 1, reference));
        }
        out
    }
}
74
75// ===========================================================================
76// Connectivity validation
77// ===========================================================================
78
79/// Walk the graph and verify every node input resolves to *something*.
80///
81/// A valid input is exactly one of:
82///   • a declared graph input (`graph.input`)
83///   • an initializer name (`graph.initializer`)
84///   • the output of a node that appears **earlier** in `graph.node`
85///
86/// This is the check ONNX Runtime performs on load — and the check that
87/// v0.2.0's `validate` command skipped, letting the rename bug through.
88pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
89    let mut known: HashSet<String> = HashSet::new();
90
91    // Seed: graph inputs + initializers are always available
92    for inp in &graph.input {
93        known.insert(inp.name.clone());
94    }
95    for init in &graph.initializer {
96        known.insert(init.name.clone());
97    }
98
99    let mut broken = Vec::new();
100
101    // Walk nodes in serialized order; each node's outputs become known afterwards
102    for node in &graph.node {
103        for name in &node.input {
104            if name.is_empty() {
105                continue; // optional input slot — empty string is valid
106            }
107            if !known.contains(name.as_str()) {
108                broken.push(format!(
109                    "Node '{}' (op={}) → unknown input '{}'",
110                    node.name, node.op_type, name
111                ));
112            }
113        }
114        // Register outputs so later nodes can consume them
115        for name in &node.output {
116            if !name.is_empty() {
117                known.insert(name.clone());
118            }
119        }
120    }
121
122    ConnectivityReport {
123        valid: broken.is_empty(),
124        broken_refs: broken,
125    }
126}
127
128// ===========================================================================
129// Opset version management
130// ===========================================================================
131
132/// Ensure the default ONNX domain opset is at least `min_version`.
133///
134/// DequantizeLinear requires opset ≥ 10 (per-tensor) or ≥ 13 (per-channel axis).
135///
136/// When bumping the opset past a breaking boundary, this function also
137/// upgrades deprecated op attributes:
138///   - **opset 9**: `BatchNormalization.spatial` removed (was always 1)
139///   - **opset 12**: `Dropout.ratio` migrated from attribute to input
140pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
141    let old_version = get_opset_version(model);
142
143    // Update or insert the default-domain opset entry
144    let mut found = false;
145    for opset in model.opset_import.iter_mut() {
146        if opset.domain.is_empty() {
147            if opset.version < min_version {
148                opset.version = min_version;
149            }
150            found = true;
151            break;
152        }
153    }
154    if !found {
155        model.opset_import.push(OperatorSetIdProto {
156            domain: String::new(),
157            version: min_version,
158        });
159    }
160
161    // Strip deprecated attributes when crossing breaking opset boundaries
162    if old_version < min_version {
163        if let Some(graph) = model.graph.as_mut() {
164            upgrade_deprecated_ops(graph, old_version, min_version);
165        }
166    }
167}
168
169/// Get the current default-domain opset version (0 if not present).
170fn get_opset_version(model: &ModelProto) -> i64 {
171    model
172        .opset_import
173        .iter()
174        .find(|o| o.domain.is_empty())
175        .map_or(0, |o| o.version)
176}
177
178/// Upgrade graph nodes whose attribute semantics changed between `old_opset`
179/// and `new_opset`.
180///
181/// - **BatchNormalization** (opset 9): `spatial` attribute removed (was always 1)
182/// - **Dropout** (opset 12): `ratio` attribute → 2nd input
183/// - **Softmax / LogSoftmax** (opset 13): default axis changed from 1 to -1
184fn upgrade_deprecated_ops(graph: &mut GraphProto, old_opset: i64, new_opset: i64) {
185    let mut new_initializers: Vec<TensorProto> = Vec::new();
186
187    for node in graph.node.iter_mut() {
188        // Opset 9: BatchNormalization removed the `spatial` attribute.
189        // It was always 1 (the only valid value) and had no effect.
190        if node.op_type == "BatchNormalization" && old_opset < 9 && new_opset >= 9 {
191            node.attribute.retain(|a| a.name != "spatial");
192        }
193
194        // Opset 12: Dropout `ratio` moved from attribute to 2nd input.
195        // Extract the value, remove the attribute, and wire in a constant.
196        if node.op_type == "Dropout" && old_opset < 12 && new_opset >= 12 {
197            let ratio = node
198                .attribute
199                .iter()
200                .find(|a| a.name == "ratio")
201                .map(|a| a.f)
202                .unwrap_or(0.5);
203            node.attribute.retain(|a| a.name != "ratio");
204
205            let init_name = format!(
206                "_quantize_rs_dropout_ratio_{}",
207                node.output.first().map_or("", |s| s.as_str()),
208            );
209            new_initializers.push(TensorProto {
210                name: init_name.clone(),
211                data_type: tensor_proto::DataType::Float as i32,
212                float_data: vec![ratio],
213                ..Default::default()
214            });
215
216            if node.input.len() < 2 {
217                node.input.push(init_name);
218            } else {
219                node.input[1] = init_name;
220            }
221        }
222
223        // Opset 13: Softmax/LogSoftmax default axis changed from 1 to -1.
224        // If the node has no explicit axis attribute, add axis=1 to preserve
225        // the old behavior.
226        if (node.op_type == "Softmax" || node.op_type == "LogSoftmax")
227            && old_opset < 13
228            && new_opset >= 13
229        {
230            let has_axis = node.attribute.iter().any(|a| a.name == "axis");
231            if !has_axis {
232                node.attribute.push(AttributeProto {
233                    name: "axis".to_string(),
234                    r#type: attribute_proto::AttributeType::Int as i32,
235                    i: 1, // old default
236                    ..Default::default()
237                });
238            }
239        }
240    }
241
242    graph.initializer.extend(new_initializers);
243}
244
245// ===========================================================================
246// QDQ transform
247// ===========================================================================
248
249/// Replace FP32 weight initializers with INT8 quantized equivalents +
250/// DequantizeLinear nodes.
251///
252/// ### What happens per weight in `inputs`:
253///
254/// **Removed:**
255///   - Initializer `"{name}"` (the original FP32 weight data)
256///
257/// **Added (initializers):**
258///   - `"{name}_quantized"` — INT8, same shape as original
259///   - `"{name}_scale"`     — FP32 scalar
260///   - `"{name}_zp"`        — INT8 scalar
261///
262/// **Added (node, prepended before all existing nodes):**
263///   - `DequantizeLinear` with output = `"{name}"`
264///
265/// Because the DequantizeLinear output carries the **original** name, every
266/// downstream node (Conv, MatMul, BatchNorm, …) remains completely unchanged.
267/// Graph connectivity is preserved by construction.
268///
269/// ---
270/// ### INT4 storage note
271///
272/// ONNX `DequantizeLinear` requires INT8 input in opsets < 21.  INT4-quantized
273/// values (range [-8, 7]) are widened to INT8 here.  The quantization *accuracy*
274/// is INT4-level (scale and zero_point were computed for the 4-bit range), but
275/// on-disk storage is 4× compression rather than the 8× that bit-packing would
276/// give.  True INT4 packing is planned for a future version (opset 21 or custom op).
277pub fn apply_qdq_transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) -> Result<()> {
278    // -----------------------------------------------------------------------
279    // 0.  Snapshot shapes before modifying the initializer list
280    // -----------------------------------------------------------------------
281    let shape_map: HashMap<String, Vec<i64>> = graph
282        .initializer
283        .iter()
284        .map(|init| (init.name.clone(), init.dims.clone()))
285        .collect();
286
287    let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
288
289    // -----------------------------------------------------------------------
290    // 1.  Remove the original FP32 initializers for every weight we're replacing
291    // -----------------------------------------------------------------------
292    graph
293        .initializer
294        .retain(|init| !quant_set.contains(init.name.as_str()));
295
296    // -----------------------------------------------------------------------
297    // 1b. Also remove weights from graph.input (critical fix for "Duplicate definition")
298    // -----------------------------------------------------------------------
299    // Some ONNX models list weights as both initializers AND graph inputs.
300    // This is valid ONNX, but when DequantizeLinear outputs reuse the original
301    // weight names, ONNX Runtime sees two definitions of the same tensor.
302    graph
303        .input
304        .retain(|inp| !quant_set.contains(inp.name.as_str()));
305
306    // -----------------------------------------------------------------------
307    // 2.  Add quantized initializer triples + build DequantizeLinear nodes
308    // -----------------------------------------------------------------------
309    let mut dq_nodes = Vec::new();
310
311    for inp in inputs {
312        let shape =
313            shape_map
314                .get(&inp.original_name)
315                .ok_or_else(|| QuantizeError::GraphTransform {
316                    reason: format!(
317                        "Weight '{}' not found in model initializers — \
318                         verify the name matches exactly",
319                        inp.original_name
320                    ),
321                })?;
322
323        let expected_len: i64 = shape.iter().product();
324        if inp.quantized_values.len() as i64 != expected_len {
325            return Err(QuantizeError::GraphTransform {
326                reason: format!(
327                    "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
328                    inp.original_name,
329                    inp.quantized_values.len(),
330                    shape,
331                    expected_len
332                ),
333            });
334        }
335
336        let names = DequantLinearNames::from_original(&inp.original_name);
337
338        graph.initializer.push(build_quantized_weight_tensor(
339            &names,
340            &inp.quantized_values,
341            shape,
342        ));
343        graph
344            .initializer
345            .push(build_scale_tensor(&names, &inp.scales));
346        graph
347            .initializer
348            .push(build_zero_point_tensor(&names, &inp.zero_points));
349
350        dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
351    }
352
353    // -----------------------------------------------------------------------
354    // 3.  Prepend DequantizeLinear nodes before all existing computation nodes.
355    //     They must appear first so their outputs are "known" when the validator
356    //     (or ONNX Runtime) walks the node list in order.
357    // -----------------------------------------------------------------------
358    let existing_nodes = std::mem::take(&mut graph.node);
359    graph.node = dq_nodes;
360    graph.node.extend(existing_nodes);
361
362    Ok(())
363}
364
365// ===========================================================================
366// Tests
367// ===========================================================================
368
369#[cfg(test)]
370mod tests {
371    use super::*;
372    use crate::onnx_proto::{
373        tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
374        ValueInfoProto,
375    };
376
377    // -----------------------------------------------------------------------
378    // Test helpers
379    // -----------------------------------------------------------------------
380
381    /// Minimal graph: one graph input "input", one FP32 initializer "w" (shape [2,2]),
382    /// one Conv node consuming both, producing "out".
383    fn make_simple_graph() -> GraphProto {
384        GraphProto {
385            input: vec![ValueInfoProto {
386                name: "input".to_string(),
387                ..Default::default()
388            }],
389            initializer: vec![TensorProto {
390                name: "w".to_string(),
391                data_type: tensor_proto::DataType::Float as i32,
392                dims: vec![2, 2],
393                float_data: vec![1.0, 2.0, 3.0, 4.0],
394                ..Default::default()
395            }],
396            node: vec![NodeProto {
397                op_type: "Conv".to_string(),
398                name: "conv0".to_string(),
399                input: vec!["input".to_string(), "w".to_string()],
400                output: vec!["out".to_string()],
401                ..Default::default()
402            }],
403            ..Default::default()
404        }
405    }
406
407    /// Two-weight graph: "w1" and "w2", two Conv nodes chained.
408    fn make_two_weight_graph() -> GraphProto {
409        GraphProto {
410            input: vec![ValueInfoProto {
411                name: "input".to_string(),
412                ..Default::default()
413            }],
414            initializer: vec![
415                TensorProto {
416                    name: "w1".to_string(),
417                    data_type: tensor_proto::DataType::Float as i32,
418                    dims: vec![2, 2],
419                    float_data: vec![1.0, 2.0, 3.0, 4.0],
420                    ..Default::default()
421                },
422                TensorProto {
423                    name: "w2".to_string(),
424                    data_type: tensor_proto::DataType::Float as i32,
425                    dims: vec![2, 2],
426                    float_data: vec![5.0, 6.0, 7.0, 8.0],
427                    ..Default::default()
428                },
429            ],
430            node: vec![
431                NodeProto {
432                    op_type: "Conv".to_string(),
433                    name: "conv1".to_string(),
434                    input: vec!["input".to_string(), "w1".to_string()],
435                    output: vec!["mid".to_string()],
436                    ..Default::default()
437                },
438                NodeProto {
439                    op_type: "Conv".to_string(),
440                    name: "conv2".to_string(),
441                    input: vec!["mid".to_string(), "w2".to_string()],
442                    output: vec!["out".to_string()],
443                    ..Default::default()
444                },
445            ],
446            ..Default::default()
447        }
448    }
449
450    // -----------------------------------------------------------------------
451    // Connectivity validation tests
452    // -----------------------------------------------------------------------
453
454    #[test]
455    fn test_connectivity_passes_on_valid_graph() {
456        let graph = make_simple_graph();
457        let report = validate_graph_connectivity(&graph);
458        assert!(
459            report.valid,
460            "original graph should be valid; broken: {:?}",
461            report.broken_refs
462        );
463    }
464
465    #[test]
466    fn test_connectivity_detects_renamed_initializer() {
467        // Simulate the exact v0.2.0 bug: rename "w" in the initializer list
468        // without updating the Conv node that references it.
469        let mut graph = make_simple_graph();
470
471        for init in graph.initializer.iter_mut() {
472            if init.name == "w" {
473                init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
474            }
475        }
476
477        let report = validate_graph_connectivity(&graph);
478        assert!(!report.valid, "should detect broken reference to 'w'");
479        assert_eq!(report.broken_refs.len(), 1);
480        assert!(
481            report.broken_refs[0].contains("'w'"),
482            "error should mention 'w': {}",
483            report.broken_refs[0]
484        );
485    }
486
487    #[test]
488    fn test_connectivity_detects_multiple_broken_refs() {
489        let mut graph = make_two_weight_graph();
490
491        for init in graph.initializer.iter_mut() {
492            if init.name == "w1" {
493                init.name = "w1_broken".to_string();
494            } else if init.name == "w2" {
495                init.name = "w2_broken".to_string();
496            }
497        }
498
499        let report = validate_graph_connectivity(&graph);
500        assert!(!report.valid);
501        assert_eq!(report.broken_refs.len(), 2);
502    }
503
504    #[test]
505    fn test_connectivity_summary_formatting() {
506        let valid = ConnectivityReport {
507            valid: true,
508            broken_refs: vec![],
509        };
510        assert!(valid.summary().contains("OK"));
511
512        let broken = ConnectivityReport {
513            valid: false,
514            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
515        };
516        let s = broken.summary();
517        assert!(s.contains("BROKEN"));
518        assert!(s.contains("1 dangling reference"));
519        assert!(s.contains("unknown input 'y'"));
520    }
521
522    // -----------------------------------------------------------------------
523    // Opset version tests
524    // -----------------------------------------------------------------------
525
526    #[test]
527    fn test_ensure_opset_bumps_low_version() {
528        let mut model = ModelProto {
529            opset_import: vec![OperatorSetIdProto {
530                domain: String::new(),
531                version: 10,
532            }],
533            ..Default::default()
534        };
535
536        ensure_opset_version(&mut model, 13);
537
538        assert_eq!(model.opset_import[0].version, 13);
539    }
540
541    #[test]
542    fn test_ensure_opset_leaves_sufficient_version() {
543        let mut model = ModelProto {
544            opset_import: vec![OperatorSetIdProto {
545                domain: String::new(),
546                version: 17,
547            }],
548            ..Default::default()
549        };
550
551        ensure_opset_version(&mut model, 13);
552
553        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
554    }
555
556    #[test]
557    fn test_ensure_opset_adds_missing_default_domain() {
558        let mut model = ModelProto::default();
559        // No opset_import at all
560        ensure_opset_version(&mut model, 13);
561
562        assert_eq!(model.opset_import.len(), 1);
563        assert!(model.opset_import[0].domain.is_empty());
564        assert_eq!(model.opset_import[0].version, 13);
565    }
566
567    // -----------------------------------------------------------------------
568    // QDQ transform tests
569    // -----------------------------------------------------------------------
570
571    #[test]
572    fn test_qdq_single_weight_produces_valid_graph() {
573        let mut graph = make_simple_graph();
574
575        let inputs = vec![QdqWeightInput {
576            original_name: "w".to_string(),
577            quantized_values: vec![25, 51, 76, 102],
578            scales: vec![0.039_215_686], // ≈ 1/25.5
579            zero_points: vec![0],
580            bits: 8,
581            axis: None,
582        }];
583
584        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
585
586        let report = validate_graph_connectivity(&graph);
587        assert!(
588            report.valid,
589            "graph after QDQ must be valid; broken: {:?}",
590            report.broken_refs
591        );
592    }
593
594    #[test]
595    fn test_qdq_adds_correct_initializers() {
596        let mut graph = make_simple_graph();
597
598        let inputs = vec![QdqWeightInput {
599            original_name: "w".to_string(),
600            quantized_values: vec![10, 20, 30, 40],
601            scales: vec![0.1],
602            zero_points: vec![-5],
603            bits: 8,
604            axis: None,
605        }];
606
607        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
608
609        let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();
610
611        assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
612        assert!(init_names.contains(&"w_scale"), "missing w_scale");
613        assert!(init_names.contains(&"w_zp"), "missing w_zp");
614        assert!(
615            !init_names.contains(&"w"),
616            "original FP32 'w' should be removed"
617        );
618    }
619
620    #[test]
621    fn test_qdq_node_order_dequant_first() {
622        let mut graph = make_simple_graph();
623
624        let inputs = vec![QdqWeightInput {
625            original_name: "w".to_string(),
626            quantized_values: vec![10, 20, 30, 40],
627            scales: vec![0.1],
628            zero_points: vec![0],
629            bits: 8,
630            axis: None,
631        }];
632
633        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
634
635        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
636
637        assert_eq!(ops.len(), 2);
638        assert_eq!(ops[0], "DequantizeLinear");
639        assert_eq!(ops[1], "Conv");
640    }
641
642    #[test]
643    fn test_qdq_dequant_output_is_original_name() {
644        let mut graph = make_simple_graph();
645
646        let inputs = vec![QdqWeightInput {
647            original_name: "w".to_string(),
648            quantized_values: vec![1, 2, 3, 4],
649            scales: vec![1.0],
650            zero_points: vec![0],
651            bits: 8,
652            axis: None,
653        }];
654
655        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
656
657        let dq = &graph.node[0]; // first node = DequantizeLinear
658        assert_eq!(
659            dq.output[0], "w",
660            "DequantizeLinear output must be original name"
661        );
662    }
663
664    #[test]
665    fn test_qdq_two_weights_both_transformed() {
666        let mut graph = make_two_weight_graph();
667
668        let inputs = vec![
669            QdqWeightInput {
670                original_name: "w1".to_string(),
671                quantized_values: vec![10, 20, 30, 40],
672                scales: vec![0.1],
673                zero_points: vec![0],
674                bits: 8,
675                axis: None,
676            },
677            QdqWeightInput {
678                original_name: "w2".to_string(),
679                quantized_values: vec![50, 60, 70, 80],
680                scales: vec![0.2],
681                zero_points: vec![-1],
682                bits: 8,
683                axis: None,
684            },
685        ];
686
687        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
688
689        // Connectivity must still be valid
690        let report = validate_graph_connectivity(&graph);
691        assert!(
692            report.valid,
693            "two-weight graph broken: {:?}",
694            report.broken_refs
695        );
696
697        // Should have 2 DequantizeLinear + 2 Conv = 4 nodes
698        assert_eq!(graph.node.len(), 4);
699
700        // First two nodes are DequantizeLinear
701        assert_eq!(graph.node[0].op_type, "DequantizeLinear");
702        assert_eq!(graph.node[1].op_type, "DequantizeLinear");
703
704        // Their outputs are the original weight names
705        let dq_outputs: Vec<&str> = graph
706            .node
707            .iter()
708            .take(2)
709            .map(|n| n.output[0].as_str())
710            .collect();
711        assert!(dq_outputs.contains(&"w1"));
712        assert!(dq_outputs.contains(&"w2"));
713    }
714
715    #[test]
716    fn test_qdq_int4_values_stored_as_int8() {
717        let mut graph = make_simple_graph();
718
719        // INT4 range [-8, 7] — these arrive as i8 from ensure_unpacked()
720        let inputs = vec![QdqWeightInput {
721            original_name: "w".to_string(),
722            quantized_values: vec![-8, -1, 0, 7],
723            scales: vec![0.5],
724            zero_points: vec![0],
725            bits: 4, // flag says INT4; storage must still be INT8
726            axis: None,
727        }];
728
729        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
730
731        let quant_init = graph
732            .initializer
733            .iter()
734            .find(|i| i.name == "w_quantized")
735            .expect("w_quantized not found");
736
737        // Data type must be INT8 (ONNX DequantizeLinear requirement)
738        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);
739
740        // Byte-level round-trip must be exact
741        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
742        assert_eq!(recovered, vec![-8, -1, 0, 7]);
743    }
744
745    #[test]
746    fn test_qdq_unknown_weight_returns_error() {
747        let mut graph = make_simple_graph();
748
749        let inputs = vec![QdqWeightInput {
750            original_name: "does_not_exist".to_string(),
751            quantized_values: vec![1, 2, 3],
752            scales: vec![1.0],
753            zero_points: vec![0],
754            bits: 8,
755            axis: None,
756        }];
757
758        let result = apply_qdq_transform(&mut graph, &inputs);
759        assert!(result.is_err());
760        assert!(
761            result.unwrap_err().to_string().contains("does_not_exist"),
762            "error should name the missing weight"
763        );
764    }
765
766    #[test]
767    fn test_qdq_non_quantized_initializers_preserved() {
768        // Add an extra initializer "bias" that is NOT being quantized.
769        // It must survive the transform untouched.
770        let mut graph = make_simple_graph();
771
772        graph.initializer.push(TensorProto {
773            name: "bias".to_string(),
774            data_type: tensor_proto::DataType::Float as i32,
775            dims: vec![2],
776            float_data: vec![0.1, 0.2],
777            ..Default::default()
778        });
779
780        // Also add "bias" as a Conv input so connectivity stays valid
781        graph.node[0].input.push("bias".to_string());
782
783        let inputs = vec![QdqWeightInput {
784            original_name: "w".to_string(),
785            quantized_values: vec![10, 20, 30, 40],
786            scales: vec![0.1],
787            zero_points: vec![0],
788            bits: 8,
789            axis: None,
790        }];
791
792        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
793
794        // "bias" must still be present and untouched
795        let bias_init = graph.initializer.iter().find(|i| i.name == "bias");
796
797        assert!(
798            bias_init.is_some(),
799            "non-quantized 'bias' initializer must be preserved"
800        );
801        assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);
802
803        // Full connectivity check
804        let report = validate_graph_connectivity(&graph);
805        assert!(report.valid, "broken: {:?}", report.broken_refs);
806    }
807
808    #[test]
809    fn test_ensure_opset_strips_deprecated_attrs() {
810        use crate::onnx_proto::NodeProto;
811
812        let mut model = ModelProto {
813            opset_import: vec![OperatorSetIdProto {
814                domain: String::new(),
815                version: 7,
816            }],
817            graph: Some(GraphProto {
818                node: vec![
819                    // BatchNormalization with deprecated `spatial` attr (removed opset 9)
820                    NodeProto {
821                        op_type: "BatchNormalization".to_string(),
822                        input: vec!["x".into(), "s".into(), "b".into(), "m".into(), "v".into()],
823                        output: vec!["bn_out".into()],
824                        attribute: vec![
825                            AttributeProto {
826                                name: "epsilon".to_string(),
827                                r#type: attribute_proto::AttributeType::Float as i32,
828                                f: 1e-5,
829                                ..Default::default()
830                            },
831                            AttributeProto {
832                                name: "spatial".to_string(),
833                                r#type: attribute_proto::AttributeType::Int as i32,
834                                i: 1,
835                                ..Default::default()
836                            },
837                        ],
838                        ..Default::default()
839                    },
840                    // Dropout with `ratio` attribute — should be migrated
841                    // from attribute to 2nd input.
842                    NodeProto {
843                        op_type: "Dropout".to_string(),
844                        input: vec!["bn_out".into()],
845                        output: vec!["drop_out".into(), "drop_mask".into()],
846                        attribute: vec![AttributeProto {
847                            name: "ratio".to_string(),
848                            r#type: attribute_proto::AttributeType::Float as i32,
849                            f: 0.3,
850                            ..Default::default()
851                        }],
852                        ..Default::default()
853                    },
854                    // Softmax with NO axis attribute — old default is 1,
855                    // opset 13 changes default to -1, so axis=1 must be added.
856                    NodeProto {
857                        op_type: "Softmax".to_string(),
858                        input: vec!["drop_out".into()],
859                        output: vec!["sm_out".into()],
860                        attribute: vec![],
861                        ..Default::default()
862                    },
863                ],
864                ..Default::default()
865            }),
866            ..Default::default()
867        };
868
869        ensure_opset_version(&mut model, 13);
870
871        // Opset should be 13
872        let opset = model
873            .opset_import
874            .iter()
875            .find(|o| o.domain.is_empty())
876            .unwrap();
877        assert_eq!(opset.version, 13);
878
879        let graph = model.graph.as_ref().unwrap();
880
881        // BatchNormalization: `spatial` must be removed, `epsilon` kept
882        let bn = &graph.node[0];
883        assert!(
884            !bn.attribute.iter().any(|a| a.name == "spatial"),
885            "BatchNormalization.spatial should be stripped"
886        );
887        assert!(
888            bn.attribute.iter().any(|a| a.name == "epsilon"),
889            "BatchNormalization.epsilon should be preserved"
890        );
891
892        // Dropout: `ratio` attribute must be removed and moved to 2nd input
893        let drop = &graph.node[1];
894        assert!(
895            !drop.attribute.iter().any(|a| a.name == "ratio"),
896            "Dropout.ratio attribute should be removed"
897        );
898        assert_eq!(drop.input.len(), 2, "Dropout should now have 2 inputs");
899        let ratio_init_name = &drop.input[1];
900
901        // The ratio value should be stored as an initializer
902        let ratio_init = graph
903            .initializer
904            .iter()
905            .find(|i| &i.name == ratio_init_name)
906            .expect("Dropout ratio initializer should exist");
907        assert_eq!(ratio_init.data_type, tensor_proto::DataType::Float as i32);
908        assert!(
909            (ratio_init.float_data[0] - 0.3).abs() < 1e-6,
910            "ratio should be 0.3"
911        );
912
913        // Softmax: must have explicit axis=1 added (old default, since opset 13 changed it to -1)
914        let sm = &graph.node[2];
915        assert_eq!(sm.op_type, "Softmax");
916        let axis_attr = sm
917            .attribute
918            .iter()
919            .find(|a| a.name == "axis")
920            .expect("Softmax should have axis attribute added");
921        assert_eq!(axis_attr.i, 1, "Softmax axis should be 1 (old default)");
922    }
923
924    #[test]
925    fn test_ensure_opset_no_downgrade() {
926        let mut model = ModelProto {
927            opset_import: vec![OperatorSetIdProto {
928                domain: String::new(),
929                version: 15,
930            }],
931            graph: Some(GraphProto::default()),
932            ..Default::default()
933        };
934
935        // Requesting 10 should NOT downgrade from 15
936        ensure_opset_version(&mut model, 10);
937        let opset = model
938            .opset_import
939            .iter()
940            .find(|o| o.domain.is_empty())
941            .unwrap();
942        assert_eq!(opset.version, 15);
943    }
944}