formualizer_eval/engine/
masks.rs1use std::ops::Range;
4use std::sync::Arc;
5
/// A fixed-length bitmask stored as packed 64-bit words.
///
/// Bit `i` lives in word `i / 64` at bit position `i % 64`. The word vector sits
/// behind an [`Arc`], so cloning a mask shares storage; the mutating methods go
/// through `Arc::make_mut`, which copies the words only when they are shared.
#[derive(Clone, Debug)]
pub struct DenseMask {
    /// Packed bit storage; the constructors size it to `len.div_ceil(64)` words.
    bits: Arc<Vec<u64>>,
    /// Number of logical bits. Indices at or beyond this are out of range.
    len: usize,
}
12
13impl DenseMask {
14 pub fn new(len: usize) -> Self {
16 let n_words = len.div_ceil(64);
17 Self {
18 bits: Arc::new(vec![0u64; n_words]),
19 len,
20 }
21 }
22
23 pub fn all_ones(len: usize) -> Self {
25 let n_words = len.div_ceil(64);
26 let mut bits = vec![!0u64; n_words];
27
28 let remainder = len % 64;
30 if remainder > 0 && n_words > 0 {
31 bits[n_words - 1] = (1u64 << remainder) - 1;
32 }
33
34 Self {
35 bits: Arc::new(bits),
36 len,
37 }
38 }
39
40 pub fn len(&self) -> usize {
42 self.len
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.len == 0
48 }
49
50 pub fn set(&mut self, index: usize, value: bool) {
52 if index >= self.len {
53 return;
54 }
55
56 let bits = Arc::make_mut(&mut self.bits);
57 let word_idx = index / 64;
58 let bit_idx = index % 64;
59
60 if value {
61 bits[word_idx] |= 1u64 << bit_idx;
62 } else {
63 bits[word_idx] &= !(1u64 << bit_idx);
64 }
65 }
66
67 pub fn get(&self, index: usize) -> bool {
69 if index >= self.len {
70 return false;
71 }
72
73 let word_idx = index / 64;
74 let bit_idx = index % 64;
75
76 (self.bits[word_idx] & (1u64 << bit_idx)) != 0
77 }
78
79 pub fn count_ones(&self) -> u64 {
81 self.bits.iter().map(|w| w.count_ones() as u64).sum()
82 }
83
84 pub fn density(&self) -> f64 {
86 if self.len == 0 {
87 return 0.0;
88 }
89 self.count_ones() as f64 / self.len as f64
90 }
91
92 pub fn and_inplace(&mut self, other: &DenseMask) {
94 let bits = Arc::make_mut(&mut self.bits);
95 let min_words = bits.len().min(other.bits.len());
96
97 for (dst, src) in bits.iter_mut().zip(other.bits.iter()).take(min_words) {
98 *dst &= *src;
99 }
100
101 for slot in bits.iter_mut().skip(min_words) {
103 *slot = 0;
104 }
105 }
106
107 pub fn or_inplace(&mut self, other: &DenseMask) {
109 let bits = Arc::make_mut(&mut self.bits);
110 let min_words = bits.len().min(other.bits.len());
111
112 for (dst, src) in bits.iter_mut().zip(other.bits.iter()).take(min_words) {
113 *dst |= *src;
114 }
115 }
116
117 pub fn not_inplace(&mut self, used_rows: Range<u32>) {
119 let bits = Arc::make_mut(&mut self.bits);
120 let start = used_rows.start as usize;
121 let end = used_rows.end.min(self.len as u32) as usize;
122
123 for i in start..end {
125 let word_idx = i / 64;
126 let bit_idx = i % 64;
127 bits[word_idx] ^= 1u64 << bit_idx;
128 }
129 }
130
131 pub fn iter_ones(&self) -> Box<dyn Iterator<Item = u32> + '_> {
133 Box::new(DenseMaskIterator {
134 mask: self,
135 word_idx: 0,
136 current_word: if self.bits.is_empty() {
137 0
138 } else {
139 self.bits[0]
140 },
141 base_index: 0,
142 })
143 }
144
145 pub fn from_indices(indices: &[u32], len: usize) -> Self {
147 let mut mask = Self::new(len);
148 for &idx in indices {
149 if (idx as usize) < len {
150 mask.set(idx as usize, true);
151 }
152 }
153 mask
154 }
155
156 pub fn select<'a, T>(&self, values: &'a [T]) -> Vec<&'a T> {
158 let min_len = self.len.min(values.len());
159 let mut result = Vec::new();
160
161 if self.density() < 0.1 {
163 for idx in self.iter_ones() {
165 if (idx as usize) < min_len {
166 result.push(&values[idx as usize]);
167 }
168 }
169 } else {
170 for (idx, value) in values.iter().take(min_len).enumerate() {
172 if self.get(idx) {
173 result.push(value);
174 }
175 }
176 }
177
178 result
179 }
180}
181
182struct DenseMaskIterator<'a> {
184 mask: &'a DenseMask,
185 word_idx: usize,
186 current_word: u64,
187 base_index: u32,
188}
189
190impl<'a> Iterator for DenseMaskIterator<'a> {
191 type Item = u32;
192
193 fn next(&mut self) -> Option<Self::Item> {
194 loop {
195 if self.current_word != 0 {
197 let bit_idx = self.current_word.trailing_zeros();
198 let index = self.base_index + bit_idx;
199
200 self.current_word &= self.current_word - 1;
202
203 if index < self.mask.len as u32 {
205 return Some(index);
206 }
207 }
208
209 self.word_idx += 1;
211 if self.word_idx >= self.mask.bits.len() {
212 return None;
213 }
214
215 self.current_word = self.mask.bits[self.word_idx];
216 self.base_index = (self.word_idx * 64) as u32;
217 }
218 }
219}