Skip to main content

quantize_rs/onnx_utils/
graph_builder.rs

1//! Graph-level operations for quantized ONNX models.
2//!
3//! Three responsibilities:
4//!   1. **QDQ transform** — replace FP32 initializers with INT8 + DequantizeLinear
5//!   2. **Connectivity validation** — walk the graph and verify every edge resolves
6//!   3. **Opset management** — ensure the model declares a sufficient opset and
7//!      upgrades deprecated op attributes when bumping across breaking opset boundaries
8
9use crate::errors::{QuantizeError, Result};
10use crate::onnx_proto::{
11    attribute_proto, tensor_proto, AttributeProto, GraphProto, ModelProto, OperatorSetIdProto,
12    TensorProto,
13};
14use std::collections::{HashMap, HashSet};
15
16use super::quantization_nodes::{
17    build_dequantize_linear_node, build_quantized_weight_tensor, build_scale_tensor,
18    build_zero_point_tensor, DequantLinearNames, StorageFormat,
19};
20
21// ===========================================================================
22// Public types
23// ===========================================================================
24
/// One weight to convert: an FP32 initializer becomes quantized storage plus a
/// `DequantizeLinear` node.
#[derive(Debug, Clone)]
pub struct QdqWeightInput {
    /// Name of the FP32 initializer to replace (e.g., `"conv1.weight"`).
    /// Must match an initializer in the graph exactly.
    pub original_name: String,
    /// Quantized values, always unpacked — one `i8` per tensor element.
    /// For INT4 these are in [-8, 7]; for INT8 in [-128, 127].
    /// The element count must equal the product of the initializer's dims.
    pub quantized_values: Vec<i8>,
    /// Quantization scales (FP32).
    /// Length 1 for per-tensor; one per channel along `axis` for per-channel.
    pub scales: Vec<f32>,
    /// Zero points, carried as `i8`; same length as `scales`.
    /// Encoded in the same storage format as the weight data (INT8, or INT4
    /// when native INT4 storage is selected for this weight).
    pub zero_points: Vec<i8>,
    /// Original bit-width (4 or 8).
    ///
    /// Not purely informational: when [`SaveOptions::native_int4`] is enabled,
    /// weights with `bits == 4` are stored as native ONNX INT4; all other
    /// weights (and every weight when the option is off) are stored as INT8.
    /// Also persisted in model metadata so `load_quantized_info` can recover it.
    pub bits: u8,
    /// Per-channel quantization axis, or `None` for per-tensor.
    pub axis: Option<usize>,
}
45
/// Settings that control how quantized weights are laid out when the model
/// is written out.
#[derive(Debug, Clone, Copy, Default)]
pub struct SaveOptions {
    /// Store 4-bit weights as native ONNX INT4 (opset 21, `DataType::Int4`,
    /// two values per byte) rather than widening every value to an INT8 byte.
    ///
    /// This halves the on-disk footprint of INT4-quantized weights (a true 8×
    /// reduction versus FP32 instead of 4×), at the cost of requiring opset 21:
    /// runtimes without opset-21 support will reject the model.
    /// Off by default to stay backward compatible.
    ///
    /// Only weights whose [`QdqWeightInput::bits`] equals 4 are affected;
    /// INT8 weights are written as INT8 either way.
    pub native_int4: bool,
}

impl SaveOptions {
    /// Builder-style setter: returns these options with native INT4 storage
    /// switched on or off (on requires opset 21).
    pub fn with_native_int4(self, enabled: bool) -> Self {
        let mut updated = self;
        updated.native_int4 = enabled;
        updated
    }
}
69
/// Outcome of a graph-connectivity check: whether every node input resolved,
/// plus a description of each one that did not.
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    /// `true` when no node input referenced an unknown tensor.
    pub valid: bool,
    /// One human-readable line per dangling reference; empty when valid.
    pub broken_refs: Vec<String>,
}

impl ConnectivityReport {
    /// Format the report for display (e.g. CLI `validate` output).
    ///
    /// Every line is indented and newline-terminated; when the graph is broken,
    /// a header with the count is followed by a numbered list of references.
    pub fn summary(&self) -> String {
        if self.valid {
            return "  Graph connectivity: OK\n".to_string();
        }

        let count = self.broken_refs.len();
        let plural = if count == 1 { "" } else { "s" };
        let header = format!(
            "  Graph connectivity: BROKEN ({} dangling reference{})\n",
            count, plural
        );
        let details: String = self
            .broken_refs
            .iter()
            .enumerate()
            .map(|(idx, msg)| format!("    {}. {}\n", idx + 1, msg))
            .collect();

        header + &details
    }
}
98
99// ===========================================================================
100// Connectivity validation
101// ===========================================================================
102
103/// Walk the graph and verify every node input resolves to *something*.
104///
105/// A valid input is exactly one of:
106///   • a declared graph input (`graph.input`)
107///   • an initializer name (`graph.initializer`)
108///   • the output of a node that appears **earlier** in `graph.node`
109///
110/// This is the check ONNX Runtime performs on load — and the check that
111/// v0.2.0's `validate` command skipped, letting the rename bug through.
112pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
113    let mut known: HashSet<String> = HashSet::new();
114
115    // Seed: graph inputs + initializers are always available
116    for inp in &graph.input {
117        known.insert(inp.name.clone());
118    }
119    for init in &graph.initializer {
120        known.insert(init.name.clone());
121    }
122
123    let mut broken = Vec::new();
124
125    // Walk nodes in serialized order; each node's outputs become known afterwards
126    for node in &graph.node {
127        for name in &node.input {
128            if name.is_empty() {
129                continue; // optional input slot — empty string is valid
130            }
131            if !known.contains(name.as_str()) {
132                broken.push(format!(
133                    "Node '{}' (op={}) → unknown input '{}'",
134                    node.name, node.op_type, name
135                ));
136            }
137        }
138        // Register outputs so later nodes can consume them
139        for name in &node.output {
140            if !name.is_empty() {
141                known.insert(name.clone());
142            }
143        }
144    }
145
146    ConnectivityReport {
147        valid: broken.is_empty(),
148        broken_refs: broken,
149    }
150}
151
152// ===========================================================================
153// Opset version management
154// ===========================================================================
155
156/// Ensure the default ONNX domain opset is at least `min_version`.
157///
158/// DequantizeLinear requires opset ≥ 10 (per-tensor) or ≥ 13 (per-channel axis).
159///
160/// When bumping the opset past a breaking boundary, this function also
161/// upgrades deprecated op attributes:
162///   - **opset 9**: `BatchNormalization.spatial` removed (was always 1)
163///   - **opset 12**: `Dropout.ratio` migrated from attribute to input
164pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
165    let old_version = get_opset_version(model);
166
167    // Update or insert the default-domain opset entry
168    let mut found = false;
169    for opset in model.opset_import.iter_mut() {
170        if opset.domain.is_empty() {
171            if opset.version < min_version {
172                opset.version = min_version;
173            }
174            found = true;
175            break;
176        }
177    }
178    if !found {
179        model.opset_import.push(OperatorSetIdProto {
180            domain: String::new(),
181            version: min_version,
182        });
183    }
184
185    // Strip deprecated attributes when crossing breaking opset boundaries
186    if old_version < min_version {
187        if let Some(graph) = model.graph.as_mut() {
188            upgrade_deprecated_ops(graph, old_version, min_version);
189        }
190    }
191}
192
193/// Get the current default-domain opset version (0 if not present).
194fn get_opset_version(model: &ModelProto) -> i64 {
195    model
196        .opset_import
197        .iter()
198        .find(|o| o.domain.is_empty())
199        .map_or(0, |o| o.version)
200}
201
202/// Upgrade graph nodes whose attribute semantics changed between `old_opset`
203/// and `new_opset`.
204///
205/// - **BatchNormalization** (opset 9): `spatial` attribute removed (was always 1)
206/// - **Dropout** (opset 12): `ratio` attribute → 2nd input
207/// - **Softmax / LogSoftmax** (opset 13): default axis changed from 1 to -1
208fn upgrade_deprecated_ops(graph: &mut GraphProto, old_opset: i64, new_opset: i64) {
209    let mut new_initializers: Vec<TensorProto> = Vec::new();
210
211    for node in graph.node.iter_mut() {
212        // Opset 9: BatchNormalization removed the `spatial` attribute.
213        // It was always 1 (the only valid value) and had no effect.
214        if node.op_type == "BatchNormalization" && old_opset < 9 && new_opset >= 9 {
215            node.attribute.retain(|a| a.name != "spatial");
216        }
217
218        // Opset 12: Dropout `ratio` moved from attribute to 2nd input.
219        // Extract the value, remove the attribute, and wire in a constant.
220        if node.op_type == "Dropout" && old_opset < 12 && new_opset >= 12 {
221            let ratio = node
222                .attribute
223                .iter()
224                .find(|a| a.name == "ratio")
225                .map(|a| a.f)
226                .unwrap_or(0.5);
227            node.attribute.retain(|a| a.name != "ratio");
228
229            let init_name = format!(
230                "_quantize_rs_dropout_ratio_{}",
231                node.output.first().map_or("", |s| s.as_str()),
232            );
233            new_initializers.push(TensorProto {
234                name: init_name.clone(),
235                data_type: tensor_proto::DataType::Float as i32,
236                float_data: vec![ratio],
237                ..Default::default()
238            });
239
240            if node.input.len() < 2 {
241                node.input.push(init_name);
242            } else {
243                node.input[1] = init_name;
244            }
245        }
246
247        // Opset 13: Softmax/LogSoftmax default axis changed from 1 to -1.
248        // If the node has no explicit axis attribute, add axis=1 to preserve
249        // the old behavior.
250        if (node.op_type == "Softmax" || node.op_type == "LogSoftmax")
251            && old_opset < 13
252            && new_opset >= 13
253        {
254            let has_axis = node.attribute.iter().any(|a| a.name == "axis");
255            if !has_axis {
256                node.attribute.push(AttributeProto {
257                    name: "axis".to_string(),
258                    r#type: attribute_proto::AttributeType::Int as i32,
259                    i: 1, // old default
260                    ..Default::default()
261                });
262            }
263        }
264    }
265
266    graph.initializer.extend(new_initializers);
267}
268
269// ===========================================================================
270// QDQ transform
271// ===========================================================================
272
273/// Replace FP32 weight initializers with INT8 quantized equivalents +
274/// DequantizeLinear nodes.
275///
276/// ### What happens per weight in `inputs`:
277///
278/// **Removed:**
279///   - Initializer `"{name}"` (the original FP32 weight data)
280///
281/// **Added (initializers):**
282///   - `"{name}_quantized"` — INT8, same shape as original
283///   - `"{name}_scale"`     — FP32 scalar
284///   - `"{name}_zp"`        — INT8 scalar
285///
286/// **Added (node, prepended before all existing nodes):**
287///   - `DequantizeLinear` with output = `"{name}"`
288///
289/// Because the DequantizeLinear output carries the **original** name, every
290/// downstream node (Conv, MatMul, BatchNorm, …) remains completely unchanged.
291/// Graph connectivity is preserved by construction.
292///
293/// ---
294/// ### INT4 storage note
295///
296/// ONNX `DequantizeLinear` requires INT8 input in opsets < 21.  By default,
297/// INT4-quantized values (range [-8, 7]) are widened to INT8 here — 4×
298/// compression from FP32.  To get the full 8× compression, pass
299/// [`SaveOptions::with_native_int4(true)`] to [`apply_qdq_transform_with_options`];
300/// that emits native `DataType::Int4` (opset 21) with two values packed per byte.
301pub fn apply_qdq_transform(graph: &mut GraphProto, inputs: &[QdqWeightInput]) -> Result<()> {
302    apply_qdq_transform_with_options(graph, inputs, SaveOptions::default())
303}
304
305/// Same as [`apply_qdq_transform`] but configurable via [`SaveOptions`].
306///
307/// Prefer this entry point for any new code; the short-form wrapper exists
308/// only for backward compatibility.
309pub fn apply_qdq_transform_with_options(
310    graph: &mut GraphProto,
311    inputs: &[QdqWeightInput],
312    options: SaveOptions,
313) -> Result<()> {
314    // -----------------------------------------------------------------------
315    // 0.  Snapshot shapes before modifying the initializer list
316    // -----------------------------------------------------------------------
317    let shape_map: HashMap<String, Vec<i64>> = graph
318        .initializer
319        .iter()
320        .map(|init| (init.name.clone(), init.dims.clone()))
321        .collect();
322
323    let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
324
325    // -----------------------------------------------------------------------
326    // 1.  Remove the original FP32 initializers for every weight we're replacing
327    // -----------------------------------------------------------------------
328    graph
329        .initializer
330        .retain(|init| !quant_set.contains(init.name.as_str()));
331
332    // -----------------------------------------------------------------------
333    // 1b. Also remove weights from graph.input (critical fix for "Duplicate definition")
334    // -----------------------------------------------------------------------
335    // Some ONNX models list weights as both initializers AND graph inputs.
336    // This is valid ONNX, but when DequantizeLinear outputs reuse the original
337    // weight names, ONNX Runtime sees two definitions of the same tensor.
338    graph
339        .input
340        .retain(|inp| !quant_set.contains(inp.name.as_str()));
341
342    // -----------------------------------------------------------------------
343    // 2.  Add quantized initializer triples + build DequantizeLinear nodes
344    // -----------------------------------------------------------------------
345    let mut dq_nodes = Vec::new();
346
347    for inp in inputs {
348        let shape =
349            shape_map
350                .get(&inp.original_name)
351                .ok_or_else(|| QuantizeError::GraphTransform {
352                    reason: format!(
353                        "Weight '{}' not found in model initializers — \
354                         verify the name matches exactly",
355                        inp.original_name
356                    ),
357                })?;
358
359        let expected_len: i64 = shape.iter().product();
360        if inp.quantized_values.len() as i64 != expected_len {
361            return Err(QuantizeError::GraphTransform {
362                reason: format!(
363                    "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
364                    inp.original_name,
365                    inp.quantized_values.len(),
366                    shape,
367                    expected_len
368                ),
369            });
370        }
371
372        let names = DequantLinearNames::from_original(&inp.original_name);
373
374        let format = if options.native_int4 && inp.bits == 4 {
375            StorageFormat::NativeInt4
376        } else {
377            StorageFormat::Int8Widened
378        };
379
380        graph.initializer.push(build_quantized_weight_tensor(
381            &names,
382            &inp.quantized_values,
383            shape,
384            format,
385        ));
386        graph
387            .initializer
388            .push(build_scale_tensor(&names, &inp.scales));
389        graph
390            .initializer
391            .push(build_zero_point_tensor(&names, &inp.zero_points, format));
392
393        dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
394    }
395
396    // -----------------------------------------------------------------------
397    // 3.  Prepend DequantizeLinear nodes before all existing computation nodes.
398    //     They must appear first so their outputs are "known" when the validator
399    //     (or ONNX Runtime) walks the node list in order.
400    // -----------------------------------------------------------------------
401    let existing_nodes = std::mem::take(&mut graph.node);
402    graph.node = dq_nodes;
403    graph.node.extend(existing_nodes);
404
405    Ok(())
406}
407
408// ===========================================================================
409// Tests
410// ===========================================================================
411
412#[cfg(test)]
413mod tests {
414    use super::*;
415    use crate::onnx_proto::{
416        tensor_proto, GraphProto, ModelProto, NodeProto, OperatorSetIdProto, TensorProto,
417        ValueInfoProto,
418    };
419
420    // -----------------------------------------------------------------------
421    // Test helpers
422    // -----------------------------------------------------------------------
423
424    /// Minimal graph: one graph input "input", one FP32 initializer "w" (shape [2,2]),
425    /// one Conv node consuming both, producing "out".
426    fn make_simple_graph() -> GraphProto {
427        GraphProto {
428            input: vec![ValueInfoProto {
429                name: "input".to_string(),
430                ..Default::default()
431            }],
432            initializer: vec![TensorProto {
433                name: "w".to_string(),
434                data_type: tensor_proto::DataType::Float as i32,
435                dims: vec![2, 2],
436                float_data: vec![1.0, 2.0, 3.0, 4.0],
437                ..Default::default()
438            }],
439            node: vec![NodeProto {
440                op_type: "Conv".to_string(),
441                name: "conv0".to_string(),
442                input: vec!["input".to_string(), "w".to_string()],
443                output: vec!["out".to_string()],
444                ..Default::default()
445            }],
446            ..Default::default()
447        }
448    }
449
450    /// Two-weight graph: "w1" and "w2", two Conv nodes chained.
451    fn make_two_weight_graph() -> GraphProto {
452        GraphProto {
453            input: vec![ValueInfoProto {
454                name: "input".to_string(),
455                ..Default::default()
456            }],
457            initializer: vec![
458                TensorProto {
459                    name: "w1".to_string(),
460                    data_type: tensor_proto::DataType::Float as i32,
461                    dims: vec![2, 2],
462                    float_data: vec![1.0, 2.0, 3.0, 4.0],
463                    ..Default::default()
464                },
465                TensorProto {
466                    name: "w2".to_string(),
467                    data_type: tensor_proto::DataType::Float as i32,
468                    dims: vec![2, 2],
469                    float_data: vec![5.0, 6.0, 7.0, 8.0],
470                    ..Default::default()
471                },
472            ],
473            node: vec![
474                NodeProto {
475                    op_type: "Conv".to_string(),
476                    name: "conv1".to_string(),
477                    input: vec!["input".to_string(), "w1".to_string()],
478                    output: vec!["mid".to_string()],
479                    ..Default::default()
480                },
481                NodeProto {
482                    op_type: "Conv".to_string(),
483                    name: "conv2".to_string(),
484                    input: vec!["mid".to_string(), "w2".to_string()],
485                    output: vec!["out".to_string()],
486                    ..Default::default()
487                },
488            ],
489            ..Default::default()
490        }
491    }
492
493    // -----------------------------------------------------------------------
494    // Connectivity validation tests
495    // -----------------------------------------------------------------------
496
497    #[test]
498    fn test_connectivity_passes_on_valid_graph() {
499        let graph = make_simple_graph();
500        let report = validate_graph_connectivity(&graph);
501        assert!(
502            report.valid,
503            "original graph should be valid; broken: {:?}",
504            report.broken_refs
505        );
506    }
507
508    #[test]
509    fn test_connectivity_detects_renamed_initializer() {
510        // Simulate the exact v0.2.0 bug: rename "w" in the initializer list
511        // without updating the Conv node that references it.
512        let mut graph = make_simple_graph();
513
514        for init in graph.initializer.iter_mut() {
515            if init.name == "w" {
516                init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
517            }
518        }
519
520        let report = validate_graph_connectivity(&graph);
521        assert!(!report.valid, "should detect broken reference to 'w'");
522        assert_eq!(report.broken_refs.len(), 1);
523        assert!(
524            report.broken_refs[0].contains("'w'"),
525            "error should mention 'w': {}",
526            report.broken_refs[0]
527        );
528    }
529
530    #[test]
531    fn test_connectivity_detects_multiple_broken_refs() {
532        let mut graph = make_two_weight_graph();
533
534        for init in graph.initializer.iter_mut() {
535            if init.name == "w1" {
536                init.name = "w1_broken".to_string();
537            } else if init.name == "w2" {
538                init.name = "w2_broken".to_string();
539            }
540        }
541
542        let report = validate_graph_connectivity(&graph);
543        assert!(!report.valid);
544        assert_eq!(report.broken_refs.len(), 2);
545    }
546
547    #[test]
548    fn test_connectivity_summary_formatting() {
549        let valid = ConnectivityReport {
550            valid: true,
551            broken_refs: vec![],
552        };
553        assert!(valid.summary().contains("OK"));
554
555        let broken = ConnectivityReport {
556            valid: false,
557            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
558        };
559        let s = broken.summary();
560        assert!(s.contains("BROKEN"));
561        assert!(s.contains("1 dangling reference"));
562        assert!(s.contains("unknown input 'y'"));
563    }
564
565    // -----------------------------------------------------------------------
566    // Opset version tests
567    // -----------------------------------------------------------------------
568
569    #[test]
570    fn test_ensure_opset_bumps_low_version() {
571        let mut model = ModelProto {
572            opset_import: vec![OperatorSetIdProto {
573                domain: String::new(),
574                version: 10,
575            }],
576            ..Default::default()
577        };
578
579        ensure_opset_version(&mut model, 13);
580
581        assert_eq!(model.opset_import[0].version, 13);
582    }
583
584    #[test]
585    fn test_ensure_opset_leaves_sufficient_version() {
586        let mut model = ModelProto {
587            opset_import: vec![OperatorSetIdProto {
588                domain: String::new(),
589                version: 17,
590            }],
591            ..Default::default()
592        };
593
594        ensure_opset_version(&mut model, 13);
595
596        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
597    }
598
599    #[test]
600    fn test_ensure_opset_adds_missing_default_domain() {
601        let mut model = ModelProto::default();
602        // No opset_import at all
603        ensure_opset_version(&mut model, 13);
604
605        assert_eq!(model.opset_import.len(), 1);
606        assert!(model.opset_import[0].domain.is_empty());
607        assert_eq!(model.opset_import[0].version, 13);
608    }
609
610    // -----------------------------------------------------------------------
611    // QDQ transform tests
612    // -----------------------------------------------------------------------
613
614    #[test]
615    fn test_qdq_single_weight_produces_valid_graph() {
616        let mut graph = make_simple_graph();
617
618        let inputs = vec![QdqWeightInput {
619            original_name: "w".to_string(),
620            quantized_values: vec![25, 51, 76, 102],
621            scales: vec![0.039_215_686], // ≈ 1/25.5
622            zero_points: vec![0],
623            bits: 8,
624            axis: None,
625        }];
626
627        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
628
629        let report = validate_graph_connectivity(&graph);
630        assert!(
631            report.valid,
632            "graph after QDQ must be valid; broken: {:?}",
633            report.broken_refs
634        );
635    }
636
637    #[test]
638    fn test_qdq_adds_correct_initializers() {
639        let mut graph = make_simple_graph();
640
641        let inputs = vec![QdqWeightInput {
642            original_name: "w".to_string(),
643            quantized_values: vec![10, 20, 30, 40],
644            scales: vec![0.1],
645            zero_points: vec![-5],
646            bits: 8,
647            axis: None,
648        }];
649
650        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
651
652        let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();
653
654        assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
655        assert!(init_names.contains(&"w_scale"), "missing w_scale");
656        assert!(init_names.contains(&"w_zp"), "missing w_zp");
657        assert!(
658            !init_names.contains(&"w"),
659            "original FP32 'w' should be removed"
660        );
661    }
662
663    #[test]
664    fn test_qdq_node_order_dequant_first() {
665        let mut graph = make_simple_graph();
666
667        let inputs = vec![QdqWeightInput {
668            original_name: "w".to_string(),
669            quantized_values: vec![10, 20, 30, 40],
670            scales: vec![0.1],
671            zero_points: vec![0],
672            bits: 8,
673            axis: None,
674        }];
675
676        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
677
678        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();
679
680        assert_eq!(ops.len(), 2);
681        assert_eq!(ops[0], "DequantizeLinear");
682        assert_eq!(ops[1], "Conv");
683    }
684
685    #[test]
686    fn test_qdq_dequant_output_is_original_name() {
687        let mut graph = make_simple_graph();
688
689        let inputs = vec![QdqWeightInput {
690            original_name: "w".to_string(),
691            quantized_values: vec![1, 2, 3, 4],
692            scales: vec![1.0],
693            zero_points: vec![0],
694            bits: 8,
695            axis: None,
696        }];
697
698        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
699
700        let dq = &graph.node[0]; // first node = DequantizeLinear
701        assert_eq!(
702            dq.output[0], "w",
703            "DequantizeLinear output must be original name"
704        );
705    }
706
707    #[test]
708    fn test_qdq_two_weights_both_transformed() {
709        let mut graph = make_two_weight_graph();
710
711        let inputs = vec![
712            QdqWeightInput {
713                original_name: "w1".to_string(),
714                quantized_values: vec![10, 20, 30, 40],
715                scales: vec![0.1],
716                zero_points: vec![0],
717                bits: 8,
718                axis: None,
719            },
720            QdqWeightInput {
721                original_name: "w2".to_string(),
722                quantized_values: vec![50, 60, 70, 80],
723                scales: vec![0.2],
724                zero_points: vec![-1],
725                bits: 8,
726                axis: None,
727            },
728        ];
729
730        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
731
732        // Connectivity must still be valid
733        let report = validate_graph_connectivity(&graph);
734        assert!(
735            report.valid,
736            "two-weight graph broken: {:?}",
737            report.broken_refs
738        );
739
740        // Should have 2 DequantizeLinear + 2 Conv = 4 nodes
741        assert_eq!(graph.node.len(), 4);
742
743        // First two nodes are DequantizeLinear
744        assert_eq!(graph.node[0].op_type, "DequantizeLinear");
745        assert_eq!(graph.node[1].op_type, "DequantizeLinear");
746
747        // Their outputs are the original weight names
748        let dq_outputs: Vec<&str> = graph
749            .node
750            .iter()
751            .take(2)
752            .map(|n| n.output[0].as_str())
753            .collect();
754        assert!(dq_outputs.contains(&"w1"));
755        assert!(dq_outputs.contains(&"w2"));
756    }
757
758    #[test]
759    fn test_qdq_int4_values_stored_as_int8() {
760        let mut graph = make_simple_graph();
761
762        // INT4 range [-8, 7] — these arrive as i8 from ensure_unpacked()
763        let inputs = vec![QdqWeightInput {
764            original_name: "w".to_string(),
765            quantized_values: vec![-8, -1, 0, 7],
766            scales: vec![0.5],
767            zero_points: vec![0],
768            bits: 4, // flag says INT4; storage must still be INT8
769            axis: None,
770        }];
771
772        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
773
774        let quant_init = graph
775            .initializer
776            .iter()
777            .find(|i| i.name == "w_quantized")
778            .expect("w_quantized not found");
779
780        // Data type must be INT8 (ONNX DequantizeLinear requirement)
781        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);
782
783        // Byte-level round-trip must be exact
784        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
785        assert_eq!(recovered, vec![-8, -1, 0, 7]);
786    }
787
788    #[test]
789    fn test_qdq_unknown_weight_returns_error() {
790        let mut graph = make_simple_graph();
791
792        let inputs = vec![QdqWeightInput {
793            original_name: "does_not_exist".to_string(),
794            quantized_values: vec![1, 2, 3],
795            scales: vec![1.0],
796            zero_points: vec![0],
797            bits: 8,
798            axis: None,
799        }];
800
801        let result = apply_qdq_transform(&mut graph, &inputs);
802        assert!(result.is_err());
803        assert!(
804            result.unwrap_err().to_string().contains("does_not_exist"),
805            "error should name the missing weight"
806        );
807    }
808
809    #[test]
810    fn test_qdq_non_quantized_initializers_preserved() {
811        // Add an extra initializer "bias" that is NOT being quantized.
812        // It must survive the transform untouched.
813        let mut graph = make_simple_graph();
814
815        graph.initializer.push(TensorProto {
816            name: "bias".to_string(),
817            data_type: tensor_proto::DataType::Float as i32,
818            dims: vec![2],
819            float_data: vec![0.1, 0.2],
820            ..Default::default()
821        });
822
823        // Also add "bias" as a Conv input so connectivity stays valid
824        graph.node[0].input.push("bias".to_string());
825
826        let inputs = vec![QdqWeightInput {
827            original_name: "w".to_string(),
828            quantized_values: vec![10, 20, 30, 40],
829            scales: vec![0.1],
830            zero_points: vec![0],
831            bits: 8,
832            axis: None,
833        }];
834
835        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");
836
837        // "bias" must still be present and untouched
838        let bias_init = graph.initializer.iter().find(|i| i.name == "bias");
839
840        assert!(
841            bias_init.is_some(),
842            "non-quantized 'bias' initializer must be preserved"
843        );
844        assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);
845
846        // Full connectivity check
847        let report = validate_graph_connectivity(&graph);
848        assert!(report.valid, "broken: {:?}", report.broken_refs);
849    }
850
851    #[test]
852    fn test_ensure_opset_strips_deprecated_attrs() {
853        use crate::onnx_proto::NodeProto;
854
855        let mut model = ModelProto {
856            opset_import: vec![OperatorSetIdProto {
857                domain: String::new(),
858                version: 7,
859            }],
860            graph: Some(GraphProto {
861                node: vec![
862                    // BatchNormalization with deprecated `spatial` attr (removed opset 9)
863                    NodeProto {
864                        op_type: "BatchNormalization".to_string(),
865                        input: vec!["x".into(), "s".into(), "b".into(), "m".into(), "v".into()],
866                        output: vec!["bn_out".into()],
867                        attribute: vec![
868                            AttributeProto {
869                                name: "epsilon".to_string(),
870                                r#type: attribute_proto::AttributeType::Float as i32,
871                                f: 1e-5,
872                                ..Default::default()
873                            },
874                            AttributeProto {
875                                name: "spatial".to_string(),
876                                r#type: attribute_proto::AttributeType::Int as i32,
877                                i: 1,
878                                ..Default::default()
879                            },
880                        ],
881                        ..Default::default()
882                    },
883                    // Dropout with `ratio` attribute — should be migrated
884                    // from attribute to 2nd input.
885                    NodeProto {
886                        op_type: "Dropout".to_string(),
887                        input: vec!["bn_out".into()],
888                        output: vec!["drop_out".into(), "drop_mask".into()],
889                        attribute: vec![AttributeProto {
890                            name: "ratio".to_string(),
891                            r#type: attribute_proto::AttributeType::Float as i32,
892                            f: 0.3,
893                            ..Default::default()
894                        }],
895                        ..Default::default()
896                    },
897                    // Softmax with NO axis attribute — old default is 1,
898                    // opset 13 changes default to -1, so axis=1 must be added.
899                    NodeProto {
900                        op_type: "Softmax".to_string(),
901                        input: vec!["drop_out".into()],
902                        output: vec!["sm_out".into()],
903                        attribute: vec![],
904                        ..Default::default()
905                    },
906                ],
907                ..Default::default()
908            }),
909            ..Default::default()
910        };
911
912        ensure_opset_version(&mut model, 13);
913
914        // Opset should be 13
915        let opset = model
916            .opset_import
917            .iter()
918            .find(|o| o.domain.is_empty())
919            .unwrap();
920        assert_eq!(opset.version, 13);
921
922        let graph = model.graph.as_ref().unwrap();
923
924        // BatchNormalization: `spatial` must be removed, `epsilon` kept
925        let bn = &graph.node[0];
926        assert!(
927            !bn.attribute.iter().any(|a| a.name == "spatial"),
928            "BatchNormalization.spatial should be stripped"
929        );
930        assert!(
931            bn.attribute.iter().any(|a| a.name == "epsilon"),
932            "BatchNormalization.epsilon should be preserved"
933        );
934
935        // Dropout: `ratio` attribute must be removed and moved to 2nd input
936        let drop = &graph.node[1];
937        assert!(
938            !drop.attribute.iter().any(|a| a.name == "ratio"),
939            "Dropout.ratio attribute should be removed"
940        );
941        assert_eq!(drop.input.len(), 2, "Dropout should now have 2 inputs");
942        let ratio_init_name = &drop.input[1];
943
944        // The ratio value should be stored as an initializer
945        let ratio_init = graph
946            .initializer
947            .iter()
948            .find(|i| &i.name == ratio_init_name)
949            .expect("Dropout ratio initializer should exist");
950        assert_eq!(ratio_init.data_type, tensor_proto::DataType::Float as i32);
951        assert!(
952            (ratio_init.float_data[0] - 0.3).abs() < 1e-6,
953            "ratio should be 0.3"
954        );
955
956        // Softmax: must have explicit axis=1 added (old default, since opset 13 changed it to -1)
957        let sm = &graph.node[2];
958        assert_eq!(sm.op_type, "Softmax");
959        let axis_attr = sm
960            .attribute
961            .iter()
962            .find(|a| a.name == "axis")
963            .expect("Softmax should have axis attribute added");
964        assert_eq!(axis_attr.i, 1, "Softmax axis should be 1 (old default)");
965    }
966
967    #[test]
968    fn test_ensure_opset_no_downgrade() {
969        let mut model = ModelProto {
970            opset_import: vec![OperatorSetIdProto {
971                domain: String::new(),
972                version: 15,
973            }],
974            graph: Some(GraphProto::default()),
975            ..Default::default()
976        };
977
978        // Requesting 10 should NOT downgrade from 15
979        ensure_opset_version(&mut model, 10);
980        let opset = model
981            .opset_import
982            .iter()
983            .find(|o| o.domain.is_empty())
984            .unwrap();
985        assert_eq!(opset.version, 15);
986    }
987}