1use crate::gf2::add_assign_binary;
2use crate::iterators::OctetIter;
3use crate::octet::Octet;
4use crate::octets::BinaryOctetVec;
5use crate::util::get_both_ranges;
6use std::mem::size_of;
7
/// Operations on a matrix whose entries are binary `Octet`s (zero or one).
///
/// `DenseBinaryMatrix` in this file is one implementation; the tests in this file check it
/// against `crate::sparse_matrix::SparseBinaryMatrix`, which is expected to agree with it.
pub trait BinaryMatrix: Clone {
    /// Creates a `height` x `width` matrix.
    ///
    /// `trailing_dense_column_hint` is an implementation hint: `DenseBinaryMatrix` ignores it,
    /// and the tests pass 1 for `SparseBinaryMatrix`.
    /// NOTE(review): presumably the number of trailing columns expected to become dense;
    /// confirm against the sparse implementation.
    fn new(height: usize, width: usize, trailing_dense_column_hint: usize) -> Self;

    /// Sets the entry at row `i`, column `j` to `value`.
    fn set(&mut self, i: usize, j: usize, value: Octet);

    /// Number of rows.
    fn height(&self) -> usize;

    /// Number of columns.
    fn width(&self) -> usize;

    /// Estimated memory used by the matrix, in bytes.
    fn size_in_bytes(&self) -> usize;

    /// Counts the one-valued entries of `row` in columns `start_col..end_col`
    /// (`DenseBinaryMatrix` treats `end_col` as exclusive).
    fn count_ones(&self, row: usize, start_col: usize, end_col: usize) -> usize;

    /// Iterates `(column, value)` pairs of `row` over columns `start_col..end_col`.
    ///
    /// The tests expect the dense implementation to yield every column in the range, and the
    /// sparse implementation to yield only the non-zero entries.
    fn get_row_iter(&self, row: usize, start_col: usize, end_col: usize) -> OctetIter;

    /// Returns the indices of the rows in `start_row..end_row` whose entry in column `col`
    /// is one.
    fn get_ones_in_column(&self, col: usize, start_row: usize, end_row: usize) -> Vec<u32>;

    /// Returns the entries of `row` from column `start_col` through the last column, packed
    /// into a `BinaryOctetVec`.
    fn get_sub_row_as_octets(&self, row: usize, start_col: usize) -> BinaryOctetVec;

    /// Returns the columns at or after `start_col` where `row` has a non-zero entry.
    fn query_non_zero_columns(&self, row: usize, start_col: usize) -> Vec<usize>;

    /// Returns the entry at row `i`, column `j`.
    fn get(&self, i: usize, j: usize) -> Octet;

    /// Exchanges rows `i` and `j`.
    fn swap_rows(&mut self, i: usize, j: usize);

    /// Exchanges columns `i` and `j`.
    ///
    /// Rows before `start_row_hint` may be left untouched (`DenseBinaryMatrix` only visits
    /// rows `start_row_hint..height`).
    /// NOTE(review): callers presumably pass a non-zero hint only when the skipped rows hold
    /// equal values in both columns; confirm.
    fn swap_columns(&mut self, i: usize, j: usize, start_row_hint: usize);

    /// Enables column-access acceleration (a no-op for `DenseBinaryMatrix`).
    fn enable_column_access_acceleration(&mut self);

    /// Disables column-access acceleration (a no-op for `DenseBinaryMatrix`).
    fn disable_column_access_acceleration(&mut self);

    /// Hints that column `i` is dense and frozen (a no-op for `DenseBinaryMatrix`).
    /// The tests check that calling it does not change any entry.
    /// NOTE(review): "frozen" presumably means the column will not be swapped again; confirm.
    fn hint_column_dense_and_frozen(&mut self, i: usize);

    /// Adds row `src` into row `dest`.
    ///
    /// `DenseBinaryMatrix` asserts `dest != src` and adds the whole row regardless of
    /// `start_col`.
    fn add_assign_rows(&mut self, dest: usize, src: usize, start_col: usize);

    /// Shrinks the matrix to `new_height` x `new_width`, keeping the leading rows and columns.
    /// `DenseBinaryMatrix` panics if either dimension would grow.
    fn resize(&mut self, new_height: usize, new_width: usize);
}
55
/// Number of matrix entries (bits) stored in each `u64` word of `DenseBinaryMatrix`.
const WORD_WIDTH: usize = size_of::<u64>() * 8;
57
/// A binary matrix that stores one bit per entry, row by row, in `u64` words.
///
/// Every row occupies the same whole number of words, so row `r` begins at word
/// `r * ceil(width / WORD_WIDTH)` of `elements`.
#[derive(Clone, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct DenseBinaryMatrix {
    // Number of rows.
    height: usize,
    // Number of columns.
    width: usize,
    // Packed bit storage, row-major; column `c` of a row is bit `c % WORD_WIDTH` of the
    // row's word `c / WORD_WIDTH`.
    elements: Vec<u64>,
}
65
66impl DenseBinaryMatrix {
67 fn bit_position(&self, row: usize, col: usize) -> (usize, usize) {
69 return (
70 row * self.row_word_width() + Self::word_offset(col),
71 col % WORD_WIDTH,
72 );
73 }
74
75 fn word_offset(col: usize) -> usize {
76 col / WORD_WIDTH
77 }
78
79 fn row_word_width(&self) -> usize {
81 (self.width + WORD_WIDTH - 1) / WORD_WIDTH
82 }
83
84 pub fn select_mask(bit: usize) -> u64 {
86 1u64 << (bit as u64)
87 }
88
89 fn select_bit_and_all_left_mask(bit: usize) -> u64 {
91 !DenseBinaryMatrix::select_all_right_of_mask(bit)
92 }
93
94 fn select_all_right_of_mask(bit: usize) -> u64 {
96 let mask = DenseBinaryMatrix::select_mask(bit);
97 mask - 1
99 }
100
101 fn clear_bit(word: &mut u64, bit: usize) {
102 *word &= !DenseBinaryMatrix::select_mask(bit);
103 }
104
105 fn set_bit(word: &mut u64, bit: usize) {
106 *word |= DenseBinaryMatrix::select_mask(bit);
107 }
108}
109
110impl BinaryMatrix for DenseBinaryMatrix {
111 fn new(height: usize, width: usize, _: usize) -> DenseBinaryMatrix {
112 let elements = vec![0; height * (width + WORD_WIDTH - 1) / WORD_WIDTH];
113 DenseBinaryMatrix {
114 height,
115 width,
116 elements,
117 }
118 }
119
120 fn set(&mut self, i: usize, j: usize, value: Octet) {
121 let (word, bit) = self.bit_position(i, j);
122 if value == Octet::zero() {
123 DenseBinaryMatrix::clear_bit(&mut self.elements[word], bit);
124 } else {
125 DenseBinaryMatrix::set_bit(&mut self.elements[word], bit);
126 }
127 }
128
129 fn height(&self) -> usize {
130 self.height
131 }
132
133 fn width(&self) -> usize {
134 self.width
135 }
136
137 fn size_in_bytes(&self) -> usize {
138 let mut bytes = size_of::<Self>();
139 bytes += size_of::<Vec<u64>>();
140 bytes += size_of::<u64>() * self.elements.len();
141
142 bytes
143 }
144
145 fn count_ones(&self, row: usize, start_col: usize, end_col: usize) -> usize {
146 let (start_word, start_bit) = self.bit_position(row, start_col);
147 let (end_word, end_bit) = self.bit_position(row, end_col);
148 if start_word == end_word {
150 let mut mask = DenseBinaryMatrix::select_bit_and_all_left_mask(start_bit);
151 mask &= DenseBinaryMatrix::select_all_right_of_mask(end_bit);
152 let bits = self.elements[start_word] & mask;
153 return bits.count_ones() as usize;
154 }
155
156 let first_word_bits =
157 self.elements[start_word] & DenseBinaryMatrix::select_bit_and_all_left_mask(start_bit);
158 let mut ones = first_word_bits.count_ones();
159 for word in (start_word + 1)..end_word {
160 ones += self.elements[word].count_ones();
161 }
162 if end_bit > 0 {
163 let bits =
164 self.elements[end_word] & DenseBinaryMatrix::select_all_right_of_mask(end_bit);
165 ones += bits.count_ones();
166 }
167
168 return ones as usize;
169 }
170
171 fn get_row_iter(&self, row: usize, start_col: usize, end_col: usize) -> OctetIter {
172 let (first_word, first_bit) = self.bit_position(row, start_col);
173 let (last_word, _) = self.bit_position(row, end_col);
174 OctetIter::new_dense_binary(
175 start_col,
176 end_col,
177 first_bit,
178 &self.elements[first_word..=last_word],
179 )
180 }
181
182 fn get_ones_in_column(&self, col: usize, start_row: usize, end_row: usize) -> Vec<u32> {
183 let mut rows = vec![];
184 for row in start_row..end_row {
185 if self.get(row, col) == Octet::one() {
186 rows.push(row as u32);
187 }
188 }
189
190 rows
191 }
192
193 fn get_sub_row_as_octets(&self, row: usize, start_col: usize) -> BinaryOctetVec {
194 let mut result = vec![
195 0;
196 (self.width - start_col + BinaryOctetVec::WORD_WIDTH - 1)
197 / BinaryOctetVec::WORD_WIDTH
198 ];
199 let mut word = result.len();
200 let mut bit = 0;
201 for col in (start_col..self.width).rev() {
202 if bit == 0 {
203 bit = BinaryOctetVec::WORD_WIDTH - 1;
204 word -= 1;
205 } else {
206 bit -= 1;
207 }
208 if self.get(row, col) == Octet::one() {
209 result[word] |= BinaryOctetVec::select_mask(bit);
210 }
211 }
212
213 BinaryOctetVec::new(result, self.width - start_col)
214 }
215
216 fn query_non_zero_columns(&self, row: usize, start_col: usize) -> Vec<usize> {
217 (start_col..self.width)
218 .filter(|col| self.get(row, *col) != Octet::zero())
219 .collect()
220 }
221
222 fn get(&self, i: usize, j: usize) -> Octet {
223 let (word, bit) = self.bit_position(i, j);
224 if self.elements[word] & DenseBinaryMatrix::select_mask(bit) == 0 {
225 return Octet::zero();
226 } else {
227 return Octet::one();
228 }
229 }
230
231 fn swap_rows(&mut self, i: usize, j: usize) {
232 let (row_i, _) = self.bit_position(i, 0);
233 let (row_j, _) = self.bit_position(j, 0);
234 for k in 0..self.row_word_width() {
235 self.elements.swap(row_i + k, row_j + k);
236 }
237 }
238
239 fn swap_columns(&mut self, i: usize, j: usize, start_row_hint: usize) {
240 let (word_i, bit_i) = self.bit_position(0, i);
242 let (word_j, bit_j) = self.bit_position(0, j);
243 let unset_i = !DenseBinaryMatrix::select_mask(bit_i);
244 let unset_j = !DenseBinaryMatrix::select_mask(bit_j);
245 let bit_i = DenseBinaryMatrix::select_mask(bit_i);
246 let bit_j = DenseBinaryMatrix::select_mask(bit_j);
247 let row_width = self.row_word_width();
248 for row in start_row_hint..self.height {
249 let i_set = self.elements[row * row_width + word_i] & bit_i != 0;
250 if self.elements[row * row_width + word_j] & bit_j == 0 {
251 self.elements[row * row_width + word_i] &= unset_i;
252 } else {
253 self.elements[row * row_width + word_i] |= bit_i;
254 }
255 if i_set {
256 self.elements[row * row_width + word_j] |= bit_j;
257 } else {
258 self.elements[row * row_width + word_j] &= unset_j;
259 }
260 }
261 }
262
263 fn enable_column_access_acceleration(&mut self) {
264 }
266
267 fn disable_column_access_acceleration(&mut self) {
268 }
270
271 fn hint_column_dense_and_frozen(&mut self, _: usize) {
272 }
274
275 fn add_assign_rows(&mut self, dest: usize, src: usize, _start_col: usize) {
276 assert_ne!(dest, src);
277 let (dest_word, _) = self.bit_position(dest, 0);
278 let (src_word, _) = self.bit_position(src, 0);
279 let row_width = self.row_word_width();
280 let (dest_row, temp_row) =
281 get_both_ranges(&mut self.elements, dest_word, src_word, row_width);
282 add_assign_binary(dest_row, temp_row);
283 }
284
285 fn resize(&mut self, new_height: usize, new_width: usize) {
286 assert!(new_height <= self.height);
287 assert!(new_width <= self.width);
288 let old_row_width = self.row_word_width();
289 self.height = new_height;
290 self.width = new_width;
291 let new_row_width = self.row_word_width();
292 let words_to_remove = old_row_width - new_row_width;
293 if words_to_remove > 0 {
294 let mut src = 0;
295 let mut dest = 0;
296 while dest < new_height * new_row_width {
297 self.elements[dest] = self.elements[src];
298 src += 1;
299 dest += 1;
300 if dest % new_row_width == 0 {
301 src += words_to_remove;
303 }
304 }
305 assert_eq!(src, new_height * old_row_width);
306 }
307 self.elements.truncate(new_height * self.row_word_width());
308 }
309}
310
#[cfg(test)]
mod tests {
    use rand::Rng;

    use crate::matrix::{BinaryMatrix, DenseBinaryMatrix};
    use crate::octet::Octet;
    use crate::sparse_matrix::SparseBinaryMatrix;

    /// Builds a dense and a sparse `size` x `size` matrix holding the same random entries.
    fn rand_dense_and_sparse(size: usize) -> (DenseBinaryMatrix, SparseBinaryMatrix) {
        let mut rng = rand::thread_rng();
        let mut dense = DenseBinaryMatrix::new(size, size, 0);
        let mut sparse = SparseBinaryMatrix::new(size, size, 1);
        // size^2 / 2 random writes of 0 or 1, mirrored into both matrices.
        for _ in 0..(size * size / 2) {
            let row = rng.gen_range(0..size);
            let col = rng.gen_range(0..size);
            let bit = rng.gen_range(0..2);
            dense.set(row, col, Octet::new(bit));
            sparse.set(row, col, Octet::new(bit));
        }
        (dense, sparse)
    }

    /// Asserts equal shapes and an equal entry in every cell.
    fn assert_matrices_eq<T: BinaryMatrix, U: BinaryMatrix>(matrix1: &T, matrix2: &U) {
        assert_eq!(matrix1.height(), matrix2.height());
        assert_eq!(matrix1.width(), matrix2.width());
        for row in 0..matrix1.height() {
            for col in 0..matrix1.width() {
                assert_eq!(
                    matrix1.get(row, col),
                    matrix2.get(row, col),
                    "Matrices are not equal at row={} col={}",
                    row,
                    col
                );
            }
        }
    }

    #[test]
    fn row_iter() {
        let (dense, sparse) = rand_dense_and_sparse(8);
        let mut rng = rand::thread_rng();
        for row in 0..dense.height() {
            let start_col = rng.gen_range(0..(dense.width() - 2));
            let end_col = rng.gen_range((start_col + 1)..dense.width());
            let mut dense_iter = dense.get_row_iter(row, start_col, end_col);
            let mut sparse_iter = sparse.get_row_iter(row, start_col, end_col);
            for col in start_col..end_col {
                assert_eq!(dense.get(row, col), sparse.get(row, col));
                // The dense iterator yields every column in the range...
                assert_eq!((col, dense.get(row, col)), dense_iter.next().unwrap());
                // ...while the sparse one only yields the non-zero entries.
                if sparse.get(row, col) != Octet::zero() {
                    assert_eq!((col, sparse.get(row, col)), sparse_iter.next().unwrap());
                }
            }
            assert!(dense_iter.next().is_none());
            assert!(sparse_iter.next().is_none());
        }
    }

    #[test]
    fn swap_rows() {
        let (mut dense, mut sparse) = rand_dense_and_sparse(8);
        for &(a, b) in &[(0, 4), (1, 6), (1, 7)] {
            dense.swap_rows(a, b);
            sparse.swap_rows(a, b);
        }
        assert_matrices_eq(&dense, &sparse);
    }

    #[test]
    fn swap_columns() {
        let (mut dense, mut sparse) = rand_dense_and_sparse(8);
        // Includes a self-swap, which must be a no-op.
        for &(a, b) in &[(0, 4), (1, 6), (1, 1)] {
            dense.swap_columns(a, b, 0);
            sparse.swap_columns(a, b, 0);
        }
        assert_matrices_eq(&dense, &sparse);
    }

    #[test]
    fn count_ones() {
        let (dense, sparse) = rand_dense_and_sparse(8);
        for &(row, start_col, end_col) in &[(0, 0, 5), (2, 2, 6), (3, 1, 2)] {
            assert_eq!(
                dense.count_ones(row, start_col, end_col),
                sparse.count_ones(row, start_col, end_col)
            );
        }
    }

    #[test]
    fn fma_rows() {
        let (mut dense, mut sparse) = rand_dense_and_sparse(8);
        for &(dest, src) in &[(0, 1), (0, 2), (2, 1)] {
            dense.add_assign_rows(dest, src, 0);
            sparse.add_assign_rows(dest, src, 0);
        }
        assert_matrices_eq(&dense, &sparse);
    }

    #[test]
    fn resize() {
        let (mut dense, mut sparse) = rand_dense_and_sparse(8);
        dense.disable_column_access_acceleration();
        sparse.disable_column_access_acceleration();
        dense.resize(5, 5);
        sparse.resize(5, 5);
        assert_matrices_eq(&dense, &sparse);
    }

    #[test]
    fn hint_column_dense_and_frozen() {
        let (dense, mut sparse) = rand_dense_and_sparse(8);
        sparse.enable_column_access_acceleration();
        for &col in &[6, 5] {
            sparse.hint_column_dense_and_frozen(col);
        }
        assert_matrices_eq(&dense, &sparse);
    }

    #[test]
    fn dense_storage_math() {
        let size = 128;
        let (mut dense, mut sparse) = rand_dense_and_sparse(size);
        let mut rng = rand::thread_rng();
        sparse.enable_column_access_acceleration();
        for col in (0..(size - 1)).rev() {
            sparse.hint_column_dense_and_frozen(col);
            assert_matrices_eq(&dense, &sparse);
        }
        assert_matrices_eq(&dense, &sparse);
        sparse.disable_column_access_acceleration();
        for _ in 0..1000 {
            let dest = rng.gen_range(0..size);
            // Redraw until the source row differs from the destination row.
            let src = loop {
                let candidate = rng.gen_range(0..size);
                if candidate != dest {
                    break candidate;
                }
            };
            dense.add_assign_rows(dest, src, 0);
            sparse.add_assign_rows(dest, src, 0);
        }
        assert_matrices_eq(&dense, &sparse);
    }
}