1use crate::graph::{Graph, NodeId};
22use crate::infer_shape;
23
24#[derive(Debug)]
26pub struct VerifyError {
27 pub node: Option<NodeId>,
28 pub message: String,
29}
30
31impl std::fmt::Display for VerifyError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self.node {
34 Some(id) => write!(f, "at {id}: {}", self.message),
35 None => write!(f, "{}", self.message),
36 }
37 }
38}
39
40pub fn verify(graph: &Graph) -> Vec<VerifyError> {
42 let mut errors = Vec::new();
43 let num_nodes = graph.len();
44
45 for node in graph.nodes() {
46 for &input in &node.inputs {
48 if input.0 as usize >= num_nodes {
49 errors.push(VerifyError {
50 node: Some(node.id),
51 message: format!(
52 "input {input} references non-existent node (graph has {num_nodes} nodes)"
53 ),
54 });
55 } else if input.0 >= node.id.0 {
56 errors.push(VerifyError {
57 node: Some(node.id),
58 message: format!(
59 "input {input} is not before {}: graph is not a DAG",
60 node.id
61 ),
62 });
63 }
64 }
65
66 let expected = node.op.num_inputs();
68 if expected > 0 && node.inputs.len() != expected {
69 errors.push(VerifyError {
70 node: Some(node.id),
71 message: format!(
72 "{} expects {} inputs, got {}",
73 node.op,
74 expected,
75 node.inputs.len()
76 ),
77 });
78 }
79 }
80
81 for &out in &graph.outputs {
83 if out.0 as usize >= num_nodes {
84 errors.push(VerifyError {
85 node: None,
86 message: format!("output {out} references non-existent node"),
87 });
88 }
89 }
90
91 errors
92}
93
94fn shapes_compatible(declared: &crate::Shape, inferred: &crate::Shape) -> bool {
96 if declared == inferred {
97 return true;
98 }
99 if declared.dtype() != inferred.dtype() {
100 return false;
101 }
102 matches!(
104 (declared.num_elements(), inferred.num_elements()),
105 (Some(1), Some(1))
106 )
107}
108
109pub fn verify_shapes(graph: &Graph) -> Vec<VerifyError> {
111 let mut errors = Vec::new();
112 for node in graph.nodes() {
113 let Some(expected) = infer_shape::infer_output_shape(graph, node) else {
114 continue;
115 };
116 if !shapes_compatible(&node.shape, &expected) {
117 errors.push(VerifyError {
118 node: Some(node.id),
119 message: format!(
120 "shape mismatch: declared {}, inferred {expected}",
121 node.shape
122 ),
123 });
124 }
125 }
126 errors
127}
128
129pub fn verify_all(graph: &Graph) -> Vec<VerifyError> {
131 let mut errors = verify(graph);
132 errors.extend(verify_shapes(graph));
133 errors
134}
135
/// Verifies a graph in debug builds and panics if any check fails.
///
/// Usage: `debug_assert_valid!(&graph, "after-fusion")`. `$graph` is passed
/// straight to `verify::verify_all`, so it must be a `&Graph`. `$stage` is any
/// `Display` value that labels where in the pipeline the check ran.
///
/// Like `debug_assert!`, the check only executes when `debug_assertions` is
/// enabled. In other builds the arguments are still type-checked but never
/// evaluated.
///
/// # Panics
///
/// Panics when `verify_all` returns any errors. The message names `$stage`
/// and lists every diagnostic on its own indented line.
#[macro_export]
macro_rules! debug_assert_valid {
    ($graph:expr, $stage:expr) => {{
        // `cfg!` (instead of `#[cfg]`) keeps the arguments type-checked in
        // release builds, mirroring `std::debug_assert!`; the branch is
        // constant-false there and is optimized away.
        if ::core::cfg!(debug_assertions) {
            let __errors = $crate::verify::verify_all($graph);
            if !__errors.is_empty() {
                let __msg = __errors
                    .iter()
                    .map(|e| e.to_string())
                    .collect::<Vec<_>>()
                    .join("\n  ");
                panic!("IR verifier failed at `{}`:\n  {}", $stage, __msg);
            }
        }
    }};
}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// A matmul declared with an output shape that cannot follow from its
    /// operands must produce exactly one shape-mismatch diagnostic.
    #[test]
    fn shape_mismatch_is_caught() {
        let mut graph = Graph::new("bad");
        let lhs = graph.input("x", Shape::new(&[4, 8], DType::F32));
        let rhs = graph.param("w", Shape::new(&[8, 16], DType::F32));
        let bogus = Shape::new(&[99, 99], DType::F32);
        let product = graph.matmul(lhs, rhs, bogus);
        graph.set_outputs(vec![product]);

        let diagnostics = verify_shapes(&graph);
        assert_eq!(diagnostics.len(), 1);
        let first = &diagnostics[0];
        assert!(first.message.contains("shape mismatch"));
    }

    /// Summing a rank-1 tensor over its only axis infers `[]`; declaring the
    /// result as `[1]` must still be accepted as the same scalar.
    #[test]
    fn scalar_rank0_and_rank1_are_compatible() {
        let mut graph = Graph::new("scalar");
        let x = graph.input("x", Shape::new(&[3], DType::F32));
        let sum_all = Op::Reduce {
            op: crate::op::ReduceOp::Sum,
            axes: vec![0],
            keep_dim: false,
        };
        let loss = graph.add_node(sum_all, vec![x], Shape::new(&[1], DType::F32));
        graph.set_outputs(vec![loss]);

        let diagnostics = verify_shapes(&graph);
        assert!(
            diagnostics.is_empty(),
            "[] inferred vs [1] declared should match for a scalar"
        );
    }

    /// A well-formed matmul graph passes both the structural and shape passes.
    #[test]
    fn verify_all_combines_checks() {
        let mut graph = Graph::new("ok");
        let activations = graph.input("x", Shape::new(&[4, 384], DType::F32));
        let weights = graph.param("w", Shape::new(&[384, 384], DType::F32));
        let out = graph.matmul(activations, weights, Shape::new(&[4, 384], DType::F32));
        graph.set_outputs(vec![out]);

        assert!(verify_all(&graph).is_empty());
    }
}