1use itertools::Itertools;
5use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};
6
7use crate::hugr::HugrError;
8use crate::ops::handle::NodeHandle;
9use crate::{Direction, Hugr, Node, Port};
10
11use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};
12
13type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>;
14
15#[derive(Clone)]
28pub struct DescendantsGraph<'g, Root = Node> {
29 root: Node,
33
34 graph: RegionGraph<'g>,
36
37 hugr: &'g Hugr,
39
40 _phantom: std::marker::PhantomData<Root>,
42}
43impl<Root: NodeHandle> HugrView for DescendantsGraph<'_, Root> {
44 #[inline]
45 fn contains_node(&self, node: Node) -> bool {
46 self.graph.contains_node(self.get_pg_index(node))
47 }
48
49 #[inline]
50 fn node_count(&self) -> usize {
51 self.graph.node_count()
52 }
53
54 #[inline]
55 fn edge_count(&self) -> usize {
56 self.graph.link_count()
57 }
58
59 #[inline]
60 fn nodes(&self) -> impl Iterator<Item = Node> + Clone {
61 self.graph.nodes_iter().map(|index| self.get_node(index))
62 }
63
64 #[inline]
65 fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
66 self.graph
67 .port_offsets(self.get_pg_index(node), dir)
68 .map_into()
69 }
70
71 #[inline]
72 fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
73 self.graph
74 .all_port_offsets(self.get_pg_index(node))
75 .map_into()
76 }
77
78 fn linked_ports(
79 &self,
80 node: Node,
81 port: impl Into<Port>,
82 ) -> impl Iterator<Item = (Node, Port)> + Clone {
83 let port = self
84 .graph
85 .port_index(self.get_pg_index(node), port.into().pg_offset())
86 .unwrap();
87 self.graph.port_links(port).map(|(_, link)| {
88 let port: PortIndex = link.into();
89 let node = self.graph.port_node(port).unwrap();
90 let offset = self.graph.port_offset(port).unwrap();
91 (self.get_node(node), offset.into())
92 })
93 }
94
95 fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
96 self.graph
97 .get_connections(self.get_pg_index(node), self.get_pg_index(other))
98 .map(|(p1, p2)| {
99 [p1, p2].map(|link| {
100 let offset = self.graph.port_offset(link).unwrap();
101 offset.into()
102 })
103 })
104 }
105
106 #[inline]
107 fn num_ports(&self, node: Node, dir: Direction) -> usize {
108 self.graph.num_ports(self.get_pg_index(node), dir)
109 }
110
111 #[inline]
112 fn children(&self, node: Node) -> impl DoubleEndedIterator<Item = Node> + Clone {
113 let children = match self.graph.contains_node(self.get_pg_index(node)) {
114 true => self.base_hugr().hierarchy.children(self.get_pg_index(node)),
115 false => portgraph::hierarchy::Children::default(),
116 };
117 children.map(|index| self.get_node(index))
118 }
119
120 #[inline]
121 fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
122 self.graph
123 .neighbours(self.get_pg_index(node), dir)
124 .map(|index| self.get_node(index))
125 }
126
127 #[inline]
128 fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
129 self.graph
130 .all_neighbours(self.get_pg_index(node))
131 .map(|index| self.get_node(index))
132 }
133}
134impl<Root: NodeHandle> RootTagged for DescendantsGraph<'_, Root> {
135 type RootHandle = Root;
136}
137
138impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root>
139where
140 Root: NodeHandle,
141{
142 fn try_new(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Result<Self, HugrError> {
143 check_tag::<Root, Node>(hugr, root)?;
144 let hugr = hugr.base_hugr();
145 Ok(Self {
146 root,
147 graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)),
148 hugr,
149 _phantom: std::marker::PhantomData,
150 })
151 }
152}
153
154impl<Root: NodeHandle> ExtractHugr for DescendantsGraph<'_, Root> {}
155
156impl<'g, Root> super::HugrInternals for DescendantsGraph<'g, Root>
157where
158 Root: NodeHandle,
159{
160 type Portgraph<'p>
161 = &'p RegionGraph<'g>
162 where
163 Self: 'p;
164
165 type Node = Node;
166
167 #[inline]
168 fn portgraph(&self) -> Self::Portgraph<'_> {
169 &self.graph
170 }
171
172 #[inline]
173 fn base_hugr(&self) -> &Hugr {
174 self.hugr
175 }
176
177 #[inline]
178 fn root_node(&self) -> Node {
179 self.root
180 }
181
182 #[inline]
183 fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
184 self.hugr.get_pg_index(node)
185 }
186
187 #[inline]
188 fn get_node(&self, index: portgraph::NodeIndex) -> Node {
189 self.hugr.get_node(index)
190 }
191}
192
#[cfg(test)]
pub(super) mod test {
    use std::borrow::Cow;

    use rstest::rstest;

    use crate::extension::prelude::{qb_t, usize_t};
    use crate::IncomingPort;
    use crate::{
        builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder},
        types::Signature,
        utils::test_quantum_extension::{h_gate, EXTENSION_ID},
    };

    use super::*;

    /// Builds a module HUGR with one function definition, `main`, of type
    /// `[usize, qubit] -> [usize, qubit]`.
    ///
    /// `main` applies an `h_gate` to the qubit and routes the `usize` through a
    /// nested DFG node that forwards its input to its output.
    ///
    /// Returns the HUGR, the function definition node, and the nested DFG node.
    pub(in crate::hugr::views) fn make_module_hgr(
    ) -> Result<(Hugr, Node, Node), Box<dyn std::error::Error>> {
        let mut module_builder = ModuleBuilder::new();

        let (f_id, inner_id) = {
            let mut func_builder = module_builder.define_function(
                "main",
                Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID),
            )?;

            let [int, qb] = func_builder.input_wires_arr();

            let q_out = func_builder.add_dataflow_op(h_gate(), vec![qb])?;

            let inner_id = {
                let inner_builder = func_builder
                    .dfg_builder(Signature::new(vec![usize_t()], vec![usize_t()]), [int])?;
                let w = inner_builder.input_wires();
                inner_builder.finish_with_outputs(w)
            }?;

            let f_id =
                func_builder.finish_with_outputs(inner_id.outputs().chain(q_out.outputs()))?;
            (f_id, inner_id)
        };
        let hugr = module_builder.finish_hugr()?;
        Ok((hugr, f_id.handle().node(), inner_id.handle().node()))
    }

    #[test]
    fn full_region() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;

        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;
        let def_io = region.get_io(def).unwrap();

        assert_eq!(region.node_count(), 7);
        assert!(region.nodes().all(|n| n == def
            || hugr.get_parent(n) == Some(def)
            || hugr.get_parent(n) == Some(inner)));
        assert_eq!(region.children(inner).count(), 2);

        assert_eq!(
            region.poly_func_type(),
            Some(
                Signature::new_endo(vec![usize_t(), qb_t()])
                    .with_extension_delta(EXTENSION_ID)
                    .into()
            )
        );

        let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?;
        assert_eq!(
            inner_region.inner_function_type().map(Cow::into_owned),
            Some(Signature::new(vec![usize_t()], vec![usize_t()]))
        );
        assert_eq!(inner_region.node_count(), 3);
        assert_eq!(inner_region.edge_count(), 1);
        assert_eq!(inner_region.children(inner).count(), 2);
        assert_eq!(inner_region.children(hugr.root()).count(), 0);
        assert_eq!(
            inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.node_ports(inner, Direction::Outgoing).count()
        );
        assert_eq!(
            inner_region.num_ports(inner, Direction::Incoming)
                + inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.all_node_ports(inner).count()
        );

        // Edges between `inner` and its siblings cross the border of
        // `inner_region`, so only the enclosing `region` reports them.
        assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
        assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
        assert_eq!(
            inner_region
                .linked_ports(inner, IncomingPort::from(0))
                .count(),
            0
        );
        assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
        assert_eq!(
            inner_region.neighbours(inner, Direction::Outgoing).count(),
            0
        );
        assert_eq!(
            inner_region.neighbours(inner, Direction::Incoming).count(),
            0
        );
        assert_eq!(region.neighbours(inner, Direction::Outgoing).count(), 1);
        assert_eq!(inner_region.all_neighbours(inner).count(), 0);

        Ok(())
    }

    #[rstest]
    fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, _inner) = make_module_hgr()?;

        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;
        let extracted = region.extract_hugr();
        extracted.validate()?;

        // Build a fresh view of the same region to compare against the extracted HUGR.
        let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?;

        assert_eq!(region.node_count(), extracted.node_count());
        assert_eq!(region.root_type(), extracted.root_type());

        Ok(())
    }
}