1use std::collections::HashMap;
15
16use crate::graph::{self};
17
18pub struct FilterMap<
37 'g,
38 BaseNodeWeight,
39 BaseEdgeWeight,
40 NodeWeight,
41 EdgeWeight,
42 Graph: graph::Graph<BaseNodeWeight, BaseEdgeWeight>,
43> {
44 base_graph: &'g Graph,
45 node_map: HashMap<Graph::NodeRef, NodeWeight>,
46 edge_map: HashMap<Graph::EdgeRef, EdgeWeight>,
47}
48
49impl<
50 'g,
51 BaseNodeWeight,
52 BaseEdgeWeight,
53 NodeWeight,
54 EdgeWeight,
55 Graph: graph::Graph<BaseNodeWeight, BaseEdgeWeight>,
56 > FilterMap<'g, BaseNodeWeight, BaseEdgeWeight, NodeWeight, EdgeWeight, Graph>
57{
58 pub fn new(
67 base_graph: &'g Graph,
68 node_map: HashMap<Graph::NodeRef, NodeWeight>,
69 edge_map: HashMap<Graph::EdgeRef, EdgeWeight>,
70 ) -> Self {
71 assert!(edge_map.keys().all(|edge| {
73 let (a, b) = base_graph.adjacent_nodes(*edge);
74 node_map.contains_key(&a) && node_map.contains_key(&b)
75 }));
76
77 Self {
78 base_graph,
79 node_map,
80 edge_map,
81 }
82 }
83 pub fn general_filter_map<NodeFn, EdgeFn>(
99 base_graph: &'g Graph,
100 node_fn: NodeFn,
101 edge_fn: EdgeFn,
102 ) -> Self
103 where
104 NodeFn: Fn(&'g Graph, Graph::NodeRef) -> Option<NodeWeight>,
105 EdgeFn: Fn(&'g Graph, Graph::EdgeRef) -> Option<EdgeWeight>,
106 {
107 let node_map: HashMap<Graph::NodeRef, NodeWeight> = base_graph
108 .nodes()
110 .filter_map(|n| {
112 let weight = node_fn(base_graph, n)?;
114 Some((n, weight))
116 })
117 .collect();
119
120 let edge_map: HashMap<Graph::EdgeRef, EdgeWeight> = base_graph
121 .edges()
123 .filter(|e| {
125 let (a, b) = base_graph.adjacent_nodes(*e);
126 node_map.contains_key(&a) && node_map.contains_key(&b)
127 })
128 .filter_map(|e| {
130 let weight = edge_fn(base_graph, e)?;
131 Some((e, weight))
132 })
133 .collect();
135
136 Self {
137 base_graph,
138 node_map,
139 edge_map,
140 }
141 }
142 pub fn weight_filter_map<NodeFn, EdgeFn>(
147 base_graph: &'g Graph,
148 node_fn: NodeFn,
149 edge_fn: EdgeFn,
150 ) -> Self
151 where
152 NodeFn: Fn(&'g BaseNodeWeight) -> Option<NodeWeight>,
153 EdgeFn: Fn(&'g BaseEdgeWeight) -> Option<EdgeWeight>,
154 BaseNodeWeight: 'g,
155 BaseEdgeWeight: 'g,
156 {
157 Self::general_filter_map(
158 base_graph,
159 |_, n| node_fn(base_graph.node_weight(n)),
160 |_, e| edge_fn(base_graph.edge_weight(e)),
161 )
162 }
163 pub fn weight_map<NodeFn, EdgeFn>(
167 base_graph: &'g Graph,
168 node_fn: NodeFn,
169 edge_fn: EdgeFn,
170 ) -> Self
171 where
172 NodeFn: Fn(&'g BaseNodeWeight) -> NodeWeight,
173 EdgeFn: Fn(&'g BaseEdgeWeight) -> EdgeWeight,
174 BaseNodeWeight: 'g,
175 BaseEdgeWeight: 'g,
176 {
177 Self::general_filter_map(
178 base_graph,
179 |_, n| Some(node_fn(base_graph.node_weight(n))),
180 |_, e| Some(edge_fn(base_graph.edge_weight(e))),
181 )
182 }
183}
184
185impl<'g, NodeWeight, EdgeWeight, Graph: graph::Graph<NodeWeight, EdgeWeight>>
186 FilterMap<'g, NodeWeight, EdgeWeight, &'g NodeWeight, &'g EdgeWeight, Graph>
187{
188 pub fn weight_filter<NodeFn, EdgeFn>(
194 base_graph: &'g Graph,
195 node_fn: NodeFn,
196 edge_fn: EdgeFn,
197 ) -> Self
198 where
199 NodeFn: Fn(&'g NodeWeight) -> bool,
200 EdgeFn: Fn(&'g EdgeWeight) -> bool,
201 NodeWeight: 'g,
202 EdgeWeight: 'g,
203 {
204 Self::weight_filter_map(
205 base_graph,
206 |n| {
207 if node_fn(n) {
208 Some(n)
209 } else {
210 None
211 }
212 },
213 |e| {
214 if edge_fn(e) {
215 Some(e)
216 } else {
217 None
218 }
219 },
220 )
221 }
222}
223
224#[macro_export]
226macro_rules! filter_pattern {
227 ($graph:expr, node_pattern: $node_pattern:pat, edge_pattern: $edge_pattern:pat) => {
229 FilterMap::weight_filter_map(
230 $graph,
231 |node| match node {
238 $node_pattern => Some(node),
239 _ => None,
240 },
241 |edge| match edge {
242 $edge_pattern => Some(edge),
243 _ => None,
244 },
245 )
246 };
247 ($graph:expr, node_pattern: $node_pattern:pat) => {
248 filter_pattern!($graph, node_pattern: $node_pattern, edge_pattern: _)
249 };
250 ($graph:expr, edge_pattern: $edge_pattern:pat) => {
251 filter_pattern!($graph, node_pattern: _, edge_pattern: $edge_pattern)
252 };
253}
254
255pub use filter_pattern;
257
258impl<
259 'g,
260 BaseNodeWeight,
261 BaseEdgeWeight,
262 NodeWeight,
263 EdgeWeight,
264 Graph: graph::Graph<BaseNodeWeight, BaseEdgeWeight>,
265 > graph::Graph<NodeWeight, EdgeWeight>
266 for FilterMap<'g, BaseNodeWeight, BaseEdgeWeight, NodeWeight, EdgeWeight, Graph>
267{
268 type NodeRef = Graph::NodeRef;
269
270 type EdgeRef = Graph::EdgeRef;
271
272 fn is_directed(&self) -> bool {
273 self.base_graph.is_directed()
274 }
275
276 fn is_directed_edge(&self, edge: Self::EdgeRef) -> bool {
277 assert!(self.edge_map.contains_key(&edge));
278 self.base_graph.is_directed_edge(edge)
279 }
280
281 type AdjacentEdgesIterator<'a>
282 = impl Iterator<Item = Graph::EdgeRef> + 'a where Self: 'a;
283
284 fn adjacent_edges(&self, node: Self::NodeRef) -> Self::AdjacentEdgesIterator<'_> {
285 assert!(self.node_map.contains_key(&node));
286 self.base_graph
287 .adjacent_edges(node)
288 .filter(|e| self.edge_map.contains_key(e))
289 }
290
291 type IncomingEdgesIterator<'a> = impl Iterator<Item = Graph::EdgeRef> + 'a where Self: 'a;
292
293 fn incoming_edges(&self, node: Self::NodeRef) -> Self::IncomingEdgesIterator<'_> {
294 assert!(self.node_map.contains_key(&node));
295 self.base_graph
296 .incoming_edges(node)
297 .filter(|e| self.edge_map.contains_key(e))
298 }
299
300 type OutgoingEdgesIterator<'a> = impl Iterator<Item = Graph::EdgeRef> + 'a where Self: 'a;
301
302 fn outgoing_edges(&self, node: Self::NodeRef) -> Self::OutgoingEdgesIterator<'_> {
303 assert!(self.node_map.contains_key(&node));
304 self.base_graph
305 .outgoing_edges(node)
306 .filter(|e| self.edge_map.contains_key(e))
307 }
308
309 fn adjacent_nodes(&self, edge: Self::EdgeRef) -> (Self::NodeRef, Self::NodeRef) {
310 assert!(self.edge_map.contains_key(&edge));
311 let (a, b) = self.base_graph.adjacent_nodes(edge);
312
313 assert!(self.node_map.contains_key(&a));
314 assert!(self.node_map.contains_key(&b));
315
316 (a, b)
317 }
318
319 fn node_weight(&self, node: Self::NodeRef) -> &NodeWeight {
320 assert!(self.node_map.contains_key(&node));
321 &self.node_map[&node]
322 }
323
324 fn edge_weight(&self, edge: Self::EdgeRef) -> &EdgeWeight {
325 assert!(self.edge_map.contains_key(&edge));
326 &self.edge_map[&edge]
327 }
328
329 type NodeWeightsIterator<'a> = impl Iterator<Item =&'a NodeWeight>
330 where
331 Self: 'a,
332 NodeWeight: 'a;
333
334 fn node_weights(&self) -> Self::NodeWeightsIterator<'_> {
335 self.node_map.values()
336 }
337
338 type EdgeWeightsIterator<'a> = impl Iterator<Item =&'a EdgeWeight>
339 where
340 Self: 'a,
341 EdgeWeight: 'a;
342
343 fn edge_weights(&self) -> Self::EdgeWeightsIterator<'_> {
344 self.edge_map.values()
345 }
346
347 type NodesIterator<'a> = impl Iterator<Item = Graph::NodeRef> + 'a
348 where
349 Self: 'a;
350
351 fn nodes(&self) -> Self::NodesIterator<'_> {
352 self.node_map.keys().copied()
353 }
354
355 type EdgesIterator<'a> = impl Iterator<Item = Graph::EdgeRef> + 'a
356 where
357 Self: 'a;
358
359 fn edges(&self) -> Self::EdgesIterator<'_> {
360 self.edge_map.keys().copied()
361 }
362
363 fn count_edges(&self) -> usize {
364 self.edge_map.len()
365 }
366
367 fn count_nodes(&self) -> usize {
368 self.node_map.len()
369 }
370}