hugr_core/hugr/patch/
replace.rs

1//! Implementation of the `Replace` operation.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use itertools::Itertools;
6use thiserror::Error;
7
8use crate::core::HugrNode;
9use crate::hugr::HugrMut;
10use crate::hugr::hugrmut::InsertionResult;
11use crate::hugr::views::check_valid_non_entrypoint;
12use crate::ops::{OpTag, OpTrait};
13use crate::types::EdgeKind;
14use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort};
15
16use super::{PatchHugrMut, PatchVerification};
17
18/// Specifies how to create a new edge.
19#[derive(Clone, Debug, PartialEq, Eq)]
20pub struct NewEdgeSpec<SrcNode, TgtNode> {
21    /// The source of the new edge. For [`Replacement::mu_inp`] and
22    /// [`Replacement::mu_new`], this is in the existing Hugr; for edges in
23    /// [`Replacement::mu_out`] this is in the [`Replacement::replacement`]
24    pub src: SrcNode,
25    /// The target of the new edge. For [`Replacement::mu_inp`], this is in the
26    /// [`Replacement::replacement`]; for edges in [`Replacement::mu_out`] and
27    /// [`Replacement::mu_new`], this is in the existing Hugr.
28    pub tgt: TgtNode,
29    /// The kind of edge to create, and any port specifiers required
30    pub kind: NewEdgeKind,
31}
32
33/// Describes an edge that should be created between two nodes already given
34#[derive(Clone, Copy, Debug, PartialEq, Eq)]
35pub enum NewEdgeKind {
36    /// An [`EdgeKind::StateOrder`] edge (between DFG nodes only)
37    Order,
38    /// An [`EdgeKind::Value`] edge (between DFG nodes only)
39    Value {
40        /// The source port
41        src_pos: OutgoingPort,
42        /// The target port
43        tgt_pos: IncomingPort,
44    },
45    /// An [`EdgeKind::Const`] or [`EdgeKind::Function`] edge
46    Static {
47        /// The source port
48        src_pos: OutgoingPort,
49        /// The target port
50        tgt_pos: IncomingPort,
51    },
52    /// A [`EdgeKind::ControlFlow`] edge (between CFG nodes only)
53    ControlFlow {
54        /// Identifies a control-flow output (successor) of the source node.
55        src_pos: OutgoingPort,
56    },
57}
58
59/// Specification of a `Replace` operation
60#[derive(Debug, Clone, PartialEq)]
61pub struct Replacement<HostNode = Node> {
62    /// The nodes to remove from the existing Hugr (known as Gamma).
63    /// These must all have a common parent (i.e. be siblings).  Called "S" in
64    /// the spec. Must be non-empty - otherwise there is no parent under
65    /// which to place [`Self::replacement`], and there would be no possible
66    /// [`Self::mu_inp`], [`Self::mu_out`] or [`Self::adoptions`].
67    pub removal: Vec<HostNode>,
68    /// A hugr (not necessarily valid, as it may be missing edges and/or nodes),
69    /// whose root is the same type as the root of [`Self::replacement`].  "G"
70    /// in the spec.
71    pub replacement: Hugr,
72    /// Describes how parts of the Hugr that would otherwise be removed should
73    /// instead be preserved but with new parents amongst the newly-inserted
74    /// nodes.  This is a Map from container nodes in [`Self::replacement`]
75    /// that have no children, to container nodes that are descended from
76    /// [`Self::removal`]. The keys are the new parents for the children of
77    /// the values.  Note no value may be ancestor or descendant of another.
78    /// This is "B" in the spec; "R" is the set of descendants of
79    /// [`Self::removal`]  that are not descendants of values here.
80    pub adoptions: HashMap<Node, HostNode>,
81    /// Edges from nodes in the existing Hugr that are not removed
82    /// ([`NewEdgeSpec::src`] in Gamma\R) to inserted nodes
83    /// ([`NewEdgeSpec::tgt`] in [`Self::replacement`]).
84    pub mu_inp: Vec<NewEdgeSpec<HostNode, Node>>,
85    /// Edges from inserted nodes ([`NewEdgeSpec::src`] in [`Self::replacement`]) to
86    /// existing nodes not removed ([`NewEdgeSpec::tgt`] in Gamma \ R).
87    pub mu_out: Vec<NewEdgeSpec<Node, HostNode>>,
88    /// Edges to add between existing nodes (both [`NewEdgeSpec::src`] and
89    /// [`NewEdgeSpec::tgt`] in Gamma \ R). For example, in cases where the
90    /// source had an edge to a removed node, and the target had an
91    /// edge from a removed node, this would allow source to be directly
92    /// connected to target.
93    pub mu_new: Vec<NewEdgeSpec<HostNode, HostNode>>,
94}
95
96impl<SrcNode: Copy, TgtNode: Copy> NewEdgeSpec<SrcNode, TgtNode> {
97    fn check_src<HostNode>(
98        &self,
99        h: &impl HugrView<Node = SrcNode>,
100        err_spec: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
101    ) -> Result<(), ReplaceError<HostNode>> {
102        let optype = h.get_optype(self.src);
103        let ok = match self.kind {
104            NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder),
105            NewEdgeKind::Value { src_pos, .. } => {
106                matches!(optype.port_kind(src_pos), Some(EdgeKind::Value(_)))
107            }
108            NewEdgeKind::Static { src_pos, .. } => optype
109                .port_kind(src_pos)
110                .as_ref()
111                .is_some_and(EdgeKind::is_static),
112            NewEdgeKind::ControlFlow { src_pos } => {
113                matches!(optype.port_kind(src_pos), Some(EdgeKind::ControlFlow))
114            }
115        };
116        ok.then_some(())
117            .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec(self.clone())))
118    }
119
120    fn check_tgt<HostNode>(
121        &self,
122        h: &impl HugrView<Node = TgtNode>,
123        err_spec: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
124    ) -> Result<(), ReplaceError<HostNode>> {
125        let optype = h.get_optype(self.tgt);
126        let ok = match self.kind {
127            NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder),
128            NewEdgeKind::Value { tgt_pos, .. } => {
129                matches!(optype.port_kind(tgt_pos), Some(EdgeKind::Value(_)))
130            }
131            NewEdgeKind::Static { tgt_pos, .. } => optype
132                .port_kind(tgt_pos)
133                .as_ref()
134                .is_some_and(EdgeKind::is_static),
135            NewEdgeKind::ControlFlow { .. } => matches!(
136                optype.port_kind(IncomingPort::from(0)),
137                Some(EdgeKind::ControlFlow)
138            ),
139        };
140        ok.then_some(())
141            .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec(self.clone())))
142    }
143}
144
145impl<HostNode: HugrNode, N: Clone> NewEdgeSpec<N, HostNode> {
146    fn check_existing_edge(
147        &self,
148        h: &impl HugrView<Node = HostNode>,
149        legal_src_ancestors: &HashSet<HostNode>,
150        err_edge: impl Fn(Self) -> WhichEdgeSpec<HostNode>,
151    ) -> Result<(), ReplaceError<HostNode>> {
152        if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind
153        {
154            let descends_from_legal = |mut descendant: HostNode| -> bool {
155                while !legal_src_ancestors.contains(&descendant) {
156                    let Some(p) = h.get_parent(descendant) else {
157                        return false;
158                    };
159                    descendant = p;
160                }
161                true
162            };
163            let found_incoming = h
164                .single_linked_output(self.tgt, tgt_pos)
165                .is_some_and(|(src_n, _)| descends_from_legal(src_n));
166            if !found_incoming {
167                return Err(ReplaceError::NoRemovedEdge(err_edge(self.clone())));
168            }
169        }
170        Ok(())
171    }
172}
173
174impl<HostNode: HugrNode> Replacement<HostNode> {
175    fn check_parent(
176        &self,
177        h: &impl HugrView<Node = HostNode>,
178    ) -> Result<HostNode, ReplaceError<HostNode>> {
179        let parent = self
180            .removal
181            .iter()
182            .map(|n| h.get_parent(*n))
183            .unique()
184            .exactly_one()
185            .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))?
186            .ok_or(ReplaceError::CantReplaceRoot)?; // If no parent
187
188        // Check replacement parent is of same tag. Note we do not require exact
189        // equality of OpType/Signature, e.g. to ease changing of Input/Output
190        // node signatures too.
191        let removed = h.get_optype(parent).tag();
192        let replacement = self.replacement.entrypoint_optype().tag();
193        if removed != replacement {
194            return Err(ReplaceError::WrongRootNodeTag {
195                removed,
196                replacement,
197            });
198        }
199        Ok(parent)
200    }
201
202    fn get_removed_nodes(
203        &self,
204        h: &impl HugrView<Node = HostNode>,
205    ) -> Result<HashSet<HostNode>, ReplaceError<HostNode>> {
206        // Check the keys of the transfer map too, the values we'll use imminently
207        self.adoptions.keys().try_for_each(|&n| {
208            (self.replacement.contains_node(n)
209                && self.replacement.get_optype(n).is_container()
210                && self.replacement.children(n).next().is_none())
211            .then_some(())
212            .ok_or(ReplaceError::InvalidAdoptingParent(n))
213        })?;
214        let mut transferred: HashSet<HostNode> = self.adoptions.values().copied().collect();
215        if transferred.len() != self.adoptions.values().len() {
216            return Err(ReplaceError::AdopteesNotSeparateDescendants(
217                self.adoptions
218                    .values()
219                    .filter(|v| !transferred.remove(v))
220                    .copied()
221                    .collect(),
222            ));
223        }
224
225        let mut removed = HashSet::new();
226        let mut queue = VecDeque::from_iter(self.removal.iter().copied());
227        while let Some(n) = queue.pop_front() {
228            let new = removed.insert(n);
229            debug_assert!(new); // Fails only if h's hierarchy has merges (is not a tree)
230            if !transferred.remove(&n) {
231                h.children(n).for_each(|ch| queue.push_back(ch));
232            }
233        }
234        if !transferred.is_empty() {
235            return Err(ReplaceError::AdopteesNotSeparateDescendants(
236                transferred.into_iter().collect(),
237            ));
238        }
239        Ok(removed)
240    }
241}
242
243impl<HostNode: HugrNode> PatchVerification for Replacement<HostNode> {
244    type Error = ReplaceError<HostNode>;
245    type Node = HostNode;
246
247    fn verify(&self, h: &impl HugrView<Node = HostNode>) -> Result<(), Self::Error> {
248        self.check_parent(h)?;
249        let removed = self.get_removed_nodes(h)?;
250        // Edge sources...
251        for e in &self.mu_inp {
252            if !h.contains_node(e.src) || removed.contains(&e.src) {
253                return Err(ReplaceError::BadEdgeSpec(
254                    Direction::Outgoing,
255                    WhichEdgeSpec::HostToRepl(e.clone()),
256                ));
257            }
258            e.check_src(h, WhichEdgeSpec::HostToRepl)?;
259        }
260        for e in &self.mu_new {
261            if !h.contains_node(e.src) || removed.contains(&e.src) {
262                return Err(ReplaceError::BadEdgeSpec(
263                    Direction::Outgoing,
264                    WhichEdgeSpec::HostToHost(e.clone()),
265                ));
266            }
267            e.check_src(h, WhichEdgeSpec::HostToHost)?;
268        }
269        self.mu_out.iter().try_for_each(|e| {
270            if check_valid_non_entrypoint(&self.replacement, e.src) {
271                e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost)
272            } else {
273                Err(ReplaceError::BadEdgeSpec(
274                    Direction::Outgoing,
275                    WhichEdgeSpec::ReplToHost(e.clone()),
276                ))
277            }
278        })?;
279        // Edge targets...
280        self.mu_inp.iter().try_for_each(|e| {
281            if check_valid_non_entrypoint(&self.replacement, e.tgt) {
282                e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl)
283            } else {
284                Err(ReplaceError::BadEdgeSpec(
285                    Direction::Incoming,
286                    WhichEdgeSpec::HostToRepl(e.clone()),
287                ))
288            }
289        })?;
290        for e in &self.mu_out {
291            if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
292                return Err(ReplaceError::BadEdgeSpec(
293                    Direction::Incoming,
294                    WhichEdgeSpec::ReplToHost(e.clone()),
295                ));
296            }
297            e.check_tgt(h, WhichEdgeSpec::ReplToHost)?;
298            // The descendant check is to allow the case where the old edge is nonlocal
299            // from a part of the Hugr being moved (which may require changing source,
300            // depending on where the transplanted portion ends up). While this subsumes
301            // the first "removed.contains" check, we'll keep that as a common-case
302            // fast-path.
303            e.check_existing_edge(h, &removed, WhichEdgeSpec::ReplToHost)?;
304        }
305        for e in &self.mu_new {
306            if !h.contains_node(e.tgt) || removed.contains(&e.tgt) {
307                return Err(ReplaceError::BadEdgeSpec(
308                    Direction::Incoming,
309                    WhichEdgeSpec::HostToHost(e.clone()),
310                ));
311            }
312            e.check_tgt(h, WhichEdgeSpec::HostToHost)?;
313            // The descendant check is to allow the case where the old edge is nonlocal
314            // from a part of the Hugr being moved (which may require changing source,
315            // depending on where the transplanted portion ends up). While this subsumes
316            // the first "removed.contains" check, we'll keep that as a common-case
317            // fast-path.
318            e.check_existing_edge(h, &removed, WhichEdgeSpec::HostToHost)?;
319        }
320        Ok(())
321    }
322
323    fn invalidated_nodes(
324        &self,
325        _: &impl HugrView<Node = Self::Node>,
326    ) -> impl Iterator<Item = Self::Node> {
327        self.removal.iter().copied()
328    }
329}
330
331impl<HostNode: HugrNode> PatchHugrMut for Replacement<HostNode> {
332    /// Map from Node in replacement to corresponding Node in the result Hugr
333    type Outcome = HashMap<Node, HostNode>;
334
335    const UNCHANGED_ON_FAILURE: bool = false;
336
337    fn apply_hugr_mut(
338        self,
339        h: &mut impl HugrMut<Node = HostNode>,
340    ) -> Result<Self::Outcome, Self::Error> {
341        let parent = self.check_parent(h)?;
342        // Calculate removed nodes here. (Does not include transfers, so enumerates only
343        // nodes we are going to remove, individually, anyway; so no *asymptotic* speed
344        // penalty)
345        let to_remove = self.get_removed_nodes(h)?;
346
347        // 1. Add all the new nodes. Note this includes replacement.root(), which we
348        //    don't want.
349        // TODO what would an error here mean? e.g. malformed self.replacement??
350        let InsertionResult {
351            inserted_entrypoint,
352            node_map,
353        } = h.insert_hugr(parent, self.replacement);
354
355        // 2. Add new edges from existing to copied nodes according to mu_in
356        let translate_idx = |n| node_map.get(&n).copied();
357        let kept = |n| (!to_remove.contains(&n)).then_some(n);
358        transfer_edges(
359            h,
360            self.mu_inp.iter(),
361            kept,
362            translate_idx,
363            WhichEdgeSpec::HostToRepl,
364            None,
365        )?;
366
367        // 3. Add new edges from copied to existing nodes according to mu_out,
368        // replacing existing value/static edges incoming to targets
369        transfer_edges(
370            h,
371            self.mu_out.iter(),
372            translate_idx,
373            kept,
374            WhichEdgeSpec::ReplToHost,
375            Some(&to_remove),
376        )?;
377
378        // 4. Add new edges between existing nodes according to mu_new,
379        // replacing existing value/static edges incoming to targets.
380        transfer_edges(
381            h,
382            self.mu_new.iter(),
383            kept,
384            kept,
385            WhichEdgeSpec::HostToHost,
386            Some(&to_remove),
387        )?;
388
389        // 5. Put newly-added copies into correct places in hierarchy
390        // (these will be correct places after removing nodes)
391        let mut remove_top_sibs = self.removal.iter();
392        for new_node in h.children(inserted_entrypoint).collect::<Vec<HostNode>>() {
393            if let Some(top_sib) = remove_top_sibs.next() {
394                h.move_before_sibling(new_node, *top_sib);
395            } else {
396                h.set_parent(new_node, parent);
397            }
398        }
399        debug_assert!(h.children(inserted_entrypoint).next().is_none());
400        h.remove_node(inserted_entrypoint);
401
402        // 6. Transfer to keys of `transfers` children of the corresponding values.
403        for (new_parent, &old_parent) in &self.adoptions {
404            let new_parent = node_map.get(new_parent).unwrap();
405            debug_assert!(h.children(old_parent).next().is_some());
406            while let Some(ch) = h.first_child(old_parent) {
407                h.set_parent(ch, *new_parent);
408            }
409        }
410
411        // 7. Remove remaining nodes
412        for n in to_remove {
413            h.remove_node(n);
414        }
415        Ok(node_map)
416    }
417}
418
419fn transfer_edges<'a, SrcNode, TgtNode, HostNode>(
420    h: &mut impl HugrMut<Node = HostNode>,
421    edges: impl Iterator<Item = &'a NewEdgeSpec<SrcNode, TgtNode>>,
422    trans_src: impl Fn(SrcNode) -> Option<HostNode>,
423    trans_tgt: impl Fn(TgtNode) -> Option<HostNode>,
424    err_spec: impl Fn(NewEdgeSpec<SrcNode, TgtNode>) -> WhichEdgeSpec<HostNode>,
425    legal_src_ancestors: Option<&HashSet<HostNode>>,
426) -> Result<(), ReplaceError<HostNode>>
427where
428    SrcNode: 'a + HugrNode,
429    TgtNode: 'a + HugrNode,
430    HostNode: 'a + HugrNode,
431{
432    for oe in edges {
433        let err_spec = err_spec(oe.clone());
434        let e = NewEdgeSpec {
435            // Translation can only fail for Nodes that are supposed to be in the replacement
436            src: trans_src(oe.src)
437                .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Outgoing, err_spec.clone()))?,
438            tgt: trans_tgt(oe.tgt)
439                .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?,
440            kind: oe.kind,
441        };
442        if !h.contains_node(e.src) {
443            return Err(ReplaceError::BadEdgeSpec(
444                Direction::Outgoing,
445                err_spec.clone(),
446            ));
447        }
448        if !h.contains_node(e.tgt) {
449            return Err(ReplaceError::BadEdgeSpec(
450                Direction::Incoming,
451                err_spec.clone(),
452            ));
453        }
454        let err_spec = |_| err_spec.clone();
455        e.check_src(h, err_spec)?;
456        e.check_tgt(h, err_spec)?;
457        match e.kind {
458            NewEdgeKind::Order => {
459                h.add_other_edge(e.src, e.tgt);
460            }
461            NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => {
462                if let Some(legal_src_ancestors) = legal_src_ancestors {
463                    e.check_existing_edge(h, legal_src_ancestors, err_spec)?;
464                    h.disconnect(e.tgt, tgt_pos);
465                }
466                h.connect(e.src, src_pos, e.tgt, tgt_pos);
467            }
468            NewEdgeKind::ControlFlow { src_pos } => h.connect(e.src, src_pos, e.tgt, 0),
469        }
470    }
471    Ok(())
472}
473
474/// Error in a [`Replacement`]
475#[derive(Clone, Debug, PartialEq, Eq, Error)]
476#[non_exhaustive]
477pub enum ReplaceError<HostNode = Node> {
478    /// The node(s) to replace had no parent i.e. were root(s).
479    // (Perhaps if there is only one node to replace we should be able to?)
480    #[error("Cannot replace the root node of the Hugr")]
481    CantReplaceRoot,
482    /// The nodes to replace did not have a unique common parent
483    #[error("Removed nodes had different parents {0:?}")]
484    MultipleParents(Vec<HostNode>),
485    /// Replacement root node had different tag from parent of removed nodes
486    #[error("Expected replacement root with tag {removed} but found {replacement}")]
487    WrongRootNodeTag {
488        /// The tag of the parent of the removed nodes
489        removed: OpTag,
490        /// The tag of the root in the replacement Hugr
491        replacement: OpTag,
492    },
493    /// Keys in [`Replacement::adoptions`] were not valid container nodes in
494    /// [`Replacement::replacement`]
495    #[error("Node {0} was not an empty container node in the replacement")]
496    InvalidAdoptingParent(Node),
497    /// Some values in [`Replacement::adoptions`] were either descendants of other
498    /// values, or not descendants of the [`Replacement::removal`]. The nodes
499    /// are indicated on a best-effort basis.
500    #[error("Nodes not free to be moved into new locations: {0:?}")]
501    AdopteesNotSeparateDescendants(Vec<HostNode>),
502    /// A node at one end of a [`NewEdgeSpec`] was not found
503    #[error("{0:?} end of edge {1:?} not found in {which_hugr}", which_hugr = .1.which_hugr(*.0))]
504    BadEdgeSpec(Direction, WhichEdgeSpec<HostNode>),
505    /// The target of the edge was found, but there was no existing edge to
506    /// replace
507    #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")]
508    NoRemovedEdge(WhichEdgeSpec<HostNode>),
509    /// The [`NewEdgeKind`] was not applicable for the source/target node(s)
510    #[error("The edge kind was not applicable to the {0:?} node: {1:?}")]
511    BadEdgeKind(Direction, WhichEdgeSpec<HostNode>),
512}
513
514/// The three kinds of [`NewEdgeSpec`] that may appear in a [`ReplaceError`]
515#[derive(Clone, Debug, PartialEq, Eq)]
516pub enum WhichEdgeSpec<HostNode> {
517    /// An edge from the host Hugr into the replacement, i.e.
518    /// [`Replacement::mu_inp`]
519    HostToRepl(NewEdgeSpec<HostNode, Node>),
520    /// An edge from the replacement to the host, i.e. [`Replacement::mu_out`]
521    ReplToHost(NewEdgeSpec<Node, HostNode>),
522    /// An edge between two nodes in the host (bypassing the replacement),
523    /// i.e. [`Replacement::mu_new`]
524    HostToHost(NewEdgeSpec<HostNode, HostNode>),
525}
526
527impl<HostNode> WhichEdgeSpec<HostNode> {
528    fn which_hugr(&self, d: Direction) -> &str {
529        match (self, d) {
530            (Self::HostToRepl(_), Direction::Incoming)
531            | (Self::ReplToHost(_), Direction::Outgoing) => "replacement Hugr",
532            _ => "retained portion of Hugr",
533        }
534    }
535}
536
537#[cfg(test)]
538mod test {
539    use std::collections::HashMap;
540
541    use cool_asserts::assert_matches;
542    use itertools::Itertools;
543
544    use crate::builder::{
545        BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr,
546        DataflowSubContainer, HugrBuilder, SubContainer, endo_sig,
547    };
548    use crate::extension::prelude::{bool_t, usize_t};
549    use crate::extension::{ExtensionRegistry, PRELUDE};
550    use crate::hugr::internal::HugrMutInternals;
551    use crate::hugr::patch::PatchVerification;
552    use crate::hugr::{HugrMut, Patch};
553    use crate::ops::custom::ExtensionOp;
554    use crate::ops::dataflow::DataflowOpTrait;
555    use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
556    use crate::ops::{self, Case, DFG, DataflowBlock, OpTag, OpType};
557    use crate::std_extensions::collections::list;
558    use crate::types::{Signature, Type, TypeRow};
559    use crate::utils::{depth, test_quantum_extension};
560    use crate::{Direction, Extension, Hugr, HugrView, OutgoingPort, type_row};
561
562    use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement, WhichEdgeSpec};
563
564    #[test]
565    #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-'
566    fn cfg() -> Result<(), Box<dyn std::error::Error>> {
567        let reg = ExtensionRegistry::new([PRELUDE.to_owned(), list::EXTENSION.to_owned()]);
568        reg.validate()?;
569        let listy = list::list_type(usize_t());
570        let pop: ExtensionOp = list::ListOp::pop
571            .with_type(usize_t())
572            .to_extension_op()
573            .unwrap();
574        let push: ExtensionOp = list::ListOp::push
575            .with_type(usize_t())
576            .to_extension_op()
577            .unwrap();
578        let just_list = TypeRow::from(vec![listy.clone()]);
579        let intermed = TypeRow::from(vec![listy.clone(), usize_t()]);
580
581        let mut cfg = CFGBuilder::new(endo_sig(just_list.clone()))?;
582
583        let pred_const = cfg.add_constant(ops::Value::unary_unit_sum());
584
585        let entry = single_node_block(&mut cfg, pop, &pred_const, true)?;
586        let bb2 = single_node_block(&mut cfg, push, &pred_const, false)?;
587
588        let exit = cfg.exit_block();
589        cfg.branch(&entry, 0, &bb2)?;
590        cfg.branch(&bb2, 0, &exit)?;
591
592        let mut h = cfg.finish_hugr().unwrap();
593        {
594            let pop = find_node(&h, "pop");
595            let push = find_node(&h, "push");
596            assert_eq!(depth(&h, pop), 2); // BB, CFG
597            assert_eq!(depth(&h, push), 2);
598
599            let popp = h.get_parent(pop).unwrap();
600            let pushp = h.get_parent(push).unwrap();
601            assert_ne!(popp, pushp); // Two different BBs
602            assert!(h.get_optype(popp).is_dataflow_block());
603            assert!(h.get_optype(pushp).is_dataflow_block());
604
605            assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap());
606        }
607
608        // Replacement: one BB with two DFGs inside.
609        // Use Hugr rather than Builder because it must be empty (not even
610        // Input/Output).
611        let mut replacement = Hugr::new_with_entrypoint(ops::CFG {
612            signature: Signature::new_endo(just_list.clone()),
613        })
614        .expect("CFG is a valid entrypoint");
615        let r_bb = replacement.add_node_with_parent(
616            replacement.entrypoint(),
617            DataflowBlock {
618                inputs: vec![listy.clone()].into(),
619                sum_rows: vec![type_row![]],
620                other_outputs: vec![listy.clone()].into(),
621            },
622        );
623        let r_df1 = replacement.add_node_with_parent(
624            r_bb,
625            DFG {
626                signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())),
627            },
628        );
629        let r_df2 = replacement.add_node_with_parent(
630            r_bb,
631            DFG {
632                signature: Signature::new(intermed, simple_unary_plus(just_list.clone())),
633            },
634        );
635        [0, 1]
636            .iter()
637            .for_each(|p| replacement.connect(r_df1, *p + 1, r_df2, *p));
638
639        {
640            let inp = replacement.add_node_before(
641                r_df1,
642                ops::Input {
643                    types: just_list.clone(),
644                },
645            );
646            let out = replacement.add_node_before(
647                r_df1,
648                ops::Output {
649                    types: simple_unary_plus(just_list),
650                },
651            );
652            replacement.connect(inp, 0, r_df1, 0);
653            replacement.connect(r_df2, 0, out, 0);
654            replacement.connect(r_df2, 1, out, 1);
655        }
656
657        h.apply_patch(Replacement {
658            removal: vec![entry.node(), bb2.node()],
659            replacement,
660            adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]),
661            mu_inp: vec![],
662            mu_out: vec![NewEdgeSpec {
663                src: r_bb,
664                tgt: exit.node(),
665                kind: NewEdgeKind::ControlFlow {
666                    src_pos: OutgoingPort::from(0),
667                },
668            }],
669            mu_new: vec![],
670        })?;
671        h.validate()?;
672        {
673            let pop = find_node(&h, "pop");
674            let push = find_node(&h, "push");
675            assert_eq!(depth(&h, pop), 3); // DFG, BB, CFG
676            assert_eq!(depth(&h, push), 3);
677
678            let popp = h.get_parent(pop).unwrap();
679            let pushp = h.get_parent(push).unwrap();
680            assert_ne!(popp, pushp); // Two different DFGs
681            assert!(h.get_optype(popp).is_dfg());
682            assert!(h.get_optype(pushp).is_dfg());
683
684            let grandp = h.get_parent(popp).unwrap();
685            assert_eq!(grandp, h.get_parent(pushp).unwrap());
686            assert!(h.get_optype(grandp).is_dataflow_block());
687        }
688
689        Ok(())
690    }
691
692    fn find_node(h: &Hugr, s: &str) -> crate::Node {
693        h.entry_descendants()
694            .filter(|n| format!("{}", h.get_optype(*n)).contains(s))
695            .exactly_one()
696            .ok()
697            .unwrap()
698    }
699
700    fn single_node_block<T: AsRef<Hugr> + AsMut<Hugr>, O: DataflowOpTrait + Into<OpType>>(
701        h: &mut CFGBuilder<T>,
702        op: O,
703        pred_const: &ConstID,
704        entry: bool,
705    ) -> Result<BasicBlockID, BuildError> {
706        let op_sig = op.signature();
707        let mut bb = if entry {
708            assert_eq!(
709                if let OpType::CFG(c) = h.hugr().get_optype(h.container_node()) {
710                    &c.signature.input
711                } else {
712                    panic!()
713                },
714                op_sig.input()
715            );
716            h.simple_entry_builder(op_sig.output.clone(), 1)?
717        } else {
718            h.simple_block_builder(op_sig.into_owned(), 1)?
719        };
720        let op: OpType = op.into();
721        let op = bb.add_dataflow_op(op, bb.input_wires())?;
722        let load_pred = bb.load_const(pred_const);
723        bb.finish_with_outputs(load_pred, op.outputs())
724    }
725
726    fn simple_unary_plus(t: TypeRow) -> TypeRow {
727        let mut v = t.into_owned();
728        v.insert(0, Type::new_unit_sum(1));
729        v.into()
730    }
731
732    #[test]
733    fn test_invalid() {
734        let utou = Signature::new_endo(vec![usize_t()]);
735        let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| {
736            ext.add_op("foo".into(), String::new(), utou.clone(), extension_ref)
737                .unwrap();
738            ext.add_op("bar".into(), String::new(), utou.clone(), extension_ref)
739                .unwrap();
740            ext.add_op("baz".into(), String::new(), utou.clone(), extension_ref)
741                .unwrap();
742        });
743        let foo = ext.instantiate_extension_op("foo", []).unwrap();
744        let bar = ext.instantiate_extension_op("bar", []).unwrap();
745        let baz = ext.instantiate_extension_op("baz", []).unwrap();
746        let mut registry = test_quantum_extension::REG.clone();
747        registry.register(ext).unwrap();
748
749        let mut h =
750            DFGBuilder::new(Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])).unwrap();
751        let [i, b] = h.input_wires_arr();
752        let mut cond = h
753            .conditional_builder(
754                (vec![type_row![]; 2], b),
755                [(usize_t(), i)],
756                vec![usize_t()].into(),
757            )
758            .unwrap();
759        let mut case1 = cond.case_builder(0).unwrap();
760        let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap();
761        let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
762        let mut case2 = cond.case_builder(1).unwrap();
763        let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap();
764        let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs()).unwrap();
765        let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap();
766        let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
767        let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
768        let cond = cond.finish_sub_container().unwrap();
769        let h = h.finish_hugr_with_outputs(cond.outputs()).unwrap();
770
771        let mut r_hugr = Hugr::new_with_entrypoint(h.get_optype(cond.node()).clone()).unwrap();
772        let r1 = r_hugr.add_node_with_parent(
773            r_hugr.entrypoint(),
774            Case {
775                signature: utou.clone(),
776            },
777        );
778        let r2 = r_hugr.add_node_with_parent(
779            r_hugr.entrypoint(),
780            Case {
781                signature: utou.clone(),
782            },
783        );
784        let rep: Replacement = Replacement {
785            removal: vec![case1, case2],
786            replacement: r_hugr,
787            adoptions: HashMap::from_iter([(r1, case1), (r2, baz_dfg.node())]),
788            mu_inp: vec![],
789            mu_out: vec![],
790            mu_new: vec![],
791        };
792        assert_eq!(h.get_parent(baz.node()), Some(baz_dfg.node()));
793        rep.verify(&h).unwrap();
794        {
795            let mut target = h.clone();
796            let node_map = rep.clone().apply(&mut target).unwrap();
797            let new_case2 = *node_map.get(&r2).unwrap();
798            assert_eq!(target.get_parent(baz.node()), Some(new_case2));
799        }
800
801        // Test some bad Replacements (using variations of the `replacement` Hugr).
802        let check_same_errors = |r: Replacement| {
803            let verify_res = r.verify(&h).unwrap_err();
804            let apply_res = r.apply(&mut h.clone()).unwrap_err();
805            assert_eq!(verify_res, apply_res);
806            apply_res
807        };
808        // Root node type needs to be that of common parent of the removed nodes:
809        let mut rep2 = rep.clone();
810        rep2.replacement
811            .replace_op(rep2.replacement.entrypoint(), h.entrypoint_optype().clone());
812        assert_eq!(
813            check_same_errors(rep2),
814            ReplaceError::WrongRootNodeTag {
815                removed: OpTag::Conditional,
816                replacement: OpTag::Dfg
817            }
818        );
819        // Removed nodes...
820        assert_eq!(
821            check_same_errors(Replacement {
822                removal: vec![h.module_root()],
823                ..rep.clone()
824            }),
825            ReplaceError::CantReplaceRoot
826        );
827        assert_eq!(
828            check_same_errors(Replacement {
829                removal: vec![case1, baz_dfg.node()],
830                ..rep.clone()
831            }),
832            ReplaceError::MultipleParents(vec![cond.node(), case2])
833        );
834        // Adoptions...
835        assert_eq!(
836            check_same_errors(Replacement {
837                adoptions: HashMap::from([(r1, case1), (rep.replacement.entrypoint(), case2)]),
838                ..rep.clone()
839            }),
840            ReplaceError::InvalidAdoptingParent(rep.replacement.entrypoint())
841        );
842        assert_eq!(
843            check_same_errors(Replacement {
844                adoptions: HashMap::from_iter([(r1, case1), (r2, case1)]),
845                ..rep.clone()
846            }),
847            ReplaceError::AdopteesNotSeparateDescendants(vec![case1])
848        );
849        assert_eq!(
850            check_same_errors(Replacement {
851                adoptions: HashMap::from_iter([(r1, case2), (r2, baz_dfg.node())]),
852                ..rep.clone()
853            }),
854            ReplaceError::AdopteesNotSeparateDescendants(vec![baz_dfg.node()])
855        );
856        // Edges....
857        let edge_from_removed = NewEdgeSpec {
858            src: case1,
859            tgt: r2,
860            kind: NewEdgeKind::Order,
861        };
862        assert_eq!(
863            check_same_errors(Replacement {
864                mu_inp: vec![edge_from_removed.clone()],
865                ..rep.clone()
866            }),
867            ReplaceError::BadEdgeSpec(
868                Direction::Outgoing,
869                WhichEdgeSpec::HostToRepl(edge_from_removed)
870            )
871        );
872        let bad_out_edge = NewEdgeSpec {
873            src: h.nodes().max().unwrap(), // not valid in replacement
874            tgt: cond.node(),
875            kind: NewEdgeKind::Order,
876        };
877        assert_eq!(
878            check_same_errors(Replacement {
879                mu_out: vec![bad_out_edge.clone()],
880                ..rep.clone()
881            }),
882            ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichEdgeSpec::ReplToHost(bad_out_edge),)
883        );
884        let bad_order_edge = NewEdgeSpec {
885            src: cond.node(),
886            tgt: h.get_io(h.entrypoint()).unwrap()[1],
887            kind: NewEdgeKind::ControlFlow { src_pos: 0.into() },
888        };
889        assert_matches!(
890            check_same_errors(Replacement {
891                mu_new: vec![bad_order_edge.clone()],
892                ..rep.clone()
893            }),
894            ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, WhichEdgeSpec::HostToHost(bad_order_edge))
895        );
896        let op = OutgoingPort::from(0);
897        let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap();
898        let new_out_edge = NewEdgeSpec {
899            src: r1.node(),
900            tgt,
901            kind: NewEdgeKind::Value {
902                src_pos: op,
903                tgt_pos: ip,
904            },
905        };
906        assert_eq!(
907            check_same_errors(Replacement {
908                mu_out: vec![new_out_edge.clone()],
909                ..rep.clone()
910            }),
911            ReplaceError::BadEdgeKind(Direction::Outgoing, WhichEdgeSpec::ReplToHost(new_out_edge))
912        );
913    }
914}