hugr_core/hugr/rewrite/
replace.rs

1//! Implementation of the `Replace` operation.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::hugr::hugrmut::InsertionResult;
9use crate::hugr::HugrMut;
10use crate::ops::{OpTag, OpTrait};
11use crate::types::EdgeKind;
12use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort};
13
14use super::Rewrite;
15
16/// Specifies how to create a new edge.
17#[derive(Clone, Debug, PartialEq, Eq)]
18pub struct NewEdgeSpec {
19    /// The source of the new edge. For [Replacement::mu_inp] and [Replacement::mu_new], this is in the
20    /// existing Hugr; for edges in [Replacement::mu_out] this is in the [Replacement::replacement]
21    pub src: Node,
22    /// The target of the new edge. For [Replacement::mu_inp], this is in the [Replacement::replacement];
23    /// for edges in [Replacement::mu_out] and [Replacement::mu_new], this is in the existing Hugr.
24    pub tgt: Node,
25    /// The kind of edge to create, and any port specifiers required
26    pub kind: NewEdgeKind,
27}
28
29/// Describes an edge that should be created between two nodes already given
30#[derive(Clone, Debug, PartialEq, Eq)]
31pub enum NewEdgeKind {
32    /// An [EdgeKind::StateOrder] edge (between DFG nodes only)
33    Order,
34    /// An [EdgeKind::Value] edge (between DFG nodes only)
35    Value {
36        /// The source port
37        src_pos: OutgoingPort,
38        /// The target port
39        tgt_pos: IncomingPort,
40    },
41    /// An [EdgeKind::Const] or [EdgeKind::Function] edge
42    Static {
43        /// The source port
44        src_pos: OutgoingPort,
45        /// The target port
46        tgt_pos: IncomingPort,
47    },
48    /// A [EdgeKind::ControlFlow] edge (between CFG nodes only)
49    ControlFlow {
50        /// Identifies a control-flow output (successor) of the source node.
51        src_pos: OutgoingPort,
52    },
53}
54
55/// Specification of a `Replace` operation
56#[derive(Debug, Clone, PartialEq)]
57pub struct Replacement {
58    /// The nodes to remove from the existing Hugr (known as Gamma).
59    /// These must all have a common parent (i.e. be siblings).  Called "S" in the spec.
60    /// Must be non-empty - otherwise there is no parent under which to place [Self::replacement],
61    /// and there would be no possible [Self::mu_inp], [Self::mu_out] or [Self::adoptions].
62    pub removal: Vec<Node>,
63    /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), whose root
64    /// is the same type as the root of [Self::replacement].  "G" in the spec.
65    pub replacement: Hugr,
66    /// Describes how parts of the Hugr that would otherwise be removed should instead be preserved but
67    /// with new parents amongst the newly-inserted nodes.  This is a Map from container nodes in
68    /// [Self::replacement] that have no children, to container nodes that are descended from [Self::removal].
69    /// The keys are the new parents for the children of the values.  Note no value may be ancestor or
70    /// descendant of another.  This is "B" in the spec; "R" is the set of descendants of [Self::removal]
71    ///  that are not descendants of values here.
72    pub adoptions: HashMap<Node, Node>,
73    /// Edges from nodes in the existing Hugr that are not removed ([NewEdgeSpec::src] in Gamma\R)
74    /// to inserted nodes ([NewEdgeSpec::tgt] in [Self::replacement]).
75    pub mu_inp: Vec<NewEdgeSpec>,
76    /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to existing nodes not removed
77    /// ([NewEdgeSpec::tgt] in Gamma \ R).
78    pub mu_out: Vec<NewEdgeSpec>,
79    /// Edges to add between existing nodes (both [NewEdgeSpec::src] and [NewEdgeSpec::tgt] in Gamma \ R).
80    /// For example, in cases where the source had an edge to a removed node, and the target had an
81    /// edge from a removed node, this would allow source to be directly connected to target.
82    pub mu_new: Vec<NewEdgeSpec>,
83}
84
85impl NewEdgeSpec {
86    fn check_src(
87        &self,
88        h: &impl HugrView<Node = Node>,
89        err_spec: &NewEdgeSpec,
90    ) -> Result<(), ReplaceError> {
91        let optype = h.get_optype(self.src);
92        let ok = match self.kind {
93            NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder),
94            NewEdgeKind::Value { src_pos, .. } => {
95                matches!(optype.port_kind(src_pos), Some(EdgeKind::Value(_)))
96            }
97            NewEdgeKind::Static { src_pos, .. } => optype
98                .port_kind(src_pos)
99                .as_ref()
100                .is_some_and(EdgeKind::is_static),
101            NewEdgeKind::ControlFlow { src_pos } => {
102                matches!(optype.port_kind(src_pos), Some(EdgeKind::ControlFlow))
103            }
104        };
105        ok.then_some(())
106            .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone()))
107    }
108    fn check_tgt(
109        &self,
110        h: &impl HugrView<Node = Node>,
111        err_spec: &NewEdgeSpec,
112    ) -> Result<(), ReplaceError> {
113        let optype = h.get_optype(self.tgt);
114        let ok = match self.kind {
115            NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder),
116            NewEdgeKind::Value { tgt_pos, .. } => {
117                matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Value(_)))
118            }
119            NewEdgeKind::Static { tgt_pos, .. } => optype
120                .port_kind(tgt_pos)
121                .as_ref()
122                .is_some_and(EdgeKind::is_static),
123            NewEdgeKind::ControlFlow { .. } => matches!(
124                optype.port_kind(IncomingPort::from(0)),
125                Some(EdgeKind::ControlFlow)
126            ),
127        };
128        ok.then_some(())
129            .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone()))
130    }
131
132    fn check_existing_edge(
133        &self,
134        h: &impl HugrView<Node = Node>,
135        legal_src_ancestors: &HashSet<Node>,
136        err_edge: impl Fn() -> NewEdgeSpec,
137    ) -> Result<(), ReplaceError> {
138        if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind
139        {
140            let descends_from_legal = |mut descendant: Node| -> bool {
141                while !legal_src_ancestors.contains(&descendant) {
142                    let Some(p) = h.get_parent(descendant) else {
143                        return false;
144                    };
145                    descendant = p;
146                }
147                true
148            };
149            let found_incoming = h
150                .single_linked_output(self.tgt, tgt_pos)
151                .is_some_and(|(src_n, _)| descends_from_legal(src_n));
152            if !found_incoming {
153                return Err(ReplaceError::NoRemovedEdge(err_edge()));
154            };
155        };
156        Ok(())
157    }
158}
159
160impl Replacement {
161    fn check_parent(&self, h: &impl HugrView<Node = Node>) -> Result<Node, ReplaceError> {
162        let parent = self
163            .removal
164            .iter()
165            .map(|n| h.get_parent(*n))
166            .unique()
167            .exactly_one()
168            .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))?
169            .ok_or(ReplaceError::CantReplaceRoot)?; // If no parent
170
171        // Check replacement parent is of same tag. Note we do not require exact equality
172        // of OpType/Signature, e.g. to ease changing of Input/Output node signatures too.
173        let removed = h.get_optype(parent).tag();
174        let replacement = self.replacement.root_type().tag();
175        if removed != replacement {
176            return Err(ReplaceError::WrongRootNodeTag {
177                removed,
178                replacement,
179            });
180        };
181        Ok(parent)
182    }
183
184    fn get_removed_nodes(
185        &self,
186        h: &impl HugrView<Node = Node>,
187    ) -> Result<HashSet<Node>, ReplaceError> {
188        // Check the keys of the transfer map too, the values we'll use imminently
189        self.adoptions.keys().try_for_each(|&n| {
190            (self.replacement.contains_node(n)
191                && self.replacement.get_optype(n).is_container()
192                && self.replacement.children(n).next().is_none())
193            .then_some(())
194            .ok_or(ReplaceError::InvalidAdoptingParent(n))
195        })?;
196        let mut transferred: HashSet<Node> = self.adoptions.values().copied().collect();
197        if transferred.len() != self.adoptions.values().len() {
198            return Err(ReplaceError::AdopteesNotSeparateDescendants(
199                self.adoptions
200                    .values()
201                    .filter(|v| !transferred.remove(v))
202                    .copied()
203                    .collect(),
204            ));
205        }
206
207        let mut removed = HashSet::new();
208        let mut queue = VecDeque::from_iter(self.removal.iter().copied());
209        while let Some(n) = queue.pop_front() {
210            let new = removed.insert(n);
211            debug_assert!(new); // Fails only if h's hierarchy has merges (is not a tree)
212            if !transferred.remove(&n) {
213                h.children(n).for_each(|ch| queue.push_back(ch))
214            }
215        }
216        if !transferred.is_empty() {
217            return Err(ReplaceError::AdopteesNotSeparateDescendants(
218                transferred.into_iter().collect(),
219            ));
220        }
221        Ok(removed)
222    }
223}
224impl Rewrite for Replacement {
225    type Error = ReplaceError;
226
227    /// Map from Node in replacement to corresponding Node in the result Hugr
228    type ApplyResult = HashMap<Node, Node>;
229
230    const UNCHANGED_ON_FAILURE: bool = false;
231
232    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
233        self.check_parent(h)?;
234        let removed = self.get_removed_nodes(h)?;
235        // Edge sources...
236        for e in self.mu_inp.iter().chain(self.mu_new.iter()) {
237            if !h.contains_node(e.src) || removed.contains(&e.src) {
238                return Err(ReplaceError::BadEdgeSpec(
239                    Direction::Outgoing,
240                    WhichHugr::Retained,
241                    e.clone(),
242                ));
243            }
244            e.check_src(h, e)?;
245        }
246        self.mu_out
247            .iter()
248            .try_for_each(|e| match self.replacement.valid_non_root(e.src) {
249                true => e.check_src(&self.replacement, e),
250                false => Err(ReplaceError::BadEdgeSpec(
251                    Direction::Outgoing,
252                    WhichHugr::Replacement,
253                    e.clone(),
254                )),
255            })?;
256        // Edge targets...
257        self.mu_inp
258            .iter()
259            .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) {
260                true => e.check_tgt(&self.replacement, e),
261                false => Err(ReplaceError::BadEdgeSpec(
262                    Direction::Incoming,
263                    WhichHugr::Replacement,
264                    e.clone(),
265                )),
266            })?;
267        for e in self.mu_out.iter().chain(self.mu_new.iter()) {
268            if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
269                return Err(ReplaceError::BadEdgeSpec(
270                    Direction::Incoming,
271                    WhichHugr::Retained,
272                    e.clone(),
273                ));
274            }
275            e.check_tgt(h, e)?;
276            // The descendant check is to allow the case where the old edge is nonlocal
277            // from a part of the Hugr being moved (which may require changing source,
278            // depending on where the transplanted portion ends up). While this subsumes
279            // the first "removed.contains" check, we'll keep that as a common-case fast-path.
280            e.check_existing_edge(h, &removed, || e.clone())?;
281        }
282        Ok(())
283    }
284
285    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
286        let parent = self.check_parent(h)?;
287        // Calculate removed nodes here. (Does not include transfers, so enumerates only
288        // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty)
289        let to_remove = self.get_removed_nodes(h)?;
290
291        // 1. Add all the new nodes. Note this includes replacement.root(), which we don't want.
292        // TODO what would an error here mean? e.g. malformed self.replacement??
293        let InsertionResult { new_root, node_map } = h.insert_hugr(parent, self.replacement);
294
295        // 2. Add new edges from existing to copied nodes according to mu_in
296        let translate_idx = |n| node_map.get(&n).copied().ok_or(WhichHugr::Replacement);
297        let kept = |n| {
298            let keep = !to_remove.contains(&n);
299            keep.then_some(n).ok_or(WhichHugr::Retained)
300        };
301        transfer_edges(h, self.mu_inp.iter(), kept, translate_idx, None)?;
302
303        // 3. Add new edges from copied to existing nodes according to mu_out,
304        // replacing existing value/static edges incoming to targets
305        transfer_edges(h, self.mu_out.iter(), translate_idx, kept, Some(&to_remove))?;
306
307        // 4. Add new edges between existing nodes according to mu_new,
308        // replacing existing value/static edges incoming to targets.
309        transfer_edges(h, self.mu_new.iter(), kept, kept, Some(&to_remove))?;
310
311        // 5. Put newly-added copies into correct places in hierarchy
312        // (these will be correct places after removing nodes)
313        let mut remove_top_sibs = self.removal.iter();
314        for new_node in h.children(new_root).collect::<Vec<Node>>().into_iter() {
315            if let Some(top_sib) = remove_top_sibs.next() {
316                h.move_before_sibling(new_node, *top_sib);
317            } else {
318                h.set_parent(new_node, parent);
319            }
320        }
321        debug_assert!(h.children(new_root).next().is_none());
322        h.remove_node(new_root);
323
324        // 6. Transfer to keys of `transfers` children of the corresponding values.
325        for (new_parent, &old_parent) in self.adoptions.iter() {
326            let new_parent = node_map.get(new_parent).unwrap();
327            debug_assert!(h.children(old_parent).next().is_some());
328            while let Some(ch) = h.first_child(old_parent) {
329                h.set_parent(ch, *new_parent);
330            }
331        }
332
333        // 7. Remove remaining nodes
334        to_remove.into_iter().for_each(|n| {
335            h.remove_node(n);
336        });
337        Ok(node_map)
338    }
339
340    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
341        self.removal.iter().copied()
342    }
343}
344
345fn transfer_edges<'a>(
346    h: &mut impl HugrMut,
347    edges: impl Iterator<Item = &'a NewEdgeSpec>,
348    trans_src: impl Fn(Node) -> Result<Node, WhichHugr>,
349    trans_tgt: impl Fn(Node) -> Result<Node, WhichHugr>,
350    legal_src_ancestors: Option<&HashSet<Node>>,
351) -> Result<(), ReplaceError> {
352    for oe in edges {
353        let e = NewEdgeSpec {
354            // Translation can only fail for Nodes that are supposed to be in the replacement
355            src: trans_src(oe.src)
356                .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Outgoing, h, oe.clone()))?,
357            tgt: trans_tgt(oe.tgt)
358                .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Incoming, h, oe.clone()))?,
359            ..oe.clone()
360        };
361        if !h.valid_node(e.src) {
362            return Err(ReplaceError::BadEdgeSpec(
363                Direction::Outgoing,
364                WhichHugr::Retained,
365                oe.clone(),
366            ));
367        }
368        if !h.valid_node(e.tgt) {
369            return Err(ReplaceError::BadEdgeSpec(
370                Direction::Incoming,
371                WhichHugr::Retained,
372                oe.clone(),
373            ));
374        };
375        e.check_src(h, oe)?;
376        e.check_tgt(h, oe)?;
377        match e.kind {
378            NewEdgeKind::Order => {
379                h.add_other_edge(e.src, e.tgt);
380            }
381            NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => {
382                if let Some(legal_src_ancestors) = legal_src_ancestors {
383                    e.check_existing_edge(h, legal_src_ancestors, || oe.clone())?;
384                    h.disconnect(e.tgt, tgt_pos);
385                }
386                h.connect(e.src, src_pos, e.tgt, tgt_pos);
387            }
388            NewEdgeKind::ControlFlow { src_pos } => h.connect(e.src, src_pos, e.tgt, 0),
389        }
390    }
391    Ok(())
392}
393
394/// Error in a [`Replacement`]
395#[derive(Clone, Debug, PartialEq, Eq, Error)]
396#[non_exhaustive]
397pub enum ReplaceError {
398    /// The node(s) to replace had no parent i.e. were root(s).
399    // (Perhaps if there is only one node to replace we should be able to?)
400    #[error("Cannot replace the root node of the Hugr")]
401    CantReplaceRoot,
402    /// The nodes to replace did not have a unique common parent
403    #[error("Removed nodes had different parents {0:?}")]
404    MultipleParents(Vec<Node>),
405    /// Replacement root node had different tag from parent of removed nodes
406    #[error("Expected replacement root with tag {removed} but found {replacement}")]
407    WrongRootNodeTag {
408        /// The tag of the parent of the removed nodes
409        removed: OpTag,
410        /// The tag of the root in the replacement Hugr
411        replacement: OpTag,
412    },
413    /// Keys in [Replacement::adoptions] were not valid container nodes in [Replacement::replacement]
414    #[error("Node {0} was not an empty container node in the replacement")]
415    InvalidAdoptingParent(Node),
416    /// Some values in [Replacement::adoptions] were either descendants of other values, or not
417    /// descendants of the [Replacement::removal]. The nodes are indicated on a best-effort basis.
418    #[error("Nodes not free to be moved into new locations: {0:?}")]
419    AdopteesNotSeparateDescendants(Vec<Node>),
420    /// A node at one end of a [NewEdgeSpec] was not found
421    #[error("{0:?} end of edge {2:?} not found in {1}")]
422    BadEdgeSpec(Direction, WhichHugr, NewEdgeSpec),
423    /// The target of the edge was found, but there was no existing edge to replace
424    #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")]
425    NoRemovedEdge(NewEdgeSpec),
426    /// The [NewEdgeKind] was not applicable for the source/target node(s)
427    #[error("The edge kind was not applicable to the {0:?} node: {1:?}")]
428    BadEdgeKind(Direction, NewEdgeSpec),
429}
430
/// Which part of the Hugr an endpoint of a [NewEdgeSpec] was expected to lie in.
/// Reported by [ReplaceError::BadEdgeSpec].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WhichHugr {
    /// The newly-inserted nodes, i.e. the [Replacement::replacement]
    Replacement,
    /// Nodes of the existing Hugr that survive the replacement. That is everything except the
    /// [Replacement::removal] nodes and their descendants, but descendants of
    /// [Replacement::adoptions] values do survive because they are re-parented.
    Retained,
}

impl std::fmt::Display for WhichHugr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            WhichHugr::Replacement => "replacement Hugr",
            WhichHugr::Retained => "retained portion of Hugr",
        };
        f.write_str(description)
    }
}
449
// Unit tests for `Replacement`: one (currently ignored) CFG restructuring, and a set of
// valid/invalid replacements on a Conditional checking that `verify` and `apply` agree.
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use cool_asserts::assert_matches;
    use itertools::Itertools;

    use crate::builder::{
        endo_sig, BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr,
        DataflowSubContainer, HugrBuilder, SubContainer,
    };
    use crate::extension::prelude::{bool_t, usize_t};
    use crate::extension::{ExtensionRegistry, PRELUDE};
    use crate::hugr::internal::HugrMutInternals;
    use crate::hugr::rewrite::replace::WhichHugr;
    use crate::hugr::{HugrMut, Rewrite};
    use crate::ops::custom::ExtensionOp;
    use crate::ops::dataflow::DataflowOpTrait;
    use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
    use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
    use crate::std_extensions::collections::list;
    use crate::types::{Signature, Type, TypeRow};
    use crate::utils::{depth, test_quantum_extension};
    use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort};

    use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement};

    /// Builds a CFG of two single-op basic blocks (`pop`, then `push`) and replaces both
    /// blocks with one basic block containing two empty DFGs. The DFGs adopt the children
    /// of the original blocks, so each op should end up one level deeper (DFG, BB, CFG).
    #[test]
    #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-'
    fn cfg() -> Result<(), Box<dyn std::error::Error>> {
        let reg = ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]);
        reg.validate()?;
        let listy = list::list_type(usize_t());
        let pop: ExtensionOp = list::ListOp::pop
            .with_type(usize_t())
            .to_extension_op()
            .unwrap();
        let push: ExtensionOp = list::ListOp::push
            .with_type(usize_t())
            .to_extension_op()
            .unwrap();
        let just_list = TypeRow::from(vec![listy.clone()]);
        let intermed = TypeRow::from(vec![listy.clone(), usize_t()]);

        let mut cfg = CFGBuilder::new(endo_sig(just_list.clone()))?;

        let pred_const = cfg.add_constant(ops::Value::unary_unit_sum());

        let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
        let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;

        // entry -> bb2 -> exit, each via successor 0.
        let exit = cfg.exit_block();
        cfg.branch(&entry, 0, &bb2)?;
        cfg.branch(&bb2, 0, &exit)?;

        let mut h = cfg.finish_hugr().unwrap();
        // Before the rewrite: each op sits directly in its own basic block.
        {
            let pop = find_node(&h, "pop");
            let push = find_node(&h, "push");
            assert_eq!(depth(&h, pop), 2); // BB, CFG
            assert_eq!(depth(&h, push), 2);

            let popp = h.get_parent(pop).unwrap();
            let pushp = h.get_parent(push).unwrap();
            assert_ne!(popp, pushp); // Two different BBs
            assert!(h.get_optype(popp).is_dataflow_block());
            assert!(h.get_optype(pushp).is_dataflow_block());

            assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
        }

        // Replacement: one BB with two DFGs inside.
        // Use Hugr rather than Builder because DFGs must be empty (not even Input/Output).
        let mut replacement = Hugr::new(ops::CFG {
            signature: Signature::new_endo(just_list.clone()),
        });
        let r_bb = replacement.add_node_with_parent(
            replacement.root(),
            DataflowBlock {
                inputs: vec![listy.clone()].into(),
                sum_rows: vec![type_row![]],
                other_outputs: vec![listy.clone()].into(),
                extension_delta: list::EXTENSION_ID.into(),
            },
        );
        let r_df1 = replacement.add_node_with_parent(
            r_bb,
            DFG {
                signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone()))
                    .with_extension_delta(list::EXTENSION_ID),
            },
        );
        let r_df2 = replacement.add_node_with_parent(
            r_bb,
            DFG {
                signature: Signature::new(intermed, simple_unary_plus(just_list.clone()))
                    .with_extension_delta(list::EXTENSION_ID),
            },
        );
        // Wire r_df1's outputs 1 and 2 (skipping the unit-sum predicate at output 0)
        // into r_df2's inputs 0 and 1.
        [0, 1]
            .iter()
            .for_each(|p| replacement.connect(r_df1, *p + 1, r_df2, *p));

        // Give the new basic block its Input and Output nodes, added ahead of r_df1.
        {
            let inp = replacement.add_node_before(
                r_df1,
                ops::Input {
                    types: just_list.clone(),
                },
            );
            let out = replacement.add_node_before(
                r_df1,
                ops::Output {
                    types: simple_unary_plus(just_list),
                },
            );
            replacement.connect(inp, 0, r_df1, 0);
            replacement.connect(r_df2, 0, out, 0);
            replacement.connect(r_df2, 1, out, 1);
        }

        // Remove both old blocks; each new DFG adopts the children of one old block, and
        // the new block branches (successor 0) to the existing exit block.
        h.apply_rewrite(Replacement {
            removal: vec![entry.node(), bb2.node()],
            replacement,
            adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]),
            mu_inp: vec![],
            mu_out: vec![NewEdgeSpec {
                src: r_bb,
                tgt: exit.node(),
                kind: NewEdgeKind::ControlFlow {
                    src_pos: OutgoingPort::from(0),
                },
            }],
            mu_new: vec![],
        })?;
        h.validate()?;
        // After the rewrite: each op is inside its own DFG, both DFGs in one basic block.
        {
            let pop = find_node(&h, "pop");
            let push = find_node(&h, "push");
            assert_eq!(depth(&h, pop), 3); // DFG, BB, CFG
            assert_eq!(depth(&h, push), 3);

            let popp = h.get_parent(pop).unwrap();
            let pushp = h.get_parent(push).unwrap();
            assert_ne!(popp, pushp); // Two different DFGs
            assert!(h.get_optype(popp).is_dfg());
            assert!(h.get_optype(pushp).is_dfg());

            let grandp = h.get_parent(popp).unwrap();
            assert_eq!(grandp, h.get_parent(pushp).unwrap());
            assert!(h.get_optype(grandp).is_dataflow_block());
        }

        Ok(())
    }

    /// Returns the unique node of `h` whose op's `Display` text contains `s`.
    /// Panics unless exactly one node matches.
    fn find_node(h: &Hugr, s: &str) -> crate::Node {
        h.nodes()
            .filter(|n| format!("{}", h.get_optype(*n)).contains(s))
            .exactly_one()
            .ok()
            .unwrap()
    }

    /// Adds to the CFG being built a basic block that applies `op` to the block's inputs and
    /// loads `pred_const` as its branch predicate. The `1` passed to the block builders is
    /// presumably the number of successors.
    /// If `entry` is true the block is built as the entry block, and `op`'s input row is
    /// asserted to equal the CFG's input row.
    fn single_node_block<T: AsRef<Hugr> + AsMut<Hugr>, O: DataflowOpTrait + Into<OpType>>(
        h: &mut CFGBuilder<T>,
        op: O,
        pred_const: &ConstID,
        entry: bool,
    ) -> Result<BasicBlockID, BuildError> {
        let op_sig = op.signature();
        let mut bb = if entry {
            assert_eq!(
                match h.hugr().get_optype(h.container_node()) {
                    OpType::CFG(c) => &c.signature.input,
                    _ => panic!(),
                },
                op_sig.input()
            );
            h.simple_entry_builder_exts(op_sig.output.clone(), 1, op_sig.runtime_reqs.clone())?
        } else {
            h.simple_block_builder(op_sig.into_owned(), 1)?
        };
        let op: OpType = op.into();
        let op = bb.add_dataflow_op(op, bb.input_wires())?;
        let load_pred = bb.load_const(pred_const);
        bb.finish_with_outputs(load_pred, op.outputs())
    }

    /// Returns `t` with a single-variant unit sum type (`Type::new_unit_sum(1)`) inserted at
    /// the front.
    fn simple_unary_plus(t: TypeRow) -> TypeRow {
        let mut v = t.into_owned();
        v.insert(0, Type::new_unit_sum(1));
        v.into()
    }

    /// Builds a DFG containing a two-case Conditional:
    /// * case1 holds `foo`;
    /// * case2 holds `bar` followed by a nested DFG holding `baz`.
    ///
    /// Checks that replacing both cases with two empty cases works. Then checks a series of
    /// invalid `Replacement`s, asserting that `verify` and `apply` report the same error.
    #[test]
    fn test_invalid() {
        let utou = Signature::new_endo(vec![usize_t()]);
        let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| {
            ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
            ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
            ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref)
                .unwrap();
        });
        let ext_name = ext.name().clone();
        let foo = ext.instantiate_extension_op("foo", []).unwrap();
        let bar = ext.instantiate_extension_op("bar", []).unwrap();
        let baz = ext.instantiate_extension_op("baz", []).unwrap();
        let mut registry = test_quantum_extension::REG.clone();
        registry.register(ext).unwrap();

        let mut h = DFGBuilder::new(
            Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])
                .with_extension_delta(ext_name.clone()),
        )
        .unwrap();
        let [i, b] = h.input_wires_arr();
        let mut cond = h
            .conditional_builder_exts(
                (vec![type_row![]; 2], b),
                [(usize_t(), i)],
                vec![usize_t()].into(),
                ext_name.clone(),
            )
            .unwrap();
        let mut case1 = cond.case_builder(0).unwrap();
        let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap();
        let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
        let mut case2 = cond.case_builder(1).unwrap();
        let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap();
        let mut baz_dfg = case2
            .dfg_builder(
                utou.clone().with_extension_delta(ext_name.clone()),
                bar.outputs(),
            )
            .unwrap();
        let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap();
        let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
        let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
        let cond = cond.finish_sub_container().unwrap();
        let h = h.finish_hugr_with_outputs(cond.outputs()).unwrap();

        // Replacement Hugr: a copy of the Conditional op as root, with two empty Case children.
        let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone());
        let r1 = r_hugr.add_node_with_parent(
            r_hugr.root(),
            Case {
                signature: utou.clone(),
            },
        );
        let r2 = r_hugr.add_node_with_parent(
            r_hugr.root(),
            Case {
                signature: utou.clone(),
            },
        );
        // Valid replacement. r1 adopts case1's children. r2 adopts the children of the nested
        // DFG, so `baz` moves into r2 while `bar` (a direct child of case2) is removed.
        let rep: Replacement = Replacement {
            removal: vec![case1, case2],
            replacement: r_hugr,
            adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
            mu_inp: vec![],
            mu_out: vec![],
            mu_new: vec![],
        };
        assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
        rep.verify(&h).unwrap();
        {
            let mut target = h.clone();
            let node_map = rep.clone().apply(&mut target).unwrap();
            let new_case2 = *node_map.get(&r2).unwrap();
            assert_eq!(target.get_parent(baz.node()), Some(new_case2));
        }

        // Test some bad Replacements (using variations of the `replacement` Hugr).
        // Runs both `verify` and `apply` (on a fresh clone), checks they fail identically,
        // and returns the error.
        let check_same_errors = |r: Replacement| {
            let verify_res = r.verify(&h).unwrap_err();
            let apply_res = r.apply(&mut h.clone()).unwrap_err();
            assert_eq!(verify_res, apply_res);
            apply_res
        };
        // Root node type needs to be that of common parent of the removed nodes:
        let mut rep2 = rep.clone();
        rep2.replacement
            .replace_op(rep2.replacement.root(), h.root_type().clone())
            .unwrap();
        assert_eq!(
            check_same_errors(rep2),
            ReplaceError::WrongRootNodeTag {
                removed: OpTag::Conditional,
                replacement: OpTag::Dfg
            }
        );
        // Removed nodes...
        assert_eq!(
            check_same_errors(Replacement {
                removal: vec![h.root()],
                ..rep.clone()
            }),
            ReplaceError::CantReplaceRoot
        );
        // case1's parent is the Conditional; baz_dfg's parent is case2.
        assert_eq!(
            check_same_errors(Replacement {
                removal: vec![case1, baz_dfg.node()],
                ..rep.clone()
            }),
            ReplaceError::MultipleParents(vec![cond.node(), case2])
        );
        // Adoptions...
        // The replacement root is not a childless container, so it cannot adopt.
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from([(r1, case1), (rep.replacement.root(), case2)]),
                ..rep.clone()
            }),
            ReplaceError::InvalidAdoptingParent(rep.replacement.root())
        );
        // The same node may not be adopted twice.
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
                ..rep.clone()
            }),
            ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
        );
        // baz_dfg lies inside case2, whose children are already adopted wholesale.
        assert_eq!(
            check_same_errors(Replacement {
                adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
                ..rep.clone()
            }),
            ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
        );
        // Edges....
        // mu_inp sources must be retained nodes; case1 is being removed.
        let edge_from_removed = NewEdgeSpec {
            src: case1,
            tgt: r2,
            kind: NewEdgeKind::Order,
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_inp: vec![edge_from_removed.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed)
        );
        let bad_out_edge = NewEdgeSpec {
            src: h.nodes().max().unwrap(), // not valid in replacement
            tgt: cond.node(),
            kind: NewEdgeKind::Order,
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_out: vec![bad_out_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge)
        );
        // A control-flow edge between dataflow nodes (the Conditional and the root DFG's
        // Output node); only the error variant and edge are asserted, not the direction.
        let bad_order_edge = NewEdgeSpec {
            src: cond.node(),
            tgt: h.get_io(h.root()).unwrap()[1],
            kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
        };
        assert_matches!(
            check_same_errors(Replacement {
                mu_new: vec![bad_order_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, bad_order_edge)
        );
        // Re-target the Conditional's output 0 to come from r1, a Case node. The test expects
        // the kind check to fail at the source (presumably a Case has no value output 0).
        let op = OutgoingPort::from(0);
        let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
        let new_out_edge = NewEdgeSpec {
            src: r1.node(),
            tgt,
            kind: NewEdgeKind::Value {
                src_pos: op,
                tgt_pos: ip,
            },
        };
        assert_eq!(
            check_same_errors(Replacement {
                mu_out: vec![new_out_edge.clone()],
                ..rep.clone()
            }),
            ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
        );
    }
}