hugr_core/hugr/views/rerooted.rs

1//! A HUGR wrapper with a modified entrypoint node, returned by
2//! [`HugrView::with_entrypoint`] and [`HugrMut::with_entrypoint_mut`].
3
4use crate::hugr::internal::{HugrInternals, HugrMutInternals};
5use crate::hugr::{HugrMut, hugrmut::InsertForestResult};
6
7use super::{HugrView, panic_invalid_node};
8
9/// A HUGR wrapper with a modified entrypoint node.
10///
11/// All nodes from the original are still present, but the main entrypoint used
12/// for traversals and optimizations is altered.
13#[derive(Clone)]
14pub struct Rerooted<H: HugrView> {
15    hugr: H,
16    entrypoint: H::Node,
17}
18
19impl<H: HugrView> Rerooted<H> {
20    /// Create a hierarchical view of a whole HUGR
21    ///
22    /// # Panics
23    ///
24    /// If the new entrypoint is not in the HUGR.
25    ///
26    /// [`OpTag`]: crate::ops::OpTag
27    pub fn new(hugr: H, entrypoint: H::Node) -> Self {
28        panic_invalid_node(&hugr, entrypoint);
29        Self { hugr, entrypoint }
30    }
31
32    /// Returns the HUGR wrapped in this view.
33    pub fn into_unwrapped(self) -> H {
34        self.hugr
35    }
36}
37
38impl<H: HugrView> HugrInternals for Rerooted<H> {
39    type RegionPortgraph<'p>
40        = H::RegionPortgraph<'p>
41    where
42        Self: 'p;
43
44    type Node = H::Node;
45
46    type RegionPortgraphNodes = H::RegionPortgraphNodes;
47
48    super::impls::hugr_internal_methods! {this, &this.hugr}
49}
50
51impl<H: HugrView> HugrView for Rerooted<H> {
52    #[inline]
53    fn entrypoint(&self) -> Self::Node {
54        self.entrypoint
55    }
56
57    #[inline]
58    fn entrypoint_optype(&self) -> &crate::ops::OpType {
59        self.hugr.get_optype(self.entrypoint)
60    }
61
62    fn mermaid_string_with_formatter(
63        &self,
64        formatter: crate::hugr::views::render::MermaidFormatter<Self>,
65    ) -> String {
66        self.hugr
67            .mermaid_string_with_formatter(formatter.with_hugr(&self.hugr))
68    }
69
70    delegate::delegate! {
71        to (&self.hugr) {
72                fn module_root(&self) -> Self::Node;
73                fn contains_node(&self, node: Self::Node) -> bool;
74                fn get_parent(&self, node: Self::Node) -> Option<Self::Node>;
75                fn get_metadata(&self, node: Self::Node, key: impl AsRef<str>) -> Option<&crate::hugr::NodeMetadata>;
76                fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType;
77                fn num_nodes(&self) -> usize;
78                fn num_edges(&self) -> usize;
79                fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize;
80                fn num_inputs(&self, node: Self::Node) -> usize;
81                fn num_outputs(&self, node: Self::Node) -> usize;
82                fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone;
83                fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator<Item = crate::Port> + Clone;
84                fn node_outputs(&self, node: Self::Node) -> impl Iterator<Item = crate::OutgoingPort> + Clone;
85                fn node_inputs(&self, node: Self::Node) -> impl Iterator<Item = crate::IncomingPort> + Clone;
86                fn all_node_ports(&self, node: Self::Node) -> impl Iterator<Item = crate::Port> + Clone;
87                fn linked_ports(&self, node: Self::Node, port: impl Into<crate::Port>) -> impl Iterator<Item = (Self::Node, crate::Port)> + Clone;
88                fn all_linked_ports(&self, node: Self::Node, dir: crate::Direction) -> itertools::Either<impl Iterator<Item = (Self::Node, crate::OutgoingPort)>, impl Iterator<Item = (Self::Node, crate::IncomingPort)>>;
89                fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator<Item = (Self::Node, crate::OutgoingPort)>;
90                fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator<Item = (Self::Node, crate::IncomingPort)>;
91                fn single_linked_port(&self, node: Self::Node, port: impl Into<crate::Port>) -> Option<(Self::Node, crate::Port)>;
92                fn single_linked_output(&self, node: Self::Node, port: impl Into<crate::IncomingPort>) -> Option<(Self::Node, crate::OutgoingPort)>;
93                fn single_linked_input(&self, node: Self::Node, port: impl Into<crate::OutgoingPort>) -> Option<(Self::Node, crate::IncomingPort)>;
94                fn linked_outputs(&self, node: Self::Node, port: impl Into<crate::IncomingPort>) -> impl Iterator<Item = (Self::Node, crate::OutgoingPort)>;
95                fn linked_inputs(&self, node: Self::Node, port: impl Into<crate::OutgoingPort>) -> impl Iterator<Item = (Self::Node, crate::IncomingPort)>;
96                fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator<Item = [crate::Port; 2]> + Clone;
97                fn is_linked(&self, node: Self::Node, port: impl Into<crate::Port>) -> bool;
98                fn children(&self, node: Self::Node) -> impl DoubleEndedIterator<Item = Self::Node> + Clone;
99                fn descendants(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
100                fn first_child(&self, node: Self::Node) -> Option<Self::Node>;
101                fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator<Item = Self::Node> + Clone;
102                fn all_neighbours(&self, node: Self::Node) -> impl Iterator<Item = Self::Node> + Clone;
103                #[expect(deprecated)]
104                fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig<Self::Node>) -> String;
105                fn dot_string(&self) -> String;
106                fn static_source(&self, node: Self::Node) -> Option<Self::Node>;
107                fn static_targets(&self, node: Self::Node) -> Option<impl Iterator<Item = (Self::Node, crate::IncomingPort)>>;
108                fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator<Item = (crate::Port, crate::types::Type)>;
109                fn extensions(&self) -> &crate::extension::ExtensionRegistry;
110                fn validate(&self) -> Result<(), crate::hugr::ValidationError<Self::Node>>;
111                fn extract_hugr(&self, parent: Self::Node) -> (crate::Hugr, impl crate::hugr::views::ExtractionResult<Self::Node> + 'static);
112        }
113    }
114}
115
116impl<H: HugrMutInternals> HugrMutInternals for Rerooted<H> {
117    super::impls::hugr_mut_internal_methods! {this, &mut this.hugr}
118}
119
120impl<H: HugrMut> HugrMut for Rerooted<H> {
121    fn set_entrypoint(&mut self, root: Self::Node) {
122        self.entrypoint = root;
123        self.hugr.set_entrypoint(root);
124    }
125
126    delegate::delegate! {
127        to (&mut self.hugr) {
128                fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef<str>) -> &mut crate::hugr::NodeMetadata;
129                fn set_metadata(&mut self, node: Self::Node, key: impl AsRef<str>, metadata: impl Into<crate::hugr::NodeMetadata>);
130                fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef<str>);
131                fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into<crate::ops::OpType>) -> Self::Node;
132                fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into<crate::ops::OpType>) -> Self::Node;
133                fn add_node_after(&mut self, sibling: Self::Node, op: impl Into<crate::ops::OpType>) -> Self::Node;
134                fn remove_node(&mut self, node: Self::Node) -> crate::ops::OpType;
135                fn remove_subtree(&mut self, node: Self::Node);
136                fn copy_descendants(&mut self, root: Self::Node, new_parent: Self::Node, subst: Option<crate::types::Substitution>) -> std::collections::BTreeMap<Self::Node, Self::Node>;
137                fn connect(&mut self, src: Self::Node, src_port: impl Into<crate::OutgoingPort>, dst: Self::Node, dst_port: impl Into<crate::IncomingPort>);
138                fn disconnect(&mut self, node: Self::Node, port: impl Into<crate::Port>);
139                fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort);
140                fn insert_forest(&mut self, other: crate::Hugr, roots: impl IntoIterator<Item=(crate::Node, Self::Node)>) -> InsertForestResult<crate::Node, Self::Node>;
141                fn insert_view_forest<Other: crate::hugr::HugrView>(&mut self, other: &Other, nodes: impl Iterator<Item=Other::Node> + Clone, roots: impl IntoIterator<Item=(Other::Node, Self::Node)>) -> InsertForestResult<Other::Node, Self::Node>;
142                fn use_extension(&mut self, extension: impl Into<std::sync::Arc<crate::extension::Extension>>);
143                fn use_extensions<Reg>(&mut self, registry: impl IntoIterator<Item = Reg>) where crate::extension::ExtensionRegistry: Extend<Reg>;
144        }
145    }
146}
147
#[cfg(test)]
mod test {
    use crate::builder::test::simple_cfg_hugr;
    use crate::builder::{Dataflow, FunctionBuilder, HugrBuilder, SubContainer};
    use crate::hugr::HugrMut;
    use crate::hugr::internal::HugrMutInternals;
    use crate::hugr::views::ExtractionResult;
    use crate::ops::handle::NodeHandle;
    use crate::ops::{DataflowBlock, OpType};
    use crate::{HugrView, type_row, types::Signature};

    /// Builds a function containing one nested DFG, then exercises both the
    /// read-only and the mutable re-rooted views on it.
    #[test]
    fn rerooted() {
        let mut fn_builder =
            FunctionBuilder::new("main", Signature::new(vec![], vec![])).unwrap();
        let inner_dfg = fn_builder
            .dfg_builder_endo([])
            .unwrap()
            .finish_sub_container()
            .unwrap()
            .node();
        let mut hugr = fn_builder.finish_hugr().unwrap();

        // Read-only view rooted at the nested DFG.
        let view = hugr.with_entrypoint(inner_dfg);
        assert_eq!(view.module_root(), hugr.module_root());
        assert_eq!(view.entrypoint(), inner_dfg);
        assert!(view.entrypoint_optype().is_dfg());
        assert!(view.get_optype(view.module_root().node()).is_module());

        // Mutable view rooted at the same node.
        let mut view_mut = hugr.with_entrypoint_mut(inner_dfg);
        {
            // The mutable view implements `HugrMutInternals`, so the op at its
            // entrypoint can be swapped out in place.
            let entry = view_mut.entrypoint();
            let block_op: OpType = DataflowBlock {
                inputs: type_row![],
                other_outputs: type_row![],
                sum_rows: vec![type_row![]],
            }
            .into();
            view_mut.replace_op(entry, block_op);

            assert!(view_mut.entrypoint_optype().is_dataflow_block());
            assert!(view_mut.get_optype(view_mut.module_root().node()).is_module());
        }

        // The edit reached the underlying HUGR, whose own entrypoint is still
        // the function definition.
        assert!(hugr.get_optype(inner_dfg).is_dataflow_block());
        assert!(hugr.entrypoint_optype().is_func_defn());
        assert!(hugr.get_optype(hugr.module_root().node()).is_module());
    }

    /// Extracts regions through a view whose entrypoint differs from the
    /// wrapped HUGR's, and checks the extracted entrypoint and CFG nesting.
    #[test]
    fn extract_rerooted() {
        let mut hugr = simple_cfg_hugr();
        let cfg_node = hugr.entrypoint();
        let block_node = hugr.first_child(cfg_node).unwrap();

        // Point the underlying HUGR at the basic block...
        hugr.set_entrypoint(block_node);
        assert!(hugr.get_optype(hugr.entrypoint()).is_dataflow_block());

        // ...while the view is rooted at the enclosing CFG.
        let view = hugr.with_entrypoint(cfg_node);
        assert!(view.get_optype(view.entrypoint()).is_cfg());

        // Extract the basic block.
        let (extracted, node_map) = view.extract_hugr(block_node);
        let new_cfg = node_map.extracted_node(cfg_node);
        let new_block = node_map.extracted_node(block_node);
        assert_eq!(extracted.entrypoint(), new_block);
        assert!(extracted.get_optype(new_cfg).is_cfg());
        assert_eq!(extracted.first_child(new_cfg), Some(new_block));
        assert!(extracted.get_optype(new_block).is_dataflow_block());

        // Extract the CFG, which is also the view's entrypoint.
        let (extracted, node_map) = view.extract_hugr(cfg_node);
        let new_cfg = node_map.extracted_node(cfg_node);
        let new_block = node_map.extracted_node(block_node);
        assert_eq!(extracted.entrypoint(), new_cfg);
        assert!(extracted.get_optype(new_cfg).is_cfg());
        assert_eq!(extracted.first_child(new_cfg), Some(new_block));
        assert!(extracted.get_optype(new_block).is_dataflow_block());
    }

    /// A view that keeps the original entrypoint renders identically to the
    /// HUGR it wraps.
    #[test]
    fn mermaid_format() {
        let hugr = simple_cfg_hugr();
        let expected = hugr.mermaid_format().finish();
        let view = hugr.with_entrypoint(hugr.entrypoint());
        assert_eq!(view.mermaid_format().finish(), expected);
    }
}