hugr_core/ops/
validate.rs

1//! Definitions for validating hugr nodes according to their operation type.
2//!
3//! Adds a `validity_flags` method to [`OpType`] that returns a series of flags
4//! used by the [`crate::hugr::validate`] module.
5//!
6//! It also defines a `validate_op_children` method for more complex tests that
7//! require traversing the children.
8
9use itertools::Itertools;
10use thiserror::Error;
11
12use crate::core::HugrNode;
13use crate::types::TypeRow;
14use crate::{Node, Port, PortIndex};
15
16use super::dataflow::{DataflowOpTrait, DataflowParent};
17use super::{BasicBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp, impl_validate_op};
18
/// A function that checks the edges between children of a node.
///
/// It receives the [`ChildrenEdgeData`] describing a single edge and returns
/// an [`EdgeValidationError`] when that edge is not valid.
///
/// Part of the [`OpValidityFlags`] struct.
pub type EdgeCheck<N> = fn(ChildrenEdgeData<N>) -> Result<(), EdgeValidationError<N>>;
23
/// A set of property flags required for an operation.
///
/// The [`Default`] value describes a non-container operation: no children
/// are allowed and no structural checks are requested.
#[non_exhaustive]
pub struct OpValidityFlags<N: HugrNode = Node> {
    /// The set of valid children operation types.
    pub allowed_children: OpTag,
    /// Additional restrictions on the first child operation.
    ///
    /// This is checked in addition to the child allowing the parent optype.
    pub allowed_first_child: OpTag,
    /// Additional restrictions on the second child operation.
    ///
    /// This is checked in addition to the child allowing the parent optype.
    pub allowed_second_child: OpTag,
    /// Whether the operation must have children.
    pub requires_children: bool,
    /// Whether the children must form a DAG (no cycles).
    pub requires_dag: bool,
    /// A validation check for edges between children.
    ///
    // Enclosed in an `Option` to avoid iterating over the edges if not needed.
    pub edge_check: Option<EdgeCheck<N>>,
}
46
47impl<N: HugrNode> Default for OpValidityFlags<N> {
48    fn default() -> Self {
49        // Defaults to flags valid for non-container operations
50        Self {
51            allowed_children: OpTag::None,
52            allowed_first_child: OpTag::Any,
53            allowed_second_child: OpTag::Any,
54            requires_children: false,
55            requires_dag: false,
56            edge_check: None,
57        }
58    }
59}
60
61impl ValidateOp for super::Module {
62    fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
63        OpValidityFlags {
64            allowed_children: OpTag::ModuleOp,
65            requires_children: false,
66            ..Default::default()
67        }
68    }
69}
70
71impl ValidateOp for super::Conditional {
72    fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
73        OpValidityFlags {
74            allowed_children: OpTag::Case,
75            requires_children: true,
76            requires_dag: false,
77            ..Default::default()
78        }
79    }
80
81    fn validate_op_children<'a, N: HugrNode>(
82        &self,
83        children: impl DoubleEndedIterator<Item = (N, &'a OpType)>,
84    ) -> Result<(), ChildrenValidationError<N>> {
85        let children = children.collect_vec();
86        // The first input to the ɣ-node is a value of Sum type,
87        // whose arity matches the number of children of the ɣ-node.
88        if self.sum_rows.len() != children.len() {
89            return Err(ChildrenValidationError::InvalidConditionalSum {
90                child: children[0].0, // Pass an arbitrary child
91                expected_count: children.len(),
92                actual_sum_rows: self.sum_rows.clone(),
93            });
94        }
95
96        // Each child must have its variant's row and the rest of `inputs` as input,
97        // and matching output
98        for (i, (child, optype)) in children.into_iter().enumerate() {
99            let case_op = optype
100                .as_case()
101                .expect("Child check should have already checked valid ops.");
102            let sig = &case_op.inner_signature();
103            if sig.input != self.case_input_row(i).unwrap() || sig.output != self.outputs {
104                return Err(ChildrenValidationError::ConditionalCaseSignature {
105                    child,
106                    optype: optype.clone(),
107                });
108            }
109        }
110
111        Ok(())
112    }
113}
114
115impl ValidateOp for super::CFG {
116    fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
117        OpValidityFlags {
118            allowed_children: OpTag::ControlFlowChild,
119            allowed_first_child: OpTag::DataflowBlock,
120            allowed_second_child: OpTag::BasicBlockExit,
121            requires_children: true,
122            requires_dag: false,
123            edge_check: Some(validate_cfg_edge),
124            ..Default::default()
125        }
126    }
127
128    fn validate_op_children<'a, N: HugrNode>(
129        &self,
130        mut children: impl Iterator<Item = (N, &'a OpType)>,
131    ) -> Result<(), ChildrenValidationError<N>> {
132        let (entry, entry_op) = children.next().unwrap();
133        let (exit, exit_op) = children.next().unwrap();
134        let entry_op = entry_op
135            .as_dataflow_block()
136            .expect("Child check should have already checked valid ops.");
137        let exit_op = exit_op
138            .as_exit_block()
139            .expect("Child check should have already checked valid ops.");
140
141        let sig = self.signature();
142        if entry_op.inner_signature().input() != sig.input() {
143            return Err(ChildrenValidationError::IOSignatureMismatch {
144                child: entry,
145                actual: entry_op.inner_signature().input().clone(),
146                expected: sig.input().clone(),
147                node_desc: "BasicBlock Input",
148                container_desc: "CFG",
149            });
150        }
151        if &exit_op.cfg_outputs != sig.output() {
152            return Err(ChildrenValidationError::IOSignatureMismatch {
153                child: exit,
154                actual: exit_op.cfg_outputs.clone(),
155                expected: sig.output().clone(),
156                node_desc: "BasicBlockExit Output",
157                container_desc: "CFG",
158            });
159        }
160        for (child, optype) in children {
161            if optype.tag() == OpTag::BasicBlockExit {
162                return Err(ChildrenValidationError::InternalExitChildren { child });
163            }
164        }
165        Ok(())
166    }
167}
/// Errors that can occur while checking the children of a node.
///
/// Every variant records the offending `child`; see
/// [`ChildrenValidationError::child`].
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum ChildrenValidationError<N: HugrNode> {
    /// A CFG graph has an exit operation as a non-second child.
    #[error("Exit basic blocks are only allowed as the second child in a CFG graph")]
    InternalExitChildren { child: N },
    /// An operation only allowed as the first/second child was found as an intermediate child.
    #[error("A {optype} operation is only allowed as a {expected_position} child")]
    InternalIOChildren {
        child: N,
        // The misplaced operation.
        optype: OpType,
        // Human-readable position ("first" / "second") used in the message.
        expected_position: &'static str,
    },
    /// The signature of the contained dataflow graph does not match the one of the container.
    #[error(
        "The {node_desc} node of a {container_desc} has a signature of {actual}, which differs from the expected type row {expected}"
    )]
    IOSignatureMismatch {
        child: N,
        // The row found on the child.
        actual: TypeRow,
        // The row required by the container.
        expected: TypeRow,
        // Description of the child node, interpolated into the message.
        node_desc: &'static str,
        // Description of the container, interpolated into the message.
        container_desc: &'static str,
    },
    /// The signature of a child case in a conditional operation does not match the container's signature.
    #[error("A conditional case has optype {sig}, which differs from the signature of Conditional container", sig=optype.dataflow_signature().unwrap_or_default())]
    ConditionalCaseSignature { child: N, optype: OpType },
    /// The conditional container's branching value does not match the number of children.
    #[error("The conditional container's branch Sum input should be a sum with {expected_count} elements, but it had {} elements. Sum rows: {actual_sum_rows:?}",
        actual_sum_rows.len())]
    InvalidConditionalSum {
        // An arbitrary child of the conditional, used to locate the error.
        child: N,
        // The number of children of the conditional.
        expected_count: usize,
        // The variant rows declared by the conditional.
        actual_sum_rows: Vec<TypeRow>,
    },
}
206
207impl<N: HugrNode> ChildrenValidationError<N> {
208    /// Returns the node index of the child that caused the error.
209    pub fn child(&self) -> N {
210        match self {
211            ChildrenValidationError::InternalIOChildren { child, .. } => *child,
212            ChildrenValidationError::InternalExitChildren { child, .. } => *child,
213            ChildrenValidationError::ConditionalCaseSignature { child, .. } => *child,
214            ChildrenValidationError::IOSignatureMismatch { child, .. } => *child,
215            ChildrenValidationError::InvalidConditionalSum { child, .. } => *child,
216        }
217    }
218}
219
/// Errors that can occur while checking the edges between children of a node.
///
/// Every variant records the offending edge; see
/// [`EdgeValidationError::edge`].
#[derive(Debug, Clone, PartialEq, Error)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum EdgeValidationError<N: HugrNode> {
    /// The dataflow signature of two connected basic blocks does not match.
    #[error("The dataflow signature of two connected basic blocks does not match. The source type was {source_ty} but the target had type {target_types}",
        source_ty = source_types.clone().unwrap_or_default(),
    )]
    CFGEdgeSignatureMismatch {
        // The edge that failed validation.
        edge: ChildrenEdgeData<N>,
        // Row offered by the source block on this edge, if any; `None` is
        // displayed as the default (empty) row in the message.
        source_types: Option<TypeRow>,
        // Row expected as input by the target block.
        target_types: TypeRow,
    },
}
235
236impl<N: HugrNode> EdgeValidationError<N> {
237    /// Returns information on the edge that caused the error.
238    pub fn edge(&self) -> &ChildrenEdgeData<N> {
239        match self {
240            EdgeValidationError::CFGEdgeSignatureMismatch { edge, .. } => edge,
241        }
242    }
243}
244
/// Auxiliary structure passed as data to [`OpValidityFlags::edge_check`].
///
/// Describes one edge between two children of the node being validated,
/// together with the operation types found at both ends.
#[derive(Debug, Clone, PartialEq)]
pub struct ChildrenEdgeData<N: HugrNode> {
    /// Source child.
    pub source: N,
    /// Target child.
    pub target: N,
    /// Operation type of the source child.
    pub source_op: OpType,
    /// Operation type of the target child.
    pub target_op: OpType,
    /// Source port.
    pub source_port: Port,
    /// Target port.
    pub target_port: Port,
}
261
262impl<T: DataflowParent> ValidateOp for T {
263    /// Returns the set of allowed parent operation types.
264    fn validity_flags<N: HugrNode>(&self) -> OpValidityFlags<N> {
265        OpValidityFlags {
266            allowed_children: OpTag::DataflowChild,
267            allowed_first_child: OpTag::Input,
268            allowed_second_child: OpTag::Output,
269            requires_children: true,
270            requires_dag: true,
271            ..Default::default()
272        }
273    }
274
275    /// Validate the ordered list of children.
276    fn validate_op_children<'a, N: HugrNode>(
277        &self,
278        children: impl DoubleEndedIterator<Item = (N, &'a OpType)>,
279    ) -> Result<(), ChildrenValidationError<N>> {
280        let sig = self.inner_signature();
281        validate_io_nodes(&sig.input, &sig.output, "DataflowParent", children)
282    }
283}
284
285/// Checks a that the list of children nodes does not contain Input and Output
286/// nodes outside of the first and second elements respectively, and that those
287/// have the correct signature.
288fn validate_io_nodes<'a, N: HugrNode>(
289    expected_input: &TypeRow,
290    expected_output: &TypeRow,
291    container_desc: &'static str,
292    mut children: impl Iterator<Item = (N, &'a OpType)>,
293) -> Result<(), ChildrenValidationError<N>> {
294    // Check that the signature matches with the Input and Output rows.
295    let (first, first_optype) = children.next().unwrap();
296    let (second, second_optype) = children.next().unwrap();
297
298    let first_sig = first_optype.dataflow_signature().unwrap_or_default();
299    if &first_sig.output != expected_input {
300        return Err(ChildrenValidationError::IOSignatureMismatch {
301            child: first,
302            actual: first_sig.into_owned().output,
303            expected: expected_input.clone(),
304            node_desc: "Input",
305            container_desc,
306        });
307    }
308    let second_sig = second_optype.dataflow_signature().unwrap_or_default();
309
310    if &second_sig.input != expected_output {
311        return Err(ChildrenValidationError::IOSignatureMismatch {
312            child: second,
313            actual: second_sig.into_owned().input,
314            expected: expected_output.clone(),
315            node_desc: "Output",
316            container_desc,
317        });
318    }
319
320    // The first and second children have already been popped from the iterator.
321    for (child, optype) in children {
322        match optype.tag() {
323            OpTag::Input => {
324                return Err(ChildrenValidationError::InternalIOChildren {
325                    child,
326                    optype: optype.clone(),
327                    expected_position: "first",
328                });
329            }
330            OpTag::Output => {
331                return Err(ChildrenValidationError::InternalIOChildren {
332                    child,
333                    optype: optype.clone(),
334                    expected_position: "second",
335                });
336            }
337            _ => {}
338        }
339    }
340    Ok(())
341}
342
343/// Validate an edge between two basic blocks in a CFG sibling graph.
344fn validate_cfg_edge<N: HugrNode>(edge: ChildrenEdgeData<N>) -> Result<(), EdgeValidationError<N>> {
345    let source = &edge
346        .source_op
347        .as_dataflow_block()
348        .expect("CFG sibling graphs can only contain basic block operations.");
349
350    let target_input = match &edge.target_op {
351        OpType::DataflowBlock(dfb) => dfb.dataflow_input(),
352        OpType::ExitBlock(exit) => exit.dataflow_input(),
353        _ => panic!("CFG sibling graphs can only contain basic block operations."),
354    };
355
356    let source_types = source.successor_input(edge.source_port.index());
357    if source_types.as_ref() != Some(target_input) {
358        let target_types = target_input.clone();
359        return Err(EdgeValidationError::CFGEdgeSignatureMismatch {
360            edge,
361            source_types,
362            target_types,
363        });
364    }
365
366    Ok(())
367}
368
#[cfg(test)]
mod test {
    use crate::extension::prelude::{Noop, usize_t};
    use crate::ops::dataflow::IOTrait;
    use crate::{Node, NodeIndex as _, ops};
    use cool_asserts::assert_matches;
    use portgraph::NodeIndex;

    use super::*;

    /// Exercises `validate_io_nodes` with matching rows, with each of the two
    /// rows mismatched, and with an `Output` node placed after the second
    /// position.
    #[test]
    fn test_validate_io_nodes() {
        let in_types: TypeRow = vec![usize_t()].into();
        let out_types: TypeRow = vec![usize_t(), usize_t()].into();

        let input_node: OpType = ops::Input::new(in_types.clone()).into();
        let output_node: OpType = ops::Output::new(out_types.clone()).into();
        let leaf_node: OpType = Noop(usize_t()).into();

        // Well-formed dataflow sibling nodes. Check the input and output node signatures.
        let well_formed: &[(usize, &OpType)] = &[
            (0, &input_node),
            (1, &output_node),
            (2, &leaf_node),
            (3, &leaf_node),
        ];
        assert_eq!(
            validate_io_nodes(&in_types, &out_types, "test", make_iter(well_formed)),
            Ok(())
        );
        // Wrong expected input row: the Input node (index 0) is reported.
        assert_matches!(
            validate_io_nodes(&out_types, &out_types, "test", make_iter(well_formed)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 0
        );
        // Wrong expected output row: the Output node (index 1) is reported.
        assert_matches!(
            validate_io_nodes(&in_types, &in_types, "test", make_iter(well_formed)),
            Err(ChildrenValidationError::IOSignatureMismatch { child, .. }) if child.index() == 1
        );

        // Internal I/O nodes
        let stray_output: &[(usize, &OpType)] = &[
            (0, &input_node),
            (1, &output_node),
            (42, &leaf_node),
            (2, &leaf_node),
            (3, &output_node),
        ];
        assert_matches!(
            validate_io_nodes(&in_types, &out_types, "test", make_iter(stray_output)),
            Err(ChildrenValidationError::InternalIOChildren { child, .. }) if child.index() == 3
        );
    }

    /// Turns `(index, optype)` pairs into the `(Node, &OpType)` items that
    /// `validate_io_nodes` expects.
    fn make_iter<'a>(
        children: &'a [(usize, &OpType)],
    ) -> impl DoubleEndedIterator<Item = (Node, &'a OpType)> {
        children
            .iter()
            .map(|&(index, optype)| (NodeIndex::new(index).into(), optype))
    }
}
430
431use super::{
432    AliasDecl, AliasDefn, Call, CallIndirect, Const, ExtensionOp, FuncDecl, Input, LoadConstant,
433    LoadFunction, OpaqueOp, Output, Tag,
434};
// Operations that need no custom validation logic in this module.
// NOTE(review): `impl_validate_op!` is defined outside this file; it
// presumably emits a `ValidateOp` impl that relies on the trait's default
// methods — confirm against the macro definition.
impl_validate_op!(FuncDecl);
impl_validate_op!(AliasDecl);
impl_validate_op!(AliasDefn);
impl_validate_op!(Input);
impl_validate_op!(Output);
impl_validate_op!(Const);
impl_validate_op!(Call);
impl_validate_op!(LoadConstant);
impl_validate_op!(LoadFunction);
impl_validate_op!(CallIndirect);
impl_validate_op!(ExtensionOp);
impl_validate_op!(OpaqueOp);
impl_validate_op!(Tag);
impl_validate_op!(ExitBlock);