oxihuman_core/
bit_matrix.rs1#![allow(dead_code)]
7
/// Dense, row-major matrix of bits packed into 64-bit words.
///
/// Every row occupies `cols.div_ceil(64)` consecutive words, so a row always starts on a
/// word boundary. Within a word, column `c` maps to bit `c % 64` (least-significant first).
/// Padding bits past `cols` in the last word of a row are never set by any method here,
/// which keeps the popcount methods exact.
// Presumably not every item is referenced by the rest of the crate; the lint is silenced
// rather than trimming the API.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BitMatrix {
    /// Number of rows. Changing this after construction does not resize the storage.
    pub rows: usize,
    /// Number of addressable columns per row. Changing this after construction does not
    /// resize the storage.
    pub cols: usize,
    /// Number of `u64` words backing one row.
    words_per_row: usize,
    /// Row-major storage, `rows * words_per_row` words long.
    data: Vec<u64>,
}

#[allow(dead_code)]
impl BitMatrix {
    /// Creates a `rows` x `cols` matrix with every bit cleared.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols.div_ceil(64)` does not fit in a `usize`.
    pub fn new(rows: usize, cols: usize) -> Self {
        let words_per_row = cols.div_ceil(64);
        // A plain multiplication wraps in release builds; the buffer would then be too
        // small and `index` could wrap as well, aliasing unrelated cells onto one word.
        let total_words = rows
            .checked_mul(words_per_row)
            .expect("BitMatrix dimensions overflow usize");
        Self {
            rows,
            cols,
            words_per_row,
            data: vec![0u64; total_words],
        }
    }

    /// Maps `(row, col)` to `(word index into data, bit position inside that word)`.
    ///
    /// Performs no bounds checking; callers validate `row` and `col` first.
    fn index(&self, row: usize, col: usize) -> (usize, usize) {
        (row * self.words_per_row + col / 64, col % 64)
    }

    /// Sets the bit at `(row, col)` to 1.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    pub fn set(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (word, bit) = self.index(row, col);
        self.data[word] |= 1u64 << bit;
    }

    /// Clears the bit at `(row, col)` to 0.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    pub fn clear_bit(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (word, bit) = self.index(row, col);
        self.data[word] &= !(1u64 << bit);
    }

    /// Flips the bit at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    pub fn toggle(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (word, bit) = self.index(row, col);
        self.data[word] ^= 1u64 << bit;
    }

    /// Returns `true` if the bit at `(row, col)` is set.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows` or `col >= self.cols`.
    pub fn get(&self, row: usize, col: usize) -> bool {
        assert!(row < self.rows && col < self.cols);
        let (word, bit) = self.index(row, col);
        self.data[word] & (1u64 << bit) != 0
    }

    /// Returns the number of set bits in `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows`.
    pub fn row_popcount(&self, row: usize) -> u32 {
        assert!(row < self.rows);
        let start = row * self.words_per_row;
        self.data[start..start + self.words_per_row]
            .iter()
            .map(|word| word.count_ones())
            .sum()
    }

    /// Returns the number of set bits in the whole matrix.
    pub fn total_popcount(&self) -> u32 {
        self.data.iter().map(|word| word.count_ones()).sum()
    }

    /// Clears every bit.
    pub fn clear_all(&mut self) {
        self.data.fill(0);
    }

    /// Sets every addressable bit, leaving the padding bits of each row at zero.
    pub fn fill_all(&mut self) {
        self.data.fill(u64::MAX);
        let tail_bits = self.cols % 64;
        // A non-zero tail implies `words_per_row >= 1`, so the chunk size below is valid.
        if tail_bits != 0 {
            let mask = (1u64 << tail_bits) - 1;
            for row_words in self.data.chunks_exact_mut(self.words_per_row) {
                if let Some(last) = row_words.last_mut() {
                    *last &= mask;
                }
            }
        }
    }

    /// Replaces row `r1` with the bitwise AND of rows `r1` and `r2`.
    ///
    /// # Panics
    ///
    /// Panics if `r1 >= self.rows` or `r2 >= self.rows`.
    pub fn row_and(&mut self, r1: usize, r2: usize) {
        self.combine_rows(r1, r2, |dst, src| dst & src);
    }

    /// Replaces row `r1` with the bitwise OR of rows `r1` and `r2`.
    ///
    /// # Panics
    ///
    /// Panics if `r1 >= self.rows` or `r2 >= self.rows`.
    pub fn row_or(&mut self, r1: usize, r2: usize) {
        self.combine_rows(r1, r2, |dst, src| dst | src);
    }

    /// Stores `op(word of r1, word of r2)` into row `r1`, word by word.
    fn combine_rows(&mut self, r1: usize, r2: usize, op: impl Fn(u64, u64) -> u64) {
        assert!(r1 < self.rows && r2 < self.rows);
        let dst_base = r1 * self.words_per_row;
        let src_base = r2 * self.words_per_row;
        for offset in 0..self.words_per_row {
            // Copy the source word first so the rows may be the same without aliasing.
            let src = self.data[src_base + offset];
            let dst = &mut self.data[dst_base + offset];
            *dst = op(*dst, src);
        }
    }
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built matrix contains no set bits.
    #[test]
    fn new_all_clear() {
        let matrix = BitMatrix::new(4, 64);
        assert_eq!(matrix.total_popcount(), 0);
    }

    /// Setting one bit makes it readable and leaves its neighbour untouched.
    #[test]
    fn set_and_get() {
        let (row, col) = (2, 33);
        let mut matrix = BitMatrix::new(4, 64);
        matrix.set(row, col);
        assert!(matrix.get(row, col));
        assert!(!matrix.get(row, col - 1));
    }

    /// Clearing a previously set bit turns it off again.
    #[test]
    fn clear_bit() {
        let (row, col) = (1, 10);
        let mut matrix = BitMatrix::new(4, 64);
        matrix.set(row, col);
        matrix.clear_bit(row, col);
        assert!(!matrix.get(row, col));
    }

    /// Toggling twice returns the bit to its initial state.
    #[test]
    fn toggle() {
        let mut matrix = BitMatrix::new(4, 64);
        for expected in [true, false] {
            matrix.toggle(0, 0);
            assert_eq!(matrix.get(0, 0), expected);
        }
    }

    /// The per-row count reflects exactly the bits set in that row.
    #[test]
    fn row_popcount() {
        let mut matrix = BitMatrix::new(3, 64);
        for col in [0, 10, 63] {
            matrix.set(1, col);
        }
        assert_eq!(matrix.row_popcount(1), 3);
    }

    /// The total count sums bits across all rows.
    #[test]
    fn total_popcount() {
        let mut matrix = BitMatrix::new(3, 64);
        for diagonal in 0..3 {
            matrix.set(diagonal, diagonal);
        }
        assert_eq!(matrix.total_popcount(), 3);
    }

    /// Filling a matrix narrower than one word sets only the addressable bits.
    #[test]
    fn fill_all_popcount() {
        let mut matrix = BitMatrix::new(2, 10);
        matrix.fill_all();
        assert_eq!(matrix.total_popcount(), 20);
    }

    /// AND keeps only the bits present in both rows.
    #[test]
    fn row_and() {
        let mut matrix = BitMatrix::new(2, 64);
        for (row, col) in [(0, 5), (0, 10), (1, 5)] {
            matrix.set(row, col);
        }
        matrix.row_and(0, 1);
        assert!(matrix.get(0, 5));
        assert!(!matrix.get(0, 10));
    }

    /// OR merges the bits of the second row into the first.
    #[test]
    fn row_or() {
        let mut matrix = BitMatrix::new(2, 64);
        matrix.set(0, 5);
        matrix.set(1, 10);
        matrix.row_or(0, 1);
        for col in [5, 10] {
            assert!(matrix.get(0, col));
        }
    }

    /// Rows spanning several words address bits beyond the first word correctly.
    #[test]
    fn wide_matrix() {
        let mut matrix = BitMatrix::new(2, 200);
        for (row, col) in [(0, 199), (1, 65)] {
            matrix.set(row, col);
        }
        assert!(matrix.get(0, 199));
        assert!(matrix.get(1, 65));
        assert!(!matrix.get(0, 65));
    }
}