formualizer_eval/engine/
masks.rs

1//! Mask API for criteria evaluation (Phase 3)
2
3use std::ops::Range;
4use std::sync::Arc;
5
/// Dense bitmask for row selection.
///
/// Row `i` is stored in word `i / 64` at bit position `i % 64`. The word
/// storage lives behind an [`Arc`], so cloning a mask is cheap; mutating
/// methods go through [`Arc::make_mut`], which copies the words only when
/// another clone still shares them.
///
/// Invariant: every bit at a position `>= len` (i.e. the unused high bits of
/// the final word) is zero. Whole-word operations such as
/// [`DenseMask::count_ones`] rely on this.
#[derive(Debug, Clone)]
pub struct DenseMask {
    bits: Arc<Vec<u64>>,
    len: usize,
}

impl DenseMask {
    /// Number of rows packed into each backing word.
    const WORD_BITS: usize = 64;

    /// Create a mask of `len` bits, all cleared.
    pub fn new(len: usize) -> Self {
        Self {
            bits: Arc::new(vec![0u64; len.div_ceil(Self::WORD_BITS)]),
            len,
        }
    }

    /// Create a mask of `len` bits, all set.
    pub fn all_ones(len: usize) -> Self {
        let mut bits = vec![!0u64; len.div_ceil(Self::WORD_BITS)];
        Self::clear_tail_bits(&mut bits, len);
        Self {
            bits: Arc::new(bits),
            len,
        }
    }

    /// Zero the bits at positions `>= len` in the last word of `bits`,
    /// restoring the type invariant. `bits` must hold `len.div_ceil(64)` words.
    fn clear_tail_bits(bits: &mut [u64], len: usize) {
        let used = len % Self::WORD_BITS;
        if used != 0 {
            if let Some(last) = bits.last_mut() {
                *last &= (1u64 << used) - 1;
            }
        }
    }

    /// Get the length in bits.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Return `true` when the mask covers zero rows.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Set or clear the bit at `index`. Out-of-range indices are ignored.
    ///
    /// Copies the shared word storage first if another clone references it.
    pub fn set(&mut self, index: usize, value: bool) {
        if index >= self.len {
            return;
        }

        let bits = Arc::make_mut(&mut self.bits);
        let word_idx = index / Self::WORD_BITS;
        let bit = 1u64 << (index % Self::WORD_BITS);

        if value {
            bits[word_idx] |= bit;
        } else {
            bits[word_idx] &= !bit;
        }
    }

    /// Read the bit at `index`; out-of-range indices read as `false`.
    pub fn get(&self, index: usize) -> bool {
        if index >= self.len {
            return false;
        }

        let word = self.bits[index / Self::WORD_BITS];
        (word & (1u64 << (index % Self::WORD_BITS))) != 0
    }

    /// Count the set bits (relies on the zero-tail invariant).
    pub fn count_ones(&self) -> u64 {
        self.bits.iter().map(|w| u64::from(w.count_ones())).sum()
    }

    /// Ratio of set bits to `len`; `0.0` for an empty mask.
    pub fn density(&self) -> f64 {
        if self.len == 0 {
            return 0.0;
        }
        self.count_ones() as f64 / self.len as f64
    }

    /// In-place AND with `other`.
    ///
    /// Words that `other` does not have are treated as all-zero, so they are
    /// cleared in `self`. AND can never set a bit, so the tail stays clear.
    pub fn and_inplace(&mut self, other: &DenseMask) {
        let bits = Arc::make_mut(&mut self.bits);
        let shared = bits.len().min(other.bits.len());
        let (head, tail) = bits.split_at_mut(shared);

        for (dst, src) in head.iter_mut().zip(other.bits.iter()) {
            *dst &= *src;
        }
        tail.fill(0);
    }

    /// In-place OR with `other`.
    ///
    /// Only the words both masks share are combined. If `other` is longer
    /// than `self`, its bits beyond `self.len` in the shared final word are
    /// discarded so the zero-tail invariant (and `count_ones`) stays correct.
    pub fn or_inplace(&mut self, other: &DenseMask) {
        let len = self.len;
        let bits = Arc::make_mut(&mut self.bits);

        for (dst, src) in bits.iter_mut().zip(other.bits.iter()) {
            *dst |= *src;
        }
        Self::clear_tail_bits(bits, len);
    }

    /// In-place NOT of the rows in `used_rows`, clamped to `len`.
    ///
    /// Rows outside the range are left untouched. Flips a whole word (or the
    /// covered part of one) per step rather than one bit at a time.
    pub fn not_inplace(&mut self, used_rows: Range<u32>) {
        let bits = Arc::make_mut(&mut self.bits);
        let start = used_rows.start as usize;
        // Clamp in `usize`: casting `len` down to `u32` would wrap for huge masks.
        let end = (used_rows.end as usize).min(self.len);

        let mut pos = start;
        while pos < end {
            let word_idx = pos / Self::WORD_BITS;
            let offset = pos % Self::WORD_BITS;
            let chunk_end = ((word_idx + 1) * Self::WORD_BITS).min(end);
            let width = chunk_end - pos;
            // `width` is in 1..=64; a full word (only possible at offset 0)
            // would overflow the shift, so handle it separately.
            let flip = if width == Self::WORD_BITS {
                !0u64
            } else {
                ((1u64 << width) - 1) << offset
            };
            bits[word_idx] ^= flip;
            pos = chunk_end;
        }
    }

    /// Iterate over set bit positions in ascending order.
    pub fn iter_ones(&self) -> Box<dyn Iterator<Item = u32> + '_> {
        Box::new(DenseMaskIterator {
            mask: self,
            word_idx: 0,
            current_word: self.bits.first().copied().unwrap_or(0),
            base_index: 0,
        })
    }

    /// Build a mask of `len` bits with the given row indices set.
    /// Indices `>= len` are ignored; duplicates are harmless.
    pub fn from_indices(indices: &[u32], len: usize) -> Self {
        let mut bits = vec![0u64; len.div_ceil(Self::WORD_BITS)];
        for idx in indices.iter().map(|&i| i as usize).filter(|&i| i < len) {
            bits[idx / Self::WORD_BITS] |= 1u64 << (idx % Self::WORD_BITS);
        }
        Self {
            bits: Arc::new(bits),
            len,
        }
    }

    /// Return references to the entries of `values` whose bit is set.
    ///
    /// Only the first `min(len, values.len())` positions are considered. The
    /// result is in ascending index order.
    pub fn select<'a, T>(&self, values: &'a [T]) -> Vec<&'a T> {
        let limit = self.len.min(values.len());

        // Choose iteration strategy based on density.
        if self.density() < 0.1 {
            // Sparse: walk set bits. They arrive in ascending order, so stop at
            // the first one past `limit`.
            self.iter_ones()
                .map(|i| i as usize)
                .take_while(|&i| i < limit)
                .map(|i| &values[i])
                .collect()
        } else {
            // Dense: linear scan over the selectable prefix.
            values[..limit]
                .iter()
                .enumerate()
                .filter(|&(i, _)| self.get(i))
                .map(|(_, v)| v)
                .collect()
        }
    }
}

/// Iterator over the set bit positions of a [`DenseMask`], in ascending order.
struct DenseMaskIterator<'a> {
    mask: &'a DenseMask,
    /// Index of the word currently being drained.
    word_idx: usize,
    /// Remaining (not yet yielded) set bits of that word.
    current_word: u64,
    /// Row index of bit 0 of the current word.
    base_index: usize,
}

impl Iterator for DenseMaskIterator<'_> {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.current_word != 0 {
                let index = self.base_index + self.current_word.trailing_zeros() as usize;
                // Clear the lowest set bit so the next call finds the following one.
                self.current_word &= self.current_word - 1;

                if index < self.mask.len {
                    // Positions beyond `u32::MAX` cannot be represented as an
                    // item; end iteration instead of yielding wrapped indices.
                    return u32::try_from(index).ok();
                }
            }

            // Current word exhausted (or only held out-of-range bits): advance.
            self.word_idx += 1;
            self.current_word = *self.mask.bits.get(self.word_idx)?;
            self.base_index = self.word_idx * DenseMask::WORD_BITS;
        }
    }
}