1use std::{collections::HashMap, fmt::Display};
4
5use itertools::Either;
6
7use crate::{
8 Hugr, HugrView, Node,
9 core::HugrNode,
10 hugr::{HugrMut, hugrmut::InsertedForest, internal::HugrMutInternals},
11};
12
13pub trait HugrLinking: HugrMut {
19 #[allow(clippy::type_complexity)]
36 fn insert_link_view_by_node<H: HugrView>(
37 &mut self,
38 parent: Option<Self::Node>,
39 other: &H,
40 children: NodeLinkingDirectives<H::Node, Self::Node>,
41 ) -> Result<InsertedForest<H::Node, Self::Node>, NodeLinkingError<H::Node, Self::Node>> {
42 let transfers = check_directives(other, parent, &children)?;
43 let nodes =
44 parent
45 .iter()
46 .flat_map(|_| other.entry_descendants())
47 .chain(children.iter().flat_map(|(&ch, dirv)| match dirv {
48 NodeLinkingDirective::Add { .. } => Either::Left(other.descendants(ch)),
49 NodeLinkingDirective::UseExisting(_) => Either::Right(std::iter::once(ch)),
50 }));
51 let mut roots = HashMap::new();
52 if let Some(parent) = parent {
53 roots.insert(other.entrypoint(), parent);
54 }
55 for ch in children.keys() {
56 roots.insert(*ch, self.module_root());
57 }
58 let mut inserted = self
59 .insert_view_forest(other, nodes, roots)
60 .expect("NodeLinkingDirectives were checked for disjointness");
61 link_by_node(self, transfers, &mut inserted.node_map);
62 Ok(inserted)
63 }
64
65 fn insert_link_hugr_by_node(
81 &mut self,
82 parent: Option<Self::Node>,
83 mut other: Hugr,
84 children: NodeLinkingDirectives<Node, Self::Node>,
85 ) -> Result<InsertedForest<Node, Self::Node>, NodeLinkingError<Node, Self::Node>> {
86 let transfers = check_directives(&other, parent, &children)?;
87 let mut roots = HashMap::new();
88 if let Some(parent) = parent {
89 roots.insert(other.entrypoint(), parent);
90 other.set_parent(other.entrypoint(), other.module_root());
91 };
92 for (ch, dirv) in children.iter() {
93 roots.insert(*ch, self.module_root());
94 if matches!(dirv, NodeLinkingDirective::UseExisting(_)) {
95 while let Some(gch) = other.first_child(*ch) {
97 other.remove_node(gch);
99 }
100 }
101 }
102 let mut inserted = self
103 .insert_forest(other, roots)
104 .expect("NodeLinkingDirectives were checked for disjointness");
105 link_by_node(self, transfers, &mut inserted.node_map);
106 Ok(inserted)
107 }
108}
109
110impl<T: HugrMut> HugrLinking for T {}
111
112#[derive(Clone, Debug, PartialEq, thiserror::Error)]
117#[non_exhaustive]
118pub enum NodeLinkingError<SN: Display = Node, TN: Display = Node> {
119 #[error(
122 "Cannot insert children (e.g. {_0}) when already inserting whole Hugr (entrypoint == module_root)"
123 )]
124 ChildOfEntrypoint(SN),
125 #[error("Requested to insert module-child {_0} but this contains the entrypoint")]
127 ChildContainsEntrypoint(SN),
128 #[error("{_0} was not a child of the module root")]
130 NotChildOfRoot(SN),
131 #[error("Target node {_0} is to be replaced by two source nodes {_1} and {_2}")]
134 NodeMultiplyReplaced(TN, SN, SN),
135}
136
137#[derive(Clone, Debug, Hash, PartialEq, Eq)]
140#[non_exhaustive]
141pub enum NodeLinkingDirective<TN = Node> {
142 Add {
144 replace: Vec<TN>,
156 },
157 UseExisting(TN),
164}
165
166impl<TN> NodeLinkingDirective<TN> {
167 pub const fn add() -> Self {
173 Self::Add { replace: vec![] }
174 }
175
176 pub fn replace(nodes: impl IntoIterator<Item = TN>) -> Self {
182 Self::Add {
183 replace: nodes.into_iter().collect(),
184 }
185 }
186}
187
188pub type NodeLinkingDirectives<SN, TN> = HashMap<SN, NodeLinkingDirective<TN>>;
193
/// Validated directives, split by kind, as produced by `check_directives` and consumed
/// by `link_by_node`.
struct Transfers<SN, TN> {
    /// Each target node to be replaced, mapped to the single source node replacing it.
    replace: HashMap<TN, SN>,
    /// Each `UseExisting` source node, mapped to the existing target node standing in
    /// for it.
    use_existing: HashMap<SN, TN>,
}
200
201fn check_directives<SRC: HugrView, TN: HugrNode>(
202 other: &SRC,
203 parent: Option<TN>,
204 children: &HashMap<SRC::Node, NodeLinkingDirective<TN>>,
205) -> Result<Transfers<SRC::Node, TN>, NodeLinkingError<SRC::Node, TN>> {
206 if parent.is_some() {
207 if other.entrypoint() == other.module_root() {
208 if let Some(c) = children.keys().next() {
209 return Err(NodeLinkingError::ChildOfEntrypoint(*c));
210 }
211 } else {
212 let mut n = other.entrypoint();
213 if children.contains_key(&n) {
214 return Err(NodeLinkingError::ChildContainsEntrypoint(n));
221 }
222 while let Some(p) = other.get_parent(n) {
223 if matches!(children.get(&p), Some(NodeLinkingDirective::Add { .. })) {
224 return Err(NodeLinkingError::ChildContainsEntrypoint(p));
225 }
226 n = p
227 }
228 }
229 }
230 let mut trns = Transfers {
231 replace: HashMap::default(),
232 use_existing: HashMap::default(),
233 };
234 for (&sn, dirv) in children {
235 if other.get_parent(sn) != Some(other.module_root()) {
236 return Err(NodeLinkingError::NotChildOfRoot(sn));
237 }
238 match dirv {
239 NodeLinkingDirective::Add { replace } => {
240 for &r in replace {
241 if let Some(old_sn) = trns.replace.insert(r, sn) {
242 return Err(NodeLinkingError::NodeMultiplyReplaced(r, old_sn, sn));
243 }
244 }
245 }
246 NodeLinkingDirective::UseExisting(tn) => {
247 trns.use_existing.insert(sn, *tn);
248 }
249 }
250 }
251 Ok(trns)
252}
253
254fn link_by_node<SN: HugrNode, TGT: HugrLinking + ?Sized>(
255 hugr: &mut TGT,
256 transfers: Transfers<SN, TGT::Node>,
257 node_map: &mut HashMap<SN, TGT::Node>,
258) {
259 for (sn, tn) in transfers.use_existing {
262 let copy = node_map.remove(&sn).unwrap();
263 debug_assert_eq!(hugr.children(copy).next(), None);
265 replace_static_src(hugr, copy, tn);
266 }
267 for (tn, sn) in transfers.replace {
268 let new_node = *node_map.get(&sn).unwrap();
269 replace_static_src(hugr, tn, new_node);
270 }
271}
272
273fn replace_static_src<H: HugrMut + ?Sized>(hugr: &mut H, old_src: H::Node, new_src: H::Node) {
274 let targets = hugr.all_linked_inputs(old_src).collect::<Vec<_>>();
275 for (target, inport) in targets {
276 let (src_node, outport) = hugr.single_linked_output(target, inport).unwrap();
277 debug_assert_eq!(src_node, old_src);
278 hugr.disconnect(target, inport);
279 hugr.connect(new_src, outport, target, inport);
280 }
281 hugr.remove_subtree(old_src);
282}
283
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use cool_asserts::assert_matches;
    use itertools::Itertools;

    use super::{HugrLinking, NodeLinkingDirective, NodeLinkingError};
    use crate::builder::test::{dfg_calling_defn_decl, simple_dfg_hugr};
    use crate::hugr::hugrmut::test::check_calls_defn_decl;
    use crate::ops::{FuncDecl, OpTag, OpTrait, handle::NodeHandle};
    use crate::{HugrView, hugr::HugrMut, types::Signature};

    #[test]
    fn test_insert_link_nodes_add() {
        // Baseline: plain (non-linking) insertion of the entrypoint subtree.
        // NOTE(review): the two flags of `check_calls_defn_decl` presumably say whether the
        // defn / decl were brought along (so that the calls resolve to them).
        let (insert, _, _) = dfg_calling_defn_decl();

        let mut h = simple_dfg_hugr();
        h.insert_from_view(h.entrypoint(), &insert);
        check_calls_defn_decl(&h, false, false);

        let mut h = simple_dfg_hugr();
        h.insert_hugr(h.entrypoint(), insert);
        check_calls_defn_decl(&h, false, false);

        // Every combination of also `Add`ing the defn and/or the decl, through both the
        // view-based and the owned-Hugr API.
        for (call1, call2) in [(false, false), (false, true), (true, false), (true, true)] {
            let (insert, defn, decl) = dfg_calling_defn_decl();
            let mod_children = HashMap::from_iter(
                call1
                    .then_some((defn.node(), NodeLinkingDirective::add()))
                    .into_iter()
                    .chain(call2.then_some((decl.node(), NodeLinkingDirective::add()))),
            );

            let mut h = simple_dfg_hugr();
            h.insert_link_view_by_node(Some(h.entrypoint()), &insert, mod_children.clone())
                .unwrap();
            check_calls_defn_decl(&h, call1, call2);

            let mut h = simple_dfg_hugr();
            h.insert_link_hugr_by_node(Some(h.entrypoint()), insert, mod_children)
                .unwrap();
            check_calls_defn_decl(&h, call1, call2);
        }
    }

    #[test]
    fn insert_link_nodes_replace() {
        let (mut host, defn, decl) = dfg_calling_defn_decl();
        assert_eq!(
            host.children(host.module_root())
                .map(|n| host.get_optype(n).tag())
                .collect_vec(),
            vec![OpTag::FuncDefn, OpTag::FuncDefn, OpTag::Function]
        );
        // Insert (without any entrypoint) the single module-child of `insert`, replacing
        // both the defn and the decl of `host`.
        let insert = simple_dfg_hugr();
        let dirvs = HashMap::from([(
            insert
                .children(insert.module_root())
                .exactly_one()
                .ok()
                .unwrap(),
            NodeLinkingDirective::Add {
                replace: vec![defn.node(), decl.node()],
            },
        )]);
        host.insert_link_hugr_by_node(None, insert, dirvs).unwrap();
        host.validate().unwrap();
        // 3 module-children - 2 replaced + 1 inserted.
        assert_eq!(
            host.children(host.module_root())
                .map(|n| host.get_optype(n).tag())
                .collect_vec(),
            vec![OpTag::FuncDefn; 2]
        );
    }

    #[test]
    fn insert_link_nodes_use_existing() {
        let (insert, defn, decl) = dfg_calling_defn_decl();
        let mut chmap =
            HashMap::from([defn.node(), decl.node()].map(|n| (n, NodeLinkingDirective::add())));
        // First build a host that already contains copies of the defn and decl
        // (plus one copy of the entrypoint subtree).
        let (h, node_map) = {
            let mut h = simple_dfg_hugr();
            let res = h
                .insert_link_view_by_node(Some(h.entrypoint()), &insert, chmap.clone())
                .unwrap();
            (h, res.node_map)
        };
        h.validate().unwrap();
        let num_nodes = h.num_nodes();
        let num_ep_nodes = h.descendants(node_map[&insert.entrypoint()]).count();
        let [inserted_defn, inserted_decl] = [defn.node(), decl.node()].map(|n| node_map[&n]);

        // Insert again. The decl is always linked to an existing node (either the earlier
        // defn copy or the earlier decl copy); the defn is either added afresh or linked to
        // its earlier copy.
        for decl_replacement in [inserted_defn, inserted_decl] {
            let decl_mode = NodeLinkingDirective::UseExisting(decl_replacement);
            chmap.insert(decl.node(), decl_mode);
            for defn_mode in [
                NodeLinkingDirective::add(),
                NodeLinkingDirective::UseExisting(inserted_defn),
            ] {
                chmap.insert(defn.node(), defn_mode.clone());
                let mut h = h.clone();
                h.insert_link_hugr_by_node(Some(h.entrypoint()), insert.clone(), chmap.clone())
                    .unwrap();
                h.validate().unwrap();
                // Nothing added at module level => only the entrypoint subtree is new.
                if defn_mode != NodeLinkingDirective::add() {
                    assert_eq!(h.num_nodes(), num_nodes + num_ep_nodes);
                }
                // The host had 3 module-children; one more only if the defn is added afresh.
                assert_eq!(
                    h.children(h.module_root()).count(),
                    3 + (defn_mode == NodeLinkingDirective::add()) as usize
                );
                // Each `UseExisting` aimed at a node adds one more static use of it.
                let expected_defn_uses = 1
                    + (defn_mode == NodeLinkingDirective::UseExisting(inserted_defn)) as usize
                    + (decl_replacement == inserted_defn) as usize;
                assert_eq!(
                    h.static_targets(inserted_defn).unwrap().count(),
                    expected_defn_uses
                );
                assert_eq!(
                    h.static_targets(inserted_decl).unwrap().count(),
                    1 + (decl_replacement == inserted_decl) as usize
                );
            }
        }
    }

    #[test]
    fn bad_insert_link_nodes() {
        // Every rejected call must leave the target exactly as it was (compared to `backup`).
        let backup = simple_dfg_hugr();
        let mut h = backup.clone();

        let (insert, defn, decl) = dfg_calling_defn_decl();
        let (defn, decl) = (defn.node(), decl.node());

        // `Add`ing the module-child that contains the entrypoint.
        let epp = insert.get_parent(insert.entrypoint()).unwrap();
        let r = h.insert_link_view_by_node(
            Some(h.entrypoint()),
            &insert,
            HashMap::from([(epp, NodeLinkingDirective::add())]),
        );
        assert_eq!(
            r.err().unwrap(),
            NodeLinkingError::ChildContainsEntrypoint(epp)
        );
        assert_eq!(h, backup);

        // A directive for a node that is not a module-child.
        let [inp, _] = insert.get_io(defn).unwrap();
        let r = h.insert_link_view_by_node(
            Some(h.entrypoint()),
            &insert,
            HashMap::from([(inp, NodeLinkingDirective::add())]),
        );
        assert_eq!(r.err().unwrap(), NodeLinkingError::NotChildOfRoot(inp));
        assert_eq!(h, backup);

        // A directive (even `UseExisting`) for the entrypoint itself.
        let mut insert = insert;
        insert.set_entrypoint(defn);
        let r = h.insert_link_view_by_node(
            Some(h.module_root()),
            &insert,
            HashMap::from([(
                defn,
                NodeLinkingDirective::UseExisting(h.get_parent(h.entrypoint()).unwrap()),
            )]),
        );
        assert_eq!(
            r.err().unwrap(),
            NodeLinkingError::ChildContainsEntrypoint(defn)
        );
        assert_eq!(h, backup);

        // Naming any module-child while inserting the whole Hugr (entrypoint == module root).
        insert.set_entrypoint(insert.module_root());
        let r = h.insert_link_hugr_by_node(
            Some(h.module_root()),
            insert,
            HashMap::from([(decl, NodeLinkingDirective::add())]),
        );
        assert_eq!(r.err().unwrap(), NodeLinkingError::ChildOfEntrypoint(decl));
        assert_eq!(h, backup);

        // Two source nodes both asking to replace the same target.
        let (insert, defn, decl) = dfg_calling_defn_decl();
        let sig = insert
            .get_optype(defn.node())
            .as_func_defn()
            .unwrap()
            .signature()
            .clone();
        let tmp = h.add_node_with_parent(h.module_root(), FuncDecl::new("replaced", sig));
        let r = h.insert_link_hugr_by_node(
            Some(h.entrypoint()),
            insert,
            HashMap::from([
                (decl.node(), NodeLinkingDirective::replace([tmp])),
                (defn.node(), NodeLinkingDirective::replace([tmp])),
            ]),
        );
        assert_matches!(
            r.err().unwrap(),
            NodeLinkingError::NodeMultiplyReplaced(tn, sn1, sn2) => {
                assert_eq!(tmp, tn);
                // HashMap iteration order decides which source is reported first,
                // so compare the two as a sorted pair.
                assert_eq!([sn1,sn2].into_iter().sorted().collect_vec(), [defn.node(), decl.node()]);
            });
    }

    #[test]
    fn test_replace_used() {
        let mut h = simple_dfg_hugr();
        let temp = h.add_node_with_parent(
            h.module_root(),
            FuncDecl::new("temp", Signature::new_endo(vec![])),
        );

        // `temp` is both the `UseExisting` target of the decl and replaced by the defn:
        // in the end every call should point at the newly inserted defn.
        let (insert, defn, decl) = dfg_calling_defn_decl();
        let node_map = h
            .insert_link_hugr_by_node(
                Some(h.entrypoint()),
                insert,
                HashMap::from([
                    (defn.node(), NodeLinkingDirective::replace([temp])),
                    (decl.node(), NodeLinkingDirective::UseExisting(temp)),
                ]),
            )
            .unwrap()
            .node_map;
        let defn = node_map[&defn.node()];
        // The `UseExisting` placeholder is not reported in the node map.
        assert_eq!(node_map.get(&decl.node()), None);
        assert!(!h.contains_node(temp));

        assert!(
            h.children(h.module_root())
                .all(|n| h.get_optype(n).is_func_defn())
        );
        for call in h.nodes().filter(|n| h.get_optype(*n).is_call()) {
            assert_eq!(h.static_source(call), Some(defn));
        }
    }
}