ade_graph/implementations/filtered_graph.rs

1use crate::implementations::Graph;
2use ade_traits::{EdgeTrait, GraphViewTrait, NodeTrait};
3use fixedbitset::FixedBitSet;
4
5pub struct FilteredGraph<'a, N: NodeTrait, E: EdgeTrait> {
6    base: &'a Graph<N, E>,
7    active: FixedBitSet,
8}
9
10impl<'a, N: NodeTrait, E: EdgeTrait> FilteredGraph<'a, N, E> {
11    pub fn new(base: &'a Graph<N, E>, active_nodes: impl IntoIterator<Item = u32>) -> Self {
12        let node_count = base.get_node_keys().count();
13
14        // Assume normalized keys: 0, 1, 2, ..., n-1
15        let mut active = FixedBitSet::with_capacity(node_count);
16        for key in active_nodes {
17            if (key as usize) < node_count {
18                active.insert(key as usize);
19            }
20        }
21
22        Self { base, active }
23    }
24
25    fn is_active(&self, key: u32) -> bool {
26        self.active.contains(key as usize)
27    }
28}
29
30impl<N: NodeTrait, E: EdgeTrait> GraphViewTrait<N, E> for FilteredGraph<'_, N, E> {
31    fn is_empty(&self) -> bool {
32        self.active.count_ones(..) == 0
33    }
34
35    fn get_node(&self, key: u32) -> &N {
36        if !self.is_active(key) {
37            panic!("Node {} not active in filtered graph", key);
38        }
39        self.base.get_node(key)
40    }
41
42    fn has_node(&self, key: u32) -> bool {
43        self.is_active(key)
44    }
45
46    fn get_nodes<'b>(&'b self) -> impl Iterator<Item = &'b N>
47    where
48        N: 'b,
49    {
50        self.base
51            .get_nodes()
52            .filter(move |n| self.is_active(n.key()))
53    }
54
55    fn get_node_keys(&self) -> impl Iterator<Item = u32> {
56        self.base
57            .get_node_keys()
58            .filter(move |&k| self.is_active(k))
59    }
60
61    fn get_edge(&self, source: u32, target: u32) -> &E {
62        if !self.is_active(source) {
63            panic!("Source node {} not active in filtered graph", source);
64        }
65        if !self.is_active(target) {
66            panic!("Target node {} not active in filtered graph", target);
67        }
68        self.base.get_edge(source, target)
69    }
70
71    fn has_edge(&self, source: u32, target: u32) -> bool {
72        self.is_active(source) && self.is_active(target) && self.base.has_edge(source, target)
73    }
74
75    fn get_edges<'b>(&'b self) -> impl Iterator<Item = &'b E>
76    where
77        E: 'b,
78    {
79        self.base
80            .get_edges()
81            .filter(move |e| self.is_active(e.source()) && self.is_active(e.target()))
82    }
83
84    fn get_predecessors<'b>(&'b self, node_key: u32) -> impl Iterator<Item = &'b N>
85    where
86        N: 'b,
87    {
88        if !self.is_active(node_key) {
89            panic!("Node {} not active in filtered graph", node_key);
90        }
91        self.base
92            .get_node(node_key)
93            .predecessors()
94            .iter()
95            .filter(move |&&pred| self.is_active(pred))
96            .map(move |&pred| self.base.get_node(pred))
97    }
98
99    fn get_successors<'b>(&'b self, node_key: u32) -> impl Iterator<Item = &'b N>
100    where
101        N: 'b,
102    {
103        if !self.is_active(node_key) {
104            panic!("Node {} not active in filtered graph", node_key);
105        }
106        self.base
107            .get_node(node_key)
108            .successors()
109            .iter()
110            .filter(move |&&succ| self.is_active(succ))
111            .map(move |&succ| self.base.get_node(succ))
112    }
113
114    fn get_successors_keys(&self, node_key: u32) -> impl Iterator<Item = u32> {
115        if !self.is_active(node_key) {
116            panic!("Node {} not active in filtered graph", node_key);
117        }
118        self.base
119            .get_successors_keys(node_key)
120            .filter(move |&succ| self.is_active(succ))
121    }
122
123    fn get_predecessors_keys(&self, node_key: u32) -> impl Iterator<Item = u32> {
124        if !self.is_active(node_key) {
125            panic!("Node {} not active in filtered graph", node_key);
126        }
127        self.base
128            .get_predecessors_keys(node_key)
129            .filter(move |&pred| self.is_active(pred))
130    }
131
132    fn filter(&self, node_keys: &[u32]) -> impl GraphViewTrait<N, E> {
133        // Intersect the requested nodes with the currently active ones
134        let filtered_keys = node_keys.iter().copied().filter(|&key| self.is_active(key));
135
136        FilteredGraph::new(self.base, filtered_keys)
137    }
138
139    fn has_sequential_keys(&self) -> bool {
140        let size = self.active.count_ones(..);
141        if size == 0 {
142            return true;
143        }
144
145        // Quick checks first - check if we have nodes 0 and size-1
146        if !self.is_active(0) || !self.is_active(size as u32 - 1) {
147            return false;
148        }
149
150        // Check if all keys from 0 to size-1 are active
151        (0..size as u32).all(|i| self.is_active(i))
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::implementations::{Edge, Node};
    use ade_traits::GraphViewTrait;

    /// Builds a base graph holding nodes with keys `0..5` and no edges.
    fn five_node_graph() -> Graph<Node, Edge> {
        let mut base_graph = Graph::<Node, Edge>::new(Vec::new(), Vec::new());
        for i in 0..5 {
            base_graph.add_node(Node::new(i));
        }
        base_graph
    }

    #[test]
    fn test_filtered_graph_has_sequential_keys() {
        let base_graph = five_node_graph();

        // Test 1: Filter to nodes 0, 1, 2 (sequential from 0)
        let filtered = FilteredGraph::new(&base_graph, vec![0, 1, 2]);
        assert!(filtered.has_sequential_keys());

        // Test 2: Filter to nodes 1, 2, 3 (not starting from 0)
        let filtered = FilteredGraph::new(&base_graph, vec![1, 2, 3]);
        assert!(!filtered.has_sequential_keys());

        // Test 3: Filter to nodes 0, 2 (not consecutive)
        let filtered = FilteredGraph::new(&base_graph, vec![0, 2]);
        assert!(!filtered.has_sequential_keys());

        // Test 4: Empty filter
        let filtered = FilteredGraph::new(&base_graph, vec![]);
        assert!(filtered.has_sequential_keys());

        // Test 5: Single node 0
        let filtered = FilteredGraph::new(&base_graph, vec![0]);
        assert!(filtered.has_sequential_keys());

        // Test 6: Single node not 0
        let filtered = FilteredGraph::new(&base_graph, vec![2]);
        assert!(!filtered.has_sequential_keys());
    }

    #[test]
    fn test_filtered_graph_membership_and_keys() {
        let base_graph = five_node_graph();
        // Key 7 does not exist in the base graph and must be ignored.
        let filtered = FilteredGraph::new(&base_graph, vec![3, 1, 7]);

        assert!(!filtered.is_empty());
        assert!(filtered.has_node(1));
        assert!(filtered.has_node(3));
        assert!(!filtered.has_node(0));
        assert!(!filtered.has_node(7));

        let mut keys: Vec<u32> = filtered.get_node_keys().collect();
        keys.sort_unstable();
        assert_eq!(keys, vec![1, 3]);
    }

    #[test]
    fn test_filtered_graph_empty_view() {
        let base_graph = five_node_graph();
        let filtered = FilteredGraph::new(&base_graph, Vec::new());

        assert!(filtered.is_empty());
        assert_eq!(filtered.get_node_keys().count(), 0);
    }

    #[test]
    fn test_filtered_graph_filter_intersects_active_nodes() {
        let base_graph = five_node_graph();
        let filtered = FilteredGraph::new(&base_graph, vec![0, 1, 2]);

        // Key 4 exists in the base graph but not in this view, so it must stay hidden.
        let nested = filtered.filter(&[1, 2, 4]);
        assert!(nested.has_node(1));
        assert!(nested.has_node(2));
        assert!(!nested.has_node(0));
        assert!(!nested.has_node(4));
    }

    #[test]
    #[should_panic(expected = "not active in filtered graph")]
    fn test_filtered_graph_get_node_panics_for_inactive_key() {
        let base_graph = five_node_graph();
        let filtered = FilteredGraph::new(&base_graph, vec![0]);
        let _ = filtered.get_node(1);
    }
}