1use core::panic;
4use std::collections::{BTreeMap, HashMap};
5use std::sync::Arc;
6
7use portgraph::view::{NodeFilter, NodeFiltered};
8use portgraph::{LinkMut, PortMut, PortView, SecondaryMap};
9
10use crate::extension::ExtensionRegistry;
11use crate::hugr::views::SiblingSubgraph;
12use crate::hugr::{HugrView, Node, OpType, RootTagged};
13use crate::hugr::{NodeMetadata, Rewrite};
14use crate::ops::OpTrait;
15use crate::types::Substitution;
16use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex};
17
18use super::internal::HugrMutInternals;
19use super::NodeMetadataMap;
20
21pub trait HugrMut: HugrMutInternals {
23 fn get_metadata_mut(&mut self, node: Node, key: impl AsRef<str>) -> &mut NodeMetadata {
29 panic_invalid_node(self, node);
30 let node_meta = self
31 .hugr_mut()
32 .metadata
33 .get_mut(node.pg_index())
34 .get_or_insert_with(Default::default);
35 node_meta
36 .entry(key.as_ref())
37 .or_insert(serde_json::Value::Null)
38 }
39
40 fn set_metadata(
46 &mut self,
47 node: Node,
48 key: impl AsRef<str>,
49 metadata: impl Into<NodeMetadata>,
50 ) {
51 let entry = self.get_metadata_mut(node, key);
52 *entry = metadata.into();
53 }
54
55 fn remove_metadata(&mut self, node: Node, key: impl AsRef<str>) {
61 panic_invalid_node(self, node);
62 let node_meta = self.hugr_mut().metadata.get_mut(node.pg_index());
63 if let Some(node_meta) = node_meta {
64 node_meta.remove(key.as_ref());
65 }
66 }
67
68 fn take_node_metadata(&mut self, node: Self::Node) -> Option<NodeMetadataMap> {
70 if !self.valid_node(node) {
71 return None;
72 }
73 self.hugr_mut().metadata.take(node.pg_index())
74 }
75
76 fn overwrite_node_metadata(&mut self, node: Node, metadata: Option<NodeMetadataMap>) {
82 panic_invalid_node(self, node);
83 self.hugr_mut().metadata.set(node.pg_index(), metadata);
84 }
85
86 #[inline]
94 fn add_node_with_parent(&mut self, parent: Node, op: impl Into<OpType>) -> Node {
95 panic_invalid_node(self, parent);
96 self.hugr_mut().add_node_with_parent(parent, op)
97 }
98
99 #[inline]
107 fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
108 panic_invalid_non_root(self, sibling);
109 self.hugr_mut().add_node_before(sibling, nodetype)
110 }
111
112 #[inline]
120 fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
121 panic_invalid_non_root(self, sibling);
122 self.hugr_mut().add_node_after(sibling, op)
123 }
124
125 #[inline]
133 fn remove_node(&mut self, node: Node) -> OpType {
134 panic_invalid_non_root(self, node);
135 self.hugr_mut().remove_node(node)
136 }
137
138 fn remove_subtree(&mut self, node: Node) {
144 panic_invalid_non_root(self, node);
145 while let Some(ch) = self.first_child(node) {
146 self.remove_subtree(ch)
147 }
148 self.hugr_mut().remove_node(node);
149 }
150
151 fn copy_descendants(
164 &mut self,
165 root: Node,
166 new_parent: Node,
167 subst: Option<Substitution>,
168 ) -> BTreeMap<Node, Node> {
169 panic_invalid_node(self, root);
170 panic_invalid_node(self, new_parent);
171 self.hugr_mut().copy_descendants(root, new_parent, subst)
172 }
173
174 #[inline]
180 fn connect(
181 &mut self,
182 src: Node,
183 src_port: impl Into<OutgoingPort>,
184 dst: Node,
185 dst_port: impl Into<IncomingPort>,
186 ) {
187 panic_invalid_node(self, src);
188 panic_invalid_node(self, dst);
189 self.hugr_mut().connect(src, src_port, dst, dst_port);
190 }
191
192 #[inline]
200 fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
201 panic_invalid_node(self, node);
202 self.hugr_mut().disconnect(node, port);
203 }
204
205 fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
217 panic_invalid_node(self, src);
218 panic_invalid_node(self, dst);
219 self.hugr_mut().add_other_edge(src, dst)
220 }
221
222 #[inline]
228 fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult {
229 panic_invalid_node(self, root);
230 self.hugr_mut().insert_hugr(root, other)
231 }
232
233 #[inline]
239 fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
240 panic_invalid_node(self, root);
241 self.hugr_mut().insert_from_view(root, other)
242 }
243
244 fn insert_subgraph(
259 &mut self,
260 root: Node,
261 other: &impl HugrView,
262 subgraph: &SiblingSubgraph,
263 ) -> HashMap<Node, Node> {
264 panic_invalid_node(self, root);
265 self.hugr_mut().insert_subgraph(root, other, subgraph)
266 }
267
268 fn apply_rewrite<R, E>(&mut self, rw: impl Rewrite<ApplyResult = R, Error = E>) -> Result<R, E>
270 where
271 Self: Sized,
272 {
273 rw.apply(self)
274 }
275
276 fn use_extension(&mut self, extension: impl Into<Arc<Extension>>) {
283 self.hugr_mut().extensions.register_updated(extension);
284 }
285
286 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
296 where
297 ExtensionRegistry: Extend<Reg>,
298 {
299 self.hugr_mut().extensions.extend(registry);
300 }
301
302 fn extensions_mut(&mut self) -> &mut ExtensionRegistry {
304 &mut self.hugr_mut().extensions
305 }
306}
307
308pub struct InsertionResult {
311 pub new_root: Node,
315 pub node_map: HashMap<Node, Node>,
318}
319
320fn translate_indices(
321 node_map: HashMap<portgraph::NodeIndex, portgraph::NodeIndex>,
322) -> impl Iterator<Item = (Node, Node)> {
323 node_map.into_iter().map(|(k, v)| (k.into(), v.into()))
324}
325
326impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMut for T {
328 fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
329 let node = self.as_mut().add_node(node.into());
330 self.as_mut()
331 .hierarchy
332 .push_child(node.pg_index(), parent.pg_index())
333 .expect("Inserting a newly-created node into the hierarchy should never fail.");
334 node
335 }
336
337 fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
338 let node = self.as_mut().add_node(nodetype.into());
339 self.as_mut()
340 .hierarchy
341 .insert_before(node.pg_index(), sibling.pg_index())
342 .expect("Inserting a newly-created node into the hierarchy should never fail.");
343 node
344 }
345
346 fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
347 let node = self.as_mut().add_node(op.into());
348 self.as_mut()
349 .hierarchy
350 .insert_after(node.pg_index(), sibling.pg_index())
351 .expect("Inserting a newly-created node into the hierarchy should never fail.");
352 node
353 }
354
355 fn remove_node(&mut self, node: Node) -> OpType {
356 panic_invalid_non_root(self, node);
357 self.as_mut().hierarchy.remove(node.pg_index());
358 self.as_mut().graph.remove_node(node.pg_index());
359 self.as_mut().op_types.take(node.pg_index())
360 }
361
362 fn connect(
363 &mut self,
364 src: Node,
365 src_port: impl Into<OutgoingPort>,
366 dst: Node,
367 dst_port: impl Into<IncomingPort>,
368 ) {
369 let src_port = src_port.into();
370 let dst_port = dst_port.into();
371 panic_invalid_port(self, src, src_port);
372 panic_invalid_port(self, dst, dst_port);
373 self.as_mut()
374 .graph
375 .link_nodes(
376 src.pg_index(),
377 src_port.index(),
378 dst.pg_index(),
379 dst_port.index(),
380 )
381 .expect("The ports should exist at this point.");
382 }
383
384 fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
385 let port = port.into();
386 let offset = port.pg_offset();
387 panic_invalid_port(self, node, port);
388 let port = self
389 .as_mut()
390 .graph
391 .port_index(node.pg_index(), offset)
392 .expect("The port should exist at this point.");
393 self.as_mut().graph.unlink_port(port);
394 }
395
396 fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
397 let src_port = self
398 .get_optype(src)
399 .other_output_port()
400 .expect("Source operation has no non-dataflow outgoing edges");
401 let dst_port = self
402 .get_optype(dst)
403 .other_input_port()
404 .expect("Destination operation has no non-dataflow incoming edges");
405 self.connect(src, src_port, dst, dst_port);
406 (src_port, dst_port)
407 }
408
409 fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult {
410 let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other);
411 for (&node, &new_node) in node_map.iter() {
415 let optype = other.op_types.take(node);
416 self.as_mut().op_types.set(new_node, optype);
417 let meta = other.metadata.take(node);
418 self.as_mut().metadata.set(new_node, meta);
419 }
420 debug_assert_eq!(
421 Some(&new_root.pg_index()),
422 node_map.get(&other.root().pg_index())
423 );
424 InsertionResult {
425 new_root,
426 node_map: translate_indices(node_map).collect(),
427 }
428 }
429
430 fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
431 let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other);
432 for (&node, &new_node) in node_map.iter() {
436 let nodetype = other.get_optype(other.get_node(node));
437 self.as_mut().op_types.set(new_node, nodetype.clone());
438 let meta = other.base_hugr().metadata.get(node);
439 self.as_mut().metadata.set(new_node, meta.clone());
440 }
441 debug_assert_eq!(
442 Some(&new_root.pg_index()),
443 node_map.get(&other.get_pg_index(other.root()))
444 );
445 InsertionResult {
446 new_root,
447 node_map: translate_indices(node_map).collect(),
448 }
449 }
450
451 fn insert_subgraph(
452 &mut self,
453 root: Node,
454 other: &impl HugrView,
455 subgraph: &SiblingSubgraph,
456 ) -> HashMap<Node, Node> {
457 let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
459 NodeFiltered::new_node_filtered(
460 other.portgraph(),
461 |node, ctx| ctx.contains(&node.into()),
462 subgraph.nodes(),
463 );
464 let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph);
465 for (&node, &new_node) in node_map.iter() {
467 let nodetype = other.get_optype(other.get_node(node));
468 self.as_mut().op_types.set(new_node, nodetype.clone());
469 let meta = other.base_hugr().metadata.get(node);
470 self.as_mut().metadata.set(new_node, meta.clone());
471 if let Ok(exts) = nodetype.used_extensions() {
473 self.use_extensions(exts);
474 }
475 }
476 translate_indices(node_map).collect()
477 }
478
479 fn copy_descendants(
480 &mut self,
481 root: Node,
482 new_parent: Node,
483 subst: Option<Substitution>,
484 ) -> BTreeMap<Node, Node> {
485 let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index());
486 let root2 = descendants.next();
487 debug_assert_eq!(root2, Some(root.pg_index()));
488 let nodes = Vec::from_iter(descendants);
489 let node_map = translate_indices(
490 portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes)
491 .copy_in_parent()
492 .expect("Is a MultiPortGraph"),
493 )
494 .collect::<BTreeMap<_, _>>();
495
496 for node in self.children(root).collect::<Vec<_>>() {
497 self.set_parent(*node_map.get(&node).unwrap(), new_parent);
498 }
499
500 for (&node, &new_node) in node_map.iter() {
502 for ch in self.children(node).collect::<Vec<_>>() {
503 self.set_parent(*node_map.get(&ch).unwrap(), new_node);
504 }
505 let new_optype = match (&subst, self.get_optype(node)) {
506 (None, op) => op.clone(),
507 (Some(subst), op) => op.substitute(subst),
508 };
509 self.as_mut().op_types.set(new_node.pg_index(), new_optype);
510 let meta = self.base_hugr().metadata.get(node.pg_index()).clone();
511 self.as_mut().metadata.set(new_node.pg_index(), meta);
512 }
513 node_map
514 }
515}
516
517fn insert_hugr_internal<H: HugrView>(
526 hugr: &mut Hugr,
527 root: Node,
528 other: &H,
529) -> (Node, HashMap<portgraph::NodeIndex, portgraph::NodeIndex>) {
530 let node_map = hugr
531 .graph
532 .insert_graph(&other.portgraph())
533 .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}"));
534 let other_root = node_map[&other.get_pg_index(other.root())];
535
536 hugr.hierarchy
538 .push_child(other_root, root.pg_index())
539 .expect("Inserting a newly-created node into the hierarchy should never fail.");
540 for (&node, &new_node) in node_map.iter() {
541 other.children(other.get_node(node)).for_each(|child| {
542 hugr.hierarchy
543 .push_child(node_map[&other.get_pg_index(child)], new_node)
544 .expect("Inserting a newly-created node into the hierarchy should never fail.");
545 });
546 }
547
548 hugr.extensions.extend(other.extensions());
550
551 (other_root.into(), node_map)
552}
553
554fn insert_subgraph_internal(
567 hugr: &mut Hugr,
568 root: Node,
569 other: &impl HugrView,
570 portgraph: &impl portgraph::LinkView,
571) -> HashMap<portgraph::NodeIndex, portgraph::NodeIndex> {
572 let node_map = hugr
573 .graph
574 .insert_graph(&portgraph)
575 .expect("Internal error while inserting a subgraph into another");
576
577 for (&node, &new_node) in node_map.iter() {
580 let new_parent = other
581 .get_parent(other.get_node(node))
582 .and_then(|parent| node_map.get(&other.get_pg_index(parent)).copied())
583 .unwrap_or(root.pg_index());
584 hugr.hierarchy
585 .push_child(new_node, new_parent)
586 .expect("Inserting a newly-created node into the hierarchy should never fail.");
587 }
588
589 node_map
590}
591
592#[track_caller]
594pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
595 if !hugr.valid_node(node) {
596 panic!(
597 "Received an invalid node {node} while mutating a HUGR:\n\n {}",
598 hugr.mermaid_string()
599 );
600 }
601}
602
603#[track_caller]
605pub(super) fn panic_invalid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
606 if !hugr.valid_non_root(node) {
607 panic!(
608 "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}",
609 hugr.mermaid_string()
610 );
611 }
612}
613
614#[track_caller]
616pub(super) fn panic_invalid_port<H: HugrView + ?Sized>(
617 hugr: &H,
618 node: Node,
619 port: impl Into<Port>,
620) {
621 let port = port.into();
622 if hugr
623 .portgraph()
624 .port_index(node.pg_index(), port.pg_offset())
625 .is_none()
626 {
627 panic!(
628 "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}",
629 hugr.mermaid_string()
630 );
631 }
632}
633
#[cfg(test)]
mod test {
    use crate::extension::PRELUDE;
    use crate::{
        extension::prelude::{usize_t, Noop},
        ops::{self, dataflow::IOTrait, FuncDefn, Input, Output},
        types::Signature,
    };

    use super::*;

    /// Adds the identity function definition `name: usize -> usize` under `parent`.
    ///
    /// The body is an `Input` wired straight to an `Output`. Returns the `FuncDefn` node.
    fn add_identity_func(hugr: &mut Hugr, parent: Node, name: &str) -> Node {
        let func = hugr.add_node_with_parent(
            parent,
            FuncDefn {
                name: name.to_string(),
                signature: Signature::new_endo(usize_t()).into(),
            },
        );
        let input = hugr.add_node_with_parent(func, Input::new(usize_t()));
        let output = hugr.add_node_with_parent(func, Output::new(usize_t()));
        hugr.connect(input, 0, output, 0);
        func
    }

    /// Builds `main: usize -> (usize, usize)`, which fans its input out through a `Noop`,
    /// and checks that the result validates.
    #[test]
    fn simple_function() -> Result<(), Box<dyn std::error::Error>> {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());

        let module: Node = hugr.root();
        let main_sig =
            Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]).with_prelude();
        let main: Node = hugr.add_node_with_parent(
            module,
            ops::FuncDefn {
                name: "main".into(),
                signature: main_sig.into(),
            },
        );

        let input = hugr.add_node_with_parent(main, ops::Input::new(vec![usize_t()]));
        let output =
            hugr.add_node_with_parent(main, ops::Output::new(vec![usize_t(), usize_t()]));
        let noop = hugr.add_node_with_parent(main, Noop(usize_t()));

        // One value in, duplicated onto both outputs.
        hugr.connect(input, 0, noop, 0);
        hugr.connect(noop, 0, output, 0);
        hugr.connect(noop, 0, output, 1);

        hugr.validate()?;
        Ok(())
    }

    /// Round-trips get / get_mut / set / remove on the root node's metadata.
    #[test]
    fn metadata() {
        const KEY: &str = "meta";

        let mut hugr = Hugr::default();
        let root: Node = hugr.root();

        assert_eq!(hugr.get_metadata(root, KEY), None);

        *hugr.get_metadata_mut(root, KEY) = "test".into();
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"test".into()));

        hugr.set_metadata(root, KEY, "new");
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"new".into()));

        hugr.remove_metadata(root, KEY);
        assert_eq!(hugr.get_metadata(root, KEY), None);
    }

    /// Removing a function's subtree removes the definition together with its body.
    #[test]
    fn remove_subtree() {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());
        let root: Node = hugr.root();
        let foo = add_identity_func(&mut hugr, root, "foo");
        let bar = add_identity_func(&mut hugr, root, "bar");
        hugr.validate().unwrap();
        // Module root plus two × (FuncDefn, Input, Output).
        assert_eq!(hugr.node_count(), 7);

        for (func, remaining) in [(foo, 4), (bar, 1)] {
            hugr.remove_subtree(func);
            hugr.validate().unwrap();
            assert_eq!(hugr.node_count(), remaining);
        }
    }
}