hugr_core/hugr/views/
petgraph.rs

1//! Implementations of petgraph's traits for Hugr Region views.
2
3use crate::core::HugrNode;
4use crate::hugr::HugrView;
5use crate::ops::OpType;
6use crate::types::EdgeKind;
7use crate::Port;
8
9use petgraph::visit as pv;
10
/// Adapter that exposes a [`HugrView`] through petgraph's graph traits.
///
/// Wrapping a view in this type lets petgraph's algorithms run directly on a
/// Hugr. The wrapper itself only holds a shared reference to the view.
#[derive(Debug)]
pub struct PetgraphWrapper<'a, T> {
    /// The wrapped view.
    pub(crate) hugr: &'a T,
}
18
19impl<T> Clone for PetgraphWrapper<'_, T> {
20    fn clone(&self) -> Self {
21        *self
22    }
23}
24
25impl<T> Copy for PetgraphWrapper<'_, T> {}
26
27impl<'a, T> From<&'a T> for PetgraphWrapper<'a, T>
28where
29    T: HugrView,
30{
31    fn from(hugr: &'a T) -> Self {
32        Self { hugr }
33    }
34}
35
36impl<T> pv::GraphBase for PetgraphWrapper<'_, T>
37where
38    T: HugrView,
39{
40    type NodeId = T::Node;
41    type EdgeId = ((T::Node, Port), (T::Node, Port));
42}
43
44impl<T> pv::GraphProp for PetgraphWrapper<'_, T>
45where
46    T: HugrView,
47{
48    type EdgeType = petgraph::Directed;
49}
50
51impl<T> pv::GraphRef for PetgraphWrapper<'_, T> where T: HugrView {}
52
53impl<T> pv::NodeCount for PetgraphWrapper<'_, T>
54where
55    T: HugrView,
56{
57    fn node_count(&self) -> usize {
58        HugrView::node_count(self.hugr)
59    }
60}
61
62impl<T> pv::NodeIndexable for PetgraphWrapper<'_, T>
63where
64    T: HugrView,
65{
66    fn node_bound(&self) -> usize {
67        HugrView::node_count(self.hugr)
68    }
69
70    fn to_index(&self, ix: Self::NodeId) -> usize {
71        self.hugr.get_pg_index(ix).into()
72    }
73
74    fn from_index(&self, ix: usize) -> Self::NodeId {
75        self.hugr.get_node(portgraph::NodeIndex::new(ix))
76    }
77}
78
79impl<T> pv::EdgeCount for PetgraphWrapper<'_, T>
80where
81    T: HugrView,
82{
83    fn edge_count(&self) -> usize {
84        HugrView::edge_count(self.hugr)
85    }
86}
87
88impl<T> pv::Data for PetgraphWrapper<'_, T>
89where
90    T: HugrView,
91{
92    type NodeWeight = OpType;
93    type EdgeWeight = EdgeKind;
94}
95
96impl<'a, T> pv::IntoNodeReferences for PetgraphWrapper<'a, T>
97where
98    T: HugrView,
99{
100    type NodeRef = HugrNodeRef<'a, T::Node>;
101    type NodeReferences = Box<dyn Iterator<Item = HugrNodeRef<'a, T::Node>> + 'a>;
102
103    fn node_references(self) -> Self::NodeReferences {
104        Box::new(
105            self.hugr
106                .nodes()
107                .map(|n| HugrNodeRef::from_node(n, self.hugr)),
108        )
109    }
110}
111
112impl<'a, T> pv::IntoNodeIdentifiers for PetgraphWrapper<'a, T>
113where
114    T: HugrView,
115{
116    type NodeIdentifiers = Box<dyn Iterator<Item = T::Node> + 'a>;
117
118    fn node_identifiers(self) -> Self::NodeIdentifiers {
119        Box::new(self.hugr.nodes())
120    }
121}
122
123impl<'a, T> pv::IntoNeighbors for PetgraphWrapper<'a, T>
124where
125    T: HugrView,
126{
127    type Neighbors = Box<dyn Iterator<Item = T::Node> + 'a>;
128
129    fn neighbors(self, n: Self::NodeId) -> Self::Neighbors {
130        Box::new(self.hugr.output_neighbours(n))
131    }
132}
133
134impl<'a, T> pv::IntoNeighborsDirected for PetgraphWrapper<'a, T>
135where
136    T: HugrView,
137{
138    type NeighborsDirected = Box<dyn Iterator<Item = T::Node> + 'a>;
139
140    fn neighbors_directed(
141        self,
142        n: Self::NodeId,
143        d: petgraph::Direction,
144    ) -> Self::NeighborsDirected {
145        Box::new(self.hugr.neighbours(n, d.into()))
146    }
147}
148
149impl<T> pv::Visitable for PetgraphWrapper<'_, T>
150where
151    T: HugrView,
152{
153    type Map = std::collections::HashSet<Self::NodeId>;
154
155    fn visit_map(&self) -> Self::Map {
156        std::collections::HashSet::new()
157    }
158
159    fn reset_map(&self, map: &mut Self::Map) {
160        map.clear();
161    }
162}
163
164impl<T> pv::GetAdjacencyMatrix for PetgraphWrapper<'_, T>
165where
166    T: HugrView,
167{
168    type AdjMatrix = std::collections::HashSet<(Self::NodeId, Self::NodeId)>;
169
170    fn adjacency_matrix(&self) -> Self::AdjMatrix {
171        let mut matrix = std::collections::HashSet::new();
172        for node in self.hugr.nodes() {
173            for neighbour in self.hugr.output_neighbours(node) {
174                matrix.insert((node, neighbour));
175            }
176        }
177        matrix
178    }
179
180    fn is_adjacent(&self, matrix: &Self::AdjMatrix, a: Self::NodeId, b: Self::NodeId) -> bool {
181        matrix.contains(&(a, b))
182    }
183}
184
185/// Reference to a Hugr node and its associated OpType.
186#[derive(Debug, Clone, Copy)]
187pub struct HugrNodeRef<'a, N> {
188    node: N,
189    op: &'a OpType,
190}
191
192impl<'a, N: HugrNode> HugrNodeRef<'a, N> {
193    pub(self) fn from_node(node: N, hugr: &'a impl HugrView<Node = N>) -> Self {
194        Self {
195            node,
196            op: hugr.get_optype(node),
197        }
198    }
199}
200
201impl<N: HugrNode> pv::NodeRef for HugrNodeRef<'_, N> {
202    type NodeId = N;
203
204    type Weight = OpType;
205
206    fn id(&self) -> Self::NodeId {
207        self.node
208    }
209
210    fn weight(&self) -> &Self::Weight {
211        self.op
212    }
213}
214
#[cfg(test)]
mod test {
    use petgraph::visit::{
        EdgeCount, GetAdjacencyMatrix, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef,
    };

    use super::PetgraphWrapper;
    use crate::hugr::views::tests::sample_hugr;
    use crate::ops::handle::NodeHandle;
    use crate::HugrView;

    #[test]
    fn test_petgraph_wrapper() {
        let (hugr, cx1, cx2) = sample_hugr();
        let graph = PetgraphWrapper::from(&hugr);

        // Sizes reported through petgraph's traits.
        assert_eq!(graph.node_count(), 5);
        assert_eq!(graph.node_bound(), 5);
        assert_eq!(graph.edge_count(), 7);

        // Indices round-trip through the underlying portgraph index.
        let cx1_node = cx1.node();
        let expected_index = cx1_node.pg_index().index();
        assert_eq!(graph.to_index(cx1_node), expected_index);
        assert_eq!(graph.from_index(expected_index), cx1_node);

        // A node reference carries the node's operation as its weight.
        let cx1_ref = graph
            .node_references()
            .find(|r| r.id() == cx1_node)
            .unwrap();
        assert_eq!(cx1_ref.weight(), hugr.get_optype(cx1_node));

        // cx1 has an outgoing edge into cx2 in the sample hugr.
        let adjacency = graph.adjacency_matrix();
        assert!(graph.is_adjacent(&adjacency, cx1_node, cx2.node()));
    }
}