hugr_core/hugr/views/
sibling.rs

//! Views onto a sibling subgraph of the HUGR: the immutable [`SiblingGraph`] and the mutable [`SiblingMut`].
2
3use std::iter;
4
5use itertools::{Either, Itertools};
6use portgraph::{LinkView, MultiPortGraph, PortView};
7
8use crate::hugr::internal::HugrMutInternals;
9use crate::hugr::{HugrError, HugrMut};
10use crate::ops::handle::NodeHandle;
11use crate::{Direction, Hugr, Node, Port};
12
13use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};
14
15type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;
16
17/// View of a HUGR sibling graph.
18///
19/// Includes only the root node and its direct children, but no deeper descendants.
20/// However, the descendants can still be accessed by creating [`SiblingGraph`]s and/or
21/// [`DescendantsGraph`]s from nodes in this view.
22///
23/// Uniquely, the root node has no parent.
24///
25/// See [`DescendantsGraph`] for a view that includes all descendants of the root.
26///
27/// Implements the [`HierarchyView`] trait, as well as [`HugrView`], it can be
28/// used interchangeably with [`DescendantsGraph`].
29///
30/// [`DescendantsGraph`]: super::DescendantsGraph
31#[derive(Clone)]
32pub struct SiblingGraph<'g, Root = Node> {
33    /// The chosen root node.
34    // TODO: this can only be made generic once the call to base_hugr is removed
35    // in try_new. See https://github.com/CQCL/hugr/issues/1926
36    root: Node,
37
38    /// The filtered portgraph encoding the adjacency structure of the HUGR.
39    graph: FlatRegionGraph<'g>,
40
41    /// The underlying Hugr onto which this view is a filter
42    hugr: &'g Hugr,
43
44    /// The operation type of the root node.
45    _phantom: std::marker::PhantomData<Root>,
46}
47
/// HugrView trait members common to both [SiblingGraph] and [SiblingMut],
/// i.e. those that can be implemented in terms of [HugrInternals::base_hugr],
/// the view's `root` field and other [HugrView] methods, without touching the
/// view's own portgraph.
macro_rules! impl_base_members {
    () => {
        #[inline]
        fn node_count(&self) -> usize {
            // The root itself plus each of its direct children.
            self.base_hugr()
                .hierarchy
                .child_count(self.get_pg_index(self.root))
                + 1
        }

        #[inline]
        fn edge_count(&self) -> usize {
            // Faster implementation than filtering all the nodes in the internal graph.
            self.nodes()
                .map(|n| self.output_neighbours(n).count())
                .sum()
        }

        #[inline]
        fn nodes(&self) -> impl Iterator<Item = Self::Node> + Clone {
            // Faster implementation than filtering all the nodes in the internal graph.
            let children = self
                .base_hugr()
                .hierarchy
                .children(self.get_pg_index(self.root))
                .map(|n| self.get_node(n));
            iter::once(self.root).chain(children)
        }

        fn children(
            &self,
            node: Self::Node,
        ) -> impl DoubleEndedIterator<Item = Self::Node> + Clone {
            // Only the root has children inside a sibling view: the children of
            // any other node are deeper descendants, which the view excludes.
            let children = if node == self.root {
                self.base_hugr().hierarchy.children(self.get_pg_index(node))
            } else {
                portgraph::hierarchy::Children::default()
            };
            children.map(|n| self.get_node(n))
        }
    };
}
92
93impl<Root: NodeHandle> HugrView for SiblingGraph<'_, Root> {
94    impl_base_members! {}
95
96    #[inline]
97    fn contains_node(&self, node: Node) -> bool {
98        self.graph.contains_node(self.get_pg_index(node))
99    }
100
101    #[inline]
102    fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
103        self.graph
104            .port_offsets(self.get_pg_index(node), dir)
105            .map_into()
106    }
107
108    #[inline]
109    fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
110        self.graph
111            .all_port_offsets(self.get_pg_index(node))
112            .map_into()
113    }
114
115    fn linked_ports(
116        &self,
117        node: Node,
118        port: impl Into<Port>,
119    ) -> impl Iterator<Item = (Node, Port)> + Clone {
120        let port = self
121            .graph
122            .port_index(self.get_pg_index(node), port.into().pg_offset())
123            .unwrap();
124        self.graph.port_links(port).map(|(_, link)| {
125            let node = self.graph.port_node(link).unwrap();
126            let offset = self.graph.port_offset(link).unwrap();
127            (self.get_node(node), offset.into())
128        })
129    }
130
131    fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
132        self.graph
133            .get_connections(self.get_pg_index(node), self.get_pg_index(other))
134            .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into()))
135    }
136
137    #[inline]
138    fn num_ports(&self, node: Node, dir: Direction) -> usize {
139        self.graph.num_ports(self.get_pg_index(node), dir)
140    }
141
142    #[inline]
143    fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
144        self.graph
145            .neighbours(self.get_pg_index(node), dir)
146            .map(|n| self.get_node(n))
147    }
148
149    #[inline]
150    fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
151        self.graph
152            .all_neighbours(self.get_pg_index(node))
153            .map(|n| self.get_node(n))
154    }
155}
156impl<Root: NodeHandle> RootTagged for SiblingGraph<'_, Root> {
157    type RootHandle = Root;
158}
159
160impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> {
161    fn new_unchecked(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Self {
162        let hugr = hugr.base_hugr();
163        Self {
164            root,
165            graph: FlatRegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)),
166            hugr,
167            _phantom: std::marker::PhantomData,
168        }
169    }
170}
171
172impl<'a, Root> HierarchyView<'a> for SiblingGraph<'a, Root>
173where
174    Root: NodeHandle,
175{
176    fn try_new(hugr: &'a impl HugrView<Node = Node>, root: Node) -> Result<Self, HugrError> {
177        assert!(
178            hugr.valid_node(root),
179            "Cannot create a sibling graph from an invalid node {}.",
180            root
181        );
182        check_tag::<Root, _>(hugr, root)?;
183        Ok(Self::new_unchecked(hugr, root))
184    }
185}
186
187impl<Root: NodeHandle> ExtractHugr for SiblingGraph<'_, Root> {}
188
189impl<'g, Root: NodeHandle> HugrInternals for SiblingGraph<'g, Root>
190where
191    Root: NodeHandle,
192{
193    type Portgraph<'p>
194        = &'p FlatRegionGraph<'g>
195    where
196        Self: 'p;
197    type Node = Node;
198
199    #[inline]
200    fn portgraph(&self) -> Self::Portgraph<'_> {
201        &self.graph
202    }
203
204    #[inline]
205    fn base_hugr(&self) -> &Hugr {
206        self.hugr
207    }
208
209    #[inline]
210    fn root_node(&self) -> Node {
211        self.root
212    }
213
214    #[inline]
215    fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
216        self.hugr.get_pg_index(node)
217    }
218
219    #[inline]
220    fn get_node(&self, index: portgraph::NodeIndex) -> Node {
221        self.hugr.get_node(index)
222    }
223}
224
225/// Mutable view onto a HUGR sibling graph.
226///
227/// Like [SiblingGraph], includes only the root node and its direct children, but no
228/// deeper descendants; but the descendants can still be accessed by creating nested
229/// [SiblingMut] instances from nodes in the view.
230///
231/// Uniquely, the root node has no parent.
232///
233/// [HugrView] methods may be slower than for an immutable [SiblingGraph]
234/// as the latter may cache information about the graph connectivity,
235/// whereas (in order to ease mutation) this does not.
236pub struct SiblingMut<'g, Root = Node> {
237    /// The chosen root node.
238    root: Node,
239
240    /// The rest of the HUGR.
241    hugr: &'g mut Hugr,
242
243    /// The operation type of the root node.
244    _phantom: std::marker::PhantomData<Root>,
245}
246
247impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
248    /// Create a new SiblingMut from a base.
249    /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference.
250    pub fn try_new<Base: HugrMut>(hugr: &'g mut Base, root: Node) -> Result<Self, HugrError> {
251        if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) {
252            return Err(HugrError::InvalidTag {
253                required: Base::RootHandle::TAG,
254                actual: Root::TAG,
255            });
256        }
257        check_tag::<Root, _>(hugr, root)?;
258        Ok(Self {
259            hugr: hugr.hugr_mut(),
260            root,
261            _phantom: std::marker::PhantomData,
262        })
263    }
264}
265
266impl<Root: NodeHandle> ExtractHugr for SiblingMut<'_, Root> {}
267
268impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> {
269    type Portgraph<'p>
270        = FlatRegionGraph<'p>
271    where
272        'g: 'p,
273        Root: 'p;
274    type Node = Node;
275
276    fn portgraph(&self) -> Self::Portgraph<'_> {
277        FlatRegionGraph::new(
278            &self.base_hugr().graph,
279            &self.base_hugr().hierarchy,
280            self.root.pg_index(),
281        )
282    }
283
284    fn base_hugr(&self) -> &Hugr {
285        self.hugr
286    }
287
288    fn root_node(&self) -> Node {
289        self.root
290    }
291
292    #[inline]
293    fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex {
294        self.hugr.get_pg_index(node)
295    }
296
297    #[inline]
298    fn get_node(&self, index: portgraph::NodeIndex) -> Node {
299        self.hugr.get_node(index)
300    }
301}
302
303impl<Root: NodeHandle> HugrView for SiblingMut<'_, Root> {
304    impl_base_members! {}
305
306    fn contains_node(&self, node: Node) -> bool {
307        // Don't call self.get_parent(). That requires valid_node(node)
308        // which infinitely-recurses back here.
309        node == self.root || self.base_hugr().get_parent(node) == Some(self.root)
310    }
311
312    fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator<Item = Port> + Clone {
313        self.base_hugr().node_ports(node, dir)
314    }
315
316    fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone {
317        self.base_hugr().all_node_ports(node)
318    }
319
320    fn linked_ports(
321        &self,
322        node: Node,
323        port: impl Into<Port>,
324    ) -> impl Iterator<Item = (Node, Port)> + Clone {
325        self.hugr
326            .linked_ports(node, port)
327            .filter(|(n, _)| self.contains_node(*n))
328    }
329
330    fn node_connections(&self, node: Node, other: Node) -> impl Iterator<Item = [Port; 2]> + Clone {
331        match self.contains_node(node) && self.contains_node(other) {
332            // The nodes are not in the sibling graph
333            false => Either::Left(iter::empty()),
334            // The nodes are in the sibling graph
335            true => Either::Right(self.hugr.node_connections(node, other)),
336        }
337    }
338
339    fn num_ports(&self, node: Node, dir: Direction) -> usize {
340        self.base_hugr().num_ports(node, dir)
341    }
342
343    fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator<Item = Node> + Clone {
344        self.hugr
345            .neighbours(node, dir)
346            .filter(|n| self.contains_node(*n))
347    }
348
349    fn all_neighbours(&self, node: Node) -> impl Iterator<Item = Node> + Clone {
350        self.hugr
351            .all_neighbours(node)
352            .filter(|n| self.contains_node(*n))
353    }
354}
355
356impl<Root: NodeHandle> RootTagged for SiblingMut<'_, Root> {
357    type RootHandle = Root;
358}
359
360impl<Root: NodeHandle> HugrMutInternals for SiblingMut<'_, Root> {
361    fn hugr_mut(&mut self) -> &mut Hugr {
362        self.hugr
363    }
364}
365
366impl<Root: NodeHandle> HugrMut for SiblingMut<'_, Root> {}
367
#[cfg(test)]
mod test {
    use std::borrow::Cow;

    use rstest::rstest;

    use crate::builder::test::simple_dfg_hugr;
    use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
    use crate::extension::prelude::{qb_t, usize_t};
    use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID};
    use crate::ops::{dataflow::IOTrait, Input, OpTag, Output};
    use crate::ops::{OpTrait, OpType};
    use crate::types::Signature;
    use crate::utils::test_quantum_extension::EXTENSION_ID;
    use crate::IncomingPort;

    use super::super::descendants::test::make_module_hgr;
    use super::*;

    /// Checks the [`HugrView`] behaviour shared by [`SiblingGraph`] and [`SiblingMut`].
    ///
    /// `region` must be a sibling view rooted at the function definition `def`,
    /// and `inner_region` one rooted at `inner`, a direct child of `def`, both as
    /// produced by `make_module_hgr`.
    fn test_properties<T>(
        hugr: &Hugr,
        def: Node,
        inner: Node,
        region: T,
        inner_region: T,
    ) -> Result<(), Box<dyn std::error::Error>>
    where
        T: HugrView<Node = Node> + Sized,
    {
        let def_io = region.get_io(def).unwrap();

        assert_eq!(region.node_count(), 5);
        assert_eq!(region.portgraph().node_count(), 5);
        // A sibling view contains exactly its root and the root's direct children;
        // grandchildren (e.g. the children of `inner`) must not leak into it.
        assert!(region
            .nodes()
            .all(|n| n == def || hugr.get_parent(n) == Some(def)));
        assert!(region.contains_node(inner));
        assert!(!inner_region.contains_node(def));
        assert_eq!(region.children(inner).count(), 0);

        assert_eq!(
            region.poly_func_type(),
            Some(
                Signature::new_endo(vec![usize_t(), qb_t()])
                    .with_extension_delta(EXTENSION_ID)
                    .into()
            )
        );

        assert_eq!(
            inner_region.inner_function_type().map(Cow::into_owned),
            Some(Signature::new(vec![usize_t()], vec![usize_t()]))
        );
        assert_eq!(inner_region.node_count(), 3);
        assert_eq!(inner_region.edge_count(), 1);
        assert_eq!(inner_region.children(inner).count(), 2);
        assert_eq!(inner_region.children(hugr.root()).count(), 0);
        assert_eq!(
            inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.node_ports(inner, Direction::Outgoing).count()
        );
        assert_eq!(
            inner_region.num_ports(inner, Direction::Incoming)
                + inner_region.num_ports(inner, Direction::Outgoing),
            inner_region.all_node_ports(inner).count()
        );

        // The inner region filters out the connections to the main function I/O nodes,
        // while the outer region includes them.
        assert_eq!(inner_region.node_connections(inner, def_io[1]).count(), 0);
        assert_eq!(region.node_connections(inner, def_io[1]).count(), 1);
        assert_eq!(
            inner_region
                .linked_ports(inner, IncomingPort::from(0))
                .count(),
            0
        );
        assert_eq!(region.linked_ports(inner, IncomingPort::from(0)).count(), 1);
        assert_eq!(
            inner_region.neighbours(inner, Direction::Outgoing).count(),
            0
        );
        assert_eq!(inner_region.all_neighbours(inner).count(), 0);

        Ok(())
    }

    #[rstest]
    fn sibling_graph_properties() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;

        test_properties::<SiblingGraph>(
            &hugr,
            def,
            inner,
            SiblingGraph::try_new(&hugr, def).unwrap(),
            SiblingGraph::try_new(&hugr, inner).unwrap(),
        )
    }

    #[rstest]
    fn sibling_mut_properties() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, def, inner) = make_module_hgr()?;
        let mut def_region_hugr = hugr.clone();
        let mut inner_region_hugr = hugr.clone();

        test_properties::<SiblingMut>(
            &hugr,
            def,
            inner,
            SiblingMut::try_new(&mut def_region_hugr, def).unwrap(),
            SiblingMut::try_new(&mut inner_region_hugr, inner).unwrap(),
        )
    }

    #[test]
    fn nested_flat() -> Result<(), Box<dyn std::error::Error>> {
        let mut module_builder = ModuleBuilder::new();
        let fty = Signature::new(vec![usize_t()], vec![usize_t()]);
        let mut fbuild = module_builder.define_function("main", fty.clone())?;
        let dfg = fbuild.dfg_builder(fty, fbuild.input_wires())?;
        let ins = dfg.input_wires();
        let sub_dfg = dfg.finish_with_outputs(ins)?;
        let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?;
        let h = module_builder.finish_hugr()?;
        let sub_dfg = sub_dfg.node();

        // We can create a view from a child or grandchild of a hugr:
        let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?;
        let fun_view: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&h, fun.node())?;
        assert_eq!(fun_view.children(sub_dfg).count(), 0);

        // And also create a view from a child of another SiblingGraph
        let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?;

        // Both ways work:
        let just_io = vec![
            Input::new(vec![usize_t()]).into(),
            Output::new(vec![usize_t()]).into(),
        ];
        for d in [dfg_view, nested_dfg_view] {
            assert_eq!(
                d.children(sub_dfg).map(|n| d.get_optype(n)).collect_vec(),
                just_io.iter().collect_vec()
            );
        }

        Ok(())
    }

    /// Mutate a SiblingMut wrapper
    #[rstest]
    fn flat_mut(mut simple_dfg_hugr: Hugr) {
        simple_dfg_hugr.validate().unwrap();
        let root = simple_dfg_hugr.root();
        let signature = simple_dfg_hugr.inner_function_type().unwrap().into_owned();

        let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
        assert_eq!(
            sib_mut.err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg
            })
        );

        let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
        let bad_nodetype: OpType = crate::ops::CFG { signature }.into();
        assert_eq!(
            sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()),
            Err(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::Cfg
            })
        );

        // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation
        simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap();
        assert!(simple_dfg_hugr.validate().is_err());
    }

    #[rstest]
    fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) {
        let root = simple_dfg_hugr.root();
        let case_nodetype = crate::ops::Case {
            signature: simple_dfg_hugr
                .root_type()
                .dataflow_signature()
                .unwrap()
                .into_owned(),
        };
        let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
        // As expected, we cannot replace the root with a Case
        assert_eq!(
            sib_mut.replace_op(root, case_nodetype),
            Err(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::Case
            })
        );

        let nested_sib_mut = SiblingMut::<DataflowParentID>::try_new(&mut sib_mut, root);
        assert!(nested_sib_mut.is_err());
    }

    #[rstest]
    fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
        let (hugr, _def, inner) = make_module_hgr()?;

        let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;
        let extracted = region.extract_hugr();
        extracted.validate()?;

        let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;

        assert_eq!(region.node_count(), extracted.node_count());
        assert_eq!(region.root_type(), extracted.root_type());

        Ok(())
    }
}