Skip to main content

rlx_ir/
verify.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Graph verification — catches IR bugs early.
17//!
18//! Verifies structural invariants: valid node references, input counts,
19//! acyclicity, output validity, and (optionally) shape consistency.
20
21use crate::graph::{Graph, NodeId};
22use crate::infer_shape;
23
24/// Error found during graph verification.
25#[derive(Debug)]
26pub struct VerifyError {
27    pub node: Option<NodeId>,
28    pub message: String,
29}
30
31impl std::fmt::Display for VerifyError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self.node {
34            Some(id) => write!(f, "at {id}: {}", self.message),
35            None => write!(f, "{}", self.message),
36        }
37    }
38}
39
40/// Verify structural integrity of a graph. Returns all errors found.
41pub fn verify(graph: &Graph) -> Vec<VerifyError> {
42    let mut errors = Vec::new();
43    let num_nodes = graph.len();
44
45    for node in graph.nodes() {
46        // Check that all input references are valid and precede this node (DAG property).
47        for &input in &node.inputs {
48            if input.0 as usize >= num_nodes {
49                errors.push(VerifyError {
50                    node: Some(node.id),
51                    message: format!(
52                        "input {input} references non-existent node (graph has {num_nodes} nodes)"
53                    ),
54                });
55            } else if input.0 >= node.id.0 {
56                errors.push(VerifyError {
57                    node: Some(node.id),
58                    message: format!(
59                        "input {input} is not before {}: graph is not a DAG",
60                        node.id
61                    ),
62                });
63            }
64        }
65
66        // Check input count matches op expectation (except variadic ops like Concat).
67        let expected = node.op.num_inputs();
68        if expected > 0 && node.inputs.len() != expected {
69            errors.push(VerifyError {
70                node: Some(node.id),
71                message: format!(
72                    "{} expects {} inputs, got {}",
73                    node.op,
74                    expected,
75                    node.inputs.len()
76                ),
77            });
78        }
79    }
80
81    // Check outputs reference valid nodes.
82    for &out in &graph.outputs {
83        if out.0 as usize >= num_nodes {
84            errors.push(VerifyError {
85                node: None,
86                message: format!("output {out} references non-existent node"),
87            });
88        }
89    }
90
91    errors
92}
93
94/// True when `declared` and `inferred` describe the same logical tensor.
95fn shapes_compatible(declared: &crate::Shape, inferred: &crate::Shape) -> bool {
96    if declared == inferred {
97        return true;
98    }
99    if declared.dtype() != inferred.dtype() {
100        return false;
101    }
102    // Scalar conventions: rank-0 `[]` and rank-1 `[1]` both mean one element.
103    matches!(
104        (declared.num_elements(), inferred.num_elements()),
105        (Some(1), Some(1))
106    )
107}
108
109/// Re-derive output shapes from inputs and diff against declared shapes.
110pub fn verify_shapes(graph: &Graph) -> Vec<VerifyError> {
111    let mut errors = Vec::new();
112    for node in graph.nodes() {
113        let Some(expected) = infer_shape::infer_output_shape(graph, node) else {
114            continue;
115        };
116        if !shapes_compatible(&node.shape, &expected) {
117            errors.push(VerifyError {
118                node: Some(node.id),
119                message: format!(
120                    "shape mismatch: declared {}, inferred {expected}",
121                    node.shape
122                ),
123            });
124        }
125    }
126    errors
127}
128
129/// Structural + shape verification.
130pub fn verify_all(graph: &Graph) -> Vec<VerifyError> {
131    let mut errors = verify(graph);
132    errors.extend(verify_shapes(graph));
133    errors
134}
135
/// Panic with every verifier error when a graph fails
/// `$crate::verify::verify_all`.
///
/// Usage: `debug_assert_valid!(&graph, "after fusion")`. `$graph` is passed
/// straight to `verify_all` and so must evaluate to a `&Graph`. `$stage` must
/// implement `Display`; it names the pipeline stage in the panic message. A
/// trailing comma is accepted.
///
/// **Debug builds only.** The check is wrapped in `#[cfg(debug_assertions)]`,
/// so in release builds the macro expands to an empty block. Neither argument
/// is evaluated or type-checked, and no verification runs.
#[macro_export]
macro_rules! debug_assert_valid {
    ($graph:expr, $stage:expr $(,)?) => {{
        #[cfg(debug_assertions)]
        {
            let __errors = $crate::verify::verify_all($graph);
            if !__errors.is_empty() {
                // One error per line, indented under the header line.
                let __msg = __errors
                    .iter()
                    .map(|e| e.to_string())
                    .collect::<Vec<_>>()
                    .join("\n  ");
                panic!("IR verifier failed at `{}`:\n  {}", $stage, __msg);
            }
        }
    }};
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// A matmul whose declared output shape cannot come from `[4, 8] x [8, 16]`
    /// must yield exactly one "shape mismatch" error.
    #[test]
    fn shape_mismatch_is_caught() {
        let mut graph = Graph::new("bad");
        let lhs = graph.input("x", Shape::new(&[4, 8], DType::F32));
        let rhs = graph.param("w", Shape::new(&[8, 16], DType::F32));
        // Deliberately declare a shape that inference cannot produce.
        let bogus = Shape::new(&[99, 99], DType::F32);
        let product = graph.matmul(lhs, rhs, bogus);
        graph.set_outputs(vec![product]);

        let found = verify_shapes(&graph);
        assert_eq!(found.len(), 1);
        assert!(found[0].message.contains("shape mismatch"));
    }

    /// A full sum-reduction declared as `[1]` must be accepted even though
    /// inference reports the scalar as `[]`.
    #[test]
    fn scalar_rank0_and_rank1_are_compatible() {
        let mut graph = Graph::new("scalar");
        let vector = graph.input("x", Shape::new(&[3], DType::F32));
        let sum_all = Op::Reduce {
            op: crate::op::ReduceOp::Sum,
            axes: vec![0],
            keep_dim: false,
        };
        let loss = graph.add_node(sum_all, vec![vector], Shape::new(&[1], DType::F32));
        graph.set_outputs(vec![loss]);

        let found = verify_shapes(&graph);
        assert!(
            found.is_empty(),
            "[] inferred vs [1] declared should match for a scalar"
        );
    }

    /// A well-formed, correctly-shaped matmul passes both structural and
    /// shape verification.
    #[test]
    fn verify_all_combines_checks() {
        let mut graph = Graph::new("ok");
        let activations = graph.input("x", Shape::new(&[4, 384], DType::F32));
        let weights = graph.param("w", Shape::new(&[384, 384], DType::F32));
        let product = graph.matmul(activations, weights, Shape::new(&[4, 384], DType::F32));
        graph.set_outputs(vec![product]);

        let found = verify_all(&graph);
        assert!(found.is_empty());
    }
}