hugr_core/hugr/views/
petgraph.rs

1//! Implementations of petgraph's traits for Hugr Region views.
2
3use crate::core::HugrNode;
4use crate::hugr::HugrView;
5use crate::ops::OpType;
6use crate::types::EdgeKind;
7use crate::{NodeIndex, Port};
8
9use petgraph::visit as pv;
10
/// Thin adapter exposing a `HugrView` through petgraph's graph traits.
///
/// Wrap a view with `PetgraphWrapper::from(&hugr)` to run petgraph's
/// algorithms over it. The wrapper only stores a shared reference, so it is
/// cheap to copy.
#[derive(Debug)]
pub struct PetgraphWrapper<'a, T> {
    /// The wrapped view.
    pub(crate) hugr: &'a T,
}
18
19impl<T> Clone for PetgraphWrapper<'_, T> {
20    fn clone(&self) -> Self {
21        *self
22    }
23}
24
25impl<T> Copy for PetgraphWrapper<'_, T> {}
26
27impl<'a, T> From<&'a T> for PetgraphWrapper<'a, T>
28where
29    T: HugrView,
30{
31    fn from(hugr: &'a T) -> Self {
32        Self { hugr }
33    }
34}
35
36impl<T> pv::GraphBase for PetgraphWrapper<'_, T>
37where
38    T: HugrView,
39{
40    type NodeId = T::Node;
41    type EdgeId = ((T::Node, Port), (T::Node, Port));
42}
43
44impl<T> pv::GraphProp for PetgraphWrapper<'_, T>
45where
46    T: HugrView,
47{
48    type EdgeType = petgraph::Directed;
49}
50
51impl<T> pv::GraphRef for PetgraphWrapper<'_, T> where T: HugrView {}
52
53impl<T> pv::NodeCount for PetgraphWrapper<'_, T>
54where
55    T: HugrView,
56{
57    fn node_count(&self) -> usize {
58        HugrView::num_nodes(self.hugr)
59    }
60}
61
62impl<T> pv::NodeIndexable for PetgraphWrapper<'_, T>
63where
64    T: HugrView,
65    // TODO: Define a trait for nodes that are equivalent to usizes, and implement it for `Node`
66    T::Node: NodeIndex + From<portgraph::NodeIndex>,
67{
68    fn node_bound(&self) -> usize {
69        HugrView::num_nodes(self.hugr)
70    }
71
72    fn to_index(&self, ix: Self::NodeId) -> usize {
73        ix.index()
74    }
75
76    fn from_index(&self, ix: usize) -> Self::NodeId {
77        portgraph::NodeIndex::new(ix).into()
78    }
79}
80
81impl<T> pv::EdgeCount for PetgraphWrapper<'_, T>
82where
83    T: HugrView,
84{
85    fn edge_count(&self) -> usize {
86        HugrView::num_edges(self.hugr)
87    }
88}
89
90impl<T> pv::Data for PetgraphWrapper<'_, T>
91where
92    T: HugrView,
93{
94    type NodeWeight = OpType;
95    type EdgeWeight = EdgeKind;
96}
97
98impl<'a, T> pv::IntoNodeReferences for PetgraphWrapper<'a, T>
99where
100    T: HugrView,
101{
102    type NodeRef = HugrNodeRef<'a, T::Node>;
103    type NodeReferences = Box<dyn Iterator<Item = HugrNodeRef<'a, T::Node>> + 'a>;
104
105    fn node_references(self) -> Self::NodeReferences {
106        Box::new(
107            self.hugr
108                .nodes()
109                .map(|n| HugrNodeRef::from_node(n, self.hugr)),
110        )
111    }
112}
113
114impl<'a, T> pv::IntoNodeIdentifiers for PetgraphWrapper<'a, T>
115where
116    T: HugrView,
117{
118    type NodeIdentifiers = Box<dyn Iterator<Item = T::Node> + 'a>;
119
120    fn node_identifiers(self) -> Self::NodeIdentifiers {
121        Box::new(self.hugr.nodes())
122    }
123}
124
125impl<'a, T> pv::IntoNeighbors for PetgraphWrapper<'a, T>
126where
127    T: HugrView,
128{
129    type Neighbors = Box<dyn Iterator<Item = T::Node> + 'a>;
130
131    fn neighbors(self, n: Self::NodeId) -> Self::Neighbors {
132        Box::new(self.hugr.output_neighbours(n))
133    }
134}
135
136impl<'a, T> pv::IntoNeighborsDirected for PetgraphWrapper<'a, T>
137where
138    T: HugrView,
139{
140    type NeighborsDirected = Box<dyn Iterator<Item = T::Node> + 'a>;
141
142    fn neighbors_directed(
143        self,
144        n: Self::NodeId,
145        d: petgraph::Direction,
146    ) -> Self::NeighborsDirected {
147        Box::new(self.hugr.neighbours(n, d.into()))
148    }
149}
150
151impl<T> pv::Visitable for PetgraphWrapper<'_, T>
152where
153    T: HugrView,
154{
155    type Map = std::collections::HashSet<Self::NodeId>;
156
157    fn visit_map(&self) -> Self::Map {
158        std::collections::HashSet::new()
159    }
160
161    fn reset_map(&self, map: &mut Self::Map) {
162        map.clear();
163    }
164}
165
166impl<T> pv::GetAdjacencyMatrix for PetgraphWrapper<'_, T>
167where
168    T: HugrView,
169{
170    type AdjMatrix = std::collections::HashSet<(Self::NodeId, Self::NodeId)>;
171
172    fn adjacency_matrix(&self) -> Self::AdjMatrix {
173        let mut matrix = std::collections::HashSet::new();
174        for node in self.hugr.nodes() {
175            for neighbour in self.hugr.output_neighbours(node) {
176                matrix.insert((node, neighbour));
177            }
178        }
179        matrix
180    }
181
182    fn is_adjacent(&self, matrix: &Self::AdjMatrix, a: Self::NodeId, b: Self::NodeId) -> bool {
183        matrix.contains(&(a, b))
184    }
185}
186
187/// Reference to a Hugr node and its associated `OpType`.
188#[derive(Debug, Clone, Copy)]
189pub struct HugrNodeRef<'a, N> {
190    node: N,
191    op: &'a OpType,
192}
193
194impl<'a, N: HugrNode> HugrNodeRef<'a, N> {
195    pub(self) fn from_node(node: N, hugr: &'a impl HugrView<Node = N>) -> Self {
196        Self {
197            node,
198            op: hugr.get_optype(node),
199        }
200    }
201}
202
203impl<N: HugrNode> pv::NodeRef for HugrNodeRef<'_, N> {
204    type NodeId = N;
205
206    type Weight = OpType;
207
208    fn id(&self) -> Self::NodeId {
209        self.node
210    }
211
212    fn weight(&self) -> &Self::Weight {
213        self.op
214    }
215}
216
#[cfg(test)]
mod test {
    use petgraph::visit::{
        EdgeCount, GetAdjacencyMatrix, IntoNodeReferences, NodeCount, NodeIndexable, NodeRef,
    };

    use super::PetgraphWrapper;
    use crate::HugrView;
    use crate::hugr::views::tests::sample_hugr;
    use crate::ops::handle::NodeHandle;

    /// Exercises the petgraph trait implementations on the shared sample hugr.
    #[test]
    fn test_petgraph_wrapper() {
        let (hugr, cx1, cx2) = sample_hugr();
        let graph = PetgraphWrapper::from(&hugr);
        let (first, second) = (cx1.node(), cx2.node());

        // Size queries.
        assert_eq!(graph.node_count(), 9);
        assert_eq!(graph.node_bound(), 9);
        assert_eq!(graph.edge_count(), 11);

        // Index conversions agree with the underlying portgraph index.
        let raw_index = first.into_portgraph().index();
        assert_eq!(graph.to_index(first), raw_index);
        assert_eq!(graph.from_index(raw_index), first);

        // Node references carry the node's operation as their weight.
        let first_ref = graph
            .node_references()
            .find(|node_ref| node_ref.id() == first)
            .unwrap();
        assert_eq!(first_ref.weight(), hugr.get_optype(first));

        // The adjacency matrix records the edge between the two gates.
        let matrix = graph.adjacency_matrix();
        assert!(graph.is_adjacent(&matrix, first, second));
    }
}