hugr_core/ops/
validate.rs

1//! Definitions for validating hugr nodes according to their operation type.
2//!
3//! Adds a `validity_flags` method to [`OpType`] that returns a series of flags
4//! used by the [`crate::hugr::validate`] module.
5//!
6//! It also defines a `validate_op_children` method for more complex tests that
7//! require traversing the children.
8
9use itertools::Itertools;
10use portgraph::{NodeIndex, PortOffset};
11use thiserror::Error;
12
13use crate::types::TypeRow;
14
15use super::dataflow::{DataflowOpTrait, DataflowParent};
16use super::{impl_validate_op, BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp};
17
18/// A set of property flags required for an operation.
19#[non_exhaustive]
20pub struct OpValidityFlags {
21    /// The set of valid children operation types.
22    pub allowed_children: OpTag,
23    /// Additional restrictions on the first child operation.
24    ///
25    /// This is checked in addition to the child allowing the parent optype.
26    pub allowed_first_child: OpTag,
27    /// Additional restrictions on the second child operation
28    ///
29    /// This is checked in addition to the child allowing the parent optype.
30    pub allowed_second_child: OpTag,
31    /// Whether the operation must have children.
32    pub requires_children: bool,
33    /// Whether the children must form a DAG (no cycles).
34    pub requires_dag: bool,
35    /// A validation check for edges between children
36    ///
37    // Enclosed in an `Option` to avoid iterating over the edges if not needed.
38    pub edge_check: Option<fn(ChildrenEdgeData) -> Result<(), EdgeValidationError>>,
39}
40
41impl Default for OpValidityFlags {
42    fn default() -> Self {
43        // Defaults to flags valid for non-container operations
44        Self {
45            allowed_children: OpTag::None,
46            allowed_first_child: OpTag::Any,
47            allowed_second_child: OpTag::Any,
48            requires_children: false,
49            requires_dag: false,
50            edge_check: None,
51        }
52    }
53}
54
55impl ValidateOp for super::Module {
56    fn validity_flags(&self) -> OpValidityFlags {
57        OpValidityFlags {
58            allowed_children: OpTag::ModuleOp,
59            requires_children: false,
60            ..Default::default()
61        }
62    }
63}
64
65impl ValidateOp for super::Conditional {
66    fn validity_flags(&self) -> OpValidityFlags {
67        OpValidityFlags {
68            allowed_children: OpTag::Case,
69            requires_children: true,
70            requires_dag: false,
71            ..Default::default()
72        }
73    }
74
75    fn validate_op_children<'a>(
76        &self,
77        children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
78    ) -> Result<(), ChildrenValidationError> {
79        let children = children.collect_vec();
80        // The first input to the ɣ-node is a value of Sum type,
81        // whose arity matches the number of children of the ɣ-node.
82        if self.sum_rows.len() != children.len() {
83            return Err(ChildrenValidationError::InvalidConditionalSum {
84                child: children[0].0, // Pass an arbitrary child
85                expected_count: children.len(),
86                actual_sum_rows: self.sum_rows.clone(),
87            });
88        }
89
90        // Each child must have its variant's row and the rest of `inputs` as input,
91        // and matching output
92        for (i, (child, optype)) in children.into_iter().enumerate() {
93            let case_op = optype
94                .as_case()
95                .expect("Child check should have already checked valid ops.");
96            let sig = &case_op.inner_signature();
97            if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
98                return Err(ChildrenValidationError::ConditionalCaseSignature {
99                    child,
100                    optype: optype.clone(),
101                });
102            }
103        }
104
105        Ok(())
106    }
107}
108
109impl ValidateOp for super::CFG {
110    fn validity_flags(&self) -> OpValidityFlags {
111        OpValidityFlags {
112            allowed_children: OpTag::ControlFlowChild,
113            allowed_first_child: OpTag::DataflowBlock,
114            allowed_second_child: OpTag::BasicBlockExit,
115            requires_children: true,
116            requires_dag: false,
117            edge_check: Some(validate_cfg_edge),
118            ..Default::default()
119        }
120    }
121
122    fn validate_op_children<'a>(
123        &self,
124        mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
125    ) -> Result<(), ChildrenValidationError> {
126        let (entry, entry_op) = children.next().unwrap();
127        let (exit, exit_op) = children.next().unwrap();
128        let entry_op = entry_op
129            .as_dataflow_block()
130            .expect("Child check should have already checked valid ops.");
131        let exit_op = exit_op
132            .as_exit_block()
133            .expect("Child check should have already checked valid ops.");
134
135        let sig = self.signature();
136        if entry_op.inner_signature().input() != sig.input() {
137            return Err(ChildrenValidationError::IOSignatureMismatch {
138                child: entry,
139                actual: entry_op.inner_signature().input().clone(),
140                expected: sig.input().clone(),
141                node_desc: "BasicBlock Input",
142                container_desc: "CFG",
143            });
144        }
145        if &exit_op.cfg_outputs != sig.output() {
146            return Err(ChildrenValidationError::IOSignatureMismatch {
147                child: exit,
148                actual: exit_op.cfg_outputs.clone(),
149                expected: sig.output().clone(),
150                node_desc: "BasicBlockExit Output",
151                container_desc: "CFG",
152            });
153        }
154        for (child, optype) in children {
155            if optype.tag() == OpTag::BasicBlockExit {
156                return Err(ChildrenValidationError::InternalExitChildren { child });
157            }
158        }
159        Ok(())
160    }
161}
162/// Errors that can occur while checking the children of a node.
163#[derive(Debug, Clone, PartialEq, Error)]
164#[allow(missing_docs)]
165#[non_exhaustive]
166pub enum ChildrenValidationError {
167    /// An CFG graph has an exit operation as a non-second child.
168    #[error("Exit basic blocks are only allowed as the second child in a CFG graph")]
169    InternalExitChildren { child: NodeIndex },
170    /// An operation only allowed as the first/second child was found as an intermediate child.
171    #[error("A {optype} operation is only allowed as a {expected_position} child")]
172    InternalIOChildren {
173        child: NodeIndex,
174        optype: OpType,
175        expected_position: &'static str,
176    },
177    /// The signature of the contained dataflow graph does not match the one of the container.
178    #[error("The {node_desc} node of a {container_desc} has a signature of {actual}, which differs from the expected type row {expected}")]
179    IOSignatureMismatch {
180        child: NodeIndex,
181        actual: TypeRow,
182        expected: TypeRow,
183        node_desc: &'static str,
184        container_desc: &'static str,
185    },
186    /// The signature of a child case in a conditional operation does not match the container's signature.
187    #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())]
188    ConditionalCaseSignature { child: NodeIndex, optype: OpType },
189    /// The conditional container's branching value does not match the number of children.
190    #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}",
191        actual_sum_rows.len())]
192    InvalidConditionalSum {
193        child: NodeIndex,
194        expected_count: usize,
195        actual_sum_rows: Vec<TypeRow>,
196    },
197}
198
199impl ChildrenValidationError {
200    /// Returns the node index of the child that caused the error.
201    pub fn child(&self) -> NodeIndex {
202        match self {
203            ChildrenValidationError::InternalIOChildren { child, .. } => *child,
204            ChildrenValidationError::InternalExitChildren { child, .. } => *child,
205            ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child,
206            ChildrenValidationError::IOSignatureMismatch { child, .. } => *child,
207            ChildrenValidationError::InvalidConditionalSum { child, .. } => *child,
208        }
209    }
210}
211
212/// Errors that can occur while checking the edges between children of a node.
213#[derive(Debug, Clone, PartialEq, Error)]
214#[allow(missing_docs)]
215#[non_exhaustive]
216pub enum EdgeValidationError {
217    /// The dataflow signature of two connected basic blocks does not match.
218    #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
219        source_ty = source_types.clone().unwrap_or_default(),
220    )]
221    CFGEdgeSignatureMismatch {
222        edge: ChildrenEdgeData,
223        source_types: Option<TypeRow>,
224        target_types: TypeRow,
225    },
226}
227
228impl EdgeValidationError {
229    /// Returns information on the edge that caused the error.
230    pub fn edge(&self) -> &ChildrenEdgeData {
231        match self {
232            EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
233        }
234    }
235}
236
237/// Auxiliary structure passed as data to [`OpValidityFlags::edge_check`].
238#[derive(Debug, Clone, PartialEq)]
239pub struct ChildrenEdgeData {
240    /// Source child.
241    pub source: NodeIndex,
242    /// Target child.
243    pub target: NodeIndex,
244    /// Operation type of the source child.
245    pub source_op: OpType,
246    /// Operation type of the target child.
247    pub target_op: OpType,
248    /// Source port.
249    pub source_port: PortOffset,
250    /// Target port.
251    pub target_port: PortOffset,
252}
253
254impl<T: DataflowParent> ValidateOp for T {
255    /// Returns the set of allowed parent operation types.
256    fn validity_flags(&self) -> OpValidityFlags {
257        OpValidityFlags {
258            allowed_children: OpTag::DataflowChild,
259            allowed_first_child: OpTag::Input,
260            allowed_second_child: OpTag::Output,
261            requires_children: true,
262            requires_dag: true,
263            ..Default::default()
264        }
265    }
266
267    /// Validate the ordered list of children.
268    fn validate_op_children<'a>(
269        &self,
270        children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
271    ) -> Result<(), ChildrenValidationError> {
272        let sig = self.inner_signature();
273        validate_io_nodes(&sig.input, &sig.output, "DataflowParent", children)
274    }
275}
276
277/// Checks a that the list of children nodes does not contain Input and Output
278/// nodes outside of the first and second elements respectively, and that those
279/// have the correct signature.
280fn validate_io_nodes<'a>(
281    expected_input: &TypeRow,
282    expected_output: &TypeRow,
283    container_desc: &'static str,
284    mut children: impl Iterator<Item = (NodeIndex, &'a OpType)>,
285) -> Result<(), ChildrenValidationError> {
286    // Check that the signature matches with the Input and Output rows.
287    let (first, first_optype) = children.next().unwrap();
288    let (second, second_optype) = children.next().unwrap();
289
290    let first_sig = first_optype.dataflow_signature().unwrap_or_default();
291    if &first_sig.output != expected_input {
292        return Err(ChildrenValidationError::IOSignatureMismatch {
293            child: first,
294            actual: first_sig.into_owned().output,
295            expected: expected_input.clone(),
296            node_desc: "Input",
297            container_desc,
298        });
299    }
300    let second_sig = second_optype.dataflow_signature().unwrap_or_default();
301
302    if &second_sig.input != expected_output {
303        return Err(ChildrenValidationError::IOSignatureMismatch {
304            child: second,
305            actual: second_sig.into_owned().input,
306            expected: expected_output.clone(),
307            node_desc: "Output",
308            container_desc,
309        });
310    }
311
312    // The first and second children have already been popped from the iterator.
313    for (child, optype) in children {
314        match optype.tag() {
315            OpTag::Input => {
316                return Err(ChildrenValidationError::InternalIOChildren {
317                    child,
318                    optype: optype.clone(),
319                    expected_position: "first",
320                })
321            }
322            OpTag::Output => {
323                return Err(ChildrenValidationError::InternalIOChildren {
324                    child,
325                    optype: optype.clone(),
326                    expected_position: "second",
327                })
328            }
329            _ => {}
330        }
331    }
332    Ok(())
333}
334
335/// Validate an edge between two basic blocks in a CFG sibling graph.
336fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> {
337    let source = &edge
338        .source_op
339        .as_dataflow_block()
340        .expect("CFG sibling graphs can only contain basic block operations.");
341
342    let target_input = match &edge.target_op {
343        OpType::DataflowBlock(dfb) => dfb.dataflow_input(),
344        OpType::ExitBlock(exit) => exit.dataflow_input(),
345        _ => panic!("CFG sibling graphs can only contain basic block operations."),
346    };
347
348    let source_types = source.successor_input(edge.source_port.index());
349    if source_types.as_ref() != Some(target_input) {
350        let target_types = target_input.clone();
351        return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
352            edge,
353            source_types,
354            target_types,
355        });
356    }
357
358    Ok(())
359}
360
#[cfg(test)]
mod test {
    use crate::extension::prelude::{usize_t, Noop};
    use crate::ops;
    use crate::ops::dataflow::IOTrait;
    use cool_asserts::assert_matches;

    use super::*;

    /// Converts plain `(index, op)` pairs into the `(NodeIndex, &OpType)`
    /// iterator shape consumed by the children validators.
    fn make_iter<'a>(
        children: &'a [(usize, &OpType)],
    ) -> impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)> {
        children
            .iter()
            .map(|&(index, optype)| (NodeIndex::new(index), optype))
    }

    #[test]
    fn test_validate_io_nodes() {
        let row_of_one: TypeRow = vec![usize_t()].into();
        let row_of_two: TypeRow = vec![usize_t(), usize_t()].into();

        let input_op: OpType = ops::Input::new(row_of_one.clone()).into();
        let output_op: OpType = ops::Output::new(row_of_two.clone()).into();
        let noop_op: OpType = Noop(usize_t()).into();

        // Input first, Output second, then ordinary nodes.
        let siblings = vec![
            (0, &input_op),
            (1, &output_op),
            (2, &noop_op),
            (3, &noop_op),
        ];
        assert_eq!(
            validate_io_nodes(&row_of_one, &row_of_two, "test", make_iter(&siblings)),
            Ok(())
        );
        // Wrong expected input row: blamed on the Input node (index 0).
        assert_matches!(
            validate_io_nodes(&row_of_two, &row_of_two, "test", make_iter(&siblings)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 0
        );
        // Wrong expected output row: blamed on the Output node (index 1).
        assert_matches!(
            validate_io_nodes(&row_of_one, &row_of_one, "test", make_iter(&siblings)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 1
        );

        // An extra Output node further down the list must be rejected.
        let siblings = vec![
            (0, &input_op),
            (1, &output_op),
            (42, &noop_op),
            (2, &noop_op),
            (3, &output_op),
        ];
        assert_matches!(
            validate_io_nodes(&row_of_one, &row_of_two, "test", make_iter(&siblings)),
            Err(ChildrenValidationError::InternalIOChildren { child, .. }) if child.index() == 3
        );
    }
}
419
420use super::{
421    AliasDecl, AliasDefn, Call, CallIndirect, Const, ExtensionOp, FuncDecl, Input, LoadConstant,
422    LoadFunction, OpaqueOp, Output, Tag,
423};
424impl_validate_op!(FuncDecl);
425impl_validate_op!(AliasDecl);
426impl_validate_op!(AliasDefn);
427impl_validate_op!(Input);
428impl_validate_op!(Output);
429impl_validate_op!(Const);
430impl_validate_op!(Call);
431impl_validate_op!(LoadConstant);
432impl_validate_op!(LoadFunction);
433impl_validate_op!(CallIndirect);
434impl_validate_op!(ExtensionOp);
435impl_validate_op!(OpaqueOp);
436impl_validate_op!(Tag);
437impl_validate_op!(ExitBlock);