1use std::iter;
4
5use itertools::{Either, Itertools};
6use portgraph::{LinkView, MultiPortGraph, PortView};
7
8use crate::hugr::internal::HugrMutInternals;
9use crate::hugr::{HugrError, HugrMut};
10use crate::ops::handle::NodeHandle;
11use crate::{Direction, Hugr, Node, Port};
12
13use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};
14
15type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;
16
17#[derive(Clone)]
32pub struct SiblingGraph<'g, Root = Node> {
33 root: Node,
37
38 graph: FlatRegionGraph<'g>,
40
41 hugr: &'g Hugr,
43
44 _phantom: std::marker::PhantomData<Root>,
46}
47
/// Shared `HugrView` members for the sibling views.
///
/// Expands to `node_count`, `edge_count`, `nodes` and `children`. The
/// implementing type must have a `root: Node` field and provide
/// `base_hugr`, `get_pg_index` and `get_node`; `std::iter` must be in scope
/// at the expansion site.
macro_rules! impl_base_members {
    () => {
        #[inline]
        fn node_count(&self) -> usize {
            // The root itself plus each of its direct children.
            let root_index = self.get_pg_index(self.root);
            1 + self.base_hugr().hierarchy.child_count(root_index)
        }

        #[inline]
        fn edge_count(&self) -> usize {
            // Sum, over every node of the view, its outgoing neighbour count.
            self.nodes()
                .fold(0, |total, n| total + self.output_neighbours(n).count())
        }

        #[inline]
        fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone {
            // Walk the hierarchy directly instead of filtering the whole graph.
            let root_index = self.get_pg_index(self.root);
            let direct_children = self
                .base_hugr()
                .hierarchy
                .children(root_index)
                .map(|index| self.get_node(index));
            iter::once(self.root).chain(direct_children)
        }

        fn children(
            &self,
            node: Self::Node,
        ) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
            // Inside a sibling view only the root has visible children.
            let indices = if node == self.root {
                self.base_hugr().hierarchy.children(self.get_pg_index(node))
            } else {
                portgraph::hierarchy::Children::default()
            };
            indices.map(|index| self.get_node(index))
        }
    };
}
92
93impl<Root: NodeHandle> HugrView for SiblingGraph<'_, Root> {
94 impl_base_members! {}
95
96 #[inline]
97 fn contains_node(&self, node: Node) -> bool {
98 self.graph.contains_node(self.get_pg_index(node))
99 }
100
101 #[inline]
102 fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
103 self.graph
104 .port_offsets(self.get_pg_index(node), dir)
105 .map_into()
106 }
107
108 #[inline]
109 fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
110 self.graph
111 .all_port_offsets(self.get_pg_index(node))
112 .map_into()
113 }
114
115 fn linked_ports(
116 &self,
117 node: Node,
118 port: impl Into<Port>,
119 ) -> impl Iterator<Item = (Node, Port)> + Clone {
120 let port = self
121 .graph
122 .port_index(self.get_pg_index(node), port.into().pg_offset())
123 .unwrap();
124 self.graph.port_links(port).map(|(_, link)| {
125 let node = self.graph.port_node(link).unwrap();
126 let offset = self.graph.port_offset(link).unwrap();
127 (self.get_node(node), offset.into())
128 })
129 }
130
131 fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
132 self.graph
133 .get_connections(self.get_pg_index(node), self.get_pg_index(other))
134 .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into()))
135 }
136
137 #[inline]
138 fn num_ports(&self, node: Node, dir: Direction) -> usize {
139 self.graph.num_ports(self.get_pg_index(node), dir)
140 }
141
142 #[inline]
143 fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
144 self.graph
145 .neighbours(self.get_pg_index(node), dir)
146 .map(|n| self.get_node(n))
147 }
148
149 #[inline]
150 fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
151 self.graph
152 .all_neighbours(self.get_pg_index(node))
153 .map(|n| self.get_node(n))
154 }
155}
156impl<Root: NodeHandle> RootTagged for SiblingGraph<'_, Root> {
157 type RootHandle = Root;
158}
159
160impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> {
161 fn new_unchecked(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Self {
162 let hugr = hugr.base_hugr();
163 Self {
164 root,
165 graph: FlatRegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)),
166 hugr,
167 _phantom: std::marker::PhantomData,
168 }
169 }
170}
171
172impl<'a, Root> HierarchyView<'a> for SiblingGraph<'a, Root>
173where
174 Root: NodeHandle,
175{
176 fn try_new(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Result<Self, HugrError> {
177 assert!(
178 hugr.valid_node(root),
179 "Cannot create a sibling graph from an invalid node {}.",
180 root
181 );
182 check_tag::<Root, _>(hugr, root)?;
183 Ok(Self::new_unchecked(hugr, root))
184 }
185}
186
187impl<Root: NodeHandle> ExtractHugr for SiblingGraph<'_, Root> {}
188
189impl<'g, Root: NodeHandle> HugrInternals for SiblingGraph<'g, Root>
190where
191 Root: NodeHandle,
192{
193 type Portgraph<'p>
194 = &'p FlatRegionGraph<'g>
195 where
196 Self: 'p;
197 type Node = Node;
198
199 #[inline]
200 fn portgraph(&self) -> Self::Portgraph<'_> {
201 &self.graph
202 }
203
204 #[inline]
205 fn base_hugr(&self) -> &Hugr {
206 self.hugr
207 }
208
209 #[inline]
210 fn root_node(&self) -> Node {
211 self.root
212 }
213
214 #[inline]
215 fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
216 self.hugr.get_pg_index(node)
217 }
218
219 #[inline]
220 fn get_node(&self, index: portgraph::NodeIndex) -> Node {
221 self.hugr.get_node(index)
222 }
223}
224
225pub struct SiblingMut<'g, Root = Node> {
237 root: Node,
239
240 hugr: &'g mut Hugr,
242
243 _phantom: std::marker::PhantomData<Root>,
245}
246
247impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
248 pub fn try_new<Base: HugrMut>(hugr: &'g mut Base, root: Node) -> Result<Self, HugrError> {
251 if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) {
252 return Err(HugrError::InvalidTag {
253 required: Base::RootHandle::TAG,
254 actual: Root::TAG,
255 });
256 }
257 check_tag::<Root, _>(hugr, root)?;
258 Ok(Self {
259 hugr: hugr.hugr_mut(),
260 root,
261 _phantom: std::marker::PhantomData,
262 })
263 }
264}
265
266impl<Root: NodeHandle> ExtractHugr for SiblingMut<'_, Root> {}
267
268impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> {
269 type Portgraph<'p>
270 = FlatRegionGraph<'p>
271 where
272 'g: 'p,
273 Root: 'p;
274 type Node = Node;
275
276 fn portgraph(&self) -> Self::Portgraph<'_> {
277 FlatRegionGraph::new(
278 &self.base_hugr().graph,
279 &self.base_hugr().hierarchy,
280 self.root.pg_index(),
281 )
282 }
283
284 fn base_hugr(&self) -> &Hugr {
285 self.hugr
286 }
287
288 fn root_node(&self) -> Node {
289 self.root
290 }
291
292 #[inline]
293 fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
294 self.hugr.get_pg_index(node)
295 }
296
297 #[inline]
298 fn get_node(&self, index: portgraph::NodeIndex) -> Node {
299 self.hugr.get_node(index)
300 }
301}
302
303impl<Root: NodeHandle> HugrView for SiblingMut<'_, Root> {
304 impl_base_members! {}
305
306 fn contains_node(&self, node: Node) -> bool {
307 node == self.root || self.base_hugr().get_parent(node) == Some(self.root)
310 }
311
312 fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
313 self.base_hugr().node_ports(node, dir)
314 }
315
316 fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
317 self.base_hugr().all_node_ports(node)
318 }
319
320 fn linked_ports(
321 &self,
322 node: Node,
323 port: impl Into<Port>,
324 ) -> impl Iterator<Item = (Node, Port)> + Clone {
325 self.hugr
326 .linked_ports(node, port)
327 .filter(|(n, _)| self.contains_node(*n))
328 }
329
330 fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
331 match self.contains_node(node) && self.contains_node(other) {
332 false => Either::Left(iter::empty()),
334 true => Either::Right(self.hugr.node_connections(node, other)),
336 }
337 }
338
339 fn num_ports(&self, node: Node, dir: Direction) -> usize {
340 self.base_hugr().num_ports(node, dir)
341 }
342
343 fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
344 self.hugr
345 .neighbours(node, dir)
346 .filter(|n| self.contains_node(*n))
347 }
348
349 fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
350 self.hugr
351 .all_neighbours(node)
352 .filter(|n| self.contains_node(*n))
353 }
354}
355
356impl<Root: NodeHandle> RootTagged for SiblingMut<'_, Root> {
357 type RootHandle = Root;
358}
359
360impl<Root: NodeHandle> HugrMutInternals for SiblingMut<'_, Root> {
361 fn hugr_mut(&mut self) -> &mut Hugr {
362 self.hugr
363 }
364}
365
366impl<Root: NodeHandle> HugrMut for SiblingMut<'_, Root> {}
367
#[cfg(test)]
mod test {
    use std::borrow::Cow;

    use rstest::rstest;

    use crate::builder::test::simple_dfg_hugr;
    use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID};
    use crate::ops::{dataflow::IOTrait, Input, OpTag, Output};
    use crate::ops::{OpTrait, OpType};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::EXTENSION_ID;
    use crate::IncomingPort;

    use super::super::descendants::test::make_module_hgr;
    use super::*;

    /// Checks the properties shared by `SiblingGraph` and `SiblingMut`.
    ///
    /// `region` must be a view rooted at the function definition `def` and
    /// `inner_region` a view rooted at the nested node `inner`. Both views are
    /// over (a copy of) the hugr returned by `make_module_hgr`.
    fn test_properties<T>(
        hugr: &Hugr,
        def: Node,
        inner: Node,
        region: T,
        inner_region: T,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        T: HugrView<Node = Node> + Sized,
    {
        let def_io = region.get_io(def).unwrap();

        // The outer view holds `def` and its direct children only; the
        // children of `inner` must not leak into it.
        assert_eq!(region.node_count(), 5);
        assert_eq!(region.portgraph().node_count(), 5);
        assert!(region
            .nodes()
            .all(|n| n == def || hugr.get_parent(n) == Some(def)));
        // `inner` is not the root of `region`, so it has no children there.
        assert_eq!(region.children(inner).count(), 0);

        assert_eq!(
            region.poly_func_type(),
            Some(
                Signature::new_endo(vec![usize_t(), qb_t()])
                    .with_extension_delta(EXTENSION_ID)
                    .into()
            )
        );

        assert_eq!(
            inner_region.inner_function_type().map(Cow::into_owned),
            Some(Signature::new(vec![usize_t()], vec![usize_t()]))
        );
        assert_eq!(inner_region.node_count(), 3);
        assert_eq!(inner_region.edge_count(), 1);
        assert_eq!(inner_region.children(inner).count(), 2);
        assert_eq!(inner_region.children(hugr.root()).count(), 0);
        assert_eq!(
            inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.node_ports(inner, Direction::Outgoing).count()
        );
        assert_eq!(
            inner_region.num_ports(inner, Direction::Incoming)
                + inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.all_node_ports(inner).count()
        );

        // Connections of `inner` lead to nodes of the outer region: they are
        // visible from `region` but filtered out of `inner_region`.
        assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
        assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
        assert_eq!(
            inner_region
                .linked_ports(inner, IncomingPort::from(0))
                .count(),
            0
        );
        assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
        assert_eq!(
            inner_region.neighbours(inner, Direction::Outgoing).count(),
            0
        );
        assert_eq!(
            inner_region.neighbours(inner, Direction::Incoming).count(),
            0
        );
        assert_eq!(inner_region.all_neighbours(inner).count(), 0);

        Ok(())
    }

    #[rstest]
    fn sibling_graph_properties() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;

        test_properties::<SiblingGraph>(
            &hugr,
            def,
            inner,
            SiblingGraph::try_new(&hugr, def).unwrap(),
            SiblingGraph::try_new(&hugr, inner).unwrap(),
        )
    }

    #[rstest]
    fn sibling_mut_properties() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;
        // Each mutable view needs its own exclusively-borrowed copy.
        let mut def_region_hugr = hugr.clone();
        let mut inner_region_hugr = hugr.clone();

        test_properties::<SiblingMut>(
            &hugr,
            def,
            inner,
            SiblingMut::try_new(&mut def_region_hugr, def).unwrap(),
            SiblingMut::try_new(&mut inner_region_hugr, inner).unwrap(),
        )
    }

    #[test]
    fn nested_flat() -> Result<(), Box<dyn std::error::Error>> {
        let mut module_builder = ModuleBuilder::new();
        let fty = Signature::new(vec![usize_t()], vec![usize_t()]);
        let mut fbuild = module_builder.define_function("main", fty.clone())?;
        let dfg = fbuild.dfg_builder(fty, fbuild.input_wires())?;
        let ins = dfg.input_wires();
        let sub_dfg = dfg.finish_with_outputs(ins)?;
        let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?;
        let h = module_builder.finish_hugr()?;
        let sub_dfg = sub_dfg.node();

        // A view rooted directly at the nested DFG.
        let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?;
        // In a view rooted at the function, the DFG has no visible children.
        let fun_view: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&h, fun.node())?;
        assert_eq!(fun_view.children(sub_dfg).count(), 0);

        // A DFG view can also be built on top of the function view.
        let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?;

        let just_io = vec![
            Input::new(vec![usize_t()]).into(),
            Output::new(vec![usize_t()]).into(),
        ];
        for d in [dfg_view, nested_dfg_view] {
            assert_eq!(
                d.children(sub_dfg).map(|n| d.get_optype(n)).collect_vec(),
                just_io.iter().collect_vec()
            );
        }

        Ok(())
    }

    #[rstest]
    fn flat_mut(mut simple_dfg_hugr: Hugr) {
        simple_dfg_hugr.validate().unwrap();
        let root = simple_dfg_hugr.root();
        let signature = simple_dfg_hugr.inner_function_type().unwrap().into_owned();

        // The root is a DFG, so a CFG-tagged view is rejected.
        let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
        assert_eq!(
            sib_mut.err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg
            })
        );

        // A DFG-tagged view refuses to turn its root into a CFG.
        let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
        let bad_nodetype: OpType = crate::ops::CFG { signature }.into();
        assert_eq!(
            sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()),
            Err(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::Cfg
            })
        );

        // On the plain Hugr the same replacement succeeds, and the problem is
        // only caught by validation.
        simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap();
        assert!(simple_dfg_hugr.validate().is_err());
    }

    #[rstest]
    fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) {
        let root = simple_dfg_hugr.root();
        let case_nodetype = crate::ops::Case {
            signature: simple_dfg_hugr
                .root_type()
                .dataflow_signature()
                .unwrap()
                .into_owned(),
        };
        let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
        // A DFG-tagged view cannot replace its root with a Case.
        assert_eq!(
            sib_mut.replace_op(root, case_nodetype),
            Err(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::Case
            })
        );

        // A nested view may not widen the root tag of the view it is built on.
        let nested_sib_mut = SiblingMut::<DataflowParentID>::try_new(&mut sib_mut, root);
        assert!(nested_sib_mut.is_err());
    }

    #[rstest]
    fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, _def, inner) = make_module_hgr()?;

        let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;
        let extracted = region.extract_hugr();
        extracted.validate()?;

        // `extract_hugr` consumed the view, so build a fresh one to compare.
        let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;

        assert_eq!(region.node_count(), extracted.node_count());
        assert_eq!(region.root_type(), extracted.root_type());

        Ok(())
    }
}