1use std::collections::{BTreeMap, HashMap, VecDeque};
4use std::sync::Arc;
5
6use portgraph::{LinkMut, PortMut, PortView, SecondaryMap};
7
8use crate::core::HugrNode;
9use crate::extension::ExtensionRegistry;
10use crate::hugr::views::SiblingSubgraph;
11use crate::hugr::{HugrView, Node, OpType};
12use crate::hugr::{NodeMetadata, Patch};
13use crate::ops::OpTrait;
14use crate::types::Substitution;
15use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex};
16
17use super::internal::HugrMutInternals;
18use super::views::{
19 Rerooted, panic_invalid_node, panic_invalid_non_entrypoint, panic_invalid_port,
20};
21
/// Mutation operations on a HUGR.
///
/// Covers changing the entrypoint, editing node metadata, adding and removing
/// nodes, wiring ports, and inserting (parts of) other HUGRs.
pub trait HugrMut: HugrMutInternals {
    /// Makes `root` the entrypoint of the HUGR.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if `root` is not a node of the graph.
    fn set_entrypoint(&mut self, root: Self::Node);

    /// Returns a [`Rerooted`] wrapper around `self` whose entrypoint is
    /// `entrypoint`; `self` stays mutably borrowed while the wrapper lives.
    ///
    /// NOTE(review): presumably `self`'s own entrypoint is left unchanged once
    /// the wrapper is dropped — confirm against [`Rerooted`].
    fn with_entrypoint_mut(&mut self, entrypoint: Self::Node) -> Rerooted<&mut Self>
    where
        Self: Sized,
    {
        Rerooted::new(self, entrypoint)
    }

    /// Returns a mutable reference to the metadata value stored under `key`
    /// on `node`.
    ///
    /// In the [`Hugr`] implementation a missing key is first inserted with a
    /// JSON `null` value, so a reference is always returned.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if `node` is not in the graph.
    fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef<str>) -> &mut NodeMetadata;

    /// Stores `metadata` under `key` on `node`, replacing any previous value.
    fn set_metadata(
        &mut self,
        node: Self::Node,
        key: impl AsRef<str>,
        metadata: impl Into<NodeMetadata>,
    );

    /// Removes the metadata entry `key` from `node`, if there is one.
    fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef<str>);

    /// Creates a new node holding `op`, adds it as a child of `parent`, and
    /// returns it.
    fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into<OpType>) -> Self::Node;

    /// Creates a new node holding `nodetype` and places it in the hierarchy
    /// immediately before `sibling` (under the same parent). Returns the new node.
    fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into<OpType>) -> Self::Node;

    /// Creates a new node holding `op` and places it in the hierarchy
    /// immediately after `sibling` (under the same parent). Returns the new node.
    fn add_node_after(&mut self, sibling: Self::Node, op: impl Into<OpType>) -> Self::Node;

    /// Removes `node` from the hierarchy and the graph and returns its operation.
    ///
    /// Only `node` itself is removed; use [`HugrMut::remove_subtree`] to also
    /// remove its descendants.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if `node` is not in the graph or is
    /// the entrypoint.
    fn remove_node(&mut self, node: Self::Node) -> OpType;

    /// Removes `node` together with all of its descendants.
    fn remove_subtree(&mut self, node: Self::Node);

    /// Copies every descendant of `root` (but not `root` itself) so that the
    /// copies of `root`'s children become children of `new_parent`, keeping
    /// the nested hierarchy below them.
    ///
    /// When `subst` is given, each copied operation is passed through
    /// `substitute`; otherwise it is cloned. Metadata is copied as-is.
    /// Returns a map from each original descendant to its copy.
    ///
    /// NOTE(review): edge copying is delegated to portgraph's
    /// `Subgraph::copy_in_parent` — confirm how edges leaving the subtree are
    /// treated.
    fn copy_descendants(
        &mut self,
        root: Self::Node,
        new_parent: Self::Node,
        subst: Option<Substitution>,
    ) -> BTreeMap<Self::Node, Self::Node>;

    /// Links output port `src_port` of `src` to input port `dst_port` of `dst`.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if either port is invalid.
    fn connect(
        &mut self,
        src: Self::Node,
        src_port: impl Into<OutgoingPort>,
        dst: Self::Node,
        dst_port: impl Into<IncomingPort>,
    );

    /// Unlinks `port` of `node` from whatever it is connected to.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if the port is invalid.
    fn disconnect(&mut self, node: Self::Node, port: impl Into<Port>);

    /// Connects the non-dataflow ("other") output of `src` to the non-dataflow
    /// input of `dst`, and returns the two ports that were used.
    ///
    /// # Panics
    ///
    /// The [`Hugr`] implementation panics if `src` has no such output or `dst`
    /// has no such input.
    fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort);

    /// Moves the entrypoint subtree of `other` into `self`.
    ///
    /// The copy of `other`'s entrypoint becomes a child of `root`. `other`'s
    /// extensions are merged into `self`.
    fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult<Node, Self::Node>;

    /// Like [`HugrMut::insert_hugr`], but copies the entrypoint subtree out of
    /// a borrowed view, cloning operations and metadata.
    fn insert_from_view<H: HugrView>(
        &mut self,
        root: Self::Node,
        other: &H,
    ) -> InsertionResult<H::Node, Self::Node>;

    /// Copies the nodes of `subgraph` (a subgraph of `other`) into `self` as
    /// children of `root`.
    ///
    /// Only links between nodes of the subgraph are recreated. Returns a map
    /// from each subgraph node to its copy.
    fn insert_subgraph<H: HugrView>(
        &mut self,
        root: Self::Node,
        other: &H,
        subgraph: &SiblingSubgraph<H::Node>,
    ) -> HashMap<H::Node, Self::Node>;

    /// Applies the patch `rw` to `self` and returns its outcome (forwards to
    /// [`Patch::apply`]).
    fn apply_patch<R, E>(&mut self, rw: impl Patch<Self, Outcome = R, Error = E>) -> Result<R, E>
    where
        Self: Sized,
    {
        rw.apply(self)
    }

    /// Records `extension` in the HUGR's extension registry.
    fn use_extension(&mut self, extension: impl Into<Arc<Extension>>);

    /// Records every extension yielded by `registry` in the HUGR's extension
    /// registry.
    fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
    where
        ExtensionRegistry: Extend<Reg>;
}
249
250pub struct InsertionResult<SourceN = Node, TargetN = Node> {
256 pub inserted_entrypoint: TargetN,
260 pub node_map: HashMap<SourceN, TargetN>,
263}
264
265fn translate_indices<N: HugrNode>(
271 mut source_node: impl FnMut(portgraph::NodeIndex) -> N,
272 mut target_node: impl FnMut(portgraph::NodeIndex) -> Node,
273 node_map: HashMap<portgraph::NodeIndex, portgraph::NodeIndex>,
274) -> impl Iterator<Item = (N, Node)> {
275 node_map
276 .into_iter()
277 .map(move |(k, v)| (source_node(k), target_node(v)))
278}
279
280impl HugrMut for Hugr {
282 #[inline]
283 fn set_entrypoint(&mut self, root: Node) {
284 panic_invalid_node(self, root);
285 self.entrypoint = root.into_portgraph();
286 }
287
288 fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef<str>) -> &mut NodeMetadata {
289 panic_invalid_node(self, node);
290 self.node_metadata_map_mut(node)
291 .entry(key.as_ref())
292 .or_insert(serde_json::Value::Null)
293 }
294
295 fn set_metadata(
296 &mut self,
297 node: Self::Node,
298 key: impl AsRef<str>,
299 metadata: impl Into<NodeMetadata>,
300 ) {
301 let entry = self.get_metadata_mut(node, key);
302 *entry = metadata.into();
303 }
304
305 fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef<str>) {
306 panic_invalid_node(self, node);
307 let node_meta = self.node_metadata_map_mut(node);
308 node_meta.remove(key.as_ref());
309 }
310
311 fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
312 let node = self.as_mut().add_node(node.into());
313 self.hierarchy
314 .push_child(node.into_portgraph(), parent.into_portgraph())
315 .expect("Inserting a newly-created node into the hierarchy should never fail.");
316 node
317 }
318
319 fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
320 let node = self.as_mut().add_node(nodetype.into());
321 self.hierarchy
322 .insert_before(node.into_portgraph(), sibling.into_portgraph())
323 .expect("Inserting a newly-created node into the hierarchy should never fail.");
324 node
325 }
326
327 fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
328 let node = self.as_mut().add_node(op.into());
329 self.hierarchy
330 .insert_after(node.into_portgraph(), sibling.into_portgraph())
331 .expect("Inserting a newly-created node into the hierarchy should never fail.");
332 node
333 }
334
335 fn remove_node(&mut self, node: Node) -> OpType {
336 panic_invalid_non_entrypoint(self, node);
337 self.hierarchy.remove(node.into_portgraph());
338 self.graph.remove_node(node.into_portgraph());
339 self.op_types.take(node.into_portgraph())
340 }
341
342 fn remove_subtree(&mut self, node: Node) {
343 panic_invalid_non_entrypoint(self, node);
344 let mut queue = VecDeque::new();
345 queue.push_back(node);
346 while let Some(n) = queue.pop_front() {
347 queue.extend(self.children(n));
348 self.remove_node(n);
349 }
350 }
351
352 fn connect(
353 &mut self,
354 src: Node,
355 src_port: impl Into<OutgoingPort>,
356 dst: Node,
357 dst_port: impl Into<IncomingPort>,
358 ) {
359 let src_port = src_port.into();
360 let dst_port = dst_port.into();
361 panic_invalid_port(self, src, src_port);
362 panic_invalid_port(self, dst, dst_port);
363 self.graph
364 .link_nodes(
365 src.into_portgraph(),
366 src_port.index(),
367 dst.into_portgraph(),
368 dst_port.index(),
369 )
370 .expect("The ports should exist at this point.");
371 }
372
373 fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
374 let port = port.into();
375 let offset = port.pg_offset();
376 panic_invalid_port(self, node, port);
377 let port = self
378 .graph
379 .port_index(node.into_portgraph(), offset)
380 .expect("The port should exist at this point.");
381 self.graph.unlink_port(port);
382 }
383
384 fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
385 let src_port = self
386 .get_optype(src)
387 .other_output_port()
388 .expect("Source operation has no non-dataflow outgoing edges");
389 let dst_port = self
390 .get_optype(dst)
391 .other_input_port()
392 .expect("Destination operation has no non-dataflow incoming edges");
393 self.connect(src, src_port, dst, dst_port);
394 (src_port, dst_port)
395 }
396
397 fn insert_hugr(
398 &mut self,
399 root: Self::Node,
400 mut other: Hugr,
401 ) -> InsertionResult<Node, Self::Node> {
402 let node_map = insert_hugr_internal(self, &other, other.entry_descendants(), |&n| {
403 if n == other.entrypoint() {
404 Some(root)
405 } else {
406 None
407 }
408 });
409 self.extensions.extend(other.extensions());
411 for (&node, &new_node) in &node_map {
415 let node_pg = node.into_portgraph();
416 let new_node_pg = new_node.into_portgraph();
417 let optype = other.op_types.take(node_pg);
418 self.op_types.set(new_node_pg, optype);
419 let meta = other.metadata.take(node_pg);
420 self.metadata.set(new_node_pg, meta);
421 }
422 InsertionResult {
423 inserted_entrypoint: node_map[&other.entrypoint()],
424 node_map,
425 }
426 }
427
428 fn insert_from_view<H: HugrView>(
429 &mut self,
430 root: Self::Node,
431 other: &H,
432 ) -> InsertionResult<H::Node, Self::Node> {
433 let node_map = insert_hugr_internal(self, other, other.entry_descendants(), |&n| {
434 if n == other.entrypoint() {
435 Some(root)
436 } else {
437 None
438 }
439 });
440 self.extensions.extend(other.extensions());
442 for (&node, &new_node) in &node_map {
446 let nodetype = other.get_optype(node);
447 self.op_types
448 .set(new_node.into_portgraph(), nodetype.clone());
449 let meta = other.node_metadata_map(node);
450 if !meta.is_empty() {
451 self.metadata
452 .set(new_node.into_portgraph(), Some(meta.clone()));
453 }
454 }
455 InsertionResult {
456 inserted_entrypoint: node_map[&other.entrypoint()],
457 node_map,
458 }
459 }
460
461 fn insert_subgraph<H: HugrView>(
462 &mut self,
463 root: Self::Node,
464 other: &H,
465 subgraph: &SiblingSubgraph<H::Node>,
466 ) -> HashMap<H::Node, Self::Node> {
467 let node_map = insert_hugr_internal(self, other, subgraph.nodes().iter().copied(), |_| {
468 Some(root)
469 });
470 for (&node, &new_node) in &node_map {
472 let nodetype = other.get_optype(node);
473 self.op_types
474 .set(new_node.into_portgraph(), nodetype.clone());
475 let meta = other.node_metadata_map(node);
476 if !meta.is_empty() {
477 self.metadata
478 .set(new_node.into_portgraph(), Some(meta.clone()));
479 }
480 if let Ok(exts) = nodetype.used_extensions() {
482 self.use_extensions(exts);
483 }
484 }
485 node_map
486 }
487
488 fn copy_descendants(
489 &mut self,
490 root: Self::Node,
491 new_parent: Self::Node,
492 subst: Option<Substitution>,
493 ) -> BTreeMap<Self::Node, Self::Node> {
494 let mut descendants = self.hierarchy.descendants(root.into_portgraph());
495 let root2 = descendants.next();
496 debug_assert_eq!(root2, Some(root.into_portgraph()));
497 let nodes = Vec::from_iter(descendants);
498 let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes)
499 .copy_in_parent()
500 .expect("Is a MultiPortGraph");
501 let node_map =
502 translate_indices(Into::into, Into::into, node_map).collect::<BTreeMap<_, _>>();
503
504 for node in self.children(root).collect::<Vec<_>>() {
505 self.set_parent(*node_map.get(&node).unwrap(), new_parent);
506 }
507
508 for (&node, &new_node) in &node_map {
510 for ch in self.children(node).collect::<Vec<_>>() {
511 self.set_parent(*node_map.get(&ch).unwrap(), new_node);
512 }
513 let new_optype = match (&subst, self.get_optype(node)) {
514 (None, op) => op.clone(),
515 (Some(subst), op) => op.substitute(subst),
516 };
517 self.op_types.set(new_node.into_portgraph(), new_optype);
518 let meta = self.metadata.get(node.into_portgraph()).clone();
519 self.metadata.set(new_node.into_portgraph(), meta);
520 }
521 node_map
522 }
523
524 #[inline]
525 fn use_extension(&mut self, extension: impl Into<Arc<Extension>>) {
526 self.extensions_mut().register_updated(extension);
527 }
528
529 #[inline]
530 fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
531 where
532 ExtensionRegistry: Extend<Reg>,
533 {
534 self.extensions_mut().extend(registry);
535 }
536}
537
/// Copies the nodes yielded by `other_nodes` (nodes of `other`) into `hugr`,
/// rebuilding their parent/child structure and the links between them.
///
/// Returns a map from each visited node of `other` to its copy in `hugr`.
///
/// Every copy is created with a placeholder [`OpType::default()`], but with the
/// same number of input and output ports as the original. Callers must
/// overwrite the operation (and metadata) afterwards.
///
/// The parent of each copy is `reroot(&old)` when that returns `Some`.
/// Otherwise it is the copy of `old`'s parent in `other`, which must have been
/// yielded earlier by `other_nodes`.
///
/// Links are only recreated when both endpoints are in `other_nodes`; links to
/// any other node of `other` are dropped.
///
/// # Panics
///
/// If `reroot` returns `None` for a node that has no parent in `other`, or
/// whose parent has not already been yielded by `other_nodes`.
fn insert_hugr_internal<H: HugrView>(
    hugr: &mut Hugr,
    other: &H,
    other_nodes: impl Iterator<Item = H::Node>,
    reroot: impl Fn(&H::Node) -> Option<Node>,
) -> HashMap<H::Node, Node> {
    // Upper bound of the iterator, or 0 when unknown; used only as a capacity hint.
    let new_node_count_hint = other_nodes.size_hint().1.unwrap_or_default();

    let mut node_map = HashMap::with_capacity(new_node_count_hint);
    // NOTE(review): presumably reserves room for `new_node_count_hint` nodes and
    // no extra ports — confirm the argument meaning of `Hugr::reserve`.
    hugr.reserve(new_node_count_hint, 0);

    for old in other_nodes {
        // Placeholder operation; callers install the real one afterwards.
        let op = OpType::default();
        let new = hugr.add_node(op);
        // Record the mapping before wiring links, so a self-loop on `old` is
        // found by the input pass below.
        node_map.insert(old, new);

        hugr.set_num_ports(new, other.num_inputs(old), other.num_outputs(old));

        let new_parent = if let Some(new_parent) = reroot(&old) {
            new_parent
        } else {
            let old_parent = other.get_parent(old).unwrap();
            *node_map
                .get(&old_parent)
                .expect("Child node came before parent in `other_nodes` iterator")
        };
        hugr.set_parent(new, new_parent);

        // Each link is created when the later of its two endpoints is visited.
        // First, links from already-copied nodes into `old`...
        for tgt in other.node_inputs(old) {
            for (neigh, src) in other.linked_outputs(old, tgt) {
                // Source not copied (yet, or at all): skip this link for now.
                let Some(&neigh) = node_map.get(&neigh) else {
                    continue;
                };
                hugr.connect(neigh, src, new, tgt);
            }
        }
        // ...then links from `old` to already-copied nodes.
        for src in other.node_outputs(old) {
            for (neigh, tgt) in other.linked_inputs(old, src) {
                // Self-loops were already connected by the input pass above.
                if neigh == old {
                    continue;
                }
                let Some(&neigh) = node_map.get(&neigh) else {
                    continue;
                };
                hugr.connect(new, src, neigh, tgt);
            }
        }
    }
    node_map
}
610
#[cfg(test)]
mod test {
    use crate::{
        extension::{
            PRELUDE,
            prelude::{Noop, usize_t},
        },
        ops::{self, FuncDefn, Input, Output, dataflow::IOTrait},
        types::Signature,
    };

    use super::*;

    /// Builds a module with one function that routes its single `usize` input
    /// through a `Noop` to both of its outputs, and checks that it validates.
    #[test]
    fn simple_function() -> Result<(), Box<dyn std::error::Error>> {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());

        let module: Node = hugr.entrypoint();
        let signature = Signature::new(usize_t(), vec![usize_t(), usize_t()]);
        let func: Node = hugr.add_node_with_parent(module, ops::FuncDefn::new("main", signature));

        let input = hugr.add_node_with_parent(func, ops::Input::new(vec![usize_t()]));
        let output =
            hugr.add_node_with_parent(func, ops::Output::new(vec![usize_t(), usize_t()]));
        let noop = hugr.add_node_with_parent(func, Noop(usize_t()));

        hugr.connect(input, 0, noop, 0);
        // The single Noop result fans out to both function outputs.
        for out_port in 0..2usize {
            hugr.connect(noop, 0, output, out_port);
        }

        hugr.validate()?;
        Ok(())
    }

    /// Round-trips a metadata entry on the root through get / get_mut / set / remove.
    #[test]
    fn metadata() {
        const KEY: &str = "meta";
        let mut hugr = Hugr::default();
        let root: Node = hugr.entrypoint();

        assert_eq!(hugr.get_metadata(root, KEY), None);

        *hugr.get_metadata_mut(root, KEY) = "test".into();
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"test".into()));

        hugr.set_metadata(root, KEY, "new");
        assert_eq!(hugr.get_metadata(root, KEY), Some(&"new".into()));

        hugr.remove_metadata(root, KEY);
        assert_eq!(hugr.get_metadata(root, KEY), None);
    }

    /// Removing a function's subtree drops the function and its Input/Output nodes.
    #[test]
    fn remove_subtree() {
        /// Adds an identity function `name` under `parent` and returns its FuncDefn node.
        fn add_identity_func(hugr: &mut Hugr, parent: Node, name: &str) -> Node {
            let func = hugr
                .add_node_with_parent(parent, FuncDefn::new(name, Signature::new_endo(usize_t())));
            let input = hugr.add_node_with_parent(func, Input::new(usize_t()));
            let output = hugr.add_node_with_parent(func, Output::new(usize_t()));
            hugr.connect(input, 0, output, 0);
            func
        }

        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());
        let root = hugr.entrypoint();
        let foo = add_identity_func(&mut hugr, root, "foo");
        let bar = add_identity_func(&mut hugr, root, "bar");
        hugr.validate().unwrap();
        assert_eq!(hugr.num_nodes(), 7);

        for (func, remaining) in [(foo, 4), (bar, 1)] {
            hugr.remove_subtree(func);
            hugr.validate().unwrap();
            assert_eq!(hugr.num_nodes(), remaining);
        }
    }
}