hugr_core/hugr/views/
descendants.rs

1//! DescendantsGraph: view onto the subgraph of the HUGR starting from a root
2//! (all descendants at all depths).
3
4use itertools::Itertools;
5use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};
6
7use crate::hugr::HugrError;
8use crate::ops::handle::NodeHandle;
9use crate::{Direction, Hugr, Node, Port};
10
11use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};
12
13type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;
14
15/// View of a HUGR descendants graph.
16///
17/// Includes the root node (which uniquely has no parent) and all its descendants.
18///
19/// See [`SiblingGraph`] for a view that includes only the root and
20/// its immediate children.  Prefer using [`SiblingGraph`] when possible,
21/// as it is more efficient.
22///
23/// Implements the [`HierarchyView`] trait, as well as [`HugrView`], it can be
24/// used interchangeably with [`SiblingGraph`].
25///
26/// [`SiblingGraph`]: super::SiblingGraph
27#[derive(Clone)]
28pub struct DescendantsGraph<'g, Root = Node> {
29    /// The chosen root node.
30    // TODO: this can only be made generic once the call to base_hugr is removed
31    // in try_new. See https://github.com/CQCL/hugr/issues/1926
32    root: Node,
33
34    /// The graph encoding the adjacency structure of the HUGR.
35    graph: RegionGraph<'g>,
36
37    /// The node hierarchy.
38    hugr: &'g Hugr,
39
40    /// The operation handle of the root node.
41    _phantom: std::marker::PhantomData<Root>,
42}
43impl<Root: NodeHandle> HugrView for DescendantsGraph<'_, Root> {
44    #[inline]
45    fn contains_node(&self, node: Node) -> bool {
46        self.graph.contains_node(self.get_pg_index(node))
47    }
48
49    #[inline]
50    fn node_count(&self) -> usize {
51        self.graph.node_count()
52    }
53
54    #[inline]
55    fn edge_count(&self) -> usize {
56        self.graph.link_count()
57    }
58
59    #[inline]
60    fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
61        self.graph.nodes_iter().map(|index| self.get_node(index))
62    }
63
64    #[inline]
65    fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
66        self.graph
67            .port_offsets(self.get_pg_index(node), dir)
68            .map_into()
69    }
70
71    #[inline]
72    fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
73        self.graph
74            .all_port_offsets(self.get_pg_index(node))
75            .map_into()
76    }
77
78    fn linked_ports(
79        &self,
80        node: Node,
81        port: impl Into<Port>,
82    ) -> impl Iterator<Item = (Node, Port)> + Clone {
83        let port = self
84            .graph
85            .port_index(self.get_pg_index(node), port.into().pg_offset())
86            .unwrap();
87        self.graph.port_links(port).map(|(_, link)| {
88            let port: PortIndex = link.into();
89            let node = self.graph.port_node(port).unwrap();
90            let offset = self.graph.port_offset(port).unwrap();
91            (self.get_node(node), offset.into())
92        })
93    }
94
95    fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
96        self.graph
97            .get_connections(self.get_pg_index(node), self.get_pg_index(other))
98            .map(|(p1, p2)| {
99                [p1, p2].map(|link| {
100                    let offset = self.graph.port_offset(link).unwrap();
101                    offset.into()
102                })
103            })
104    }
105
106    #[inline]
107    fn num_ports(&self, node: Node, dir: Direction) -> usize {
108        self.graph.num_ports(self.get_pg_index(node), dir)
109    }
110
111    #[inline]
112    fn children(&self, node: Node) -> impl DoubleEndedIterator<Item = Node> + Clone {
113        let children = match self.graph.contains_node(self.get_pg_index(node)) {
114            true => self.base_hugr().hierarchy.children(self.get_pg_index(node)),
115            false => portgraph::hierarchy::Children::default(),
116        };
117        children.map(|index| self.get_node(index))
118    }
119
120    #[inline]
121    fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
122        self.graph
123            .neighbours(self.get_pg_index(node), dir)
124            .map(|index| self.get_node(index))
125    }
126
127    #[inline]
128    fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
129        self.graph
130            .all_neighbours(self.get_pg_index(node))
131            .map(|index| self.get_node(index))
132    }
133}
134impl<Root: NodeHandle> RootTagged for DescendantsGraph<'_, Root> {
135    type RootHandle = Root;
136}
137
138impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
139where
140    Root: NodeHandle,
141{
142    fn try_new(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Result<Self, HugrError> {
143        check_tag::<Root, Node>(hugr, root)?;
144        let hugr = hugr.base_hugr();
145        Ok(Self {
146            root,
147            graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)),
148            hugr,
149            _phantom: std::marker::PhantomData,
150        })
151    }
152}
153
154impl<Root: NodeHandle> ExtractHugr for DescendantsGraph<'_, Root> {}
155
156impl<'g, Root> super::HugrInternals for DescendantsGraph<'g, Root>
157where
158    Root: NodeHandle,
159{
160    type Portgraph<'p>
161        = &'p RegionGraph<'g>
162    where
163        Self: 'p;
164
165    type Node = Node;
166
167    #[inline]
168    fn portgraph(&self) -> Self::Portgraph<'_> {
169        &self.graph
170    }
171
172    #[inline]
173    fn base_hugr(&self) -> &Hugr {
174        self.hugr
175    }
176
177    #[inline]
178    fn root_node(&self) -> Node {
179        self.root
180    }
181
182    #[inline]
183    fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
184        self.hugr.get_pg_index(node)
185    }
186
187    #[inline]
188    fn get_node(&self, index: portgraph::NodeIndex) -> Node {
189        self.hugr.get_node(index)
190    }
191}
192
#[cfg(test)]
pub(super) mod test {
    use std::borrow::Cow;

    use rstest::rstest;

    use crate::extension::prelude::{qb_t, usize_t};
    use crate::{IncomingPort, OutgoingPort};
    use crate::{
        builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder},
        types::Signature,
        utils::test_quantum_extension::{h_gate, EXTENSION_ID},
    };

    use super::*;

    /// Make a module hugr with a fn definition containing an inner dfg node.
    ///
    /// Returns the hugr, the fn node id, and the nested dfg node id.
    pub(in crate::hugr::views) fn make_module_hgr(
    ) -> Result<(Hugr, Node, Node), Box<dyn std::error::Error>> {
        let mut module_builder = ModuleBuilder::new();

        let (f_id, inner_id) = {
            let mut func_builder = module_builder.define_function(
                "main",
                Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID),
            )?;

            let [int, qb] = func_builder.input_wires_arr();

            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_id = {
                let inner_builder = func_builder
                    .dfg_builder(Signature::new(vec![usize_t()], vec![usize_t()]), [int])?;
                let w = inner_builder.input_wires();
                inner_builder.finish_with_outputs(w)
            }?;

            let f_id =
                func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?;
            (f_id, inner_id)
        };
        let hugr = module_builder.finish_hugr()?;
        Ok((hugr, f_id.handle().node(), inner_id.handle().node()))
    }

    #[test]
    fn full_region() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;

        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;
        let def_io = region.get_io(def).unwrap();

        assert_eq!(region.node_count(), 7);
        assert!(region.nodes().all(|n| n == def
            || hugr.get_parent(n) == Some(def)
            || hugr.get_parent(n) == Some(inner)));
        assert_eq!(region.children(inner).count(), 2);

        assert_eq!(
            region.poly_func_type(),
            Some(
                Signature::new_endo(vec![usize_t(), qb_t()])
                    .with_extension_delta(EXTENSION_ID)
                    .into()
            )
        );

        let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?;
        assert_eq!(
            inner_region.inner_function_type().map(Cow::into_owned),
            Some(Signature::new(vec![usize_t()], vec![usize_t()]))
        );
        assert_eq!(inner_region.node_count(), 3);
        assert_eq!(inner_region.edge_count(), 1);
        assert_eq!(inner_region.children(inner).count(), 2);
        assert_eq!(inner_region.children(hugr.root()).count(), 0);
        assert_eq!(
            inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.node_ports(inner, Direction::Outgoing).count()
        );
        assert_eq!(
            inner_region.num_ports(inner, Direction::Incoming)
                + inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.all_node_ports(inner).count()
        );

        // The inner region filters out the connections to the main function I/O nodes,
        // while the outer region includes them.
        assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
        assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
        assert_eq!(
            inner_region
                .linked_ports(inner, IncomingPort::from(0))
                .count(),
            0
        );
        assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
        // The dfg's output feeds the function's Output node, which likewise lies
        // outside the inner region but inside the outer one.
        assert_eq!(
            inner_region
                .linked_ports(inner, OutgoingPort::from(0))
                .count(),
            0
        );
        assert_eq!(region.linked_ports(inner, OutgoingPort::from(0)).count(), 1);
        assert_eq!(
            inner_region.neighbours(inner, Direction::Outgoing).count(),
            0
        );
        assert_eq!(inner_region.all_neighbours(inner).count(), 0);

        Ok(())
    }

    #[rstest]
    fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, _inner) = make_module_hgr()?;

        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;
        let extracted = region.extract_hugr();
        extracted.validate()?;

        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;

        assert_eq!(region.node_count(), extracted.node_count());
        assert_eq!(region.root_type(), extracted.root_type());

        Ok(())
    }
}