use std::collections::{BTreeSet, HashMap, HashSet};

use crate::cluster::{ExpandedEndpoint, ExpandedGraph, PrimitiveCatalog, PrimitiveKind, ValueType};

use super::types::{Endpoint, GraphValidationError, ValidatedEdge, ValidatedGraph, ValidatedNode};
49
50pub fn validate<C: PrimitiveCatalog>(
51 expanded: &ExpandedGraph,
52 catalog: &C,
53) -> Result<ValidatedGraph, GraphValidationError> {
54 let mut nodes: HashMap<String, ValidatedNode> = HashMap::new();
55
56 for (id, node) in &expanded.nodes {
57 let meta = catalog
58 .get(&node.implementation.impl_id, &node.implementation.version)
59 .ok_or_else(|| GraphValidationError::MissingPrimitive {
60 id: node.implementation.impl_id.clone(),
61 version: node.implementation.version.clone(),
62 })?;
63
64 nodes.insert(
65 id.clone(),
66 ValidatedNode {
67 runtime_id: id.clone(),
68 impl_id: node.implementation.impl_id.clone(),
69 version: node.implementation.version.clone(),
70 kind: meta.kind.clone(),
71 inputs: meta.inputs.clone(),
72 outputs: meta.outputs.clone(),
73 parameters: node.parameters.clone(),
74 },
75 );
76 }
77
78 let edges: Vec<ValidatedEdge> = expanded
79 .edges
80 .iter()
81 .map(|e| {
82 Ok(ValidatedEdge {
83 from: map_endpoint(&e.from)?,
84 to: map_endpoint(&e.to)?,
85 })
86 })
87 .collect::<Result<Vec<_>, _>>()?;
88
89 enforce_edge_nodes_exist(&nodes, &edges)?;
90 enforce_single_edge_per_input(&edges)?;
91 let topo_order = topological_sort(&nodes, &edges)?;
92
93 enforce_wiring_matrix(&nodes, &edges)?;
94 enforce_required_inputs(&nodes, &edges)?;
95 enforce_types(&nodes, &edges)?;
96 enforce_action_gating(&nodes, &edges)?;
97 enforce_boundary_outputs(&nodes, &expanded.boundary_outputs)?;
98
99 Ok(ValidatedGraph {
100 nodes,
101 edges,
102 topo_order,
103 boundary_outputs: expanded.boundary_outputs.clone(),
104 })
105}
106
107fn map_endpoint(ep: &ExpandedEndpoint) -> Result<Endpoint, GraphValidationError> {
108 match ep {
109 ExpandedEndpoint::NodePort { node_id, port_name } => Ok(Endpoint::NodePort {
110 node_id: node_id.clone(),
111 port_name: port_name.clone(),
112 }),
113 ExpandedEndpoint::ExternalInput { name } => {
114 Err(GraphValidationError::ExternalInputNotAllowed { name: name.clone() })
115 }
116 }
117}
118
119fn topological_sort(
120 nodes: &HashMap<String, ValidatedNode>,
121 edges: &[ValidatedEdge],
122) -> Result<Vec<String>, GraphValidationError> {
123 let mut in_degree: HashMap<String, usize> = nodes.keys().map(|k| (k.clone(), 0)).collect();
124 let mut dependents: HashMap<String, Vec<String>> =
125 nodes.keys().map(|k| (k.clone(), vec![])).collect();
126
127 for edge in edges {
128 let Endpoint::NodePort { node_id: from, .. } = &edge.from;
129 let Endpoint::NodePort { node_id: to, .. } = &edge.to;
130 *in_degree
131 .get_mut(to)
132 .ok_or_else(|| GraphValidationError::UnknownNode(to.clone()))? += 1;
133 dependents
134 .get_mut(from)
135 .ok_or_else(|| GraphValidationError::UnknownNode(from.clone()))?
136 .push(to.clone());
137 }
138
139 let mut queue: BTreeSet<String> = in_degree
140 .iter()
141 .filter(|(_, deg)| **deg == 0)
142 .map(|(id, _)| id.clone())
143 .collect();
144
145 let mut sorted = Vec::new();
146
147 while let Some(node_id) = queue.iter().next().cloned() {
148 queue.remove(&node_id);
149 sorted.push(node_id.clone());
150
151 if let Some(deps) = dependents.get(&node_id) {
152 for dep in deps {
153 let deg = in_degree
154 .get_mut(dep)
155 .ok_or_else(|| GraphValidationError::UnknownNode(dep.clone()))?;
156 *deg -= 1;
157 if *deg == 0 {
158 queue.insert(dep.clone());
159 }
160 }
161 }
162 }
163
164 if sorted.len() != nodes.len() {
165 return Err(GraphValidationError::CycleDetected);
166 }
167
168 Ok(sorted)
169}
170
171fn enforce_edge_nodes_exist(
172 nodes: &HashMap<String, ValidatedNode>,
173 edges: &[ValidatedEdge],
174) -> Result<(), GraphValidationError> {
175 for edge in edges {
176 let Endpoint::NodePort { node_id: from, .. } = &edge.from;
177 if !nodes.contains_key(from) {
178 return Err(GraphValidationError::UnknownNode(from.clone()));
179 }
180
181 let Endpoint::NodePort { node_id: to, .. } = &edge.to;
182 if !nodes.contains_key(to) {
183 return Err(GraphValidationError::UnknownNode(to.clone()));
184 }
185 }
186 Ok(())
187}
188
189fn enforce_wiring_matrix(
190 nodes: &HashMap<String, ValidatedNode>,
191 edges: &[ValidatedEdge],
192) -> Result<(), GraphValidationError> {
193 for edge in edges {
194 let Endpoint::NodePort {
195 node_id: from,
196 port_name: _from_port,
197 } = &edge.from;
198 let Endpoint::NodePort {
199 node_id: to,
200 port_name: to_port,
201 } = &edge.to;
202
203 let from_node = nodes
204 .get(from)
205 .ok_or_else(|| GraphValidationError::UnknownNode(from.clone()))?;
206 let to_node = nodes
207 .get(to)
208 .ok_or_else(|| GraphValidationError::UnknownNode(to.clone()))?;
209
210 if !wiring_allowed_for_edge(from_node, to_node, to_port)? {
211 return Err(GraphValidationError::InvalidEdgeKind {
212 from: from_node.kind.clone(),
213 to: to_node.kind.clone(),
214 });
215 }
216 }
217 Ok(())
218}
219
220fn enforce_required_inputs(
221 nodes: &HashMap<String, ValidatedNode>,
222 edges: &[ValidatedEdge],
223) -> Result<(), GraphValidationError> {
224 let mut incoming: HashMap<(&String, &str), bool> = HashMap::new();
225 for edge in edges {
226 let Endpoint::NodePort {
227 node_id: to,
228 port_name,
229 } = &edge.to;
230 incoming.insert((to, port_name.as_str()), true);
231 }
232
233 for node in nodes.values() {
234 for input in node.required_inputs() {
235 if !incoming.contains_key(&(&node.runtime_id, input.name.as_str())) {
236 return Err(GraphValidationError::MissingRequiredInput {
237 node: node.runtime_id.clone(),
238 input: input.name.clone(),
239 });
240 }
241 }
242 }
243 Ok(())
244}
245
246fn enforce_types(
247 nodes: &HashMap<String, ValidatedNode>,
248 edges: &[ValidatedEdge],
249) -> Result<(), GraphValidationError> {
250 for edge in edges {
251 let Endpoint::NodePort {
252 node_id: from,
253 port_name: from_port,
254 } = &edge.from;
255 let Endpoint::NodePort {
256 node_id: to,
257 port_name: to_port,
258 } = &edge.to;
259
260 let from_node = nodes
261 .get(from)
262 .ok_or_else(|| GraphValidationError::UnknownNode(from.clone()))?;
263 let to_node = nodes
264 .get(to)
265 .ok_or_else(|| GraphValidationError::UnknownNode(to.clone()))?;
266
267 let from_type = from_node
268 .outputs
269 .get(from_port)
270 .ok_or_else(|| GraphValidationError::MissingOutputMetadata {
271 node: from.clone(),
272 output: from_port.clone(),
273 })?
274 .value_type
275 .clone();
276
277 let expected = to_node
278 .inputs
279 .iter()
280 .find(|i| i.name == *to_port)
281 .ok_or_else(|| GraphValidationError::MissingInputMetadata {
282 node: to.clone(),
283 input: to_port.clone(),
284 })?
285 .value_type
286 .clone();
287
288 if from_type != expected {
289 return Err(GraphValidationError::TypeMismatch {
290 from: from.clone(),
291 output: from_port.clone(),
292 to: to.clone(),
293 input: to_port.clone(),
294 expected,
295 got: from_type,
296 });
297 }
298 }
299
300 Ok(())
301}
302
303fn enforce_action_gating(
304 nodes: &HashMap<String, ValidatedNode>,
305 edges: &[ValidatedEdge],
306) -> Result<(), GraphValidationError> {
307 let mut action_inputs: HashMap<String, bool> = HashMap::new();
308
309 for edge in edges {
310 let Endpoint::NodePort { node_id: to, .. } = &edge.to;
311 if let Some(target) = nodes.get(to) {
312 if target.kind == PrimitiveKind::Action {
313 let Endpoint::NodePort {
314 node_id: from,
315 port_name: from_port,
316 } = &edge.from;
317 if let Some(src) = nodes.get(from) {
318 if src.kind == PrimitiveKind::Trigger {
319 if let Some(meta) = src.outputs.get(from_port) {
320 if meta.value_type == ValueType::Event {
321 action_inputs.insert(to.clone(), true);
322 }
323 }
324 }
325 }
326 }
327 }
328 }
329
330 for (id, node) in nodes {
331 if node.kind == PrimitiveKind::Action && !action_inputs.get(id).copied().unwrap_or(false) {
332 return Err(GraphValidationError::ActionNotGated(id.clone()));
333 }
334 }
335
336 Ok(())
337}
338
339fn enforce_boundary_outputs(
340 nodes: &HashMap<String, ValidatedNode>,
341 boundary_outputs: &[crate::cluster::OutputPortSpec],
342) -> Result<(), GraphValidationError> {
343 for output in boundary_outputs {
344 let target_node = nodes
345 .get(&output.maps_to.node_id)
346 .ok_or_else(|| GraphValidationError::UnknownNode(output.maps_to.node_id.clone()))?;
347
348 if !target_node.outputs.contains_key(&output.maps_to.port_name) {
349 return Err(GraphValidationError::MissingOutputMetadata {
350 node: output.maps_to.node_id.clone(),
351 output: output.maps_to.port_name.clone(),
352 });
353 }
354 }
355
356 Ok(())
357}
358
359fn wiring_allowed(from: &PrimitiveKind, to: &PrimitiveKind) -> bool {
360 matches!(
361 (from, to),
362 (PrimitiveKind::Source, PrimitiveKind::Compute)
363 | (PrimitiveKind::Compute, PrimitiveKind::Compute)
364 | (PrimitiveKind::Compute, PrimitiveKind::Trigger)
365 | (PrimitiveKind::Trigger, PrimitiveKind::Trigger)
366 | (PrimitiveKind::Trigger, PrimitiveKind::Action)
367 )
368}
369
370fn wiring_allowed_for_edge(
371 from_node: &ValidatedNode,
372 to_node: &ValidatedNode,
373 to_port: &str,
374) -> Result<bool, GraphValidationError> {
375 if wiring_allowed(&from_node.kind, &to_node.kind) {
376 return Ok(true);
377 }
378
379 if matches!(
382 from_node.kind,
383 PrimitiveKind::Source | PrimitiveKind::Compute
384 ) && to_node.kind == PrimitiveKind::Action
385 {
386 let target_input = to_node
387 .inputs
388 .iter()
389 .find(|input| input.name == to_port)
390 .ok_or_else(|| GraphValidationError::MissingInputMetadata {
391 node: to_node.runtime_id.clone(),
392 input: to_port.to_string(),
393 })?;
394
395 if matches!(
396 target_input.value_type,
397 ValueType::Number | ValueType::Series | ValueType::Bool | ValueType::String
398 ) {
399 return Ok(true);
401 }
402 }
403
404 Ok(false)
405}
406
407fn enforce_single_edge_per_input(edges: &[ValidatedEdge]) -> Result<(), GraphValidationError> {
410 let mut inbound_count: HashMap<(&String, &String), usize> = HashMap::new();
411
412 for edge in edges {
413 let Endpoint::NodePort { node_id, port_name } = &edge.to;
414 *inbound_count.entry((node_id, port_name)).or_insert(0) += 1;
415 }
416
417 for ((node_id, port_name), count) in inbound_count {
418 if count > 1 {
419 return Err(GraphValidationError::MultipleInboundEdges {
420 node: node_id.clone(),
421 input: port_name.clone(),
422 });
423 }
424 }
425
426 Ok(())
427}