// bitset_matrix/lib.rs

//! A compact, row-major 2D bitset matrix with fast row-wise bitwise operations.
//!
//! The matrix stores each row as a contiguous run of `u64` words. Row-wise operations
//! (AND/OR/XOR) are implemented as word-wise loops for speed. Column-wise operations are
//! supported but are naturally slower: a column contributes one bit to every row, so
//! touching a column means visiting one word per row.

/// Dense, row-major matrix of bits.
///
/// Each row occupies `words_per_row` consecutive `u64` words of `data`; column `c` of a row
/// lives in word `c / 64` at bit position `c % 64`. Bits past `cols` in a row's last word are
/// kept at zero, so whole-word comparisons (including the derived equality) stay exact.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct BitMatrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns, i.e. valid bits per row.
    cols: usize,
    /// Words allocated per row: `ceil(cols / 64)`.
    words_per_row: usize,
    /// Backing storage holding `rows * words_per_row` words.
    data: Vec<u64>,
}

15// SIMD helper module (feature-gated)
16mod simd;
17
18impl BitMatrix {
19    /// Create a new `rows x cols` zeroed bit matrix.
20    pub fn new(rows: usize, cols: usize) -> Self {
21        let words_per_row = (cols + 63) / 64;
22        let data = vec![0u64; rows * words_per_row];
23        let mut m = Self { rows, cols, words_per_row, data };
24        m.clear_unused_bits();
25        m
26    }
27
28    /// Number of rows.
29    pub fn rows(&self) -> usize { self.rows }
30
31    /// Number of columns.
32    pub fn cols(&self) -> usize { self.cols }
33
34    fn index(&self, row: usize, col: usize) -> (usize, u64) {
35        assert!(row < self.rows, "row out of bounds");
36        assert!(col < self.cols, "col out of bounds");
37        let word = col / 64;
38        let bit = (col % 64) as u64;
39        (row * self.words_per_row + word, 1u64 << bit)
40    }
41
42    /// Set a bit at (row, col).
43    pub fn set(&mut self, row: usize, col: usize, val: bool) {
44        let (idx, mask) = self.index(row, col);
45        if val { self.data[idx] |= mask; } else { self.data[idx] &= !mask; }
46    }
47
48    /// Get the bit at (row, col).
49    pub fn get(&self, row: usize, col: usize) -> bool {
50        let (idx, mask) = self.index(row, col);
51        (self.data[idx] & mask) != 0
52    }
53
54    /// Returns a slice of the words for `row`.
55    pub fn row_words(&self, row: usize) -> &[u64] {
56        assert!(row < self.rows, "row out of bounds");
57        let start = row * self.words_per_row;
58        &self.data[start..start + self.words_per_row]
59    }
60
61    fn row_words_mut(&mut self, row: usize) -> &mut [u64] {
62        assert!(row < self.rows, "row out of bounds");
63        let start = row * self.words_per_row;
64        &mut self.data[start..start + self.words_per_row]
65    }
66
67    fn last_word_mask(&self) -> u64 {
68        let rem = self.cols % 64;
69        if rem == 0 { !0u64 } else { (1u64 << rem) - 1 }
70    }
71
72    fn clear_unused_bits(&mut self) {
73        if self.cols % 64 == 0 { return; }
74        let mask = self.last_word_mask();
75        for r in 0..self.rows {
76            let idx = r * self.words_per_row + (self.words_per_row - 1);
77            self.data[idx] &= mask;
78        }
79    }
80
81    /// Count number of set bits in the matrix.
82    pub fn count_ones(&self) -> usize {
83        let mut sum = 0usize;
84        let mask = self.last_word_mask();
85        for r in 0..self.rows {
86            let start = r * self.words_per_row;
87            for w in 0..self.words_per_row {
88                let mut v = self.data[start + w];
89                // mask last word in row
90                if w + 1 == self.words_per_row { v &= mask; }
91                sum += v.count_ones() as usize;
92            }
93        }
94        sum
95    }
96
97    /// Bitwise AND producing a new matrix. Requires same shape.
98    pub fn bitand(&self, other: &Self) -> Self {
99        assert_eq!(self.rows, other.rows);
100        assert_eq!(self.cols, other.cols);
101        let mut out = self.clone();
102        out.bitand_assign(other);
103        out
104    }
105
106    /// In-place AND with `other`.
107    pub fn bitand_assign(&mut self, other: &Self) {
108        assert_eq!(self.rows, other.rows);
109        assert_eq!(self.cols, other.cols);
110        for r in 0..self.rows {
111            let start = r * self.words_per_row;
112            let end = start + self.words_per_row;
113            simd::block_and(&mut self.data[start..end], &other.data[start..end]);
114        }
115        self.clear_unused_bits();
116    }
117
118    /// In-place OR with `other`.
119    pub fn bitor_assign(&mut self, other: &Self) {
120        assert_eq!(self.rows, other.rows);
121        assert_eq!(self.cols, other.cols);
122        for r in 0..self.rows {
123            let start = r * self.words_per_row;
124            let end = start + self.words_per_row;
125            simd::block_or(&mut self.data[start..end], &other.data[start..end]);
126        }
127        self.clear_unused_bits();
128    }
129
130    /// In-place XOR with `other`.
131    pub fn bitxor_assign(&mut self, other: &Self) {
132        assert_eq!(self.rows, other.rows);
133        assert_eq!(self.cols, other.cols);
134        for r in 0..self.rows {
135            let start = r * self.words_per_row;
136            let end = start + self.words_per_row;
137            simd::block_xor(&mut self.data[start..end], &other.data[start..end]);
138        }
139        self.clear_unused_bits();
140    }
141
142    /// Fast in-place row-wise AND: `dst_row` &= `src_row`.
143    pub fn row_and_assign(&mut self, dst_row: usize, src_row: usize) {
144        let w = self.words_per_row;
145        let mask = self.last_word_mask();
146        let dst_start = dst_row * w;
147        let src_start = src_row * w;
148        for i in 0..w {
149            let dst_i = dst_start + i;
150            let src_i = src_start + i;
151            self.data[dst_i] &= self.data[src_i];
152        }
153        self.data[dst_start + w - 1] &= mask;
154    }
155
156    /// Fast in-place row-wise OR: `dst_row` |= `src_row`.
157    pub fn row_or_assign(&mut self, dst_row: usize, src_row: usize) {
158        let w = self.words_per_row;
159        let mask = self.last_word_mask();
160        let dst_start = dst_row * w;
161        let src_start = src_row * w;
162        for i in 0..w {
163            let dst_i = dst_start + i;
164            let src_i = src_start + i;
165            self.data[dst_i] |= self.data[src_i];
166        }
167        self.data[dst_start + w - 1] &= mask;
168    }
169
170    /// Fast in-place row-wise XOR: `dst_row` ^= `src_row`.
171    pub fn row_xor_assign(&mut self, dst_row: usize, src_row: usize) {
172        let w = self.words_per_row;
173        let mask = self.last_word_mask();
174        let dst_start = dst_row * w;
175        let src_start = src_row * w;
176        for i in 0..w {
177            let dst_i = dst_start + i;
178            let src_i = src_start + i;
179            self.data[dst_i] ^= self.data[src_i];
180        }
181        self.data[dst_start + w - 1] &= mask;
182    }
183
184    /// In-place column-wise AND: for each row r, `dst_col[r] &= src_col[r]`.
185    pub fn col_and_assign(&mut self, dst_col: usize, src_col: usize) {
186        assert!(dst_col < self.cols && src_col < self.cols);
187        let dst_word = dst_col / 64;
188        let dst_mask = 1u64 << (dst_col % 64);
189        let src_word = src_col / 64;
190        let src_mask = 1u64 << (src_col % 64);
191        for r in 0..self.rows {
192            let base = r * self.words_per_row;
193            let src_idx = base + src_word;
194            if (self.data[src_idx] & src_mask) == 0 {
195                let dst_idx = base + dst_word;
196                self.data[dst_idx] &= !dst_mask;
197            }
198        }
199    }
200
201    /// In-place column-wise OR: for each row r, `dst_col[r] |= src_col[r]`.
202    pub fn col_or_assign(&mut self, dst_col: usize, src_col: usize) {
203        assert!(dst_col < self.cols && src_col < self.cols);
204        let dst_word = dst_col / 64;
205        let dst_mask = 1u64 << (dst_col % 64);
206        let src_word = src_col / 64;
207        let src_mask = 1u64 << (src_col % 64);
208        for r in 0..self.rows {
209            let base = r * self.words_per_row;
210            let src_idx = base + src_word;
211            if (self.data[src_idx] & src_mask) != 0 {
212                let dst_idx = base + dst_word;
213                self.data[dst_idx] |= dst_mask;
214            }
215        }
216    }
217
218    /// In-place column-wise XOR: for each row r, `dst_col[r] ^= src_col[r]`.
219    pub fn col_xor_assign(&mut self, dst_col: usize, src_col: usize) {
220        assert!(dst_col < self.cols && src_col < self.cols);
221        let dst_word = dst_col / 64;
222        let dst_mask = 1u64 << (dst_col % 64);
223        let src_word = src_col / 64;
224        let src_mask = 1u64 << (src_col % 64);
225        for r in 0..self.rows {
226            let base = r * self.words_per_row;
227            let src_idx = base + src_word;
228            if (self.data[src_idx] & src_mask) != 0 {
229                let dst_idx = base + dst_word;
230                self.data[dst_idx] ^= dst_mask;
231            }
232        }
233    }
234
235    /// Get a column as a Vec<bool> (col-wise access is slower).
236    pub fn column(&self, col: usize) -> Vec<bool> {
237        assert!(col < self.cols);
238        let mut v = Vec::with_capacity(self.rows);
239        for r in 0..self.rows {
240            v.push(self.get(r, col));
241        }
242        v
243    }
244
245    /// Set a column from a slice of bools.
246    pub fn set_column(&mut self, col: usize, src: &[bool]) {
247        assert!(col < self.cols);
248        assert!(src.len() == self.rows);
249        for r in 0..self.rows { self.set(r, col, src[r]); }
250    }
251
252    /// Row iterator (yields booleans across columns for a row).
253    pub fn iter_row(&self, row: usize) -> RowIter<'_> {
254        assert!(row < self.rows);
255        RowIter { m: self, row, col: 0 }
256    }
257
258    /// Column iterator (yields booleans across rows for a column).
259    pub fn iter_col(&self, col: usize) -> ColIter<'_> {
260        assert!(col < self.cols);
261        ColIter { m: self, col, row: 0 }
262    }
263
264    /// Convert to Vec<Vec<bool>>.
265    pub fn to_vec(&self) -> Vec<Vec<bool>> {
266        let mut out = Vec::with_capacity(self.rows);
267        for r in 0..self.rows {
268            let mut row = Vec::with_capacity(self.cols);
269            for c in 0..self.cols { row.push(self.get(r, c)); }
270            out.push(row);
271        }
272        out
273    }
274}
275
276/// Row iterator type
277pub struct RowIter<'a> { m: &'a BitMatrix, row: usize, col: usize }
278
279impl<'a> Iterator for RowIter<'a> {
280    type Item = bool;
281    fn next(&mut self) -> Option<bool> {
282        if self.col >= self.m.cols { return None; }
283        let v = self.m.get(self.row, self.col);
284        self.col += 1;
285        Some(v)
286    }
287}
288
289/// Column iterator type
290pub struct ColIter<'a> { m: &'a BitMatrix, col: usize, row: usize }
291
292impl<'a> Iterator for ColIter<'a> {
293    type Item = bool;
294    fn next(&mut self) -> Option<bool> {
295        if self.row >= self.m.rows { return None; }
296        let v = self.m.get(self.row, self.col);
297        self.row += 1;
298        Some(v)
299    }
300}
301
302impl From<Vec<Vec<bool>>> for BitMatrix {
303    fn from(v: Vec<Vec<bool>>) -> Self {
304        let rows = v.len();
305        let cols = if rows == 0 { 0 } else { v[0].len() };
306        let mut m = BitMatrix::new(rows, cols);
307        for (r, rowv) in v.into_iter().enumerate() {
308            assert_eq!(rowv.len(), cols);
309            for (c, b) in rowv.into_iter().enumerate() {
310                if b { m.set(r, c, true); }
311            }
312        }
313        m
314    }
315}
316
317impl BitMatrix {
318    /// Create from a Vec<Vec<bool>> by value.
319    pub fn from_vec(v: Vec<Vec<bool>>) -> Self { BitMatrix::from(v) }
320
321    /// Convert into Vec<Vec<bool>> consuming self.
322    pub fn into_vec(self) -> Vec<Vec<bool>> { self.to_vec() }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_set_get() {
        let mut m = BitMatrix::new(3, 130); // 130 columns -> 3 words per row
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 130);
        assert!(!m.get(1, 1));
        m.set(1, 1, true);
        assert!(m.get(1, 1));
        m.set(1, 129, true);
        assert!(m.get(1, 129));
        assert_eq!(m.count_ones(), 2);
        m.set(1, 1, false);
        assert!(!m.get(1, 1));
        assert_eq!(m.count_ones(), 1);
    }

    #[test]
    fn row_ops_and_matrix_ops() {
        let mut a = BitMatrix::new(2, 70);
        let mut b = BitMatrix::new(2, 70);
        a.set(0, 1, true);
        a.set(0, 69, true);
        b.set(0, 1, true);
        b.set(0, 2, true);

        a.row_and_assign(0, 0); // AND-ing a row with itself is a no-op
        assert!(a.get(0, 1));
        assert!(a.get(0, 69));

        // test matrix and/or/xor
        let c = a.bitand(&b);
        assert!(c.get(0, 1));
        assert!(!c.get(0, 2));
        assert!(!c.get(0, 69));

        a.bitor_assign(&b);
        assert!(a.get(0, 2));

        a.bitxor_assign(&b);
        // (a | b) ^ b == a & !b: every bit set in b is now clear; bits only in a survive.
        assert!(!a.get(0, 2));
        assert!(!a.get(0, 1));
        assert!(a.get(0, 69));
    }

    #[test]
    fn row_ops_between_distinct_rows() {
        let mut m = BitMatrix::new(3, 70);
        m.set(0, 0, true);
        m.set(0, 69, true);
        m.set(1, 69, true);
        m.set(2, 5, true);

        m.row_or_assign(2, 0); // row 2 = {0, 5, 69}
        assert_eq!(m.iter_row(2).filter(|&b| b).count(), 3);

        m.row_and_assign(2, 1); // row 2 = {69}
        assert!(m.get(2, 69));
        assert!(!m.get(2, 0));
        assert!(!m.get(2, 5));

        m.row_xor_assign(0, 0); // x ^ x == 0
        assert!(!m.get(0, 0));
        assert!(!m.get(0, 69));
        assert_eq!(m.count_ones(), 2); // (1, 69) and (2, 69)
    }

    #[test]
    fn exact_word_multiple_columns() {
        let mut m = BitMatrix::new(2, 128); // no padding bits at all
        m.set(0, 63, true);
        m.set(1, 127, true);
        m.row_or_assign(0, 1);
        assert!(m.get(0, 63));
        assert!(m.get(0, 127));
        assert_eq!(m.count_ones(), 3);
    }

    #[test]
    fn column_get_set() {
        let mut m = BitMatrix::new(4, 10);
        m.set_column(3, &[true, false, true, false]);
        let col = m.column(3);
        assert_eq!(col, vec![true, false, true, false]);
    }

    #[test]
    fn column_ops_and_iterators() {
        let mut m = BitMatrix::new(4, 10);
        // Column 1 is set in rows 0-1, column 2 in rows 2-3.
        m.set(0, 1, true);
        m.set(1, 1, true);
        m.set(2, 2, true);
        m.set(3, 2, true);

        // OR column 3 with column 1 -> column 3 = rows {0, 1}
        m.col_or_assign(3, 1);
        assert!(m.get(0, 3));
        assert!(m.get(1, 3));
        assert!(!m.get(2, 3));

        // XOR column 3 with column 2 -> column 3 = rows {0, 1, 2, 3}
        m.col_xor_assign(3, 2);
        assert!(m.get(2, 3));
        assert!(m.get(3, 3));

        // AND column 3 with column 2 (clear bits where col2 == 0) -> rows {2, 3}
        m.col_and_assign(3, 2);
        assert!(!m.get(0, 3));
        assert!(!m.get(1, 3));
        assert!(m.get(2, 3));
        assert!(m.get(3, 3));

        // iterators
        let row0: Vec<bool> = m.iter_row(0).collect();
        assert_eq!(row0.len(), 10);
        assert_eq!(row0.iter().filter(|&&b| b).count(), 1); // only (0, 1)
        let col2: Vec<bool> = m.iter_col(2).collect();
        assert_eq!(col2, vec![false, false, true, true]);

        // to/from vec conversions
        let v = m.to_vec();
        let m2 = BitMatrix::from(v.clone());
        assert_eq!(m2.to_vec(), v);
        assert_eq!(m2, m);
    }

    #[test]
    fn masks_keep_bounds() {
        let mut m = BitMatrix::new(1, 70); // 70 columns -> 2 words; last word has 6 valid bits
        m.set(0, 69, true);
        assert!(m.get(0, 69));
        // Whole-row operations must not leak into the padding bits of the last word.
        m.row_or_assign(0, 0);
        assert_eq!(m.row_words(0)[1], 1u64 << 5);
        assert_eq!(m.count_ones(), 1);
    }

    #[test]
    #[should_panic(expected = "col out of bounds")]
    fn get_past_last_column_panics() {
        let m = BitMatrix::new(1, 70);
        let _ = m.get(0, 70);
    }

    #[test]
    #[should_panic(expected = "row out of bounds")]
    fn get_past_last_row_panics() {
        let m = BitMatrix::new(1, 70);
        let _ = m.get(1, 0);
    }

    #[test]
    fn empty_vec_round_trip() {
        let m = BitMatrix::from(Vec::<Vec<bool>>::new());
        assert_eq!(m.rows(), 0);
        assert_eq!(m.cols(), 0);
        assert_eq!(m.count_ones(), 0);
        assert!(m.to_vec().is_empty());
    }
}