1use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::sync::Arc;
5
6use portgraph::{LinkMut, PortMut, PortView, SecondaryMap};
7
8use crate::core::HugrNode;
9use crate::extension::ExtensionRegistry;
10use crate::hugr::views::SiblingSubgraph;
11use crate::hugr::{HugrView, Node, OpType};
12use crate::hugr::{NodeMetadata, Patch};
13use crate::ops::OpTrait;
14use crate::types::Substitution;
15use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex};
16
17use super::internal::HugrMutInternals;
18use super::views::{
19 Rerooted, panic_invalid_node, panic_invalid_non_entrypoint, panic_invalid_port,
20};
21
/// Low-level mutation of a HUGR: changing its entrypoint, editing node
/// metadata, adding and removing nodes and edges, inserting other HUGRs (or
/// parts of them), applying patches, and registering the extensions it uses.
pub trait HugrMut: HugrMutInternals {
    /// Sets the entrypoint node of the HUGR.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if `root` is not a node of the
    /// HUGR.
    fn set_entrypoint(&mut self, root: Self::Node);

    /// Wraps a mutable borrow of this HUGR in a [`Rerooted`] view constructed
    /// with `entrypoint`.
    ///
    /// NOTE(review): presumably the view reports `entrypoint` as its
    /// entrypoint without changing `self`'s own entrypoint — confirm against
    /// `Rerooted::new`.
    fn with_entrypoint_mut(&mut self, entrypoint: Self::Node) -> Rerooted<&mut Self>
    where
        Self: Sized,
    {
        Rerooted::new(self, entrypoint)
    }

    /// Returns a mutable reference to the metadata value stored under `key`
    /// on `node`.
    ///
    /// The implementation for [`Hugr`] first inserts
    /// `serde_json::Value::Null` under `key` if no value is present yet.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if `node` is not a node of the
    /// HUGR.
    fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef<str>) -> &mut NodeMetadata;

    /// Stores `metadata` under `key` on `node`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if `node` is not a node of the
    /// HUGR.
    fn set_metadata(
        &mut self,
        node: Self::Node,
        key: impl AsRef<str>,
        metadata: impl Into<NodeMetadata>,
    );

    /// Removes the metadata value stored under `key` on `node`, if any.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if `node` is not a node of the
    /// HUGR.
    fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef<str>);

    /// Adds a new node with operation `op` as a child of `parent` and returns
    /// the new node.
    ///
    /// The implementation for [`Hugr`] attaches the node with the
    /// hierarchy's `push_child`, i.e. after any existing children.
    fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into<OpType>) -> Self::Node;

    /// Adds a new node with operation `nodetype` immediately before `sibling`
    /// among its parent's children and returns the new node.
    fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into<OpType>) -> Self::Node;

    /// Adds a new node with operation `op` immediately after `sibling` among
    /// its parent's children and returns the new node.
    fn add_node_after(&mut self, sibling: Self::Node, op: impl Into<OpType>) -> Self::Node;

    /// Removes `node` from the HUGR and returns its operation.
    ///
    /// Only `node` itself is removed; use [`HugrMut::remove_subtree`] to also
    /// remove its descendants.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] first calls
    /// `panic_invalid_non_entrypoint` on `node` (presumably rejecting invalid
    /// nodes and the entrypoint — confirm in `views`).
    fn remove_node(&mut self, node: Self::Node) -> OpType;

    /// Removes `node` together with all of its descendants.
    ///
    /// # Panics
    ///
    /// In the [`Hugr`] implementation every removed node goes through
    /// [`HugrMut::remove_node`], so the same checks apply to `node` and to
    /// each of its descendants.
    fn remove_subtree(&mut self, node: Self::Node);

    /// Copies the strict descendants of `root` (not `root` itself), so that
    /// the copies of `root`'s children become children of `new_parent`.
    /// If `subst` is given, it is applied to every copied operation.
    ///
    /// Returns a map from each copied node to its copy. In the [`Hugr`]
    /// implementation, edges are copied by portgraph's
    /// `Subgraph::copy_in_parent`.
    ///
    /// NOTE(review): the copies are not validated. Placing them under
    /// `new_parent` may leave the HUGR invalid (e.g. disallowed child
    /// operations, or copied edges that come from outside the subtree).
    /// Confirm what callers are expected to guarantee.
    fn copy_descendants(
        &mut self,
        root: Self::Node,
        new_parent: Self::Node,
        subst: Option<Substitution>,
    ) -> BTreeMap<Self::Node, Self::Node>;

    /// Connects output port `src_port` of `src` to input port `dst_port` of
    /// `dst`.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if either port is not valid.
    fn connect(
        &mut self,
        src: Self::Node,
        src_port: impl Into<OutgoingPort>,
        dst: Self::Node,
        dst_port: impl Into<IncomingPort>,
    );

    /// Removes the links attached to `port` on `node`.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if the port is not valid.
    fn disconnect(&mut self, node: Self::Node, port: impl Into<Port>);

    /// Connects the "other" (non-dataflow) output port of `src` to the
    /// "other" input port of `dst`, and returns the two ports used.
    ///
    /// # Panics
    ///
    /// The implementation for [`Hugr`] panics if `src`'s operation has no
    /// other output port, or if `dst`'s operation has no other input port.
    fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort);

    /// Inserts `other`'s entrypoint and all of its descendants as a child of
    /// `root`, by calling [`HugrMut::insert_region`] with `other.entrypoint()`.
    fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult<Node, Self::Node> {
        let region = other.entrypoint();
        Self::insert_region(self, root, other, region)
    }

    /// Inserts `region` and all of its descendants from `other` into this
    /// HUGR, making the copy of `region` a child of `root`.
    ///
    /// The returned [`InsertionResult::inserted_entrypoint`] is the copy of
    /// `region`. The [`Hugr`] implementation also adds `other`'s extensions
    /// to this HUGR's registry.
    fn insert_region(
        &mut self,
        root: Self::Node,
        other: Hugr,
        region: Node,
    ) -> InsertionResult<Node, Self::Node>;

    /// Copies the entrypoint of the view `other`, together with its
    /// descendants, into this HUGR, making the copy of the entrypoint a child
    /// of `root`.
    ///
    /// `other` is only borrowed, so the [`Hugr`] implementation clones
    /// operations and metadata.
    fn insert_from_view<H: HugrView>(
        &mut self,
        root: Self::Node,
        other: &H,
    ) -> InsertionResult<H::Node, Self::Node>;

    /// Copies the nodes of `subgraph` (taken from `other`) into this HUGR,
    /// each as a child of `root`, and returns the map from original nodes to
    /// their copies.
    ///
    /// In the [`Hugr`] implementation, only edges with both endpoints inside
    /// the subgraph are copied. The extensions used by the copied operations
    /// are registered whenever they can be computed.
    fn insert_subgraph<H: HugrView>(
        &mut self,
        root: Self::Node,
        other: &H,
        subgraph: &SiblingSubgraph<H::Node>,
    ) -> HashMap<H::Node, Self::Node>;

    /// Applies the patch `rw` to this HUGR and returns its outcome or error.
    fn apply_patch<R, E>(&mut self, rw: impl Patch<Self, Outcome = R, Error = E>) -> Result<R, E>
    where
        Self: Sized,
    {
        rw.apply(self)
    }

    /// Registers `extension` in this HUGR's extension registry.
    fn use_extension(&mut self, extension: impl Into<Arc<Extension>>);

    /// Registers every extension yielded by `registry` in this HUGR's
    /// extension registry.
    fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
    where
        ExtensionRegistry: Extend<Reg>;
}
265
266pub struct InsertionResult<SourceN = Node, TargetN = Node> {
272 pub inserted_entrypoint: TargetN,
278 pub node_map: HashMap<SourceN, TargetN>,
281}
282
283fn translate_indices<N: HugrNode>(
289 mut source_node: impl FnMut(portgraph::NodeIndex) -> N,
290 mut target_node: impl FnMut(portgraph::NodeIndex) -> Node,
291 node_map: HashMap<portgraph::NodeIndex, portgraph::NodeIndex>,
292) -> impl Iterator<Item = (N, Node)> {
293 node_map
294 .into_iter()
295 .map(move |(k, v)| (source_node(k), target_node(v)))
296}
297
298impl HugrMut for Hugr {
300 #[inline]
301 fn set_entrypoint(&mut self, root: Node) {
302 panic_invalid_node(self, root);
303 self.entrypoint = root.into_portgraph();
304 }
305
306 fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef<str>) -> &mut NodeMetadata {
307 panic_invalid_node(self, node);
308 self.node_metadata_map_mut(node)
309 .entry(key.as_ref())
310 .or_insert(serde_json::Value::Null)
311 }
312
313 fn set_metadata(
314 &mut self,
315 node: Self::Node,
316 key: impl AsRef<str>,
317 metadata: impl Into<NodeMetadata>,
318 ) {
319 let entry = self.get_metadata_mut(node, key);
320 *entry = metadata.into();
321 }
322
323 fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef<str>) {
324 panic_invalid_node(self, node);
325 let node_meta = self.node_metadata_map_mut(node);
326 node_meta.remove(key.as_ref());
327 }
328
329 fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
330 let node = self.as_mut().add_node(node.into());
331 self.hierarchy
332 .push_child(node.into_portgraph(), parent.into_portgraph())
333 .expect("Inserting a newly-created node into the hierarchy should never fail.");
334 node
335 }
336
337 fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
338 let node = self.as_mut().add_node(nodetype.into());
339 self.hierarchy
340 .insert_before(node.into_portgraph(), sibling.into_portgraph())
341 .expect("Inserting a newly-created node into the hierarchy should never fail.");
342 node
343 }
344
345 fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
346 let node = self.as_mut().add_node(op.into());
347 self.hierarchy
348 .insert_after(node.into_portgraph(), sibling.into_portgraph())
349 .expect("Inserting a newly-created node into the hierarchy should never fail.");
350 node
351 }
352
353 fn remove_node(&mut self, node: Node) -> OpType {
354 panic_invalid_non_entrypoint(self, node);
355 self.hierarchy.remove(node.into_portgraph());
356 self.graph.remove_node(node.into_portgraph());
357 self.op_types.take(node.into_portgraph())
358 }
359
360 fn remove_subtree(&mut self, node: Node) {
361 panic_invalid_non_entrypoint(self, node);
362 let mut queue = VecDeque::new();
363 queue.push_back(node);
364 while let Some(n) = queue.pop_front() {
365 queue.extend(self.children(n));
366 self.remove_node(n);
367 }
368 }
369
370 fn connect(
371 &mut self,
372 src: Node,
373 src_port: impl Into<OutgoingPort>,
374 dst: Node,
375 dst_port: impl Into<IncomingPort>,
376 ) {
377 let src_port = src_port.into();
378 let dst_port = dst_port.into();
379 panic_invalid_port(self, src, src_port);
380 panic_invalid_port(self, dst, dst_port);
381 self.graph
382 .link_nodes(
383 src.into_portgraph(),
384 src_port.index(),
385 dst.into_portgraph(),
386 dst_port.index(),
387 )
388 .expect("The ports should exist at this point.");
389 }
390
391 fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
392 let port = port.into();
393 let offset = port.pg_offset();
394 panic_invalid_port(self, node, port);
395 let port = self
396 .graph
397 .port_index(node.into_portgraph(), offset)
398 .expect("The port should exist at this point.");
399 self.graph.unlink_port(port);
400 }
401
402 fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
403 let src_port = self
404 .get_optype(src)
405 .other_output_port()
406 .expect("Source operation has no non-dataflow outgoing edges");
407 let dst_port = self
408 .get_optype(dst)
409 .other_input_port()
410 .expect("Destination operation has no non-dataflow incoming edges");
411 self.connect(src, src_port, dst, dst_port);
412 (src_port, dst_port)
413 }
414
415 fn insert_region(
416 &mut self,
417 root: Self::Node,
418 mut other: Hugr,
419 region: Node,
420 ) -> InsertionResult<Node, Self::Node> {
421 let node_map = insert_hugr_internal(self, &other, other.descendants(region), |&n| {
422 if n == region { Some(root) } else { None }
423 });
424 self.extensions.extend(other.extensions());
426 for (&node, &new_node) in &node_map {
430 let node_pg = node.into_portgraph();
431 let new_node_pg = new_node.into_portgraph();
432 let optype = other.op_types.take(node_pg);
433 self.op_types.set(new_node_pg, optype);
434 let meta = other.metadata.take(node_pg);
435 self.metadata.set(new_node_pg, meta);
436 }
437 InsertionResult {
438 inserted_entrypoint: node_map[®ion],
439 node_map,
440 }
441 }
442
443 fn insert_from_view<H: HugrView>(
444 &mut self,
445 root: Self::Node,
446 other: &H,
447 ) -> InsertionResult<H::Node, Self::Node> {
448 let node_map = insert_hugr_internal(self, other, other.entry_descendants(), |&n| {
449 if n == other.entrypoint() {
450 Some(root)
451 } else {
452 None
453 }
454 });
455 self.extensions.extend(other.extensions());
457 for (&node, &new_node) in &node_map {
461 let nodetype = other.get_optype(node);
462 self.op_types
463 .set(new_node.into_portgraph(), nodetype.clone());
464 let meta = other.node_metadata_map(node);
465 if !meta.is_empty() {
466 self.metadata
467 .set(new_node.into_portgraph(), Some(meta.clone()));
468 }
469 }
470 InsertionResult {
471 inserted_entrypoint: node_map[&other.entrypoint()],
472 node_map,
473 }
474 }
475
476 fn insert_subgraph<H: HugrView>(
477 &mut self,
478 root: Self::Node,
479 other: &H,
480 subgraph: &SiblingSubgraph<H::Node>,
481 ) -> HashMap<H::Node, Self::Node> {
482 let node_map = insert_hugr_internal(self, other, subgraph.nodes().iter().copied(), |_| {
483 Some(root)
484 });
485 for (&node, &new_node) in &node_map {
487 let nodetype = other.get_optype(node);
488 self.op_types
489 .set(new_node.into_portgraph(), nodetype.clone());
490 let meta = other.node_metadata_map(node);
491 if !meta.is_empty() {
492 self.metadata
493 .set(new_node.into_portgraph(), Some(meta.clone()));
494 }
495 if let Ok(exts) = nodetype.used_extensions() {
497 self.use_extensions(exts);
498 }
499 }
500 node_map
501 }
502
503 fn copy_descendants(
504 &mut self,
505 root: Self::Node,
506 new_parent: Self::Node,
507 subst: Option<Substitution>,
508 ) -> BTreeMap<Self::Node, Self::Node> {
509 let mut descendants = self.hierarchy.descendants(root.into_portgraph());
510 let root2 = descendants.next();
511 debug_assert_eq!(root2, Some(root.into_portgraph()));
512 let nodes = Vec::from_iter(descendants);
513 let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes)
514 .copy_in_parent()
515 .expect("Is a MultiPortGraph");
516 let node_map =
517 translate_indices(Into::into, Into::into, node_map).collect::<BTreeMap<_, _>>();
518
519 for node in self.children(root).collect::<Vec<_>>() {
520 self.set_parent(*node_map.get(&node).unwrap(), new_parent);
521 }
522
523 for (&node, &new_node) in &node_map {
525 for ch in self.children(node).collect::<Vec<_>>() {
526 self.set_parent(*node_map.get(&ch).unwrap(), new_node);
527 }
528 let new_optype = match (&subst, self.get_optype(node)) {
529 (None, op) => op.clone(),
530 (Some(subst), op) => op.substitute(subst),
531 };
532 self.op_types.set(new_node.into_portgraph(), new_optype);
533 let meta = self.metadata.get(node.into_portgraph()).clone();
534 self.metadata.set(new_node.into_portgraph(), meta);
535 }
536 node_map
537 }
538
539 #[inline]
540 fn use_extension(&mut self, extension: impl Into<Arc<Extension>>) {
541 self.extensions_mut().register_updated(extension);
542 }
543
544 #[inline]
545 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
546 where
547 ExtensionRegistry: Extend<Reg>,
548 {
549 self.extensions_mut().extend(registry);
550 }
551}
552
553fn insert_hugr_internal<H: HugrView>(
572 hugr: &mut Hugr,
573 other: &H,
574 other_nodes: impl Iterator<Item = H::Node>,
575 reroot: impl Fn(&H::Node) -> Option<Node>,
576) -> HashMap<H::Node, Node> {
577 let new_node_count_hint = other_nodes.size_hint().1.unwrap_or_default();
578
579 let mut node_map = HashMap::with_capacity(new_node_count_hint);
581 hugr.reserve(new_node_count_hint, 0);
582
583 for old in other_nodes {
584 let op = OpType::default();
587 let new = hugr.add_node(op);
588 node_map.insert(old, new);
589
590 hugr.set_num_ports(new, other.num_inputs(old), other.num_outputs(old));
591
592 let new_parent = if let Some(new_parent) = reroot(&old) {
593 new_parent
594 } else {
595 let old_parent = other.get_parent(old).unwrap();
596 *node_map
597 .get(&old_parent)
598 .expect("Child node came before parent in `other_nodes` iterator")
599 };
600 hugr.set_parent(new, new_parent);
601
602 for tgt in other.node_inputs(old) {
604 for (neigh, src) in other.linked_outputs(old, tgt) {
605 let Some(&neigh) = node_map.get(&neigh) else {
606 continue;
607 };
608 hugr.connect(neigh, src, new, tgt);
609 }
610 }
611 for src in other.node_outputs(old) {
612 for (neigh, tgt) in other.linked_inputs(old, src) {
613 if neigh == old {
614 continue;
615 }
616 let Some(&neigh) = node_map.get(&neigh) else {
617 continue;
618 };
619 hugr.connect(new, src, neigh, tgt);
620 }
621 }
622 }
623 node_map
624}
625
#[cfg(test)]
mod test {
    use crate::extension::PRELUDE;
    use crate::{
        extension::prelude::{Noop, usize_t},
        ops::{self, FuncDefn, Input, Output, dataflow::IOTrait},
        types::Signature,
    };

    use super::*;

    /// Builds a module with a single function that passes its input through
    /// a `Noop` to both of its outputs, and checks that the result validates.
    #[test]
    fn simple_function() -> Result<(), Box<dyn std::error::Error>> {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());

        let module: Node = hugr.entrypoint();
        let signature = Signature::new(usize_t(), vec![usize_t(), usize_t()]);
        let func: Node = hugr.add_node_with_parent(module, ops::FuncDefn::new("main", signature));

        let func_input = hugr.add_node_with_parent(func, ops::Input::new(vec![usize_t()]));
        let func_output =
            hugr.add_node_with_parent(func, ops::Output::new(vec![usize_t(), usize_t()]));
        let noop = hugr.add_node_with_parent(func, Noop(usize_t()));

        // Input -> Noop, then fan the Noop's single output out to both outputs.
        hugr.connect(func_input, 0, noop, 0);
        hugr.connect(noop, 0, func_output, 0);
        hugr.connect(noop, 0, func_output, 1);

        hugr.validate()?;
        Ok(())
    }

    /// Exercises getting, setting and removing a single metadata key.
    #[test]
    fn metadata() {
        const KEY: &str = "meta";

        let mut hugr = Hugr::default();
        let root: Node = hugr.entrypoint();

        assert_eq!(hugr.get_metadata(root, KEY), None);

        *hugr.get_metadata_mut(root, KEY) = "test".into();
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"test".into()));

        hugr.set_metadata(root, KEY, "new");
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"new".into()));

        hugr.remove_metadata(root, KEY);
        assert_eq!(hugr.get_metadata(root, KEY), None);
    }

    /// Removing a function definition's subtree also removes its input and
    /// output nodes.
    #[test]
    fn remove_subtree() {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());
        let root = hugr.entrypoint();

        // Each identity function adds three nodes: the definition, its Input
        // and its Output.
        let mut add_identity_fn = |name: &str| {
            let defn = hugr.add_node_with_parent(
                root,
                FuncDefn::new(name, Signature::new_endo(usize_t())),
            );
            let input = hugr.add_node_with_parent(defn, Input::new(usize_t()));
            let output = hugr.add_node_with_parent(defn, Output::new(usize_t()));
            hugr.connect(input, 0, output, 0);
            defn
        };
        let foo = add_identity_fn("foo");
        let bar = add_identity_fn("bar");

        hugr.validate().unwrap();
        assert_eq!(hugr.num_nodes(), 7);

        for (subtree, remaining) in [(foo, 4), (bar, 1)] {
            hugr.remove_subtree(subtree);
            hugr.validate().unwrap();
            assert_eq!(hugr.num_nodes(), remaining);
        }
    }
}