Skip to main content

bb_ir/
verify.rs

1//! IR well-formedness checkers run between compiler passes. Each
2//! function is a pure check returning `Result<(), VerifyError>`;
3//! the compiler invokes them at the per-pass seams described in
4//! `docs/COMPILER.md`.
5
6use std::collections::HashSet;
7
8use crate::proto::onnx::{FunctionProto, ModelProto};
9
/// Well-formedness violation reported by the verifiers in this module.
///
/// Every variant carries enough context (function name, node index, or wire
/// id) to locate the offending op.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyError {
    /// A node's `NodeProto.op_type` is the empty string.
    EmptyOpType {
        /// Name of the function that owns the node.
        function_name: String,
        /// Index of the node within that function's `node` list.
        node_index: usize,
    },

    /// A `wire.Send` has no `wire.Recv` with the same `wire_id`, so the
    /// downstream side would block forever.
    UnpairedWireSend {
        /// `wire_id` stamped on the orphaned Send.
        wire_id: u64,
        /// Name of the function hosting the Send.
        function_name: String,
    },

    /// A `wire.Recv` has no `wire.Send` with the same `wire_id`.
    UnpairedWireRecv {
        /// `wire_id` stamped on the orphaned Recv.
        wire_id: u64,
        /// Name of the function hosting the Recv.
        function_name: String,
    },

    /// A `Call*` node targets a function name that is not present in
    /// `model.functions`.
    UnresolvedFunctionCall {
        /// Name of the missing callee.
        target_name: String,
        /// Name of the calling function.
        function_name: String,
        /// Index of the offending `Call*` node within the caller.
        node_index: usize,
    },

    /// `ModelProto.functions` contains no functions at all.
    EmptyFunctionTable,
}

impl std::fmt::Display for VerifyError {
    /// Render a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyOpType {
                function_name: func,
                node_index: idx,
            } => write!(f, "empty op_type at function `{}` node #{}", func, idx),
            Self::UnpairedWireSend {
                wire_id: id,
                function_name: func,
            } => write!(
                f,
                "wire.Send with wire_id={} in function `{}` has no matching Recv",
                id, func
            ),
            Self::UnpairedWireRecv {
                wire_id: id,
                function_name: func,
            } => write!(
                f,
                "wire.Recv with wire_id={} in function `{}` has no matching Send",
                id, func
            ),
            Self::UnresolvedFunctionCall {
                target_name: callee,
                function_name: caller,
                node_index: idx,
            } => write!(
                f,
                "function `{}` node #{} calls undefined function `{}`",
                caller, idx, callee
            ),
            Self::EmptyFunctionTable => f.write_str("ModelProto.functions is empty"),
        }
    }
}

/// `VerifyError` wraps no underlying cause, so the default `source()` (`None`)
/// is kept.
impl std::error::Error for VerifyError {}
94
95/// Verify every `NodeProto.op_type` is non-empty and the function
96/// table is non-empty.
97pub fn types(model: &ModelProto) -> Result<(), VerifyError> {
98    if model.functions.is_empty() {
99        return Err(VerifyError::EmptyFunctionTable);
100    }
101    for function in &model.functions {
102        for (i, node) in function.node.iter().enumerate() {
103            if node.op_type.is_empty() {
104                return Err(VerifyError::EmptyOpType {
105                    function_name: function.name.clone(),
106                    node_index: i,
107                });
108            }
109        }
110    }
111    Ok(())
112}
113
114/// Verify each `wire_id` has both a `Send` and a `Recv`.
115pub fn wire_pairs(model: &ModelProto) -> Result<(), VerifyError> {
116    for function in &model.functions {
117        let mut sends: HashSet<u64> = HashSet::new();
118        let mut recvs: HashSet<u64> = HashSet::new();
119        for node in &function.node {
120            let Some(wire_id) = read_wire_id(node) else {
121                continue;
122            };
123            if node.op_type == "Send" {
124                sends.insert(wire_id);
125            } else if node.op_type == "Recv" {
126                recvs.insert(wire_id);
127            }
128        }
129        if let Some(wire_id) = sends.difference(&recvs).next() {
130            return Err(VerifyError::UnpairedWireSend {
131                wire_id: *wire_id,
132                function_name: function.name.clone(),
133            });
134        }
135        if let Some(wire_id) = recvs.difference(&sends).next() {
136            return Err(VerifyError::UnpairedWireRecv {
137                wire_id: *wire_id,
138                function_name: function.name.clone(),
139            });
140        }
141    }
142    Ok(())
143}
144
145/// Verify each `Call*` node names a function in `model.functions`.
146pub fn function_calls(model: &ModelProto) -> Result<(), VerifyError> {
147    let defined: HashSet<&str> = model.functions.iter().map(|f| f.name.as_str()).collect();
148    for function in &model.functions {
149        for (i, node) in function.node.iter().enumerate() {
150            if !node.op_type.starts_with("Call") {
151                continue;
152            }
153            // Recorder stamps the target name on `node.name`.
154            let target = node.name.as_str();
155            if target.is_empty() {
156                continue;
157            }
158            if !defined.contains(target) {
159                return Err(VerifyError::UnresolvedFunctionCall {
160                    target_name: target.to_string(),
161                    function_name: function.name.clone(),
162                    node_index: i,
163                });
164            }
165        }
166    }
167    Ok(())
168}
169
170/// Read [`crate::keys::WIRE_ID_KEY`] as `u64`. `None` when missing
171/// or non-numeric.
172fn read_wire_id(node: &crate::proto::onnx::NodeProto) -> Option<u64> {
173    node.metadata_props
174        .iter()
175        .find(|p| p.key == crate::keys::WIRE_ID_KEY)
176        .and_then(|p| p.value.parse::<u64>().ok())
177}
178
179/// Single-function wire-id check (no wrapping `ModelProto` needed).
180pub fn wire_pairs_in_function(function: &FunctionProto) -> Result<(), VerifyError> {
181    let mut sends: HashSet<u64> = HashSet::new();
182    let mut recvs: HashSet<u64> = HashSet::new();
183    for node in &function.node {
184        let Some(wire_id) = read_wire_id(node) else {
185            continue;
186        };
187        if node.op_type == "Send" {
188            sends.insert(wire_id);
189        } else if node.op_type == "Recv" {
190            recvs.insert(wire_id);
191        }
192    }
193    if let Some(wire_id) = sends.difference(&recvs).next() {
194        return Err(VerifyError::UnpairedWireSend {
195            wire_id: *wire_id,
196            function_name: function.name.clone(),
197        });
198    }
199    if let Some(wire_id) = recvs.difference(&sends).next() {
200        return Err(VerifyError::UnpairedWireRecv {
201            wire_id: *wire_id,
202            function_name: function.name.clone(),
203        });
204    }
205    Ok(())
206}
207