1use core::panic;
4use std::collections::{BTreeMap, HashMap};
5use std::sync::Arc;
6
7use portgraph::view::{NodeFilter, NodeFiltered};
8use portgraph::{LinkMut, PortMut, PortView, SecondaryMap};
9
10use crate::extension::ExtensionRegistry;
11use crate::hugr::views::SiblingSubgraph;
12use crate::hugr::{HugrView, Node, OpType, RootTagged};
13use crate::hugr::{NodeMetadata, Rewrite};
14use crate::ops::OpTrait;
15use crate::types::Substitution;
16use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex};
17
18use super::internal::HugrMutInternals;
19use super::NodeMetadataMap;
20
21pub trait HugrMut: HugrMutInternals {
23    fn get_metadata_mut(&mut self, node: Node, key: impl AsRef<str>) -> &mut NodeMetadata {
29        panic_invalid_node(self, node);
30        let node_meta = self
31            .hugr_mut()
32            .metadata
33            .get_mut(node.pg_index())
34            .get_or_insert_with(Default::default);
35        node_meta
36            .entry(key.as_ref())
37            .or_insert(serde_json::Value::Null)
38    }
39
40    fn set_metadata(
46        &mut self,
47        node: Node,
48        key: impl AsRef<str>,
49        metadata: impl Into<NodeMetadata>,
50    ) {
51        let entry = self.get_metadata_mut(node, key);
52        *entry = metadata.into();
53    }
54
55    fn remove_metadata(&mut self, node: Node, key: impl AsRef<str>) {
61        panic_invalid_node(self, node);
62        let node_meta = self.hugr_mut().metadata.get_mut(node.pg_index());
63        if let Some(node_meta) = node_meta {
64            node_meta.remove(key.as_ref());
65        }
66    }
67
68    fn take_node_metadata(&mut self, node: Self::Node) -> Option<NodeMetadataMap> {
70        if !self.valid_node(node) {
71            return None;
72        }
73        self.hugr_mut().metadata.take(node.pg_index())
74    }
75
76    fn overwrite_node_metadata(&mut self, node: Node, metadata: Option<NodeMetadataMap>) {
82        panic_invalid_node(self, node);
83        self.hugr_mut().metadata.set(node.pg_index(), metadata);
84    }
85
86    #[inline]
94    fn add_node_with_parent(&mut self, parent: Node, op: impl Into<OpType>) -> Node {
95        panic_invalid_node(self, parent);
96        self.hugr_mut().add_node_with_parent(parent, op)
97    }
98
99    #[inline]
107    fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
108        panic_invalid_non_root(self, sibling);
109        self.hugr_mut().add_node_before(sibling, nodetype)
110    }
111
112    #[inline]
120    fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
121        panic_invalid_non_root(self, sibling);
122        self.hugr_mut().add_node_after(sibling, op)
123    }
124
125    #[inline]
133    fn remove_node(&mut self, node: Node) -> OpType {
134        panic_invalid_non_root(self, node);
135        self.hugr_mut().remove_node(node)
136    }
137
138    fn remove_subtree(&mut self, node: Node) {
144        panic_invalid_non_root(self, node);
145        while let Some(ch) = self.first_child(node) {
146            self.remove_subtree(ch)
147        }
148        self.hugr_mut().remove_node(node);
149    }
150
151    fn copy_descendants(
164        &mut self,
165        root: Node,
166        new_parent: Node,
167        subst: Option<Substitution>,
168    ) -> BTreeMap<Node, Node> {
169        panic_invalid_node(self, root);
170        panic_invalid_node(self, new_parent);
171        self.hugr_mut().copy_descendants(root, new_parent, subst)
172    }
173
174    #[inline]
180    fn connect(
181        &mut self,
182        src: Node,
183        src_port: impl Into<OutgoingPort>,
184        dst: Node,
185        dst_port: impl Into<IncomingPort>,
186    ) {
187        panic_invalid_node(self, src);
188        panic_invalid_node(self, dst);
189        self.hugr_mut().connect(src, src_port, dst, dst_port);
190    }
191
192    #[inline]
200    fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
201        panic_invalid_node(self, node);
202        self.hugr_mut().disconnect(node, port);
203    }
204
205    fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
217        panic_invalid_node(self, src);
218        panic_invalid_node(self, dst);
219        self.hugr_mut().add_other_edge(src, dst)
220    }
221
222    #[inline]
228    fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult {
229        panic_invalid_node(self, root);
230        self.hugr_mut().insert_hugr(root, other)
231    }
232
233    #[inline]
239    fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
240        panic_invalid_node(self, root);
241        self.hugr_mut().insert_from_view(root, other)
242    }
243
244    fn insert_subgraph(
259        &mut self,
260        root: Node,
261        other: &impl HugrView,
262        subgraph: &SiblingSubgraph,
263    ) -> HashMap<Node, Node> {
264        panic_invalid_node(self, root);
265        self.hugr_mut().insert_subgraph(root, other, subgraph)
266    }
267
268    fn apply_rewrite<R, E>(&mut self, rw: impl Rewrite<ApplyResult = R, Error = E>) -> Result<R, E>
270    where
271        Self: Sized,
272    {
273        rw.apply(self)
274    }
275
276    fn use_extension(&mut self, extension: impl Into<Arc<Extension>>) {
283        self.hugr_mut().extensions.register_updated(extension);
284    }
285
286    fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>)
296    where
297        ExtensionRegistry: Extend<Reg>,
298    {
299        self.hugr_mut().extensions.extend(registry);
300    }
301
302    fn extensions_mut(&mut self) -> &mut ExtensionRegistry {
304        &mut self.hugr_mut().extensions
305    }
306}
307
308pub struct InsertionResult {
311    pub new_root: Node,
315    pub node_map: HashMap<Node, Node>,
318}
319
320fn translate_indices(
321    node_map: HashMap<portgraph::NodeIndex, portgraph::NodeIndex>,
322) -> impl Iterator<Item = (Node, Node)> {
323    node_map.into_iter().map(|(k, v)| (k.into(), v.into()))
324}
325
326impl<T: RootTagged<RootHandle = Node, Node = Node> + AsMut<Hugr>> HugrMut for T {
328    fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
329        let node = self.as_mut().add_node(node.into());
330        self.as_mut()
331            .hierarchy
332            .push_child(node.pg_index(), parent.pg_index())
333            .expect("Inserting a newly-created node into the hierarchy should never fail.");
334        node
335    }
336
337    fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
338        let node = self.as_mut().add_node(nodetype.into());
339        self.as_mut()
340            .hierarchy
341            .insert_before(node.pg_index(), sibling.pg_index())
342            .expect("Inserting a newly-created node into the hierarchy should never fail.");
343        node
344    }
345
346    fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
347        let node = self.as_mut().add_node(op.into());
348        self.as_mut()
349            .hierarchy
350            .insert_after(node.pg_index(), sibling.pg_index())
351            .expect("Inserting a newly-created node into the hierarchy should never fail.");
352        node
353    }
354
355    fn remove_node(&mut self, node: Node) -> OpType {
356        panic_invalid_non_root(self, node);
357        self.as_mut().hierarchy.remove(node.pg_index());
358        self.as_mut().graph.remove_node(node.pg_index());
359        self.as_mut().op_types.take(node.pg_index())
360    }
361
362    fn connect(
363        &mut self,
364        src: Node,
365        src_port: impl Into<OutgoingPort>,
366        dst: Node,
367        dst_port: impl Into<IncomingPort>,
368    ) {
369        let src_port = src_port.into();
370        let dst_port = dst_port.into();
371        panic_invalid_port(self, src, src_port);
372        panic_invalid_port(self, dst, dst_port);
373        self.as_mut()
374            .graph
375            .link_nodes(
376                src.pg_index(),
377                src_port.index(),
378                dst.pg_index(),
379                dst_port.index(),
380            )
381            .expect("The ports should exist at this point.");
382    }
383
384    fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
385        let port = port.into();
386        let offset = port.pg_offset();
387        panic_invalid_port(self, node, port);
388        let port = self
389            .as_mut()
390            .graph
391            .port_index(node.pg_index(), offset)
392            .expect("The port should exist at this point.");
393        self.as_mut().graph.unlink_port(port);
394    }
395
396    fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
397        let src_port = self
398            .get_optype(src)
399            .other_output_port()
400            .expect("Source operation has no non-dataflow outgoing edges");
401        let dst_port = self
402            .get_optype(dst)
403            .other_input_port()
404            .expect("Destination operation has no non-dataflow incoming edges");
405        self.connect(src, src_port, dst, dst_port);
406        (src_port, dst_port)
407    }
408
409    fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult {
410        let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other);
411        for (&node, &new_node) in node_map.iter() {
415            let optype = other.op_types.take(node);
416            self.as_mut().op_types.set(new_node, optype);
417            let meta = other.metadata.take(node);
418            self.as_mut().metadata.set(new_node, meta);
419        }
420        debug_assert_eq!(
421            Some(&new_root.pg_index()),
422            node_map.get(&other.root().pg_index())
423        );
424        InsertionResult {
425            new_root,
426            node_map: translate_indices(node_map).collect(),
427        }
428    }
429
430    fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
431        let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other);
432        for (&node, &new_node) in node_map.iter() {
436            let nodetype = other.get_optype(other.get_node(node));
437            self.as_mut().op_types.set(new_node, nodetype.clone());
438            let meta = other.base_hugr().metadata.get(node);
439            self.as_mut().metadata.set(new_node, meta.clone());
440        }
441        debug_assert_eq!(
442            Some(&new_root.pg_index()),
443            node_map.get(&other.get_pg_index(other.root()))
444        );
445        InsertionResult {
446            new_root,
447            node_map: translate_indices(node_map).collect(),
448        }
449    }
450
451    fn insert_subgraph(
452        &mut self,
453        root: Node,
454        other: &impl HugrView,
455        subgraph: &SiblingSubgraph,
456    ) -> HashMap<Node, Node> {
457        let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
459            NodeFiltered::new_node_filtered(
460                other.portgraph(),
461                |node, ctx| ctx.contains(&node.into()),
462                subgraph.nodes(),
463            );
464        let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph);
465        for (&node, &new_node) in node_map.iter() {
467            let nodetype = other.get_optype(other.get_node(node));
468            self.as_mut().op_types.set(new_node, nodetype.clone());
469            let meta = other.base_hugr().metadata.get(node);
470            self.as_mut().metadata.set(new_node, meta.clone());
471            if let Ok(exts) = nodetype.used_extensions() {
473                self.use_extensions(exts);
474            }
475        }
476        translate_indices(node_map).collect()
477    }
478
479    fn copy_descendants(
480        &mut self,
481        root: Node,
482        new_parent: Node,
483        subst: Option<Substitution>,
484    ) -> BTreeMap<Node, Node> {
485        let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index());
486        let root2 = descendants.next();
487        debug_assert_eq!(root2, Some(root.pg_index()));
488        let nodes = Vec::from_iter(descendants);
489        let node_map = translate_indices(
490            portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes)
491                .copy_in_parent()
492                .expect("Is a MultiPortGraph"),
493        )
494        .collect::<BTreeMap<_, _>>();
495
496        for node in self.children(root).collect::<Vec<_>>() {
497            self.set_parent(*node_map.get(&node).unwrap(), new_parent);
498        }
499
500        for (&node, &new_node) in node_map.iter() {
502            for ch in self.children(node).collect::<Vec<_>>() {
503                self.set_parent(*node_map.get(&ch).unwrap(), new_node);
504            }
505            let new_optype = match (&subst, self.get_optype(node)) {
506                (None, op) => op.clone(),
507                (Some(subst), op) => op.substitute(subst),
508            };
509            self.as_mut().op_types.set(new_node.pg_index(), new_optype);
510            let meta = self.base_hugr().metadata.get(node.pg_index()).clone();
511            self.as_mut().metadata.set(new_node.pg_index(), meta);
512        }
513        node_map
514    }
515}
516
517fn insert_hugr_internal<H: HugrView>(
526    hugr: &mut Hugr,
527    root: Node,
528    other: &H,
529) -> (Node, HashMap<portgraph::NodeIndex, portgraph::NodeIndex>) {
530    let node_map = hugr
531        .graph
532        .insert_graph(&other.portgraph())
533        .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}"));
534    let other_root = node_map[&other.get_pg_index(other.root())];
535
536    hugr.hierarchy
538        .push_child(other_root, root.pg_index())
539        .expect("Inserting a newly-created node into the hierarchy should never fail.");
540    for (&node, &new_node) in node_map.iter() {
541        other.children(other.get_node(node)).for_each(|child| {
542            hugr.hierarchy
543                .push_child(node_map[&other.get_pg_index(child)], new_node)
544                .expect("Inserting a newly-created node into the hierarchy should never fail.");
545        });
546    }
547
548    hugr.extensions.extend(other.extensions());
550
551    (other_root.into(), node_map)
552}
553
554fn insert_subgraph_internal(
567    hugr: &mut Hugr,
568    root: Node,
569    other: &impl HugrView,
570    portgraph: &impl portgraph::LinkView,
571) -> HashMap<portgraph::NodeIndex, portgraph::NodeIndex> {
572    let node_map = hugr
573        .graph
574        .insert_graph(&portgraph)
575        .expect("Internal error while inserting a subgraph into another");
576
577    for (&node, &new_node) in node_map.iter() {
580        let new_parent = other
581            .get_parent(other.get_node(node))
582            .and_then(|parent| node_map.get(&other.get_pg_index(parent)).copied())
583            .unwrap_or(root.pg_index());
584        hugr.hierarchy
585            .push_child(new_node, new_parent)
586            .expect("Inserting a newly-created node into the hierarchy should never fail.");
587    }
588
589    node_map
590}
591
592#[track_caller]
594pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
595    if !hugr.valid_node(node) {
596        panic!(
597            "Received an invalid node {node} while mutating a HUGR:\n\n {}",
598            hugr.mermaid_string()
599        );
600    }
601}
602
603#[track_caller]
605pub(super) fn panic_invalid_non_root<H: HugrView + ?Sized>(hugr: &H, node: H::Node) {
606    if !hugr.valid_non_root(node) {
607        panic!(
608            "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}",
609            hugr.mermaid_string()
610        );
611    }
612}
613
614#[track_caller]
616pub(super) fn panic_invalid_port<H: HugrView + ?Sized>(
617    hugr: &H,
618    node: Node,
619    port: impl Into<Port>,
620) {
621    let port = port.into();
622    if hugr
623        .portgraph()
624        .port_index(node.pg_index(), port.pg_offset())
625        .is_none()
626    {
627        panic!(
628            "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}",
629            hugr.mermaid_string()
630        );
631    }
632}
633
#[cfg(test)]
mod test {
    use crate::extension::PRELUDE;
    use crate::{
        extension::prelude::{usize_t, Noop},
        ops::{self, dataflow::IOTrait, FuncDefn, Input, Output},
        types::Signature,
    };

    use super::*;

    /// Builds a function `main` that passes its single `usize` input through a `Noop` and
    /// returns the result twice, then checks that the HUGR validates.
    #[test]
    fn simple_function() -> Result<(), Box<dyn std::error::Error>> {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());
        let module: Node = hugr.root();

        let func_defn = ops::FuncDefn {
            name: "main".into(),
            signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()])
                .with_prelude()
                .into(),
        };
        let func: Node = hugr.add_node_with_parent(module, func_defn);

        // Children are created in the order Input, Output, then the body node.
        let input = hugr.add_node_with_parent(func, ops::Input::new(vec![usize_t()]));
        let output =
            hugr.add_node_with_parent(func, ops::Output::new(vec![usize_t(), usize_t()]));
        let noop = hugr.add_node_with_parent(func, Noop(usize_t()));

        hugr.connect(input, 0, noop, 0);
        // The single Noop output fans out to both function outputs.
        hugr.connect(noop, 0, output, 0);
        hugr.connect(noop, 0, output, 1);

        hugr.validate()?;
        Ok(())
    }

    /// Exercises reading, writing, overwriting and removing one metadata key on the root.
    #[test]
    fn metadata() {
        let mut hugr = Hugr::default();
        let root: Node = hugr.root();
        let key = "meta";

        // Nothing is stored initially.
        assert_eq!(hugr.get_metadata(root, key), None);

        // A write through the mutable accessor is visible to readers.
        *hugr.get_metadata_mut(root, key) = "test".into();
        assert_eq!(hugr.get_metadata(root, key), Some(&"test".into()));

        // `set_metadata` replaces the previous value.
        hugr.set_metadata(root, key, "new");
        assert_eq!(hugr.get_metadata(root, key), Some(&"new".into()));

        hugr.remove_metadata(root, key);
        assert_eq!(hugr.get_metadata(root, key), None);
    }

    /// Removing a function definition's subtree also removes its Input and Output nodes.
    #[test]
    fn remove_subtree() {
        let mut hugr = Hugr::default();
        hugr.use_extension(PRELUDE.to_owned());
        let root = hugr.root();

        // Each function contributes three nodes: the FuncDefn plus its Input and Output.
        let [foo, bar] = ["foo", "bar"].map(|name| {
            let func = hugr.add_node_with_parent(
                root,
                FuncDefn {
                    name: name.to_string(),
                    signature: Signature::new_endo(usize_t()).into(),
                },
            );
            let input = hugr.add_node_with_parent(func, Input::new(usize_t()));
            let output = hugr.add_node_with_parent(func, Output::new(usize_t()));
            hugr.connect(input, 0, output, 0);
            func
        });
        hugr.validate().unwrap();
        assert_eq!(hugr.node_count(), 7);

        for (removed, expected_count) in [(foo, 4), (bar, 1)] {
            hugr.remove_subtree(removed);
            hugr.validate().unwrap();
            assert_eq!(hugr.node_count(), expected_count);
        }
    }
}