Skip to main content

bb_compiler/
validate.rs

1//! Structural sanity check. Reject malformed input before any
2//! other pass mutates it. Pure function over `GraphProto`.
3//!
4//! Implemented rules:
5//!
//! - Rule 1 — op domain known (framework reserved opsets + `ai.onnx`); `op_type` itself is not checked.
7//! - Rule 2 — inputs reachable.
8//! - Rule 3 — outputs unique.
9//! - Rule 5 — every graph input has a matching `ValueInfoProto`.
10//! - Rule 6 — role-domain NodeProtos carry canonical metadata.
11//! - Rule 7 — no cycles.
12
13use std::collections::{HashMap, HashSet};
14
15use crate::error::ValidationError;
16use bb_ir::proto::onnx::GraphProto;
17
18/// Validate the recorded graph. Pure.
19pub fn validate(graph: &GraphProto) -> Result<(), ValidationError> {
20    rule_1_known_op(graph)?;
21    rule_2_inputs_reachable(graph)?;
22    rule_3_outputs_unique(graph)?;
23    rule_5_type_declarations_present(graph)?;
24    rule_6_slot_metadata_well_formed(graph)?;
25    rule_7_no_cycles(graph)?;
26    Ok(())
27}
28
/// Opset domains owned by this framework or by ONNX itself. A domain is
/// reserved if it equals one of these or is a dot-separated sub-domain of
/// one (e.g. `ai.onnx.ml`, `ai.bytesandbrains.role.router`).
const RESERVED_OPSET_PREFIXES: &[&str] = &["ai.bytesandbrains", "ai.onnx"];

/// Whether `domain` belongs to a reserved opset.
///
/// The empty domain is accepted: the ONNX spec defines `""` as the default
/// operator set, i.e. an alias for `ai.onnx`, and that is how most ONNX
/// exporters encode built-in ops in `NodeProto.domain`.
///
/// Matching is on whole dot-separated segments, so `ai.onnxfoo` is *not*
/// reserved even though it textually starts with `ai.onnx`. Uses
/// `strip_prefix` rather than building `"{prefix}."` so no allocation
/// happens per node.
fn is_reserved_opset(domain: &str) -> bool {
    domain.is_empty()
        || RESERVED_OPSET_PREFIXES.iter().any(|&prefix| {
            matches!(
                domain.strip_prefix(prefix),
                Some(rest) if rest.is_empty() || rest.starts_with('.')
            )
        })
}
37
38/// Rule 1 - every `(domain, op_type)` belongs to a known opset.
39fn rule_1_known_op(graph: &GraphProto) -> Result<(), ValidationError> {
40    for node in &graph.node {
41        if !is_reserved_opset(&node.domain) {
42            return Err(ValidationError::UnknownOp {
43                node_name: node.name.clone(),
44                op_type: node.op_type.clone(),
45                domain: node.domain.clone(),
46            });
47        }
48    }
49    Ok(())
50}
51
52/// Rule 2 - every input value name is produced upstream or appears
53/// in `graph.input`.
54fn rule_2_inputs_reachable(graph: &GraphProto) -> Result<(), ValidationError> {
55    let mut produced: HashSet<&str> = HashSet::new();
56    for input in &graph.input {
57        produced.insert(input.name.as_str());
58    }
59    // First scan all node outputs so we don't reject forward refs
60    // within a DAG-valid topological order before we've collected
61    // them - rule 7 separately enforces acyclicity.
62    for node in &graph.node {
63        for out in &node.output {
64            produced.insert(out.as_str());
65        }
66    }
67    for node in &graph.node {
68        for inp in &node.input {
69            if inp.is_empty() {
70                // ONNX permits empty input slots (e.g. optional Conv
71                // bias) - skip rather than flag.
72                continue;
73            }
74            if !produced.contains(inp.as_str()) {
75                return Err(ValidationError::DanglingInput {
76                    node_name: node.name.clone(),
77                    input_name: inp.clone(),
78                });
79            }
80        }
81    }
82    Ok(())
83}
84
85/// Rule 3 - every output value name is written by at most one op.
86fn rule_3_outputs_unique(graph: &GraphProto) -> Result<(), ValidationError> {
87    let mut writers: HashMap<&str, &str> = HashMap::new();
88    for node in &graph.node {
89        for out in &node.output {
90            if out.is_empty() {
91                continue;
92            }
93            if let Some(&prev) = writers.get(out.as_str()) {
94                return Err(ValidationError::DuplicateOutput {
95                    value_name: out.clone(),
96                    node_a: prev.to_string(),
97                    node_b: node.name.clone(),
98                });
99            }
100            writers.insert(out.as_str(), node.name.as_str());
101        }
102    }
103    Ok(())
104}
105
106/// Rule 5 - every `graph.input` has a matching `ValueInfoProto.type`.
107fn rule_5_type_declarations_present(graph: &GraphProto) -> Result<(), ValidationError> {
108    for input in &graph.input {
109        if input.r#type.is_none() {
110            return Err(ValidationError::MissingTypeInfo {
111                input_name: input.name.clone(),
112            });
113        }
114    }
115    Ok(())
116}
117
118/// Rule 6 - every role-domain NodeProto carries the canonical
119/// metadata keys.
120///
121/// For `domain` starting with `"ai.bytesandbrains.role."`:
122/// - EITHER `(concrete_type, instance)` BOTH present, OR
123/// - `(required_trait, slot_id)` BOTH present.
124fn rule_6_slot_metadata_well_formed(graph: &GraphProto) -> Result<(), ValidationError> {
125    for node in &graph.node {
126        if !node.domain.starts_with("ai.bytesandbrains.role.") {
127            continue;
128        }
129        let has_concrete = meta_has(node, "ai.bytesandbrains.concrete_type")
130            && meta_has(node, "ai.bytesandbrains.instance");
131        let has_generic = meta_has(node, "ai.bytesandbrains.required_trait")
132            && meta_has(node, "ai.bytesandbrains.slot_id");
133        if !has_concrete && !has_generic {
134            return Err(ValidationError::MalformedSlotMetadata {
135                node_name: node.name.clone(),
136                detail: format!(
137                    "role-domain NodeProto {} lacks both (concrete_type, instance) and (required_trait, slot_id) metadata",
138                    node.op_type,
139                ),
140            });
141        }
142    }
143    Ok(())
144}
145
146fn meta_has(node: &bb_ir::proto::onnx::NodeProto, key: &str) -> bool {
147    node.metadata_props.iter().any(|p| p.key == key)
148}
149
150/// Rule 7 - no cycles. Kahn's algorithm over the producer-consumer
151/// DAG.
152fn rule_7_no_cycles(graph: &GraphProto) -> Result<(), ValidationError> {
153    // Build producer map: value_name → producing NodeProto index.
154    let mut producer: HashMap<&str, usize> = HashMap::new();
155    for (idx, node) in graph.node.iter().enumerate() {
156        for out in &node.output {
157            producer.insert(out.as_str(), idx);
158        }
159    }
160    // Build edges: for each node, find each input's producing node →
161    // adjacency.
162    let n = graph.node.len();
163    let mut in_degree = vec![0usize; n];
164    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
165    for (idx, node) in graph.node.iter().enumerate() {
166        for inp in &node.input {
167            if let Some(&p) = producer.get(inp.as_str()) {
168                if p != idx {
169                    adj[p].push(idx);
170                    in_degree[idx] += 1;
171                }
172            }
173        }
174    }
175    // Kahn's: drain zero-in-degree nodes.
176    let mut queue: std::collections::VecDeque<usize> = in_degree
177        .iter()
178        .enumerate()
179        .filter(|(_, d)| **d == 0)
180        .map(|(i, _)| i)
181        .collect();
182    let mut visited = 0;
183    while let Some(idx) = queue.pop_front() {
184        visited += 1;
185        for &next in &adj[idx] {
186            in_degree[next] -= 1;
187            if in_degree[next] == 0 {
188                queue.push_back(next);
189            }
190        }
191    }
192    if visited != n {
193        let involves: Vec<String> = graph
194            .node
195            .iter()
196            .enumerate()
197            .filter(|(i, _)| in_degree[*i] > 0)
198            .map(|(_, n)| n.name.clone())
199            .collect();
200        return Err(ValidationError::CyclicGraph { involves });
201    }
202    Ok(())
203}
204