Skip to main content

grafeo_core/execution/
selection.rs

1//! SelectionVector for filtering.
2
/// A selection vector indicating which rows are active.
///
/// Used for efficient filtering without copying data: the vector stores the
/// positions of the selected rows as compact `u16` values.
///
/// # Ordering
///
/// [`SelectionVector::intersect`], [`SelectionVector::union`] and
/// [`SelectionVector::contains`] assume the stored indices are in ascending
/// order. [`SelectionVector::new_all`] and [`SelectionVector::from_predicate`]
/// produce ascending indices; [`SelectionVector::push`] stores whatever it is
/// given and does not re-sort, so callers that push must keep the order
/// themselves if they intend to use those operations.
#[derive(Debug, Clone)]
pub struct SelectionVector {
    /// Indices of selected rows.
    indices: Vec<u16>,
}

impl SelectionVector {
    /// Maximum capacity (limited to u16 for space efficiency).
    ///
    /// This is also the largest row index that can be stored.
    pub const MAX_CAPACITY: usize = u16::MAX as usize;

    /// Narrows a row index to the stored `u16` representation.
    ///
    /// Shared by every code path that stores a caller-supplied index so that
    /// an out-of-range value is rejected instead of being silently truncated
    /// by an `as` cast.
    ///
    /// # Panics
    ///
    /// Panics if `index` exceeds `SelectionVector::MAX_CAPACITY` (65535).
    fn narrow_index(index: usize) -> u16 {
        assert!(
            index <= Self::MAX_CAPACITY,
            "row index {} exceeds SelectionVector::MAX_CAPACITY ({})",
            index,
            Self::MAX_CAPACITY
        );
        // Lossless: the assertion above guarantees the value fits in a u16.
        index as u16
    }

    /// Creates a new selection vector selecting all rows up to count.
    ///
    /// The result holds the indices `0..count` in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if `count` exceeds `SelectionVector::MAX_CAPACITY` (65535).
    #[must_use]
    pub fn new_all(count: usize) -> Self {
        assert!(
            count <= Self::MAX_CAPACITY,
            "count {} exceeds SelectionVector::MAX_CAPACITY ({})",
            count,
            Self::MAX_CAPACITY
        );
        // Lossless: `count` was just checked against the u16 range.
        let end = count as u16;
        Self {
            indices: (0..end).collect(),
        }
    }

    /// Creates a new empty selection vector.
    #[must_use]
    pub fn new_empty() -> Self {
        Self {
            indices: Vec::new(),
        }
    }

    /// Creates a new, empty selection vector with room for `capacity` indices.
    ///
    /// The reservation is clamped to `SelectionVector::MAX_CAPACITY`, so an
    /// oversized request does not over-allocate.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            indices: Vec::with_capacity(capacity.min(Self::MAX_CAPACITY)),
        }
    }

    /// Creates a selection vector from a predicate.
    ///
    /// Calls `predicate` for each index in `0..count` in ascending order and
    /// selects the indices for which it returns true.
    ///
    /// # Panics
    ///
    /// Panics if the predicate selects an index greater than
    /// `SelectionVector::MAX_CAPACITY` (65535), because such an index cannot
    /// be stored. `count` itself may be larger as long as no index beyond the
    /// limit is selected.
    #[must_use]
    pub fn from_predicate<F>(count: usize, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let indices: Vec<u16> = (0..count)
            .filter(|&i| predicate(i))
            // Checked narrowing: a plain `as u16` here would wrap indices
            // >= 65536 around to small values and select the wrong rows.
            .map(Self::narrow_index)
            .collect();
        Self { indices }
    }

    /// Returns the number of selected rows.
    #[must_use]
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns true if no rows are selected.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Gets the actual row index at position.
    ///
    /// Returns `None` when `position` is not less than [`SelectionVector::len`].
    #[must_use]
    pub fn get(&self, position: usize) -> Option<usize> {
        self.indices.get(position).map(|&raw| usize::from(raw))
    }

    /// Pushes a new index.
    ///
    /// The index is appended as-is; no sorting or de-duplication is performed.
    ///
    /// # Panics
    ///
    /// Panics if `index` exceeds `SelectionVector::MAX_CAPACITY` (65535).
    pub fn push(&mut self, index: usize) {
        self.indices.push(Self::narrow_index(index));
    }

    /// Returns the indices as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u16] {
        &self.indices
    }

    /// Clears all selections.
    pub fn clear(&mut self) {
        self.indices.clear();
    }

    /// Filters this selection by another predicate.
    ///
    /// Returns a new selection containing only indices that pass the predicate.
    /// The relative order of the surviving indices is preserved.
    #[must_use]
    pub fn filter<F>(&self, predicate: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let indices: Vec<u16> = self
            .indices
            .iter()
            .copied()
            .filter(|&raw| predicate(usize::from(raw)))
            .collect();
        Self { indices }
    }

    /// Computes the intersection of two selection vectors.
    ///
    /// Both vectors are assumed to hold their indices in ascending order; the
    /// merge below does not verify this.
    #[must_use]
    pub fn intersect(&self, other: &Self) -> Self {
        use std::cmp::Ordering;

        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        // The intersection can never be longer than the shorter input.
        let mut result = Vec::with_capacity(lhs.len().min(rhs.len()));
        let mut i = 0;
        let mut j = 0;

        // Two-pointer merge: advance whichever side holds the smaller value.
        while i < lhs.len() && j < rhs.len() {
            match lhs[i].cmp(&rhs[j]) {
                Ordering::Less => i += 1,
                Ordering::Greater => j += 1,
                Ordering::Equal => {
                    result.push(lhs[i]);
                    i += 1;
                    j += 1;
                }
            }
        }

        Self { indices: result }
    }

    /// Computes the union of two selection vectors.
    ///
    /// Both vectors are assumed to hold their indices in ascending order; the
    /// merge below does not verify this.
    #[must_use]
    pub fn union(&self, other: &Self) -> Self {
        use std::cmp::Ordering;

        let (lhs, rhs) = (self.as_slice(), other.as_slice());
        // The union is at least as long as the longer input.
        let mut result = Vec::with_capacity(lhs.len().max(rhs.len()));
        let mut i = 0;
        let mut j = 0;

        // Two-pointer merge: emit the smaller value, or one copy when equal.
        while i < lhs.len() && j < rhs.len() {
            match lhs[i].cmp(&rhs[j]) {
                Ordering::Less => {
                    result.push(lhs[i]);
                    i += 1;
                }
                Ordering::Greater => {
                    result.push(rhs[j]);
                    j += 1;
                }
                Ordering::Equal => {
                    result.push(lhs[i]);
                    i += 1;
                    j += 1;
                }
            }
        }

        // At most one of these tails is non-empty.
        result.extend_from_slice(&lhs[i..]);
        result.extend_from_slice(&rhs[j..]);

        Self { indices: result }
    }

    /// Returns an iterator over selected indices.
    pub fn iter(&self) -> impl Iterator<Item = usize> + '_ {
        self.indices.iter().copied().map(usize::from)
    }

    /// Checks if a given index is in the selection.
    ///
    /// Uses a binary search, so the answer is only reliable when the stored
    /// indices are in ascending order. Indices above
    /// `SelectionVector::MAX_CAPACITY` can never be stored and always yield
    /// `false`.
    #[must_use]
    pub fn contains(&self, index: usize) -> bool {
        if index > Self::MAX_CAPACITY {
            return false;
        }
        // Lossless: `index` was just checked against the u16 range.
        let needle = index as u16;
        self.indices.binary_search(&needle).is_ok()
    }
}

impl Default for SelectionVector {
    /// Returns an empty selection, equivalent to [`SelectionVector::new_empty`].
    fn default() -> Self {
        Self::new_empty()
    }
}

impl IntoIterator for SelectionVector {
    type Item = usize;
    type IntoIter = std::iter::Map<std::vec::IntoIter<u16>, fn(u16) -> usize>;

    /// Consumes the selection and yields each stored index widened to `usize`.
    fn into_iter(self) -> Self::IntoIter {
        // Spell out the fn-pointer type so it matches `IntoIter` exactly.
        let widen: fn(u16) -> usize = usize::from;
        self.indices.into_iter().map(widen)
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// `new_all(n)` selects every row `0..n`, in order.
    #[test]
    fn test_selection_all() {
        let everything = SelectionVector::new_all(10);
        assert_eq!(everything.len(), 10);

        (0..10).for_each(|row| assert_eq!(everything.get(row), Some(row)));
    }

    /// `from_predicate` keeps exactly the indices the predicate accepts.
    #[test]
    fn test_selection_from_predicate() {
        let evens = SelectionVector::from_predicate(10, |i| i % 2 == 0);

        assert_eq!(evens.len(), 5);
        for (position, expected) in vec![0usize, 2, 4].into_iter().enumerate() {
            assert_eq!(evens.get(position), Some(expected));
        }
    }

    /// `filter` narrows an existing selection.
    #[test]
    fn test_selection_filter() {
        let upper_half = SelectionVector::new_all(10).filter(|i| i >= 5);

        assert_eq!(upper_half.len(), 5);
        assert_eq!(upper_half.get(0), Some(5));
    }

    /// `intersect` keeps only the indices present in both selections.
    #[test]
    fn test_selection_intersect() {
        // Multiples of two: 0, 2, 4, 6, 8
        let twos = SelectionVector::from_predicate(10, |i| i % 2 == 0);
        // Multiples of three: 0, 3, 6, 9
        let threes = SelectionVector::from_predicate(10, |i| i % 3 == 0);

        // Common to both: 0, 6
        let common = twos.intersect(&threes);

        assert_eq!(common.len(), 2);
        assert_eq!((common.get(0), common.get(1)), (Some(0), Some(6)));
    }

    /// `union` merges two selections without duplicating shared indices.
    #[test]
    fn test_selection_union() {
        // 1, 3
        let left = SelectionVector::from_predicate(10, |i| i == 1 || i == 3);
        // 2, 3
        let right = SelectionVector::from_predicate(10, |i| i == 2 || i == 3);

        // Merged: 1, 2, 3
        let merged = left.union(&right);

        assert_eq!(merged.len(), 3);
        for position in 0..3 {
            assert_eq!(merged.get(position), Some(position + 1));
        }
    }

    /// `iter` yields the selected indices as `usize` values.
    #[test]
    fn test_selection_iterator() {
        let first_three = SelectionVector::from_predicate(5, |i| i < 3);

        assert_eq!(first_three.iter().collect::<Vec<usize>>(), vec![0, 1, 2]);
    }
}