1use itertools::Itertools;
10use thiserror::Error;
11
12use crate::core::HugrNode;
13use crate::types::TypeRow;
14use crate::{Node, Port, PortIndex};
15
16use super::dataflow::{DataflowOpTrait, DataflowParent};
17use super::{BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp, impl_validate_op};
18
/// Signature of an optional check run on an edge between two children of a
/// container operation; it receives the edge's endpoints, their operations and
/// ports as a `ChildrenEdgeData`, and reports an `EdgeValidationError` on failure.
pub type EdgeCheck<N> = fn(ChildrenEdgeData<N>) -> Result<(), EdgeValidationError<N>>;
23
#[non_exhaustive]
/// Structural constraints an operation places on its children, as returned by
/// `ValidateOp::validity_flags`.
pub struct OpValidityFlags<N: HugrNode = Node> {
    /// Operation tag that every child must fall under (`OpTag::None` by default).
    pub allowed_children: OpTag,
    /// Extra constraint on the first child, e.g. `OpTag::Input` for dataflow
    /// containers or `OpTag::DataflowBlock` (the entry block) for a CFG.
    pub allowed_first_child: OpTag,
    /// Extra constraint on the second child, e.g. `OpTag::Output` for dataflow
    /// containers or `OpTag::BasicBlockExit` for a CFG.
    pub allowed_second_child: OpTag,
    /// Whether the operation is required to have children.
    pub requires_children: bool,
    /// Whether the children's sibling graph is required to be a DAG
    /// (set for dataflow containers, unset for CFGs and Conditionals).
    pub requires_dag: bool,
    /// Optional check applied to edges between children; `None` disables it.
    pub edge_check: Option<EdgeCheck<N>>,
}
46
47impl<N: HugrNode> Default for OpValidityFlags<N> {
48 fn default() -> Self {
49 Self {
51 allowed_children: OpTag::None,
52 allowed_first_child: OpTag::Any,
53 allowed_second_child: OpTag::Any,
54 requires_children: false,
55 requires_dag: false,
56 edge_check: None,
57 }
58 }
59}
60
61impl ValidateOp for super::Module {
62 fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
63 OpValidityFlags {
64 allowed_children: OpTag::ModuleOp,
65 requires_children: false,
66 ..Default::default()
67 }
68 }
69}
70
71impl ValidateOp for super::Conditional {
72 fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
73 OpValidityFlags {
74 allowed_children: OpTag::Case,
75 requires_children: true,
76 requires_dag: false,
77 ..Default::default()
78 }
79 }
80
81 fn validate_op_children<'a, N: HugrNode>(
82 &self,
83 children: impl DoubleEndedIterator<Item = (N, &'a OpType)>,
84 ) -> Result<(), ChildrenValidationError<N>> {
85 let children = children.collect_vec();
86 if self.sum_rows.len() != children.len() {
89 return Err(ChildrenValidationError::InvalidConditionalSum {
90 child: children[0].0, expected_count: children.len(),
92 actual_sum_rows: self.sum_rows.clone(),
93 });
94 }
95
96 for (i, (child, optype)) in children.into_iter().enumerate() {
99 let case_op = optype
100 .as_case()
101 .expect("Child check should have already checked valid ops.");
102 let sig = &case_op.inner_signature();
103 if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
104 return Err(ChildrenValidationError::ConditionalCaseSignature {
105 child,
106 optype: optype.clone(),
107 });
108 }
109 }
110
111 Ok(())
112 }
113}
114
115impl ValidateOp for super::CFG {
116 fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
117 OpValidityFlags {
118 allowed_children: OpTag::ControlFlowChild,
119 allowed_first_child: OpTag::DataflowBlock,
120 allowed_second_child: OpTag::BasicBlockExit,
121 requires_children: true,
122 requires_dag: false,
123 edge_check: Some(validate_cfg_edge),
124 ..Default::default()
125 }
126 }
127
128 fn validate_op_children<'a, N: HugrNode>(
129 &self,
130 mut children: impl Iterator<Item = (N, &'a OpType)>,
131 ) -> Result<(), ChildrenValidationError<N>> {
132 let (entry, entry_op) = children.next().unwrap();
133 let (exit, exit_op) = children.next().unwrap();
134 let entry_op = entry_op
135 .as_dataflow_block()
136 .expect("Child check should have already checked valid ops.");
137 let exit_op = exit_op
138 .as_exit_block()
139 .expect("Child check should have already checked valid ops.");
140
141 let sig = self.signature();
142 if entry_op.inner_signature().input() != sig.input() {
143 return Err(ChildrenValidationError::IOSignatureMismatch {
144 child: entry,
145 actual: entry_op.inner_signature().input().clone(),
146 expected: sig.input().clone(),
147 node_desc: "BasicBlock Input",
148 container_desc: "CFG",
149 });
150 }
151 if &exit_op.cfg_outputs != sig.output() {
152 return Err(ChildrenValidationError::IOSignatureMismatch {
153 child: exit,
154 actual: exit_op.cfg_outputs.clone(),
155 expected: sig.output().clone(),
156 node_desc: "BasicBlockExit Output",
157 container_desc: "CFG",
158 });
159 }
160 for (child, optype) in children {
161 if optype.tag() == OpTag::BasicBlockExit {
162 return Err(ChildrenValidationError::InternalExitChildren { child });
163 }
164 }
165 Ok(())
166 }
167}
/// Errors raised by `ValidateOp::validate_op_children` when a container's
/// children are inconsistent with the container operation.
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum ChildrenValidationError<N: HugrNode> {
    /// A CFG has an exit block somewhere other than its second child.
    #[error("Exit basic blocks are only allowed as the second child in a CFG graph")]
    InternalExitChildren { child: N },
    /// An `Input`/`Output` node appears after the first two children of a
    /// dataflow container.
    #[error("A {optype} operation is only allowed as a {expected_position} child")]
    InternalIOChildren {
        child: N,
        optype: OpType,
        /// `"first"` for `Input`, `"second"` for `Output`.
        expected_position: &'static str,
    },
    /// A boundary node's type row differs from what its container expects.
    #[error(
        "The {node_desc} node of a {container_desc} has a signature of {actual}, which differs from the expected type row {expected}"
    )]
    IOSignatureMismatch {
        child: N,
        /// Row found on the child.
        actual: TypeRow,
        /// Row required by the container.
        expected: TypeRow,
        /// Human-readable description of the offending child (e.g. `"Input"`).
        node_desc: &'static str,
        /// Human-readable description of the container (e.g. `"CFG"`).
        container_desc: &'static str,
    },
    /// A `Case` child's signature disagrees with its `Conditional` parent.
    #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())]
    ConditionalCaseSignature { child: N, optype: OpType },
    /// The number of `Case` children differs from the number of rows in the
    /// conditional's branch `Sum` input.
    #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}",
        actual_sum_rows.len())]
    InvalidConditionalSum {
        /// An arbitrary child used to locate the error (the first one).
        child: N,
        /// Number of `Case` children found.
        expected_count: usize,
        /// The conditional's `sum_rows`.
        actual_sum_rows: Vec<TypeRow>,
    },
}
206
207impl<N: HugrNode> ChildrenValidationError<N> {
208 pub fn child(&self) -> N {
210 match self {
211 ChildrenValidationError::InternalIOChildren { child, .. } => *child,
212 ChildrenValidationError::InternalExitChildren { child, .. } => *child,
213 ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child,
214 ChildrenValidationError::IOSignatureMismatch { child, .. } => *child,
215 ChildrenValidationError::InvalidConditionalSum { child, .. } => *child,
216 }
217 }
218}
219
/// Errors raised by an `EdgeCheck` on an edge between two children of a
/// container.
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum EdgeValidationError<N: HugrNode> {
    /// The row a CFG block sends along a control-flow edge differs from the
    /// dataflow input of the block it targets.
    #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
        source_ty = source_types.clone().unwrap_or_default(),
    )]
    CFGEdgeSignatureMismatch {
        /// The offending edge.
        edge: ChildrenEdgeData<N>,
        /// Result of the source block's `successor_input` for the edge's
        /// source port (`None` when it yields no row; printed as an empty row).
        source_types: Option<TypeRow>,
        /// The target block's dataflow input row.
        target_types: TypeRow,
    },
}
235
236impl<N: HugrNode> EdgeValidationError<N> {
237 pub fn edge(&self) -> &ChildrenEdgeData<N> {
239 match self {
240 EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
241 }
242 }
243}
244
/// Description of an edge between two children of the same container, as
/// passed to an `EdgeCheck`.
#[derive(Debug, Clone, PartialEq)]
pub struct ChildrenEdgeData<N: HugrNode> {
    /// Child node the edge leaves from.
    pub source: N,
    /// Child node the edge arrives at.
    pub target: N,
    /// Operation of the `source` node.
    pub source_op: OpType,
    /// Operation of the `target` node.
    pub target_op: OpType,
    /// Port on `source` the edge leaves from.
    pub source_port: Port,
    /// Port on `target` the edge arrives at.
    pub target_port: Port,
}
261
262impl<T: DataflowParent> ValidateOp for T {
263 fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
265 OpValidityFlags {
266 allowed_children: OpTag::DataflowChild,
267 allowed_first_child: OpTag::Input,
268 allowed_second_child: OpTag::Output,
269 requires_children: true,
270 requires_dag: true,
271 ..Default::default()
272 }
273 }
274
275 fn validate_op_children<'a, N: HugrNode>(
277 &self,
278 children: impl DoubleEndedIterator<Item = (N, &'a OpType)>,
279 ) -> Result<(), ChildrenValidationError<N>> {
280 let sig = self.inner_signature();
281 validate_io_nodes(&sig.input, &sig.output, "DataflowParent", children)
282 }
283}
284
285fn validate_io_nodes<'a, N: HugrNode>(
289 expected_input: &TypeRow,
290 expected_output: &TypeRow,
291 container_desc: &'static str,
292 mut children: impl Iterator<Item = (N, &'a OpType)>,
293) -> Result<(), ChildrenValidationError<N>> {
294 let (first, first_optype) = children.next().unwrap();
296 let (second, second_optype) = children.next().unwrap();
297
298 let first_sig = first_optype.dataflow_signature().unwrap_or_default();
299 if &first_sig.output != expected_input {
300 return Err(ChildrenValidationError::IOSignatureMismatch {
301 child: first,
302 actual: first_sig.into_owned().output,
303 expected: expected_input.clone(),
304 node_desc: "Input",
305 container_desc,
306 });
307 }
308 let second_sig = second_optype.dataflow_signature().unwrap_or_default();
309
310 if &second_sig.input != expected_output {
311 return Err(ChildrenValidationError::IOSignatureMismatch {
312 child: second,
313 actual: second_sig.into_owned().input,
314 expected: expected_output.clone(),
315 node_desc: "Output",
316 container_desc,
317 });
318 }
319
320 for (child, optype) in children {
322 match optype.tag() {
323 OpTag::Input => {
324 return Err(ChildrenValidationError::InternalIOChildren {
325 child,
326 optype: optype.clone(),
327 expected_position: "first",
328 });
329 }
330 OpTag::Output => {
331 return Err(ChildrenValidationError::InternalIOChildren {
332 child,
333 optype: optype.clone(),
334 expected_position: "second",
335 });
336 }
337 _ => {}
338 }
339 }
340 Ok(())
341}
342
343fn validate_cfg_edge<N: HugrNode>(edge: ChildrenEdgeData<N>) -> Result<(), EdgeValidationError<N>> {
345 let source = &edge
346 .source_op
347 .as_dataflow_block()
348 .expect("CFG sibling graphs can only contain basic block operations.");
349
350 let target_input = match &edge.target_op {
351 OpType::DataflowBlock(dfb) => dfb.dataflow_input(),
352 OpType::ExitBlock(exit) => exit.dataflow_input(),
353 _ => panic!("CFG sibling graphs can only contain basic block operations."),
354 };
355
356 let source_types = source.successor_input(edge.source_port.index());
357 if source_types.as_ref() != Some(target_input) {
358 let target_types = target_input.clone();
359 return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
360 edge,
361 source_types,
362 target_types,
363 });
364 }
365
366 Ok(())
367}
368
#[cfg(test)]
mod test {
    use crate::extension::prelude::{Noop, usize_t};
    use crate::ops::dataflow::IOTrait;
    use crate::{Node, NodeIndex as _, ops};
    use cool_asserts::assert_matches;
    use portgraph::NodeIndex;

    use super::*;

    /// Turns `(index, op)` pairs into the `(Node, &OpType)` children the
    /// validators consume.
    fn make_iter<'a>(
        children: &'a [(usize, &OpType)],
    ) -> impl DoubleEndedIterator<Item = (Node, &'a OpType)> {
        children
            .iter()
            .map(|&(index, op)| (NodeIndex::new(index).into(), op))
    }

    #[test]
    fn test_validate_io_nodes() {
        let in_row: TypeRow = vec![usize_t()].into();
        let out_row: TypeRow = vec![usize_t(), usize_t()].into();

        let input_op: OpType = ops::Input::new(in_row.clone()).into();
        let output_op: OpType = ops::Output::new(out_row.clone()).into();
        let noop_op: OpType = Noop(usize_t()).into();

        // Input first, Output second, plain leaves afterwards: accepted.
        let well_formed: Vec<(usize, &OpType)> = vec![
            (0, &input_op),
            (1, &output_op),
            (2, &noop_op),
            (3, &noop_op),
        ];
        assert_eq!(
            validate_io_nodes(&in_row, &out_row, "test", make_iter(&well_formed)),
            Ok(())
        );
        // Expected input row disagrees with the Input node: node 0 is blamed.
        assert_matches!(
            validate_io_nodes(&out_row, &out_row, "test", make_iter(&well_formed)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 0
        );
        // Expected output row disagrees with the Output node: node 1 is blamed.
        assert_matches!(
            validate_io_nodes(&in_row, &in_row, "test", make_iter(&well_formed)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 1
        );

        // A later, second Output node is rejected and reported by its index.
        let stray_output: Vec<(usize, &OpType)> = vec![
            (0, &input_op),
            (1, &output_op),
            (42, &noop_op),
            (2, &noop_op),
            (3, &output_op),
        ];
        assert_matches!(
            validate_io_nodes(&in_row, &out_row, "test", make_iter(&stray_output)),
            Err(ChildrenValidationError::InternalIOChildren { child, .. }) if child.index() == 3
        );
    }
}
430
431use super::{
432 AliasDecl, AliasDefn, Call, CallIndirect, Const, ExtensionOp, FuncDecl, Input, LoadConstant,
433 LoadFunction, OpaqueOp, Output, Tag,
434};
435impl_validate_op!(FuncDecl);
436impl_validate_op!(AliasDecl);
437impl_validate_op!(AliasDefn);
438impl_validate_op!(Input);
439impl_validate_op!(Output);
440impl_validate_op!(Const);
441impl_validate_op!(Call);
442impl_validate_op!(LoadConstant);
443impl_validate_op!(LoadFunction);
444impl_validate_op!(CallIndirect);
445impl_validate_op!(ExtensionOp);
446impl_validate_op!(OpaqueOp);
447impl_validate_op!(Tag);
448impl_validate_op!(ExitBlock);