hugr_core/
hugr.rs

1//! The Hugr data structure, and its basic component handles.
2
3pub mod hugrmut;
4
5pub(crate) mod ident;
6pub mod internal;
7pub mod rewrite;
8pub mod serialize;
9pub mod validate;
10pub mod views;
11
12use std::collections::VecDeque;
13use std::io::Read;
14use std::iter;
15
16pub(crate) use self::hugrmut::HugrMut;
17pub use self::validate::ValidationError;
18
19pub use ident::{IdentList, InvalidIdentifier};
20pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError};
21
22use portgraph::multiportgraph::MultiPortGraph;
23use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap};
24use thiserror::Error;
25
26pub use self::views::{HugrView, RootTagged};
27use crate::core::NodeIndex;
28use crate::extension::resolution::{
29    resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError,
30    WeakExtensionRegistry,
31};
32use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
33use crate::ops::{OpTag, OpTrait};
34pub use crate::ops::{OpType, DEFAULT_OPTYPE};
35use crate::{Direction, Node};
36
/// The Hugr data structure.
///
/// Per-node data (`op_types`, `metadata`) lives in dense maps keyed by the same
/// `portgraph::NodeIndex` values used by `graph` and `hierarchy`, so any code that
/// moves a node to a different index has to keep all four of them consistent.
#[derive(Clone, Debug, PartialEq)]
pub struct Hugr {
    /// The graph encoding the adjacency structure of the HUGR.
    graph: MultiPortGraph,

    /// The node hierarchy (parent/child relations between nodes of `graph`).
    hierarchy: Hierarchy,

    /// The single root node in the hierarchy.
    root: portgraph::NodeIndex,

    /// Operation types for each node.
    op_types: UnmanagedDenseMap<portgraph::NodeIndex, OpType>,

    /// Node metadata; `None` for a node that has no metadata map.
    metadata: UnmanagedDenseMap<portgraph::NodeIndex, Option<NodeMetadataMap>>,

    /// Extensions used by the operations in the Hugr.
    ///
    /// Seeded from the root operation on construction and replaced wholesale by
    /// `Hugr::resolve_extension_defs`.
    extensions: ExtensionRegistry,
}
58
59impl Default for Hugr {
60    fn default() -> Self {
61        Self::new(crate::ops::Module::new())
62    }
63}
64
65impl AsRef<Hugr> for Hugr {
66    fn as_ref(&self) -> &Hugr {
67        self
68    }
69}
70
71impl AsMut<Hugr> for Hugr {
72    fn as_mut(&mut self) -> &mut Hugr {
73        self
74    }
75}
76
77/// Arbitrary metadata entry for a node.
78///
79/// Each entry is associated to a string key.
80pub type NodeMetadata = serde_json::Value;
81
82/// The container of all the metadata entries for a node.
83pub type NodeMetadataMap = serde_json::Map<String, NodeMetadata>;
84
85/// Public API for HUGRs.
86impl Hugr {
87    /// Create a new Hugr, with a single root node.
88    pub fn new(root_node: impl Into<OpType>) -> Self {
89        Self::with_capacity(root_node.into(), 0, 0)
90    }
91
92    /// Load a Hugr from a json reader.
93    ///
94    /// Validates the Hugr against the provided extension registry, ensuring all
95    /// operations are resolved.
96    ///
97    /// If the feature `extension_inference` is enabled, we will ensure every function
98    /// correctly specifies the extensions required by its contained ops.
99    pub fn load_json(
100        reader: impl Read,
101        extension_registry: &ExtensionRegistry,
102    ) -> Result<Self, LoadHugrError> {
103        let mut hugr: Hugr = serde_json::from_reader(reader)?;
104
105        hugr.resolve_extension_defs(extension_registry)?;
106        hugr.validate_no_extensions()?;
107
108        if cfg!(feature = "extension_inference") {
109            hugr.infer_extensions(false)?;
110            hugr.validate_extensions()?;
111        }
112
113        Ok(hugr)
114    }
115
116    /// Infers an extension-delta for any non-function container node
117    /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta
118    /// will be the smallest delta compatible with its children and that includes any
119    /// other [ExtensionId]s in the current delta.
120    ///
121    /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED],
122    /// ExtensionIds are removed from the delta if they are *not* used by any child node.
123    ///
124    /// The non-function container nodes are:
125    /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop]
126    ///
127    /// [Case]: crate::ops::Case
128    /// [CFG]: crate::ops::CFG
129    /// [Conditional]: crate::ops::Conditional
130    /// [DataflowBlock]: crate::ops::DataflowBlock
131    /// [DFG]: crate::ops::DFG
132    /// [TailLoop]: crate::ops::TailLoop
133    /// [extension_delta]: crate::ops::OpType::extension_delta
134    /// [ExtensionId]: crate::extension::ExtensionId
135    pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> {
136        fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> {
137            match optype {
138                OpType::DFG(dfg) => Some(&mut dfg.signature.runtime_reqs),
139                OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta),
140                OpType::TailLoop(tl) => Some(&mut tl.extension_delta),
141                OpType::CFG(cfg) => Some(&mut cfg.signature.runtime_reqs),
142                OpType::Conditional(c) => Some(&mut c.extension_delta),
143                OpType::Case(c) => Some(&mut c.signature.runtime_reqs),
144                //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed
145                //OpType::FuncDefn(_) // Not at present due to the possibility of recursion
146                _ => None,
147            }
148        }
149        fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> {
150            let mut child_sets = h
151                .children(node)
152                .collect::<Vec<_>>() // Avoid borrowing h over recursive call
153                .into_iter()
154                .map(|ch| Ok((ch, infer(h, ch, remove)?)))
155                .collect::<Result<Vec<_>, _>>()?;
156
157            let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else {
158                return Ok(h.get_optype(node).extension_delta());
159            };
160            if es.contains(&TO_BE_INFERRED) {
161                // Do not remove anything from current delta - any other elements are a lower bound
162                child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst
163            } else if remove {
164                child_sets.iter().try_for_each(|(ch, ch_exts)| {
165                    if !es.is_superset(ch_exts) {
166                        return Err(ExtensionError {
167                            parent: node,
168                            parent_extensions: es.clone(),
169                            child: *ch,
170                            child_extensions: ch_exts.clone(),
171                        });
172                    }
173                    Ok(())
174                })?;
175            } else {
176                return Ok(es.clone()); // Can't neither add nor remove, so nothing to do
177            }
178            let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e));
179            *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged);
180
181            Ok(es.clone())
182        }
183        infer(self, self.root(), remove)?;
184        Ok(())
185    }
186
187    /// Given a Hugr that has been deserialized, collect all extensions used to
188    /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into
189    /// [`OpType::ExtensionOp`]s and updating the extension pointer in all
190    /// internal [`crate::types::CustomType`]s to point to the extensions in the
191    /// register.
192    ///
193    /// When listing "used extensions" we only care about _definitional_
194    /// extension requirements, i.e., the operations and types that are required
195    /// to define the HUGR nodes and wire types. This is computed from the union
196    /// of all extension required across the HUGR.
197    ///
198    /// This is distinct from _runtime_ extension requirements computed in
199    /// [`Hugr::infer_extensions`], which are computed more granularly in each
200    /// function signature by the `runtime_reqs` field and define the set
201    /// of capabilities required by the runtime to execute each function.
202    ///
203    /// Updates the internal extension registry with the extensions used in the
204    /// definition.
205    ///
206    /// # Parameters
207    ///
208    /// - `extensions`: The extension set considered when resolving opaque
209    ///     operations and types. The original Hugr's internal extension
210    ///     registry is ignored and replaced with the newly computed one.
211    ///
212    /// # Errors
213    ///
214    /// - If an opaque operation cannot be resolved to an extension operation.
215    /// - If an extension operation references an extension that is missing from
216    ///   the registry.
217    /// - If a custom type references an extension that is missing from the
218    ///   registry.
219    pub fn resolve_extension_defs(
220        &mut self,
221        extensions: &ExtensionRegistry,
222    ) -> Result<(), ExtensionResolutionError> {
223        let mut used_extensions = ExtensionRegistry::default();
224
225        // Here we need to iterate the optypes in the hugr mutably, to avoid
226        // having to clone and accumulate all replacements before finally
227        // applying them.
228        //
229        // This is not something we want to expose it the API, so we manually
230        // iterate instead of writing it as a method.
231        //
232        // Since we don't have a non-borrowing iterator over all the possible
233        // NodeIds, we have to simulate it by iterating over all possible
234        // indices and checking if the node exists.
235        let weak_extensions: WeakExtensionRegistry = extensions.into();
236        for n in 0..self.graph.node_capacity() {
237            let pg_node = portgraph::NodeIndex::new(n);
238            let node: Node = pg_node.into();
239            if !self.contains_node(node) {
240                continue;
241            }
242
243            let op = &mut self.op_types[pg_node];
244
245            if let Some(extension) = resolve_op_extensions(node, op, extensions)? {
246                used_extensions.register_updated_ref(extension);
247            }
248            used_extensions.extend(
249                resolve_op_types_extensions(Some(node), op, &weak_extensions)?.map(|weak| {
250                    weak.upgrade()
251                        .expect("Extension comes from a valid registry")
252                }),
253            );
254        }
255
256        self.extensions = used_extensions;
257        Ok(())
258    }
259}
260
261/// Internal API for HUGRs, not intended for use by users.
262impl Hugr {
263    /// Create a new Hugr, with a single root node and preallocated capacity.
264    pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self {
265        let mut graph = MultiPortGraph::with_capacity(nodes, ports);
266        let hierarchy = Hierarchy::new();
267        let mut op_types = UnmanagedDenseMap::with_capacity(nodes);
268        let root = graph.add_node(root_node.input_count(), root_node.output_count());
269        let extensions = root_node.used_extensions();
270        op_types[root] = root_node;
271
272        Self {
273            graph,
274            hierarchy,
275            root,
276            op_types,
277            metadata: UnmanagedDenseMap::with_capacity(nodes),
278            extensions: extensions.unwrap_or_default(),
279        }
280    }
281
282    /// Set the root node of the hugr.
283    pub(crate) fn set_root(&mut self, root: Node) {
284        self.hierarchy.detach(self.root);
285        self.root = root.pg_index();
286    }
287
288    /// Add a node to the graph.
289    pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node {
290        let node = self
291            .graph
292            .add_node(nodetype.input_count(), nodetype.output_count());
293        self.op_types[node] = nodetype;
294        node.into()
295    }
296
297    /// Produce a canonical ordering of the descendant nodes of a root,
298    /// following the graph hierarchy.
299    ///
300    /// This starts with the root, and then proceeds in BFS order through the
301    /// contained regions.
302    ///
303    /// Used by [`HugrMut::canonicalize_nodes`] and the serialization code.
304    fn canonical_order(&self, root: Node) -> impl Iterator<Item = Node> + '_ {
305        // Generate a BFS-ordered list of nodes based on the hierarchy
306        let mut queue = VecDeque::from([root]);
307        iter::from_fn(move || {
308            let node = queue.pop_front()?;
309            for child in self.children(node) {
310                queue.push_back(child);
311            }
312            Some(node)
313        })
314    }
315
316    /// Compact the nodes indices of the hugr to be contiguous, and order them as a breadth-first
317    /// traversal of the hierarchy.
318    ///
319    /// The rekey function is called for each moved node with the old and new indices.
320    ///
321    /// After this operation, a serialization and deserialization of the Hugr is guaranteed to
322    /// preserve the indices.
323    pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) {
324        // Generate the ordered list of nodes
325        let mut ordered = Vec::with_capacity(self.node_count());
326        let root = self.root();
327        ordered.extend(self.as_mut().canonical_order(root));
328
329        // Permute the nodes in the graph to match the order.
330        //
331        // Invariant: All the elements before `position` are in the correct place.
332        for position in 0..ordered.len() {
333            // Find the element's location. If it originally came from a previous position
334            // then it has been swapped somewhere else, so we follow the permutation chain.
335            let mut source: Node = ordered[position];
336            while position > source.index() {
337                source = ordered[source.index()];
338            }
339
340            let target: Node = portgraph::NodeIndex::new(position).into();
341            if target != source {
342                let pg_target = target.pg_index();
343                let pg_source = source.pg_index();
344                self.graph.swap_nodes(pg_target, pg_source);
345                self.op_types.swap(pg_target, pg_source);
346                self.hierarchy.swap_nodes(pg_target, pg_source);
347                rekey(source, target);
348            }
349        }
350        self.root = portgraph::NodeIndex::new(0);
351
352        // Finish by compacting the copy nodes.
353        // The operation nodes will be left in place.
354        // This step is not strictly necessary.
355        self.graph.compact_nodes(|_, _| {});
356    }
357}
358
359#[derive(Debug, Clone, PartialEq, Error)]
360#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
361/// An error in the extension deltas.
362pub struct ExtensionError {
363    parent: Node,
364    parent_extensions: ExtensionSet,
365    child: Node,
366    child_extensions: ExtensionSet,
367}
368
369/// Errors that can occur while manipulating a Hugr.
370///
371/// TODO: Better descriptions, not just re-exporting portgraph errors.
372#[derive(Debug, Clone, PartialEq, Eq, Error)]
373#[non_exhaustive]
374pub enum HugrError {
375    /// The node was not of the required [OpTag]
376    /// (e.g. to conform to the [RootTagged::RootHandle] of a [HugrView])
377    #[error("Invalid tag: required a tag in {required} but found {actual}")]
378    #[allow(missing_docs)]
379    InvalidTag { required: OpTag, actual: OpTag },
380    /// An invalid port was specified.
381    #[error("Invalid port direction {0:?}.")]
382    InvalidPortDirection(Direction),
383}
384
385/// Errors that can occur while loading and validating a Hugr json.
386#[derive(Debug, Error)]
387#[non_exhaustive]
388pub enum LoadHugrError {
389    /// Error while loading the Hugr from JSON.
390    #[error("Error while loading the Hugr from JSON: {0}")]
391    Load(#[from] serde_json::Error),
392    /// Validation of the loaded Hugr failed.
393    #[error(transparent)]
394    Validation(#[from] ValidationError),
395    /// Error when resolving extension operations and types.
396    #[error(transparent)]
397    Extension(#[from] ExtensionResolutionError),
398    /// Error when inferring runtime extensions.
399    #[error(transparent)]
400    RuntimeInference(#[from] ExtensionError),
401}
402
#[cfg(test)]
mod test {
    //! Unit tests for Hugr construction, JSON loading and extension-delta inference.

    use std::sync::Arc;
    use std::{fs::File, io::BufReader};

    use super::internal::HugrMutInternals;
    #[cfg(feature = "extension_inference")]
    use super::ValidationError;
    use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
    use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED};
    use crate::ops::{ExtensionOp, OpName};
    use crate::types::type_param::TypeParam;
    use crate::types::{
        FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV, TypeRow,
    };

    use crate::{const_extension_ids, ops, test_file, type_row, Extension};
    use cool_asserts::assert_matches;
    use lazy_static::lazy_static;
    use rstest::rstest;

    const_extension_ids! {
        pub(crate) const LIFT_EXT_ID: ExtensionId = "LIFT_EXT_ID";
    }
    lazy_static! {
        /// Tests only extension holding an Op that can add arbitrary extensions to a row.
        pub(crate) static ref LIFT_EXT: Arc<Extension> = {
            Extension::new_arc(
                LIFT_EXT_ID,
                hugr::extension::Version::new(0, 0, 0),
                |ext, extension_ref| {
                    // "Lift" is parameterised by an extension set (param 0) and a row of
                    // types (param 1): it is the identity on that row and adds the
                    // extension set to its delta.
                    ext.add_op(
                        OpName::new_inline("Lift"),
                        "".into(),
                        PolyFuncTypeRV::new(
                            vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)],
                            FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any))
                                .with_extension_delta(ExtensionSet::type_var(0)),
                        ),
                        extension_ref,
                    )
                    .unwrap();
                },
            )
        };
    }

    /// Instantiates the test-only `Lift` op over `type_row`, adding `extensions`
    /// to its extension delta.
    pub(crate) fn lift_op(
        type_row: impl Into<TypeRow>,
        extensions: impl Into<ExtensionSet>,
    ) -> ExtensionOp {
        LIFT_EXT
            .instantiate_extension_op(
                "Lift",
                [
                    TypeArg::Extensions {
                        es: extensions.into(),
                    },
                    TypeArg::Sequence {
                        elems: type_row
                            .into()
                            .iter()
                            .map(|t| TypeArg::Type { ty: t.clone() })
                            .collect(),
                    },
                ],
            )
            .unwrap()
    }

    #[test]
    fn impls_send_and_sync() {
        // Send and Sync are automatically impl'd by the compiler, if possible.
        // This test will fail to compile if that wasn't possible.
        #[allow(dead_code)]
        trait Test: Send + Sync {}
        impl Test for Hugr {}
    }

    #[test]
    fn io_node() {
        use crate::builder::test::simple_dfg_hugr;
        use cool_asserts::assert_matches;

        let hugr = simple_dfg_hugr();
        assert_matches!(hugr.get_io(hugr.root()), Some(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    fn hugr_validation_0() {
        // https://github.com/CQCL/hugr/issues/1091 bad case
        let hugr = Hugr::load_json(
            BufReader::new(File::open(test_file!("hugr-0.json")).unwrap()),
            &PRELUDE_REGISTRY,
        );
        assert_matches!(hugr, Err(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    fn hugr_validation_1() {
        // https://github.com/CQCL/hugr/issues/1091 good case
        let hugr = Hugr::load_json(
            BufReader::new(File::open(test_file!("hugr-1.json")).unwrap()),
            &PRELUDE_REGISTRY,
        );
        assert_matches!(&hugr, Ok(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    fn hugr_validation_2() {
        // https://github.com/CQCL/hugr/issues/1185 bad case
        let hugr = Hugr::load_json(
            BufReader::new(File::open(test_file!("hugr-2.json")).unwrap()),
            &PRELUDE_REGISTRY,
        );
        assert_matches!(hugr, Err(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    fn hugr_validation_3() {
        // https://github.com/CQCL/hugr/issues/1185 good case
        let hugr = Hugr::load_json(
            BufReader::new(File::open(test_file!("hugr-3.json")).unwrap()),
            &PRELUDE_REGISTRY,
        );
        assert_matches!(&hugr, Ok(_));
    }

    const_extension_ids! {
        const XA: ExtensionId = "EXT_A";
        const XB: ExtensionId = "EXT_B";
    }

    // A single DFG whose delta contains TO_BE_INFERRED (plus the listed lower bound)
    // is inferred to the lower bound plus the child's XA and LIFT_EXT_ID.
    #[rstest]
    #[case([], XA.into())]
    #[case([XA], XA.into())]
    #[case([XB], ExtensionSet::from_iter([XA, XB]))]

    fn infer_single_delta(
        #[case] parent: impl IntoIterator<Item = ExtensionId>,
        #[values(true, false)] remove: bool, // makes no difference when inferring
        #[case] result: ExtensionSet,
    ) {
        let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
        let (mut h, _) = build_ext_dfg(parent);
        h.infer_extensions(remove).unwrap();
        assert_eq!(h, build_ext_dfg(result.union(LIFT_EXT_ID.into())).0);
    }

    // Without TO_BE_INFERRED, an over-large delta is only shrunk when `remove` is true.
    #[test]
    fn infer_removes_from_delta() {
        let parent = ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]);
        let mut h = build_ext_dfg(parent.clone()).0;
        let backup = h.clone();
        h.infer_extensions(false).unwrap();
        assert_eq!(h, backup); // did nothing
        h.infer_extensions(true).unwrap();
        assert_eq!(
            h,
            build_ext_dfg(ExtensionSet::from_iter([XA, LIFT_EXT_ID])).0
        );
    }

    // A delta that does not cover the child is left untouched when `remove` is false,
    // and reported as an ExtensionError when `remove` is true.
    #[test]
    fn infer_bad_remove() {
        let (mut h, mid) = build_ext_dfg(XB.into());
        let backup = h.clone();
        h.infer_extensions(false).unwrap();
        assert_eq!(h, backup); // did nothing
        let val_res = h.validate();
        let expected_err = ExtensionError {
            parent: h.root(),
            parent_extensions: XB.into(),
            child: mid,
            child_extensions: ExtensionSet::from_iter([XA, LIFT_EXT_ID]),
        };
        #[cfg(feature = "extension_inference")]
        assert_eq!(
            val_res,
            Err(ValidationError::ExtensionError(expected_err.clone()))
        );
        #[cfg(not(feature = "extension_inference"))]
        assert!(val_res.is_ok());

        let inf_res = h.infer_extensions(true);
        assert_eq!(inf_res, Err(expected_err));
    }

    /// Builds a DFG-rooted Hugr with extension delta `parent` whose children are
    /// Input -> Lift(XA) -> Output; returns the Hugr and the Lift node.
    fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) {
        let ty = Type::new_function(Signature::new_endo(type_row![]));
        let mut h = Hugr::new(ops::DFG {
            signature: Signature::new_endo(ty.clone()).with_extension_delta(parent.clone()),
        });
        let root = h.root();
        let mid = add_inliftout(&mut h, root, ty);
        (h, mid)
    }

    /// Adds Input, Lift(XA) and Output children to `p`, wired in sequence on a
    /// single `ty` value; returns the Lift node.
    fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
        let inp = h.add_node_with_parent(
            p,
            ops::Input {
                types: ty.clone().into(),
            },
        );
        let out = h.add_node_with_parent(
            p,
            ops::Output {
                types: ty.clone().into(),
            },
        );
        let mid = h.add_node_with_parent(p, lift_op(ty, XA));
        h.connect(inp, 0, mid, 0);
        h.connect(mid, 0, out, 0);
        mid
    }

    // Conditional (grandparent) -> Case (parent) -> Input/Lift(XA)/Output.
    // LIFT_EXT_ID is added to every set below before building.
    #[rstest]
    // Base case success: delta inferred for parent equals grandparent.
    #[case([XA], [TO_BE_INFERRED], true, [XA])]
    // Success: delta inferred for parent is subset of grandparent
    #[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
    // Base case failure: infers [XA] for parent but grandparent has disjoint set
    #[case([XB], [TO_BE_INFERRED], false, [XA])]
    // Failure: as previous, but extra "lower bound" on parent that has no effect
    #[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
    // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB
    #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
    // Success: grandparent includes extra XB required for parent's "lower bound"
    #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
    // Success: grandparent is also inferred so can include 'extra' XB from parent
    #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
    // No inference: extraneous XB in parent is removed so all become [XA].
    #[case([XA], [XA, XB], true, [XA])]
    fn infer_three_generations(
        #[case] grandparent: impl IntoIterator<Item = ExtensionId>,
        #[case] parent: impl IntoIterator<Item = ExtensionId>,
        #[case] success: bool,
        #[case] result: impl IntoIterator<Item = ExtensionId>,
    ) {
        let ty = Type::new_function(Signature::new_endo(type_row![]));
        let grandparent = ExtensionSet::from_iter(grandparent).union(LIFT_EXT_ID.into());
        let parent = ExtensionSet::from_iter(parent).union(LIFT_EXT_ID.into());
        let result = ExtensionSet::from_iter(result).union(LIFT_EXT_ID.into());
        let root_ty = ops::Conditional {
            sum_rows: vec![type_row![]],
            other_inputs: ty.clone().into(),
            outputs: ty.clone().into(),
            extension_delta: grandparent.clone(),
        };
        let mut h = Hugr::new(root_ty.clone());
        let p = h.add_node_with_parent(
            h.root(),
            ops::Case {
                signature: Signature::new_endo(ty.clone()).with_extension_delta(parent),
            },
        );
        add_inliftout(&mut h, p, ty.clone());
        assert!(h.validate_extensions().is_err());
        let backup = h.clone();
        let inf_res = h.infer_extensions(true);
        if success {
            assert!(inf_res.is_ok());
            // Both the Case and the Conditional are expected to end up with `result`.
            let expected_p = ops::Case {
                signature: Signature::new_endo(ty).with_extension_delta(result.clone()),
            };
            let mut expected = backup;
            expected.replace_op(p, expected_p).unwrap();
            let expected_gp = ops::Conditional {
                extension_delta: result,
                ..root_ty
            };
            expected.replace_op(h.root(), expected_gp).unwrap();

            assert_eq!(h, expected);
        } else {
            assert_eq!(
                inf_res,
                Err(ExtensionError {
                    parent: h.root(),
                    parent_extensions: grandparent,
                    child: p,
                    child_extensions: result
                })
            );
        }
    }
}