1use itertools::Itertools;
10use portgraph::{NodeIndex, PortOffset};
11use thiserror::Error;
12
13use crate::types::TypeRow;
14
15use super::dataflow::{DataflowOpTrait, DataflowParent};
16use super::{impl_validate_op, BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp};
17
18#[non_exhaustive]
20pub struct OpValidityFlags {
21 pub allowed_children: OpTag,
23 pub allowed_first_child: OpTag,
27 pub allowed_second_child: OpTag,
31 pub requires_children: bool,
33 pub requires_dag: bool,
35 pub edge_check: Option<fn(ChildrenEdgeData) -> Result<(), EdgeValidationError>>,
39}
40
41impl Default for OpValidityFlags {
42 fn default() -> Self {
43 Self {
45 allowed_children: OpTag::None,
46 allowed_first_child: OpTag::Any,
47 allowed_second_child: OpTag::Any,
48 requires_children: false,
49 requires_dag: false,
50 edge_check: None,
51 }
52 }
53}
54
55impl ValidateOp for super::Module {
56 fn validity_flags(&self) -> OpValidityFlags {
57 OpValidityFlags {
58 allowed_children: OpTag::ModuleOp,
59 requires_children: false,
60 ..Default::default()
61 }
62 }
63}
64
65impl ValidateOp for super::Conditional {
66 fn validity_flags(&self) -> OpValidityFlags {
67 OpValidityFlags {
68 allowed_children: OpTag::Case,
69 requires_children: true,
70 requires_dag: false,
71 ..Default::default()
72 }
73 }
74
75 fn validate_op_children<'a>(
76 &self,
77 children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
78 ) -> Result<(), ChildrenValidationError> {
79 let children = children.collect_vec();
80 if self.sum_rows.len() != children.len() {
83 return Err(ChildrenValidationError::InvalidConditionalSum {
84 child: children[0].0, expected_count: children.len(),
86 actual_sum_rows: self.sum_rows.clone(),
87 });
88 }
89
90 for (i, (child, optype)) in children.into_iter().enumerate() {
93 let case_op = optype
94 .as_case()
95 .expect("Child check should have already checked valid ops.");
96 let sig = &case_op.inner_signature();
97 if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
98 return Err(ChildrenValidationError::ConditionalCaseSignature {
99 child,
100 optype: optype.clone(),
101 });
102 }
103 }
104
105 Ok(())
106 }
107}
108
109impl ValidateOp for super::CFG {
110 fn validity_flags(&self) -> OpValidityFlags {
111 OpValidityFlags {
112 allowed_children: OpTag::ControlFlowChild,
113 allowed_first_child: OpTag::DataflowBlock,
114 allowed_second_child: OpTag::BasicBlockExit,
115 requires_children: true,
116 requires_dag: false,
117 edge_check: Some(validate_cfg_edge),
118 ..Default::default()
119 }
120 }
121
122 fn validate_op_children<'a>(
123 &self,
124 mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
125 ) -> Result<(), ChildrenValidationError> {
126 let (entry, entry_op) = children.next().unwrap();
127 let (exit, exit_op) = children.next().unwrap();
128 let entry_op = entry_op
129 .as_dataflow_block()
130 .expect("Child check should have already checked valid ops.");
131 let exit_op = exit_op
132 .as_exit_block()
133 .expect("Child check should have already checked valid ops.");
134
135 let sig = self.signature();
136 if entry_op.inner_signature().input() != sig.input() {
137 return Err(ChildrenValidationError::IOSignatureMismatch {
138 child: entry,
139 actual: entry_op.inner_signature().input().clone(),
140 expected: sig.input().clone(),
141 node_desc: "BasicBlock Input",
142 container_desc: "CFG",
143 });
144 }
145 if &exit_op.cfg_outputs != sig.output() {
146 return Err(ChildrenValidationError::IOSignatureMismatch {
147 child: exit,
148 actual: exit_op.cfg_outputs.clone(),
149 expected: sig.output().clone(),
150 node_desc: "BasicBlockExit Output",
151 container_desc: "CFG",
152 });
153 }
154 for (child, optype) in children {
155 if optype.tag() == OpTag::BasicBlockExit {
156 return Err(ChildrenValidationError::InternalExitChildren { child });
157 }
158 }
159 Ok(())
160 }
161}
162#[derive(Debug, Clone, PartialEq, Error)]
164#[allow(missing_docs)]
165#[non_exhaustive]
166pub enum ChildrenValidationError {
167 #[error("Exit basic blocks are only allowed as the second child in a CFG graph")]
169 InternalExitChildren { child: NodeIndex },
170 #[error("A {optype} operation is only allowed as a {expected_position} child")]
172 InternalIOChildren {
173 child: NodeIndex,
174 optype: OpType,
175 expected_position: &'static str,
176 },
177 #[error("The {node_desc} node of a {container_desc} has a signature of {actual}, which differs from the expected type row {expected}")]
179 IOSignatureMismatch {
180 child: NodeIndex,
181 actual: TypeRow,
182 expected: TypeRow,
183 node_desc: &'static str,
184 container_desc: &'static str,
185 },
186 #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())]
188 ConditionalCaseSignature { child: NodeIndex, optype: OpType },
189 #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}",
191 actual_sum_rows.len())]
192 InvalidConditionalSum {
193 child: NodeIndex,
194 expected_count: usize,
195 actual_sum_rows: Vec<TypeRow>,
196 },
197}
198
199impl ChildrenValidationError {
200 pub fn child(&self) -> NodeIndex {
202 match self {
203 ChildrenValidationError::InternalIOChildren { child, .. } => *child,
204 ChildrenValidationError::InternalExitChildren { child, .. } => *child,
205 ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child,
206 ChildrenValidationError::IOSignatureMismatch { child, .. } => *child,
207 ChildrenValidationError::InvalidConditionalSum { child, .. } => *child,
208 }
209 }
210}
211
212#[derive(Debug, Clone, PartialEq, Error)]
214#[allow(missing_docs)]
215#[non_exhaustive]
216pub enum EdgeValidationError {
217 #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
219 source_ty = source_types.clone().unwrap_or_default(),
220 )]
221 CFGEdgeSignatureMismatch {
222 edge: ChildrenEdgeData,
223 source_types: Option<TypeRow>,
224 target_types: TypeRow,
225 },
226}
227
228impl EdgeValidationError {
229 pub fn edge(&self) -> &ChildrenEdgeData {
231 match self {
232 EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
233 }
234 }
235}
236
237#[derive(Debug, Clone, PartialEq)]
239pub struct ChildrenEdgeData {
240 pub source: NodeIndex,
242 pub target: NodeIndex,
244 pub source_op: OpType,
246 pub target_op: OpType,
248 pub source_port: PortOffset,
250 pub target_port: PortOffset,
252}
253
254impl<T: DataflowParent> ValidateOp for T {
255 fn validity_flags(&self) -> OpValidityFlags {
257 OpValidityFlags {
258 allowed_children: OpTag::DataflowChild,
259 allowed_first_child: OpTag::Input,
260 allowed_second_child: OpTag::Output,
261 requires_children: true,
262 requires_dag: true,
263 ..Default::default()
264 }
265 }
266
267 fn validate_op_children<'a>(
269 &self,
270 children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
271 ) -> Result<(), ChildrenValidationError> {
272 let sig = self.inner_signature();
273 validate_io_nodes(&sig.input, &sig.output, "DataflowParent", children)
274 }
275}
276
277fn validate_io_nodes<'a>(
281 expected_input: &TypeRow,
282 expected_output: &TypeRow,
283 container_desc: &'static str,
284 mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
285) -> Result<(), ChildrenValidationError> {
286 let (first, first_optype) = children.next().unwrap();
288 let (second, second_optype) = children.next().unwrap();
289
290 let first_sig = first_optype.dataflow_signature().unwrap_or_default();
291 if &first_sig.output != expected_input {
292 return Err(ChildrenValidationError::IOSignatureMismatch {
293 child: first,
294 actual: first_sig.into_owned().output,
295 expected: expected_input.clone(),
296 node_desc: "Input",
297 container_desc,
298 });
299 }
300 let second_sig = second_optype.dataflow_signature().unwrap_or_default();
301
302 if &second_sig.input != expected_output {
303 return Err(ChildrenValidationError::IOSignatureMismatch {
304 child: second,
305 actual: second_sig.into_owned().input,
306 expected: expected_output.clone(),
307 node_desc: "Output",
308 container_desc,
309 });
310 }
311
312 for (child, optype) in children {
314 match optype.tag() {
315 OpTag::Input => {
316 return Err(ChildrenValidationError::InternalIOChildren {
317 child,
318 optype: optype.clone(),
319 expected_position: "first",
320 })
321 }
322 OpTag::Output => {
323 return Err(ChildrenValidationError::InternalIOChildren {
324 child,
325 optype: optype.clone(),
326 expected_position: "second",
327 })
328 }
329 _ => {}
330 }
331 }
332 Ok(())
333}
334
335fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> {
337 let source = &edge
338 .source_op
339 .as_dataflow_block()
340 .expect("CFG sibling graphs can only contain basic block operations.");
341
342 let target_input = match &edge.target_op {
343 OpType::DataflowBlock(dfb) => dfb.dataflow_input(),
344 OpType::ExitBlock(exit) => exit.dataflow_input(),
345 _ => panic!("CFG sibling graphs can only contain basic block operations."),
346 };
347
348 let source_types = source.successor_input(edge.source_port.index());
349 if source_types.as_ref() != Some(target_input) {
350 let target_types = target_input.clone();
351 return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
352 edge,
353 source_types,
354 target_types,
355 });
356 }
357
358 Ok(())
359}
360
#[cfg(test)]
mod test {
    use crate::extension::prelude::{usize_t, Noop};
    use crate::ops;
    use crate::ops::dataflow::IOTrait;
    use cool_asserts::assert_matches;

    use super::*;

    #[test]
    fn test_validate_io_nodes() {
        let in_types: TypeRow = vec![usize_t()].into();
        let out_types: TypeRow = vec![usize_t(), usize_t()].into();

        let input_node: OpType = ops::Input::new(in_types.clone()).into();
        let output_node = ops::Output::new(out_types.clone()).into();
        let leaf_node = Noop(usize_t()).into();

        // Well-formed container: Input, Output, then ordinary nodes.
        let children = vec![
            (0, &input_node),
            (1, &output_node),
            (2, &leaf_node),
            (3, &leaf_node),
        ];
        assert_eq!(
            validate_io_nodes(&in_types, &out_types, "test", make_iter(&children)),
            Ok(())
        );
        assert_matches!(
            validate_io_nodes(&out_types, &out_types, "test", make_iter(&children)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 0
        );
        assert_matches!(
            validate_io_nodes(&in_types, &in_types, "test", make_iter(&children)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 1
        );

        // A stray Output node is reported, pointing at the "second" position.
        let children = vec![
            (0, &input_node),
            (1, &output_node),
            (42, &leaf_node),
            (2, &leaf_node),
            (3, &output_node),
        ];
        assert_matches!(
            validate_io_nodes(&in_types, &out_types, "test", make_iter(&children)),
            Err(ChildrenValidationError::InternalIOChildren { child, expected_position, .. })
                if child.index() == 3 && expected_position == "second"
        );

        // A stray Input node is reported, pointing at the "first" position.
        let children = vec![
            (0, &input_node),
            (1, &output_node),
            (2, &leaf_node),
            (7, &input_node),
        ];
        let err = validate_io_nodes(&in_types, &out_types, "test", make_iter(&children))
            .unwrap_err();
        assert_eq!(err.child().index(), 7);
        assert_matches!(
            err,
            ChildrenValidationError::InternalIOChildren { expected_position: "first", .. }
        );
    }

    fn make_iter<'a>(
        children: &'a [(usize, &OpType)],
    ) -> impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)> {
        children.iter().map(|(n, op)| (NodeIndex::new(*n), *op))
    }
}
419
420use super::{
421 AliasDecl, AliasDefn, Call, CallIndirect, Const, ExtensionOp, FuncDecl, Input, LoadConstant,
422 LoadFunction, OpaqueOp, Output, Tag,
423};
424impl_validate_op!(FuncDecl);
425impl_validate_op!(AliasDecl);
426impl_validate_op!(AliasDefn);
427impl_validate_op!(Input);
428impl_validate_op!(Output);
429impl_validate_op!(Const);
430impl_validate_op!(Call);
431impl_validate_op!(LoadConstant);
432impl_validate_op!(LoadFunction);
433impl_validate_op!(CallIndirect);
434impl_validate_op!(ExtensionOp);
435impl_validate_op!(OpaqueOp);
436impl_validate_op!(Tag);
437impl_validate_op!(ExitBlock);