hugr_core/hugr/
linking.rs

1//! Directives and errors relating to linking Hugrs.
2
3use std::{collections::HashMap, fmt::Display};
4
5use itertools::Either;
6
7use crate::{
8    Hugr, HugrView, Node,
9    core::HugrNode,
10    hugr::{HugrMut, hugrmut::InsertedForest, internal::HugrMutInternals},
11};
12
13/// Methods that merge Hugrs, adding static edges between old and inserted nodes.
14///
15/// This is done by module-children from the inserted (source) Hugr replacing, or being replaced by,
16/// module-children already in the target Hugr; static edges from the replaced node,
17/// are transferred to come from the replacing node, and the replaced node(/subtree) then deleted.
18pub trait HugrLinking: HugrMut {
19    /// Copy and link nodes from another Hugr into this one, with linking specified by Node.
20    ///
21    /// If `parent` is non-None, then `other`'s entrypoint-subtree is copied under it.
22    /// `children` of the Module root of `other` may also be inserted with their
23    /// subtrees or linked according to their [NodeLinkingDirective].
24    ///
25    /// # Errors
26    ///
27    /// * If `children` are not `children` of the root of `other`
28    /// * If `parent` is Some, and `other.entrypoint()` is either
29    ///   * among `children`, or
30    ///   * descends from an element of `children` with [NodeLinkingDirective::Add]
31    ///
32    /// # Panics
33    ///
34    /// If `parent` is `Some` but not in the graph.
35    #[allow(clippy::type_complexity)]
36    fn insert_link_view_by_node<H: HugrView>(
37        &mut self,
38        parent: Option<Self::Node>,
39        other: &H,
40        children: NodeLinkingDirectives<H::Node, Self::Node>,
41    ) -> Result<InsertedForest<H::Node, Self::Node>, NodeLinkingError<H::Node, Self::Node>> {
42        let transfers = check_directives(other, parent, &children)?;
43        let nodes =
44            parent
45                .iter()
46                .flat_map(|_| other.entry_descendants())
47                .chain(children.iter().flat_map(|(&ch, dirv)| match dirv {
48                    NodeLinkingDirective::Add { .. } => Either::Left(other.descendants(ch)),
49                    NodeLinkingDirective::UseExisting(_) => Either::Right(std::iter::once(ch)),
50                }));
51        let mut roots = HashMap::new();
52        if let Some(parent) = parent {
53            roots.insert(other.entrypoint(), parent);
54        }
55        for ch in children.keys() {
56            roots.insert(*ch, self.module_root());
57        }
58        let mut inserted = self
59            .insert_view_forest(other, nodes, roots)
60            .expect("NodeLinkingDirectives were checked for disjointness");
61        link_by_node(self, transfers, &mut inserted.node_map);
62        Ok(inserted)
63    }
64
65    /// Insert and link another Hugr into this one, with linking specified by Node.
66    ///
67    /// If `parent` is non-None, then `other`'s entrypoint-subtree is placed under it.
68    /// `children` of the Module root of `other` may also be inserted with their
69    /// subtrees or linked according to their [NodeLinkingDirective].
70    ///
71    /// # Errors
72    ///
73    /// * If `children` are not `children` of the root of `other`
74    /// * If `other`s entrypoint is among `children`, or descends from an element
75    ///   of `children` with [NodeLinkingDirective::Add]
76    ///
77    /// # Panics
78    ///
79    /// If `parent` is not in this graph.
80    fn insert_link_hugr_by_node(
81        &mut self,
82        parent: Option<Self::Node>,
83        mut other: Hugr,
84        children: NodeLinkingDirectives<Node, Self::Node>,
85    ) -> Result<InsertedForest<Node, Self::Node>, NodeLinkingError<Node, Self::Node>> {
86        let transfers = check_directives(&other, parent, &children)?;
87        let mut roots = HashMap::new();
88        if let Some(parent) = parent {
89            roots.insert(other.entrypoint(), parent);
90            other.set_parent(other.entrypoint(), other.module_root());
91        };
92        for (ch, dirv) in children.iter() {
93            roots.insert(*ch, self.module_root());
94            if matches!(dirv, NodeLinkingDirective::UseExisting(_)) {
95                // We do not need to copy the children of ch
96                while let Some(gch) = other.first_child(*ch) {
97                    // No point in deleting subtree, we won't copy disconnected nodes
98                    other.remove_node(gch);
99                }
100            }
101        }
102        let mut inserted = self
103            .insert_forest(other, roots)
104            .expect("NodeLinkingDirectives were checked for disjointness");
105        link_by_node(self, transfers, &mut inserted.node_map);
106        Ok(inserted)
107    }
108}
109
// Blanket implementation: every `HugrMut` gets the linking methods, using the
// default method bodies provided by the trait.
impl<T: HugrMut> HugrLinking for T {}
111
/// An error resulting from a [NodeLinkingDirective] passed to [HugrLinking::insert_link_hugr_by_node]
/// or [HugrLinking::insert_link_view_by_node].
///
/// `SN` is the type of nodes in the source (inserted) Hugr; `TN` similarly for the target Hugr.
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
#[non_exhaustive]
pub enum NodeLinkingError<SN: Display = Node, TN: Display = Node> {
    /// Inserting the whole Hugr, yet also asked to insert some of its children
    /// (so the inserted Hugr's entrypoint was its module-root).
    ///
    /// The payload is one (arbitrary) example of such a requested child.
    #[error(
        "Cannot insert children (e.g. {_0}) when already inserting whole Hugr (entrypoint == module_root)"
    )]
    ChildOfEntrypoint(SN),
    /// A module-child requested contained (or was) the entrypoint, while the
    /// entrypoint-subtree was also being inserted under a parent.
    ///
    /// The payload is the offending module-child.
    #[error("Requested to insert module-child {_0} but this contains the entrypoint")]
    ChildContainsEntrypoint(SN),
    /// A module-child requested was not a child of the module root
    /// of the source Hugr. The payload is that node.
    #[error("{_0} was not a child of the module root")]
    NotChildOfRoot(SN),
    /// A node in the target Hugr was in a [NodeLinkingDirective::Add::replace] for multiple
    /// inserted nodes (it is not clear to which we should transfer edges).
    ///
    /// The payload is the target node, followed by two of the source nodes claiming it.
    #[error("Target node {_0} is to be replaced by two source nodes {_1} and {_2}")]
    NodeMultiplyReplaced(TN, SN, SN),
}
136
/// Directive for how to treat a particular module-child in the source Hugr.
/// (`TN` is the type of nodes in the target Hugr.)
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub enum NodeLinkingDirective<TN = Node> {
    /// Insert the module-child (with subtree if any) into the target Hugr.
    Add {
        // TODO If non-None, change the name of the inserted function
        //rename: Option<String>,
        /// Existing/old nodes in the target which will be removed (with their subtrees),
        /// and any static ([EdgeKind::Function]/[EdgeKind::Const]) edges from them changed
        /// to leave the newly-inserted node instead. (Typically, this `Vec` would contain
        /// at most one [FuncDefn], or perhaps-multiple, aliased, [FuncDecl]s.)
        ///
        /// May be empty, in which case the node is simply added.
        ///
        /// [FuncDecl]: crate::ops::FuncDecl
        /// [FuncDefn]: crate::ops::FuncDefn
        /// [EdgeKind::Const]: crate::types::EdgeKind::Const
        /// [EdgeKind::Function]: crate::types::EdgeKind::Function
        replace: Vec<TN>,
    },
    /// Do not insert the node/subtree from the source, but for any static edge from it
    /// to an inserted node, instead add an edge from the specified node already existing
    /// in the target Hugr. (Static edges are [EdgeKind::Function] and [EdgeKind::Const].)
    ///
    /// [EdgeKind::Const]: crate::types::EdgeKind::Const
    /// [EdgeKind::Function]: crate::types::EdgeKind::Function
    UseExisting(TN),
}
165
166impl<TN> NodeLinkingDirective<TN> {
167    /// Just add the node (and any subtree) into the target.
168    /// (Could lead to an invalid Hugr if the target Hugr
169    /// already has another with the same name and both are [Public])
170    ///
171    /// [Public]: crate::Visibility::Public
172    pub const fn add() -> Self {
173        Self::Add { replace: vec![] }
174    }
175
176    /// The new node should replace the specified node(s) already existing
177    /// in the target.
178    ///
179    /// (Could lead to an invalid Hugr if they have different signatures,
180    /// or if the target already has another function with the same name and both are public.)
181    pub fn replace(nodes: impl IntoIterator<Item = TN>) -> Self {
182        Self::Add {
183            replace: nodes.into_iter().collect(),
184        }
185    }
186}
187
/// Details, node-by-node, how module-children of a source Hugr should be inserted into a
/// target Hugr.
///
/// Keys (`SN`) are nodes of the source Hugr; each maps to the [NodeLinkingDirective]
/// for that node, which may refer to nodes (`TN`) of the target Hugr.
///
/// For use with [HugrLinking::insert_link_hugr_by_node] and [HugrLinking::insert_link_view_by_node].
pub type NodeLinkingDirectives<SN, TN> = HashMap<SN, NodeLinkingDirective<TN>>;
193
/// The edge transfers implied by a set of [NodeLinkingDirective]s, grouped by kind.
/// Built by `check_directives` and consumed by `link_by_node`.
///
/// Invariant: no SourceNode can be in both maps (by type of [NodeLinkingDirective]),
/// because each source node has exactly one directive.
/// TargetNodes can be (in RHS of multiple directives).
struct Transfers<SourceNode, TargetNode> {
    /// Source node -> node already in the target to use in its place
    /// (from [NodeLinkingDirective::UseExisting]).
    use_existing: HashMap<SourceNode, TargetNode>,
    /// Target node to be replaced -> source node whose inserted copy replaces it
    /// (from [NodeLinkingDirective::Add::replace]).
    replace: HashMap<TargetNode, SourceNode>,
}
200
201fn check_directives<SRC: HugrView, TN: HugrNode>(
202    other: &SRC,
203    parent: Option<TN>,
204    children: &HashMap<SRC::Node, NodeLinkingDirective<TN>>,
205) -> Result<Transfers<SRC::Node, TN>, NodeLinkingError<SRC::Node, TN>> {
206    if parent.is_some() {
207        if other.entrypoint() == other.module_root() {
208            if let Some(c) = children.keys().next() {
209                return Err(NodeLinkingError::ChildOfEntrypoint(*c));
210            }
211        } else {
212            let mut n = other.entrypoint();
213            if children.contains_key(&n) {
214                // If parent == hugr.module_root() and the directive is to Add, we could
215                // allow that - it amounts to two instructions to do the same thing.
216                // (If the directive is to UseExisting, then we'd have nothing to add
217                //  beneath parent! And if parent != hugr.module_root(), then not only
218                //  would we have to double-copy the entrypoint-subtree, but also
219                //  (unless n is a Const!) we would be creating an illegal Hugr.)
220                return Err(NodeLinkingError::ChildContainsEntrypoint(n));
221            }
222            while let Some(p) = other.get_parent(n) {
223                if matches!(children.get(&p), Some(NodeLinkingDirective::Add { .. })) {
224                    return Err(NodeLinkingError::ChildContainsEntrypoint(p));
225                }
226                n = p
227            }
228        }
229    }
230    let mut trns = Transfers {
231        replace: HashMap::default(),
232        use_existing: HashMap::default(),
233    };
234    for (&sn, dirv) in children {
235        if other.get_parent(sn) != Some(other.module_root()) {
236            return Err(NodeLinkingError::NotChildOfRoot(sn));
237        }
238        match dirv {
239            NodeLinkingDirective::Add { replace } => {
240                for &r in replace {
241                    if let Some(old_sn) = trns.replace.insert(r, sn) {
242                        return Err(NodeLinkingError::NodeMultiplyReplaced(r, old_sn, sn));
243                    }
244                }
245            }
246            NodeLinkingDirective::UseExisting(tn) => {
247                trns.use_existing.insert(sn, *tn);
248            }
249        }
250    }
251    Ok(trns)
252}
253
254fn link_by_node<SN: HugrNode, TGT: HugrLinking + ?Sized>(
255    hugr: &mut TGT,
256    transfers: Transfers<SN, TGT::Node>,
257    node_map: &mut HashMap<SN, TGT::Node>,
258) {
259    // Resolve `use_existing` first in case the existing node is also replaced by
260    // a new node (which we know will not be in RHS of any entry in `replace`).
261    for (sn, tn) in transfers.use_existing {
262        let copy = node_map.remove(&sn).unwrap();
263        // Because of `UseExisting` we avoided adding `sn`s descendants
264        debug_assert_eq!(hugr.children(copy).next(), None);
265        replace_static_src(hugr, copy, tn);
266    }
267    for (tn, sn) in transfers.replace {
268        let new_node = *node_map.get(&sn).unwrap();
269        replace_static_src(hugr, tn, new_node);
270    }
271}
272
273fn replace_static_src<H: HugrMut + ?Sized>(hugr: &mut H, old_src: H::Node, new_src: H::Node) {
274    let targets = hugr.all_linked_inputs(old_src).collect::<Vec<_>>();
275    for (target, inport) in targets {
276        let (src_node, outport) = hugr.single_linked_output(target, inport).unwrap();
277        debug_assert_eq!(src_node, old_src);
278        hugr.disconnect(target, inport);
279        hugr.connect(new_src, outport, target, inport);
280    }
281    hugr.remove_subtree(old_src);
282}
283
// Tests for the node-directed linking methods of `HugrLinking`.
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use cool_asserts::assert_matches;
    use itertools::Itertools;

    use super::{HugrLinking, NodeLinkingDirective, NodeLinkingError};
    use crate::builder::test::{dfg_calling_defn_decl, simple_dfg_hugr};
    use crate::hugr::hugrmut::test::check_calls_defn_decl;
    use crate::ops::{FuncDecl, OpTag, OpTrait, handle::NodeHandle};
    use crate::{HugrView, hugr::HugrMut, types::Signature};

    /// Inserts the fixture with every combination of `Add` directives for its defn and
    /// decl, through both the view-based and the owning method, and checks each result
    /// with `check_calls_defn_decl`.
    #[test]
    fn test_insert_link_nodes_add() {
        // Default (non-linking) methods...just for comparison
        let (insert, _, _) = dfg_calling_defn_decl();

        let mut h = simple_dfg_hugr();
        h.insert_from_view(h.entrypoint(), &insert);
        check_calls_defn_decl(&h, false, false);

        let mut h = simple_dfg_hugr();
        h.insert_hugr(h.entrypoint(), insert);
        check_calls_defn_decl(&h, false, false);

        // Specify which decls to transfer. No real "linking" here though.
        for (call1, call2) in [(false, false), (false, true), (true, false), (true, true)] {
            let (insert, defn, decl) = dfg_calling_defn_decl();
            // `call1` selects the defn, `call2` the decl; unselected nodes get no directive.
            let mod_children = HashMap::from_iter(
                call1
                    .then_some((defn.node(), NodeLinkingDirective::add()))
                    .into_iter()
                    .chain(call2.then_some((decl.node(), NodeLinkingDirective::add()))),
            );

            let mut h = simple_dfg_hugr();
            h.insert_link_view_by_node(Some(h.entrypoint()), &insert, mod_children.clone())
                .unwrap();
            check_calls_defn_decl(&h, call1, call2);

            let mut h = simple_dfg_hugr();
            h.insert_link_hugr_by_node(Some(h.entrypoint()), insert, mod_children)
                .unwrap();
            check_calls_defn_decl(&h, call1, call2);
        }
    }

    /// A single inserted module-child replaces two nodes (the defn and the decl) of the
    /// host; the host must still validate and its module-children are checked by tag.
    #[test]
    fn insert_link_nodes_replace() {
        let (mut host, defn, decl) = dfg_calling_defn_decl();
        assert_eq!(
            host.children(host.module_root())
                .map(|n| host.get_optype(n).tag())
                .collect_vec(),
            vec![OpTag::FuncDefn, OpTag::FuncDefn, OpTag::Function]
        );
        let insert = simple_dfg_hugr();
        // The only module-child of `insert` is added, replacing both `defn` and `decl`.
        let dirvs = HashMap::from([(
            insert
                .children(insert.module_root())
                .exactly_one()
                .ok()
                .unwrap(),
            NodeLinkingDirective::Add {
                replace: vec![defn.node(), decl.node()],
            },
        )]);
        // No `parent`: only the module-child is inserted, not the entrypoint-subtree.
        host.insert_link_hugr_by_node(None, insert, dirvs).unwrap();
        host.validate().unwrap();
        assert_eq!(
            host.children(host.module_root())
                .map(|n| host.get_optype(n).tag())
                .collect_vec(),
            vec![OpTag::FuncDefn; 2]
        );
    }

    /// Inserts the fixture once with plain `Add` directives, then inserts it again with
    /// `UseExisting` directives pointing at the nodes from the first insertion, checking
    /// node counts and the number of static uses of the first-inserted defn and decl.
    #[test]
    fn insert_link_nodes_use_existing() {
        let (insert, defn, decl) = dfg_calling_defn_decl();
        let mut chmap =
            HashMap::from([defn.node(), decl.node()].map(|n| (n, NodeLinkingDirective::add())));
        // First insertion: gives us the target nodes to refer to afterwards.
        let (h, node_map) = {
            let mut h = simple_dfg_hugr();
            let res = h
                .insert_link_view_by_node(Some(h.entrypoint()), &insert, chmap.clone())
                .unwrap();
            (h, res.node_map)
        };
        h.validate().unwrap();
        let num_nodes = h.num_nodes();
        let num_ep_nodes = h.descendants(node_map[&insert.entrypoint()]).count();
        let [inserted_defn, inserted_decl] = [defn.node(), decl.node()].map(|n| node_map[&n]);

        // No reason we can't add the decl again, or replace the defn with the decl,
        // but here we'll limit to the "interesting" (likely) cases
        for decl_replacement in [inserted_defn, inserted_decl] {
            let decl_mode = NodeLinkingDirective::UseExisting(decl_replacement);
            chmap.insert(decl.node(), decl_mode);
            for defn_mode in [
                NodeLinkingDirective::add(),
                NodeLinkingDirective::UseExisting(inserted_defn),
            ] {
                chmap.insert(defn.node(), defn_mode.clone());
                // Second insertion, into a fresh copy of the result of the first.
                let mut h = h.clone();
                h.insert_link_hugr_by_node(Some(h.entrypoint()), insert.clone(), chmap.clone())
                    .unwrap();
                h.validate().unwrap();
                // If both children use existing nodes, only the entrypoint-subtree is new.
                if defn_mode != NodeLinkingDirective::add() {
                    assert_eq!(h.num_nodes(), num_nodes + num_ep_nodes);
                }
                // One extra module-child exactly when the defn was added again.
                assert_eq!(
                    h.children(h.module_root()).count(),
                    3 + (defn_mode == NodeLinkingDirective::add()) as usize
                );
                // One use from the first insertion, plus one for each directive of the
                // second insertion that points at `inserted_defn`.
                let expected_defn_uses = 1
                    + (defn_mode == NodeLinkingDirective::UseExisting(inserted_defn)) as usize
                    + (decl_replacement == inserted_defn) as usize;
                assert_eq!(
                    h.static_targets(inserted_defn).unwrap().count(),
                    expected_defn_uses
                );
                assert_eq!(
                    h.static_targets(inserted_decl).unwrap().count(),
                    1 + (decl_replacement == inserted_decl) as usize
                );
            }
        }
    }

    /// Each invalid request yields the expected [NodeLinkingError]; for all but the
    /// last case it is also checked that the target Hugr is left unchanged.
    #[test]
    fn bad_insert_link_nodes() {
        let backup = simple_dfg_hugr();
        let mut h = backup.clone();

        let (insert, defn, decl) = dfg_calling_defn_decl();
        let (defn, decl) = (defn.node(), decl.node());

        // Adding the parent of the entrypoint as well as the entrypoint-subtree itself.
        let epp = insert.get_parent(insert.entrypoint()).unwrap();
        let r = h.insert_link_view_by_node(
            Some(h.entrypoint()),
            &insert,
            HashMap::from([(epp, NodeLinkingDirective::add())]),
        );
        assert_eq!(
            r.err().unwrap(),
            NodeLinkingError::ChildContainsEntrypoint(epp)
        );
        assert_eq!(h, backup);

        // A directive for a node that is not a child of the module root.
        let [inp, _] = insert.get_io(defn).unwrap();
        let r = h.insert_link_view_by_node(
            Some(h.entrypoint()),
            &insert,
            HashMap::from([(inp, NodeLinkingDirective::add())]),
        );
        assert_eq!(r.err().unwrap(), NodeLinkingError::NotChildOfRoot(inp));
        assert_eq!(h, backup);

        // A directive (even `UseExisting`) for the entrypoint itself.
        let mut insert = insert;
        insert.set_entrypoint(defn);
        let r = h.insert_link_view_by_node(
            Some(h.module_root()),
            &insert,
            HashMap::from([(
                defn,
                NodeLinkingDirective::UseExisting(h.get_parent(h.entrypoint()).unwrap()),
            )]),
        );
        assert_eq!(
            r.err().unwrap(),
            NodeLinkingError::ChildContainsEntrypoint(defn)
        );
        assert_eq!(h, backup);

        // Inserting the whole Hugr (entrypoint is the module root) plus one of its children.
        insert.set_entrypoint(insert.module_root());
        let r = h.insert_link_hugr_by_node(
            Some(h.module_root()),
            insert,
            HashMap::from([(decl, NodeLinkingDirective::add())]),
        );
        assert_eq!(r.err().unwrap(), NodeLinkingError::ChildOfEntrypoint(decl));
        assert_eq!(h, backup);

        // Two different source nodes both asking to replace the same target node.
        let (insert, defn, decl) = dfg_calling_defn_decl();
        let sig = insert
            .get_optype(defn.node())
            .as_func_defn()
            .unwrap()
            .signature()
            .clone();
        let tmp = h.add_node_with_parent(h.module_root(), FuncDecl::new("replaced", sig));
        let r = h.insert_link_hugr_by_node(
            Some(h.entrypoint()),
            insert,
            HashMap::from([
                (decl.node(), NodeLinkingDirective::replace([tmp])),
                (defn.node(), NodeLinkingDirective::replace([tmp])),
            ]),
        );
        // The order of the two source nodes in the error is unspecified, so sort them.
        assert_matches!(
            r.err().unwrap(),
            NodeLinkingError::NodeMultiplyReplaced(tn, sn1, sn2) => {
                assert_eq!(tmp, tn);
                assert_eq!([sn1,sn2].into_iter().sorted().collect_vec(), [defn.node(), decl.node()]);
        });
    }

    /// A target node (`temp`) is both the `UseExisting` stand-in for one source node and
    /// replaced by another: afterwards `temp` is gone and every call uses the inserted defn.
    #[test]
    fn test_replace_used() {
        let mut h = simple_dfg_hugr();
        let temp = h.add_node_with_parent(
            h.module_root(),
            FuncDecl::new("temp", Signature::new_endo(vec![])),
        );

        let (insert, defn, decl) = dfg_calling_defn_decl();
        let node_map = h
            .insert_link_hugr_by_node(
                Some(h.entrypoint()),
                insert,
                HashMap::from([
                    (defn.node(), NodeLinkingDirective::replace([temp])),
                    (decl.node(), NodeLinkingDirective::UseExisting(temp)),
                ]),
            )
            .unwrap()
            .node_map;
        let defn = node_map[&defn.node()];
        // Nodes with `UseExisting` are not reported as inserted.
        assert_eq!(node_map.get(&decl.node()), None);
        assert!(!h.contains_node(temp));

        assert!(
            h.children(h.module_root())
                .all(|n| h.get_optype(n).is_func_defn())
        );
        for call in h.nodes().filter(|n| h.get_optype(*n).is_call()) {
            assert_eq!(h.static_source(call), Some(defn));
        }
    }
}