1use std::collections::{HashMap, HashSet};
14
15use crate::error::ValidationError;
16use bb_ir::proto::onnx::GraphProto;
17
18pub fn validate(graph: &GraphProto) -> Result<(), ValidationError> {
20 rule_1_known_op(graph)?;
21 rule_2_inputs_reachable(graph)?;
22 rule_3_outputs_unique(graph)?;
23 rule_5_type_declarations_present(graph)?;
24 rule_6_slot_metadata_well_formed(graph)?;
25 rule_7_no_cycles(graph)?;
26 Ok(())
27}
28
/// Operator-set domains that are accepted by the validator, either exactly
/// or as the dot-separated parent of a longer domain.
const RESERVED_OPSET_PREFIXES: &[&str] = &["ai.bytesandbrains", "ai.onnx"];

/// Returns `true` when `domain` equals one of [`RESERVED_OPSET_PREFIXES`] or
/// extends one of them with a `.`-separated suffix.
///
/// For example `"ai.onnx"` and `"ai.onnx.ml"` match, while `"ai.onnxx"` and
/// the empty string do not.
fn is_reserved_opset(domain: &str) -> bool {
    RESERVED_OPSET_PREFIXES.iter().any(|prefix| {
        // Strip the prefix and inspect what is left instead of building a
        // temporary "<prefix>." string for every comparison.
        match domain.strip_prefix(*prefix) {
            Some(rest) => rest.is_empty() || rest.starts_with('.'),
            None => false,
        }
    })
}
37
38fn rule_1_known_op(graph: &GraphProto) -> Result<(), ValidationError> {
40 for node in &graph.node {
41 if !is_reserved_opset(&node.domain) {
42 return Err(ValidationError::UnknownOp {
43 node_name: node.name.clone(),
44 op_type: node.op_type.clone(),
45 domain: node.domain.clone(),
46 });
47 }
48 }
49 Ok(())
50}
51
52fn rule_2_inputs_reachable(graph: &GraphProto) -> Result<(), ValidationError> {
55 let mut produced: HashSet<&str> = HashSet::new();
56 for input in &graph.input {
57 produced.insert(input.name.as_str());
58 }
59 for node in &graph.node {
63 for out in &node.output {
64 produced.insert(out.as_str());
65 }
66 }
67 for node in &graph.node {
68 for inp in &node.input {
69 if inp.is_empty() {
70 continue;
73 }
74 if !produced.contains(inp.as_str()) {
75 return Err(ValidationError::DanglingInput {
76 node_name: node.name.clone(),
77 input_name: inp.clone(),
78 });
79 }
80 }
81 }
82 Ok(())
83}
84
85fn rule_3_outputs_unique(graph: &GraphProto) -> Result<(), ValidationError> {
87 let mut writers: HashMap<&str, &str> = HashMap::new();
88 for node in &graph.node {
89 for out in &node.output {
90 if out.is_empty() {
91 continue;
92 }
93 if let Some(&prev) = writers.get(out.as_str()) {
94 return Err(ValidationError::DuplicateOutput {
95 value_name: out.clone(),
96 node_a: prev.to_string(),
97 node_b: node.name.clone(),
98 });
99 }
100 writers.insert(out.as_str(), node.name.as_str());
101 }
102 }
103 Ok(())
104}
105
106fn rule_5_type_declarations_present(graph: &GraphProto) -> Result<(), ValidationError> {
108 for input in &graph.input {
109 if input.r#type.is_none() {
110 return Err(ValidationError::MissingTypeInfo {
111 input_name: input.name.clone(),
112 });
113 }
114 }
115 Ok(())
116}
117
118fn rule_6_slot_metadata_well_formed(graph: &GraphProto) -> Result<(), ValidationError> {
125 for node in &graph.node {
126 if !node.domain.starts_with("ai.bytesandbrains.role.") {
127 continue;
128 }
129 let has_concrete = meta_has(node, "ai.bytesandbrains.concrete_type")
130 && meta_has(node, "ai.bytesandbrains.instance");
131 let has_generic = meta_has(node, "ai.bytesandbrains.required_trait")
132 && meta_has(node, "ai.bytesandbrains.slot_id");
133 if !has_concrete && !has_generic {
134 return Err(ValidationError::MalformedSlotMetadata {
135 node_name: node.name.clone(),
136 detail: format!(
137 "role-domain NodeProto {} lacks both (concrete_type, instance) and (required_trait, slot_id) metadata",
138 node.op_type,
139 ),
140 });
141 }
142 }
143 Ok(())
144}
145
146fn meta_has(node: &bb_ir::proto::onnx::NodeProto, key: &str) -> bool {
147 node.metadata_props.iter().any(|p| p.key == key)
148}
149
150fn rule_7_no_cycles(graph: &GraphProto) -> Result<(), ValidationError> {
153 let mut producer: HashMap<&str, usize> = HashMap::new();
155 for (idx, node) in graph.node.iter().enumerate() {
156 for out in &node.output {
157 producer.insert(out.as_str(), idx);
158 }
159 }
160 let n = graph.node.len();
163 let mut in_degree = vec![0usize; n];
164 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
165 for (idx, node) in graph.node.iter().enumerate() {
166 for inp in &node.input {
167 if let Some(&p) = producer.get(inp.as_str()) {
168 if p != idx {
169 adj[p].push(idx);
170 in_degree[idx] += 1;
171 }
172 }
173 }
174 }
175 let mut queue: std::collections::VecDeque<usize> = in_degree
177 .iter()
178 .enumerate()
179 .filter(|(_, d)| **d == 0)
180 .map(|(i, _)| i)
181 .collect();
182 let mut visited = 0;
183 while let Some(idx) = queue.pop_front() {
184 visited += 1;
185 for &next in &adj[idx] {
186 in_degree[next] -= 1;
187 if in_degree[next] == 0 {
188 queue.push_back(next);
189 }
190 }
191 }
192 if visited != n {
193 let involves: Vec<String> = graph
194 .node
195 .iter()
196 .enumerate()
197 .filter(|(i, _)| in_degree[*i] > 0)
198 .map(|(_, n)| n.name.clone())
199 .collect();
200 return Err(ValidationError::CyclicGraph { involves });
201 }
202 Ok(())
203}
204