/// A dense, row-major matrix of bits.
///
/// Each row is stored as `words_per_row` consecutive `u64` words in `data`; the bit for
/// column `c` is bit `c % 64` of the row's word `c / 64`. The mutating methods in this
/// file keep every bit past `cols` in a row's last word at zero, so the derived
/// `PartialEq`/`Eq`/`Hash` compare only the logical contents.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BitMatrix {
    /// Number of rows.
    rows: usize,
    /// Number of logical columns (bits per row).
    cols: usize,
    /// Number of `u64` words allocated per row, i.e. `ceil(cols / 64)`.
    words_per_row: usize,
    /// Row-major word storage, `rows * words_per_row` words long.
    data: Vec<u64>,
}
14
15mod simd;
17
18impl BitMatrix {
19 pub fn new(rows: usize, cols: usize) -> Self {
21 let words_per_row = (cols + 63) / 64;
22 let data = vec![0u64; rows * words_per_row];
23 let mut m = Self { rows, cols, words_per_row, data };
24 m.clear_unused_bits();
25 m
26 }
27
28 pub fn rows(&self) -> usize { self.rows }
30
31 pub fn cols(&self) -> usize { self.cols }
33
34 fn index(&self, row: usize, col: usize) -> (usize, u64) {
35 assert!(row < self.rows, "row out of bounds");
36 assert!(col < self.cols, "col out of bounds");
37 let word = col / 64;
38 let bit = (col % 64) as u64;
39 (row * self.words_per_row + word, 1u64 << bit)
40 }
41
42 pub fn set(&mut self, row: usize, col: usize, val: bool) {
44 let (idx, mask) = self.index(row, col);
45 if val { self.data[idx] |= mask; } else { self.data[idx] &= !mask; }
46 }
47
48 pub fn get(&self, row: usize, col: usize) -> bool {
50 let (idx, mask) = self.index(row, col);
51 (self.data[idx] & mask) != 0
52 }
53
54 pub fn row_words(&self, row: usize) -> &[u64] {
56 assert!(row < self.rows, "row out of bounds");
57 let start = row * self.words_per_row;
58 &self.data[start..start + self.words_per_row]
59 }
60
61 fn row_words_mut(&mut self, row: usize) -> &mut [u64] {
62 assert!(row < self.rows, "row out of bounds");
63 let start = row * self.words_per_row;
64 &mut self.data[start..start + self.words_per_row]
65 }
66
67 fn last_word_mask(&self) -> u64 {
68 let rem = self.cols % 64;
69 if rem == 0 { !0u64 } else { (1u64 << rem) - 1 }
70 }
71
72 fn clear_unused_bits(&mut self) {
73 if self.cols % 64 == 0 { return; }
74 let mask = self.last_word_mask();
75 for r in 0..self.rows {
76 let idx = r * self.words_per_row + (self.words_per_row - 1);
77 self.data[idx] &= mask;
78 }
79 }
80
81 pub fn count_ones(&self) -> usize {
83 let mut sum = 0usize;
84 let mask = self.last_word_mask();
85 for r in 0..self.rows {
86 let start = r * self.words_per_row;
87 for w in 0..self.words_per_row {
88 let mut v = self.data[start + w];
89 if w + 1 == self.words_per_row { v &= mask; }
91 sum += v.count_ones() as usize;
92 }
93 }
94 sum
95 }
96
97 pub fn bitand(&self, other: &Self) -> Self {
99 assert_eq!(self.rows, other.rows);
100 assert_eq!(self.cols, other.cols);
101 let mut out = self.clone();
102 out.bitand_assign(other);
103 out
104 }
105
106 pub fn bitand_assign(&mut self, other: &Self) {
108 assert_eq!(self.rows, other.rows);
109 assert_eq!(self.cols, other.cols);
110 for r in 0..self.rows {
111 let start = r * self.words_per_row;
112 let end = start + self.words_per_row;
113 simd::block_and(&mut self.data[start..end], &other.data[start..end]);
114 }
115 self.clear_unused_bits();
116 }
117
118 pub fn bitor_assign(&mut self, other: &Self) {
120 assert_eq!(self.rows, other.rows);
121 assert_eq!(self.cols, other.cols);
122 for r in 0..self.rows {
123 let start = r * self.words_per_row;
124 let end = start + self.words_per_row;
125 simd::block_or(&mut self.data[start..end], &other.data[start..end]);
126 }
127 self.clear_unused_bits();
128 }
129
130 pub fn bitxor_assign(&mut self, other: &Self) {
132 assert_eq!(self.rows, other.rows);
133 assert_eq!(self.cols, other.cols);
134 for r in 0..self.rows {
135 let start = r * self.words_per_row;
136 let end = start + self.words_per_row;
137 simd::block_xor(&mut self.data[start..end], &other.data[start..end]);
138 }
139 self.clear_unused_bits();
140 }
141
142 pub fn row_and_assign(&mut self, dst_row: usize, src_row: usize) {
144 let w = self.words_per_row;
145 let mask = self.last_word_mask();
146 let dst_start = dst_row * w;
147 let src_start = src_row * w;
148 for i in 0..w {
149 let dst_i = dst_start + i;
150 let src_i = src_start + i;
151 self.data[dst_i] &= self.data[src_i];
152 }
153 self.data[dst_start + w - 1] &= mask;
154 }
155
156 pub fn row_or_assign(&mut self, dst_row: usize, src_row: usize) {
158 let w = self.words_per_row;
159 let mask = self.last_word_mask();
160 let dst_start = dst_row * w;
161 let src_start = src_row * w;
162 for i in 0..w {
163 let dst_i = dst_start + i;
164 let src_i = src_start + i;
165 self.data[dst_i] |= self.data[src_i];
166 }
167 self.data[dst_start + w - 1] &= mask;
168 }
169
170 pub fn row_xor_assign(&mut self, dst_row: usize, src_row: usize) {
172 let w = self.words_per_row;
173 let mask = self.last_word_mask();
174 let dst_start = dst_row * w;
175 let src_start = src_row * w;
176 for i in 0..w {
177 let dst_i = dst_start + i;
178 let src_i = src_start + i;
179 self.data[dst_i] ^= self.data[src_i];
180 }
181 self.data[dst_start + w - 1] &= mask;
182 }
183
184 pub fn col_and_assign(&mut self, dst_col: usize, src_col: usize) {
186 assert!(dst_col < self.cols && src_col < self.cols);
187 let dst_word = dst_col / 64;
188 let dst_mask = 1u64 << (dst_col % 64);
189 let src_word = src_col / 64;
190 let src_mask = 1u64 << (src_col % 64);
191 for r in 0..self.rows {
192 let base = r * self.words_per_row;
193 let src_idx = base + src_word;
194 if (self.data[src_idx] & src_mask) == 0 {
195 let dst_idx = base + dst_word;
196 self.data[dst_idx] &= !dst_mask;
197 }
198 }
199 }
200
201 pub fn col_or_assign(&mut self, dst_col: usize, src_col: usize) {
203 assert!(dst_col < self.cols && src_col < self.cols);
204 let dst_word = dst_col / 64;
205 let dst_mask = 1u64 << (dst_col % 64);
206 let src_word = src_col / 64;
207 let src_mask = 1u64 << (src_col % 64);
208 for r in 0..self.rows {
209 let base = r * self.words_per_row;
210 let src_idx = base + src_word;
211 if (self.data[src_idx] & src_mask) != 0 {
212 let dst_idx = base + dst_word;
213 self.data[dst_idx] |= dst_mask;
214 }
215 }
216 }
217
218 pub fn col_xor_assign(&mut self, dst_col: usize, src_col: usize) {
220 assert!(dst_col < self.cols && src_col < self.cols);
221 let dst_word = dst_col / 64;
222 let dst_mask = 1u64 << (dst_col % 64);
223 let src_word = src_col / 64;
224 let src_mask = 1u64 << (src_col % 64);
225 for r in 0..self.rows {
226 let base = r * self.words_per_row;
227 let src_idx = base + src_word;
228 if (self.data[src_idx] & src_mask) != 0 {
229 let dst_idx = base + dst_word;
230 self.data[dst_idx] ^= dst_mask;
231 }
232 }
233 }
234
235 pub fn column(&self, col: usize) -> Vec<bool> {
237 assert!(col < self.cols);
238 let mut v = Vec::with_capacity(self.rows);
239 for r in 0..self.rows {
240 v.push(self.get(r, col));
241 }
242 v
243 }
244
245 pub fn set_column(&mut self, col: usize, src: &[bool]) {
247 assert!(col < self.cols);
248 assert!(src.len() == self.rows);
249 for r in 0..self.rows { self.set(r, col, src[r]); }
250 }
251
252 pub fn iter_row(&self, row: usize) -> RowIter<'_> {
254 assert!(row < self.rows);
255 RowIter { m: self, row, col: 0 }
256 }
257
258 pub fn iter_col(&self, col: usize) -> ColIter<'_> {
260 assert!(col < self.cols);
261 ColIter { m: self, col, row: 0 }
262 }
263
264 pub fn to_vec(&self) -> Vec<Vec<bool>> {
266 let mut out = Vec::with_capacity(self.rows);
267 for r in 0..self.rows {
268 let mut row = Vec::with_capacity(self.cols);
269 for c in 0..self.cols { row.push(self.get(r, c)); }
270 out.push(row);
271 }
272 out
273 }
274}
275
276pub struct RowIter<'a> { m: &'a BitMatrix, row: usize, col: usize }
278
279impl<'a> Iterator for RowIter<'a> {
280 type Item = bool;
281 fn next(&mut self) -> Option<bool> {
282 if self.col >= self.m.cols { return None; }
283 let v = self.m.get(self.row, self.col);
284 self.col += 1;
285 Some(v)
286 }
287}
288
289pub struct ColIter<'a> { m: &'a BitMatrix, col: usize, row: usize }
291
292impl<'a> Iterator for ColIter<'a> {
293 type Item = bool;
294 fn next(&mut self) -> Option<bool> {
295 if self.row >= self.m.rows { return None; }
296 let v = self.m.get(self.row, self.col);
297 self.row += 1;
298 Some(v)
299 }
300}
301
302impl From<Vec<Vec<bool>>> for BitMatrix {
303 fn from(v: Vec<Vec<bool>>) -> Self {
304 let rows = v.len();
305 let cols = if rows == 0 { 0 } else { v[0].len() };
306 let mut m = BitMatrix::new(rows, cols);
307 for (r, rowv) in v.into_iter().enumerate() {
308 assert_eq!(rowv.len(), cols);
309 for (c, b) in rowv.into_iter().enumerate() {
310 if b { m.set(r, c, true); }
311 }
312 }
313 m
314 }
315}
316
317impl BitMatrix {
318 pub fn from_vec(v: Vec<Vec<bool>>) -> Self { BitMatrix::from(v) }
320
321 pub fn into_vec(self) -> Vec<Vec<bool>> { self.to_vec() }
323}
324
// Unit tests for `BitMatrix`: element access, row/column/matrix bitwise ops,
// the zero-tail-bit invariant, degenerate shapes and panic messages.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_set_get() {
        let mut m = BitMatrix::new(3, 130);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 130);
        assert!(!m.get(1, 1));
        m.set(1, 1, true);
        assert!(m.get(1, 1));
        m.set(1, 129, true);
        assert!(m.get(1, 129));
        assert_eq!(m.count_ones(), 2);

        // Clearing a bit undoes the set without disturbing other bits.
        m.set(1, 1, false);
        assert!(!m.get(1, 1));
        assert!(m.get(1, 129));
        assert_eq!(m.count_ones(), 1);
    }

    #[test]
    fn row_ops_and_matrix_ops() {
        let mut a = BitMatrix::new(2, 70);
        let mut b = BitMatrix::new(2, 70);
        a.set(0, 1, true);
        a.set(0, 69, true);
        b.set(0, 1, true);
        b.set(0, 2, true);

        // AND-ing a row with itself leaves it unchanged.
        a.row_and_assign(0, 0);
        assert!(a.get(0, 1));
        assert!(a.get(0, 69));

        let c = a.bitand(&b);
        assert!(c.get(0, 1));
        assert!(!c.get(0, 2));
        assert!(!c.get(0, 69));
        assert_eq!(c.count_ones(), 1);

        a.bitor_assign(&b); // row 0 = {1, 2, 69}
        assert!(a.get(0, 2));
        assert_eq!(a.count_ones(), 3);

        a.bitxor_assign(&b); // row 0 = {69}
        assert!(!a.get(0, 2));
        assert!(!a.get(0, 1));
        assert!(a.get(0, 69));
        assert_eq!(a.count_ones(), 1);
    }

    #[test]
    fn row_ops_between_distinct_rows() {
        let mut m = BitMatrix::new(2, 70);
        m.set(0, 1, true);
        m.set(0, 69, true);
        m.set(1, 1, true);
        m.set(1, 5, true);

        m.row_or_assign(0, 1); // row 0 = {1, 5, 69}
        assert!(m.get(0, 1) && m.get(0, 5) && m.get(0, 69));

        m.row_and_assign(1, 0); // row 1 = {1, 5} & {1, 5, 69} = {1, 5}
        assert_eq!(m.iter_row(1).filter(|&bit| bit).count(), 2);

        m.row_xor_assign(0, 1); // row 0 = {1, 5, 69} ^ {1, 5} = {69}
        assert!(!m.get(0, 1));
        assert!(!m.get(0, 5));
        assert!(m.get(0, 69));

        m.row_xor_assign(1, 1); // x ^ x == 0
        assert!(m.row_words(1).iter().all(|&w| w == 0));
        assert_eq!(m.count_ones(), 1);
    }

    #[test]
    fn column_get_set() {
        let mut m = BitMatrix::new(4, 10);
        m.set_column(3, &[true, false, true, false]);
        let col = m.column(3);
        assert_eq!(col, vec![true, false, true, false]);
    }

    #[test]
    fn column_ops_and_iterators() {
        let mut m = BitMatrix::new(4, 10);
        m.set(0, 1, true);
        m.set(1, 1, true);
        m.set(2, 2, true);
        m.set(3, 2, true);

        m.col_or_assign(3, 1);
        assert!(m.get(0, 3));
        assert!(m.get(1, 3));
        assert!(!m.get(2, 3));

        m.col_xor_assign(3, 2);
        assert!(m.get(2, 3));
        assert!(m.get(3, 3));

        m.col_and_assign(3, 2);
        assert!(!m.get(0, 3));
        assert!(!m.get(1, 3));
        assert!(m.get(2, 3));
        assert!(m.get(3, 3));

        let row0: Vec<bool> = m.iter_row(0).collect();
        assert_eq!(row0.len(), 10);
        let col2: Vec<bool> = m.iter_col(2).collect();
        assert_eq!(col2, vec![false, false, true, true]);

        let v = m.to_vec();
        let m2 = BitMatrix::from(v.clone());
        assert_eq!(m2.to_vec(), v);
    }

    #[test]
    fn masks_keep_bounds() {
        let mut m = BitMatrix::new(1, 70);
        m.set(0, 69, true);
        assert!(m.get(0, 69));
        // Column 69 is bit 5 of the second word.
        assert_eq!(m.row_words(0), &[0u64, 1 << 5][..]);

        for c in 0..70 {
            m.set(0, c, true);
        }
        assert_eq!(m.count_ones(), 70);
        // Only the low 6 bits of the last word (columns 64..70) are ever set.
        assert_eq!(m.row_words(0), &[u64::MAX, (1u64 << 6) - 1][..]);

        let copy = m.clone();
        m.bitxor_assign(&copy);
        assert_eq!(m.count_ones(), 0);
        assert_eq!(m, BitMatrix::new(1, 70));
    }

    #[test]
    fn zero_sized_matrices() {
        let empty = BitMatrix::new(0, 0);
        assert_eq!(empty.count_ones(), 0);
        assert!(empty.to_vec().is_empty());

        let no_cols = BitMatrix::new(3, 0);
        assert_eq!(no_cols.rows(), 3);
        assert_eq!(no_cols.count_ones(), 0);
        assert!(no_cols.row_words(2).is_empty());
        assert_eq!(no_cols.to_vec(), vec![Vec::<bool>::new(); 3]);
        assert_eq!(no_cols.iter_row(2).count(), 0);
    }

    #[test]
    fn nested_vec_roundtrip() {
        let v = vec![vec![true, false, true], vec![false, false, true]];
        let m = BitMatrix::from_vec(v.clone());
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.count_ones(), 3);
        assert_eq!(m.into_vec(), v);

        let empty = BitMatrix::from(Vec::<Vec<bool>>::new());
        assert_eq!((empty.rows(), empty.cols()), (0, 0));
    }

    #[test]
    #[should_panic]
    fn ragged_rows_panic() {
        let _ = BitMatrix::from(vec![vec![true], vec![true, false]]);
    }

    #[test]
    #[should_panic(expected = "col out of bounds")]
    fn get_past_last_col_panics() {
        let m = BitMatrix::new(2, 70);
        let _ = m.get(0, 70);
    }

    #[test]
    #[should_panic(expected = "row out of bounds")]
    fn set_past_last_row_panics() {
        let mut m = BitMatrix::new(2, 70);
        m.set(2, 0, true);
    }
}