Skip to main content

quantize_rs/onnx_utils/
graph_builder.rs

1//! Graph-level operations for quantized ONNX models.
2//!
3//! Three responsibilities:
4//!   1. **QDQ transform** — replace FP32 initializers with INT8 + DequantizeLinear
5//!   2. **Connectivity validation** — walk the graph and verify every edge resolves
6//!   3. **Opset management** — ensure the model declares opset ≥ 13
7
8use crate::errors::{QuantizeError, Result};
9use crate::onnx_proto::{GraphProto, ModelProto, OperatorSetIdProto};
10use std::collections::{HashMap, HashSet};
11
12use super::quantization_nodes::{
13    DequantLinearNames,
14    build_dequantize_linear_node,
15    build_quantized_weight_tensor,
16    build_scale_tensor,
17    build_zero_point_tensor,
18};
19
20// ===========================================================================
21// Public types
22// ===========================================================================
23
/// One weight to convert: FP32 initializer → INT8 + DequantizeLinear block.
#[derive(Debug)]
pub struct QdqWeightInput {
    /// Original initializer name (e.g., `"conv1.weight"`)
    pub original_name: String,
    /// Quantized values as i8.  For INT4 these are in [-8, 7]; for INT8 in [-128, 127].
    /// Always unpacked — one value per element.
    /// Length must equal the product of the original initializer's dims;
    /// `apply_qdq_transform` rejects the weight otherwise.
    pub quantized_values: Vec<i8>,
    /// Quantization scales (FP32).
    /// Length 1 for per-tensor; one per channel for per-channel.
    pub scales: Vec<f32>,
    /// Zero points (INT8).
    /// Same length as `scales`.
    pub zero_points: Vec<i8>,
    /// Original bit-width (4 or 8).  Informational only — ONNX storage is always INT8.
    /// Persisted in model metadata so `load_quantized_info` can recover it.
    pub bits: u8,
    /// Per-channel quantization axis, or `None` for per-tensor.
    pub axis: Option<usize>,
}
44
/// Result of a graph-connectivity check.
///
/// Produced by [`validate_graph_connectivity`]; rendered for humans by
/// [`summary`](Self::summary).
#[derive(Debug)]
#[must_use]
pub struct ConnectivityReport {
    /// `true` if every node input resolves to a known tensor.
    pub valid: bool,
    /// Human-readable description of every dangling reference.  Empty when valid.
    pub broken_refs: Vec<String>,
}
54
55impl ConnectivityReport {
56    /// Render the report as a printable string (useful for CLI `validate` output).
57    pub fn summary(&self) -> String {
58        if self.valid {
59            "  Graph connectivity: OK\n".to_string()
60        } else {
61            let mut s = format!(
62                "  Graph connectivity: BROKEN ({} dangling reference{})\n",
63                self.broken_refs.len(),
64                if self.broken_refs.len() == 1 { "" } else { "s" }
65            );
66            for (i, r) in self.broken_refs.iter().enumerate() {
67                s.push_str(&format!("    {}. {}\n", i + 1, r));
68            }
69            s
70        }
71    }
72}
73
74// ===========================================================================
75// Connectivity validation
76// ===========================================================================
77
78/// Walk the graph and verify every node input resolves to *something*.
79///
80/// A valid input is exactly one of:
81///   • a declared graph input (`graph.input`)
82///   • an initializer name (`graph.initializer`)
83///   • the output of a node that appears **earlier** in `graph.node`
84///
85/// This is the check ONNX Runtime performs on load — and the check that
86/// v0.2.0's `validate` command skipped, letting the rename bug through.
87pub fn validate_graph_connectivity(graph: &GraphProto) -> ConnectivityReport {
88    let mut known: HashSet<String> = HashSet::new();
89
90    // Seed: graph inputs + initializers are always available
91    for inp in &graph.input {
92        known.insert(inp.name.clone());
93    }
94    for init in &graph.initializer {
95        known.insert(init.name.clone());
96    }
97
98    let mut broken = Vec::new();
99
100    // Walk nodes in serialized order; each node's outputs become known afterwards
101    for node in &graph.node {
102        for name in &node.input {
103            if name.is_empty() {
104                continue; // optional input slot — empty string is valid
105            }
106            if !known.contains(name.as_str()) {
107                broken.push(format!(
108                    "Node '{}' (op={}) → unknown input '{}'",
109                    node.name, node.op_type, name
110                ));
111            }
112        }
113        // Register outputs so later nodes can consume them
114        for name in &node.output {
115            if !name.is_empty() {
116                known.insert(name.clone());
117            }
118        }
119    }
120
121    ConnectivityReport {
122        valid: broken.is_empty(),
123        broken_refs: broken,
124    }
125}
126
127// ===========================================================================
128// Opset version management
129// ===========================================================================
130
131/// Ensure the default ONNX domain opset is at least `min_version`.
132///
133/// DequantizeLinear requires opset ≥ 10 (per-tensor) or ≥ 13 (per-channel axis).
134/// We always request 13 to leave the door open for per-channel.
135pub fn ensure_opset_version(model: &mut ModelProto, min_version: i64) {
136    // The default ONNX domain is identified by an empty string
137    for opset in model.opset_import.iter_mut() {
138        if opset.domain.is_empty() {
139            if opset.version < min_version {
140                opset.version = min_version;
141            }
142            return; // found and updated (or already sufficient)
143        }
144    }
145
146    // No default-domain entry at all — add one
147    model.opset_import.push(OperatorSetIdProto {
148        domain:  String::new(), // "" = standard ONNX domain
149        version: min_version,
150    });
151}
152
153// ===========================================================================
154// QDQ transform
155// ===========================================================================
156
157/// Replace FP32 weight initializers with INT8 quantized equivalents +
158/// DequantizeLinear nodes.
159///
160/// ### What happens per weight in `inputs`:
161///
162/// **Removed:**
163///   - Initializer `"{name}"` (the original FP32 weight data)
164///
165/// **Added (initializers):**
166///   - `"{name}_quantized"` — INT8, same shape as original
167///   - `"{name}_scale"`     — FP32 scalar
168///   - `"{name}_zp"`        — INT8 scalar
169///
170/// **Added (node, prepended before all existing nodes):**
171///   - `DequantizeLinear` with output = `"{name}"`
172///
173/// Because the DequantizeLinear output carries the **original** name, every
174/// downstream node (Conv, MatMul, BatchNorm, …) remains completely unchanged.
175/// Graph connectivity is preserved by construction.
176///
177/// ---
178/// ### INT4 storage note
179///
180/// ONNX `DequantizeLinear` requires INT8 input in opsets < 21.  INT4-quantized
181/// values (range [-8, 7]) are widened to INT8 here.  The quantization *accuracy*
182/// is INT4-level (scale and zero_point were computed for the 4-bit range), but
183/// on-disk storage is 4× compression rather than the 8× that bit-packing would
184/// give.  True INT4 packing is planned for a future version (opset 21 or custom op).
185pub fn apply_qdq_transform(
186    graph: &mut GraphProto,
187    inputs: &[QdqWeightInput],
188) -> Result<()> {
189    // -----------------------------------------------------------------------
190    // 0.  Snapshot shapes before modifying the initializer list
191    // -----------------------------------------------------------------------
192    let shape_map: HashMap<String, Vec<i64>> = graph
193        .initializer
194        .iter()
195        .map(|init| (init.name.clone(), init.dims.clone()))
196        .collect();
197
198    let quant_set: HashSet<&str> = inputs.iter().map(|i| i.original_name.as_str()).collect();
199
200    // -----------------------------------------------------------------------
201    // 1.  Remove the original FP32 initializers for every weight we're replacing
202    // -----------------------------------------------------------------------
203    graph.initializer.retain(|init| !quant_set.contains(init.name.as_str()));
204
205    // -----------------------------------------------------------------------
206    // 1b. Also remove weights from graph.input (critical fix for "Duplicate definition")
207    // -----------------------------------------------------------------------
208    // Some ONNX models list weights as both initializers AND graph inputs.
209    // This is valid ONNX, but when DequantizeLinear outputs reuse the original
210    // weight names, ONNX Runtime sees two definitions of the same tensor.
211    graph.input.retain(|inp| !quant_set.contains(inp.name.as_str()));
212
213    // -----------------------------------------------------------------------
214    // 2.  Add quantized initializer triples + build DequantizeLinear nodes
215    // -----------------------------------------------------------------------
216    let mut dq_nodes = Vec::new();
217
218    for inp in inputs {
219        let shape = shape_map
220            .get(&inp.original_name)
221            .ok_or_else(|| {
222                QuantizeError::GraphTransform {
223                    reason: format!(
224                        "Weight '{}' not found in model initializers — \
225                         verify the name matches exactly",
226                        inp.original_name
227                    ),
228                }
229            })?;
230
231        let expected_len: i64 = shape.iter().product();
232        if inp.quantized_values.len() as i64 != expected_len {
233            return Err(QuantizeError::GraphTransform {
234                reason: format!(
235                    "Weight '{}': quantized_values has {} elements but shape {:?} expects {}",
236                    inp.original_name, inp.quantized_values.len(), shape, expected_len
237                ),
238            });
239        }
240
241        let names = DequantLinearNames::from_original(&inp.original_name);
242
243        graph.initializer.push(
244            build_quantized_weight_tensor(&names, &inp.quantized_values, shape),
245        );
246        graph.initializer.push(
247            build_scale_tensor(&names, &inp.scales),
248        );
249        graph.initializer.push(
250            build_zero_point_tensor(&names, &inp.zero_points),
251        );
252
253        dq_nodes.push(build_dequantize_linear_node(&names, inp.axis));
254    }
255
256    // -----------------------------------------------------------------------
257    // 3.  Prepend DequantizeLinear nodes before all existing computation nodes.
258    //     They must appear first so their outputs are "known" when the validator
259    //     (or ONNX Runtime) walks the node list in order.
260    // -----------------------------------------------------------------------
261    let existing_nodes = std::mem::take(&mut graph.node);
262    graph.node = dq_nodes;
263    graph.node.extend(existing_nodes);
264
265    Ok(())
266}
267
268// ===========================================================================
269// Tests
270// ===========================================================================
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::onnx_proto::{
        GraphProto, ModelProto, NodeProto, OperatorSetIdProto,
        TensorProto, ValueInfoProto, tensor_proto,
    };

    // -----------------------------------------------------------------------
    // Test helpers
    // -----------------------------------------------------------------------
    // NOTE: all fixtures are hand-built protos — nothing is read from disk.

    /// Minimal graph: one graph input "input", one FP32 initializer "w" (shape [2,2]),
    /// one Conv node consuming both, producing "out".
    fn make_simple_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto { name: "input".to_string(), ..Default::default() }],
            initializer: vec![TensorProto {
                name:       "w".to_string(),
                data_type:  tensor_proto::DataType::Float as i32,
                dims:       vec![2, 2],
                float_data: vec![1.0, 2.0, 3.0, 4.0],
                ..Default::default()
            }],
            node: vec![NodeProto {
                op_type: "Conv".to_string(),
                name:    "conv0".to_string(),
                input:   vec!["input".to_string(), "w".to_string()],
                output:  vec!["out".to_string()],
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    /// Two-weight graph: "w1" and "w2", two Conv nodes chained.
    fn make_two_weight_graph() -> GraphProto {
        GraphProto {
            input: vec![ValueInfoProto { name: "input".to_string(), ..Default::default() }],
            initializer: vec![
                TensorProto {
                    name:       "w1".to_string(),
                    data_type:  tensor_proto::DataType::Float as i32,
                    dims:       vec![2, 2],
                    float_data: vec![1.0, 2.0, 3.0, 4.0],
                    ..Default::default()
                },
                TensorProto {
                    name:       "w2".to_string(),
                    data_type:  tensor_proto::DataType::Float as i32,
                    dims:       vec![2, 2],
                    float_data: vec![5.0, 6.0, 7.0, 8.0],
                    ..Default::default()
                },
            ],
            node: vec![
                NodeProto {
                    op_type: "Conv".to_string(),
                    name:    "conv1".to_string(),
                    input:   vec!["input".to_string(), "w1".to_string()],
                    output:  vec!["mid".to_string()],
                    ..Default::default()
                },
                NodeProto {
                    op_type: "Conv".to_string(),
                    name:    "conv2".to_string(),
                    input:   vec!["mid".to_string(), "w2".to_string()],
                    output:  vec!["out".to_string()],
                    ..Default::default()
                },
            ],
            ..Default::default()
        }
    }

    // -----------------------------------------------------------------------
    // Connectivity validation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_connectivity_passes_on_valid_graph() {
        let graph  = make_simple_graph();
        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "original graph should be valid; broken: {:?}",
            report.broken_refs
        );
    }

    #[test]
    fn test_connectivity_detects_renamed_initializer() {
        // Simulate the exact v0.2.0 bug: rename "w" in the initializer list
        // without updating the Conv node that references it.
        let mut graph = make_simple_graph();

        for init in graph.initializer.iter_mut() {
            if init.name == "w" {
                init.name = "w__qINT8_s0.00392_z-3_len4".to_string();
            }
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid, "should detect broken reference to 'w'");
        assert_eq!(report.broken_refs.len(), 1);
        assert!(
            report.broken_refs[0].contains("'w'"),
            "error should mention 'w': {}",
            report.broken_refs[0]
        );
    }

    #[test]
    fn test_connectivity_detects_multiple_broken_refs() {
        // Rename both weights: each Conv node then has one dangling input.
        let mut graph = make_two_weight_graph();

        for init in graph.initializer.iter_mut() {
            if init.name == "w1" {
                init.name = "w1_broken".to_string();
            } else if init.name == "w2" {
                init.name = "w2_broken".to_string();
            }
        }

        let report = validate_graph_connectivity(&graph);
        assert!(!report.valid);
        assert_eq!(report.broken_refs.len(), 2);
    }

    #[test]
    fn test_connectivity_summary_formatting() {
        let valid = ConnectivityReport {
            valid: true,
            broken_refs: vec![],
        };
        assert!(valid.summary().contains("OK"));

        let broken = ConnectivityReport {
            valid: false,
            broken_refs: vec!["Node 'x' → unknown input 'y'".to_string()],
        };
        let s = broken.summary();
        assert!(s.contains("BROKEN"));
        assert!(s.contains("1 dangling reference"));
        assert!(s.contains("unknown input 'y'"));
    }

    // -----------------------------------------------------------------------
    // Opset version tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_ensure_opset_bumps_low_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto { domain: String::new(), version: 10 }],
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import[0].version, 13);
    }

    #[test]
    fn test_ensure_opset_leaves_sufficient_version() {
        let mut model = ModelProto {
            opset_import: vec![OperatorSetIdProto { domain: String::new(), version: 17 }],
            ..Default::default()
        };

        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import[0].version, 17, "should not downgrade");
    }

    #[test]
    fn test_ensure_opset_adds_missing_default_domain() {
        let mut model = ModelProto::default();
        // No opset_import at all
        ensure_opset_version(&mut model, 13);

        assert_eq!(model.opset_import.len(), 1);
        assert!(model.opset_import[0].domain.is_empty());
        assert_eq!(model.opset_import[0].version, 13);
    }

    // -----------------------------------------------------------------------
    // QDQ transform tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_qdq_single_weight_produces_valid_graph() {
        // End-to-end invariant: the transform must never break connectivity.
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![25, 51, 76, 102],
            scales:           vec![0.039_215_686], // ≈ 1/25.5
            zero_points:      vec![0],
            bits:             8,
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let report = validate_graph_connectivity(&graph);
        assert!(
            report.valid,
            "graph after QDQ must be valid; broken: {:?}",
            report.broken_refs
        );
    }

    #[test]
    fn test_qdq_adds_correct_initializers() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales:           vec![0.1],
            zero_points:      vec![-5],
            bits:             8,
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let init_names: Vec<&str> = graph.initializer.iter().map(|i| i.name.as_str()).collect();

        assert!(init_names.contains(&"w_quantized"), "missing w_quantized");
        assert!(init_names.contains(&"w_scale"),     "missing w_scale");
        assert!(init_names.contains(&"w_zp"),        "missing w_zp");
        assert!(
            !init_names.contains(&"w"),
            "original FP32 'w' should be removed"
        );
    }

    #[test]
    fn test_qdq_node_order_dequant_first() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales:           vec![0.1],
            zero_points:      vec![0],
            bits:             8,
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let ops: Vec<&str> = graph.node.iter().map(|n| n.op_type.as_str()).collect();

        assert_eq!(ops.len(), 2);
        assert_eq!(ops[0], "DequantizeLinear");
        assert_eq!(ops[1], "Conv");
    }

    #[test]
    fn test_qdq_dequant_output_is_original_name() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![1, 2, 3, 4],
            scales:           vec![1.0],
            zero_points:      vec![0],
            bits:             8,
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let dq = &graph.node[0]; // first node = DequantizeLinear
        assert_eq!(dq.output[0], "w", "DequantizeLinear output must be original name");
    }

    #[test]
    fn test_qdq_two_weights_both_transformed() {
        let mut graph = make_two_weight_graph();

        let inputs = vec![
            QdqWeightInput {
                original_name:    "w1".to_string(),
                quantized_values: vec![10, 20, 30, 40],
                scales:           vec![0.1],
                zero_points:      vec![0],
                bits:             8,
                axis:             None,
            },
            QdqWeightInput {
                original_name:    "w2".to_string(),
                quantized_values: vec![50, 60, 70, 80],
                scales:           vec![0.2],
                zero_points:      vec![-1],
                bits:             8,
                axis:             None,
            },
        ];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        // Connectivity must still be valid
        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "two-weight graph broken: {:?}", report.broken_refs);

        // Should have 2 DequantizeLinear + 2 Conv = 4 nodes
        assert_eq!(graph.node.len(), 4);

        // First two nodes are DequantizeLinear
        assert_eq!(graph.node[0].op_type, "DequantizeLinear");
        assert_eq!(graph.node[1].op_type, "DequantizeLinear");

        // Their outputs are the original weight names
        let dq_outputs: Vec<&str> = graph.node.iter().take(2)
            .map(|n| n.output[0].as_str())
            .collect();
        assert!(dq_outputs.contains(&"w1"));
        assert!(dq_outputs.contains(&"w2"));
    }

    #[test]
    fn test_qdq_int4_values_stored_as_int8() {
        let mut graph = make_simple_graph();

        // INT4 range [-8, 7] — these arrive as i8 from ensure_unpacked()
        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![-8, -1, 0, 7],
            scales:           vec![0.5],
            zero_points:      vec![0],
            bits:             4, // flag says INT4; storage must still be INT8
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        let quant_init = graph
            .initializer
            .iter()
            .find(|i| i.name == "w_quantized")
            .expect("w_quantized not found");

        // Data type must be INT8 (ONNX DequantizeLinear requirement)
        assert_eq!(quant_init.data_type, tensor_proto::DataType::Int8 as i32);

        // Byte-level round-trip must be exact
        // (`b as i8` reinterprets each raw byte as two's complement,
        //  recovering the negative values)
        let recovered: Vec<i8> = quant_init.raw_data.iter().map(|&b| b as i8).collect();
        assert_eq!(recovered, vec![-8, -1, 0, 7]);
    }

    #[test]
    fn test_qdq_unknown_weight_returns_error() {
        let mut graph = make_simple_graph();

        let inputs = vec![QdqWeightInput {
            original_name:    "does_not_exist".to_string(),
            quantized_values: vec![1, 2, 3],
            scales:           vec![1.0],
            zero_points:      vec![0],
            bits:             8,
            axis:             None,
        }];

        let result = apply_qdq_transform(&mut graph, &inputs);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("does_not_exist"),
            "error should name the missing weight"
        );
    }

    #[test]
    fn test_qdq_non_quantized_initializers_preserved() {
        // Add an extra initializer "bias" that is NOT being quantized.
        // It must survive the transform untouched.
        let mut graph = make_simple_graph();

        graph.initializer.push(TensorProto {
            name:       "bias".to_string(),
            data_type:  tensor_proto::DataType::Float as i32,
            dims:       vec![2],
            float_data: vec![0.1, 0.2],
            ..Default::default()
        });

        // Also add "bias" as a Conv input so connectivity stays valid
        graph.node[0].input.push("bias".to_string());

        let inputs = vec![QdqWeightInput {
            original_name:    "w".to_string(),
            quantized_values: vec![10, 20, 30, 40],
            scales:           vec![0.1],
            zero_points:      vec![0],
            bits:             8,
            axis:             None,
        }];

        apply_qdq_transform(&mut graph, &inputs).expect("QDQ transform failed");

        // "bias" must still be present and untouched
        let bias_init = graph.initializer.iter().find(|i| i.name == "bias");

        assert!(bias_init.is_some(), "non-quantized 'bias' initializer must be preserved");
        assert!((bias_init.unwrap().float_data[0] - 0.1).abs() < 1e-6);

        // Full connectivity check
        let report = validate_graph_connectivity(&graph);
        assert!(report.valid, "broken: {:?}", report.broken_refs);
    }
}
685}