1use ffi::*;
2use friendly::binary_vector::BinVector;
3use libc::c_int;
4use std::cmp;
5use std::ops;
6use std::ptr;
7#[cfg(feature = "serde")]
8use vob::Vob;
9
/// Serde "remote" shim that describes how a `ptr::NonNull<Mzd>` is serialized:
/// as a struct with a single field `rows`, holding one `Vob` per matrix row.
#[cfg(feature = "serde")]
#[derive(Serialize)]
#[serde(remote = "ptr::NonNull<Mzd>")]
struct MzdSerializer {
    /// Produced by `mzd_to_vecs`, which extracts every row of the matrix.
    #[serde(getter = "mzd_to_vecs")]
    rows: Vec<Vob>,
}
17
/// Serde getter for `MzdSerializer`: converts every row of the borrowed
/// matrix into a `Vob`.
///
/// `mzd` is borrowed from a live `BinMatrix`, so the temporary wrapper built
/// here must never run `BinMatrix`'s destructor. `ManuallyDrop` guarantees
/// that even if a row extraction panics and the stack unwinds; a trailing
/// `mem::forget` would be skipped on that path and the borrowed matrix would
/// be destroyed (and later destroyed again by its real owner).
#[cfg(feature = "serde")]
fn mzd_to_vecs(mzd: &ptr::NonNull<Mzd>) -> Vec<Vob> {
    let m = std::mem::ManuallyDrop::new(BinMatrix { mzd: *mzd });
    (0..m.nrows())
        .map(|r| m.get_window(r, 0, r + 1, m.ncols()).as_vector().into_vob())
        .collect()
}
29
/// A dense matrix over GF(2), backed by an M4RI matrix (`Mzd`) that this
/// value owns; the matrix is dropped together with the `BinMatrix`.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct BinMatrix {
    /// Pointer to the owned M4RI matrix. With the `serde` feature it is
    /// serialized through `MzdSerializer` under the key `"matrix"`.
    #[cfg_attr(feature = "serde", serde(with = "MzdSerializer", rename = "matrix"))]
    mzd: ptr::NonNull<Mzd>,
}
37
// SAFETY: a `BinMatrix` is intended to be the sole owner of the M4RI matrix
// behind its `NonNull`, so moving or sharing it moves or shares that matrix only.
// NOTE(review): `Send` additionally assumes M4RI's allocation/free routines may
// be called from any thread -- confirm against the M4RI build in use.
// NOTE(review): `Sync` is questionable: `get_word_mut(&self)` returns
// `&mut Word` from a shared reference, so two threads holding `&BinMatrix`
// can obtain aliasing mutable references to the same word from safe code.
unsafe impl Sync for BinMatrix {}
unsafe impl Send for BinMatrix {}
40
41impl ops::Drop for BinMatrix {
42 fn drop(&mut self) {
43 unsafe { ptr::drop_in_place(self.mzd.as_ptr()) }
44 }
45}
46
/// Wraps a raw pointer (typically returned by an M4RI allocation routine) in
/// `ptr::NonNull` without a runtime check in release builds.
///
/// Must be invoked inside an `unsafe` context. Debug builds assert that the
/// pointer is non-NULL, turning a silent invariant violation (UB in
/// `new_unchecked`) into a clear panic.
macro_rules! nonnull {
    ($exp:expr) => {{
        let raw = $exp;
        debug_assert!(!raw.is_null(), "M4RI returned a NULL matrix pointer");
        ptr::NonNull::new_unchecked(raw)
    }};
}
52
// Multiplication strategy: M4RM ("Method of the Four Russians"), selected when
// `m4rm_mul` is the only multiplication feature enabled. The trailing `0` is
// passed as `mzd_mul_m4rm`'s `k` argument.
#[cfg(all(
    feature = "m4rm_mul",
    not(feature = "strassen_mul"),
    not(feature = "naive_mul")
))]
macro_rules! mul_impl {
    ($out:expr, $lhs:expr, $rhs:expr) => {
        mzd_mul_m4rm($out, $lhs, $rhs, 0)
    };
}
62
// Default multiplication strategy: M4RI's generic `mzd_mul`. Active when
// `strassen_mul` is selected or when no multiplication feature is enabled,
// i.e. whenever neither `m4rm_mul` nor `naive_mul` is set. The trailing `0`
// is passed as `mzd_mul`'s `cutoff` argument.
#[cfg(not(any(feature = "m4rm_mul", feature = "naive_mul")))]
macro_rules! mul_impl {
    ($out:expr, $lhs:expr, $rhs:expr) => {
        mzd_mul($out, $lhs, $rhs, 0)
    };
}
75
// Multiplication strategy: schoolbook `mzd_mul_naive`, selected when
// `naive_mul` is the only multiplication feature enabled.
#[cfg(all(
    feature = "naive_mul",
    not(feature = "m4rm_mul"),
    not(feature = "strassen_mul")
))]
macro_rules! mul_impl {
    ($out:expr, $lhs:expr, $rhs:expr) => {
        mzd_mul_naive($out, $lhs, $rhs)
    };
}
85
// Guard: when two or more multiplication strategies are enabled, any use of
// `mul_impl!` fails the build with an explanatory message.
#[cfg(any(
    all(feature = "m4rm_mul", feature = "strassen_mul"),
    all(feature = "m4rm_mul", feature = "naive_mul"),
    all(feature = "strassen_mul", feature = "naive_mul")
))]
macro_rules! mul_impl {
    ($($args:expr),*) => {
        compile_error!("You need to set only one of the feature flags as mul strategy")
    };
}
96
97impl BinMatrix {
98 pub fn zero(rows: usize, cols: usize) -> BinMatrix {
100 if rows == 0 || cols == 0 {
101 panic!("Can't create a 0 matrix");
102 }
103 let mzd = unsafe { nonnull!(mzd_init(rows as c_int, cols as c_int)) };
104 BinMatrix { mzd }
105 }
106
107 pub fn new(rows: Vec<BinVector>) -> BinMatrix {
109 let rowlen = rows[0].len();
110 let storage: Vec<Vec<u64>> = rows
111 .iter()
112 .map(|vec| {
113 vec.get_storage()
114 .into_iter()
115 .copied()
116 .map(|b| b as u64)
117 .collect()
118 })
119 .collect();
120 BinMatrix::from_slices(&storage, rowlen)
121 }
122
123 pub fn from_slices<T: AsRef<[u64]>>(rows: &[T], rowlen: usize) -> BinMatrix {
125 if rows.is_empty() || rowlen == 0 {
126 panic!("Can't create a 0 matrix");
127 }
128
129 for row in rows {
130 debug_assert!(row.as_ref().len() * 64 >= rowlen, "expected len {} bits but got only {} blocks", rowlen, row.as_ref().len());
131 }
132
133 let mzd_ptr = unsafe { mzd_init(rows.len() as c_int, rowlen as c_int) };
134
135 let blocks_per_row = rowlen / 64 + if rowlen % 64 == 0 { 0 } else { 1 };
136 for (row_index, row) in rows.into_iter().enumerate() {
138 let row_ptr: *const *mut Word = unsafe { (*mzd_ptr).rows.add(row_index) };
139 for (block_index, row_block) in row
140 .as_ref()
141 .iter()
142 .take(blocks_per_row)
143 .copied()
144 .enumerate()
145 {
146 assert_eq!(
147 ::std::mem::size_of::<usize>(),
148 ::std::mem::size_of::<u64>(),
149 "only works on 64 bit"
150 );
151 let row_block = if block_index == rowlen / 64 {
152 row_block & ((1 << (rowlen % 64)) - 1)
153 } else {
154 row_block
155 };
156 unsafe {
157 *((*row_ptr).add(block_index)) = row_block as u64;
158 }
159 }
160 }
161
162 unsafe {
163 BinMatrix {
164 mzd: nonnull!(mzd_ptr),
165 }
166 }
167 }
168
169 pub fn count_ones(&self) -> u32 {
173 assert!(self.nrows() == 1 || self.ncols() == 1, "only works on single row or single column matrices");
174 let mut accumulator = 0;
175 for row in 0..self.nrows() {
176 let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows.add(row) };
177 for i in 0..(self.ncols() / 64) {
178 let word_ptr: *const Word = unsafe { (*row_ptr).add(i) };
179 accumulator += unsafe { (*word_ptr).count_ones() };
180 }
181 if self.ncols() % 64 != 0 {
183 let word_ptr: *const Word = unsafe { (*row_ptr).add((self.ncols() - 1) / 64) };
184 let word = unsafe { *word_ptr } & ((1 << self.ncols() % 64) - 1);
185 accumulator += word.count_ones();
186 }
187 }
188 accumulator
189 }
190
191 pub fn random(rows: usize, columns: usize) -> BinMatrix {
193 let mzd = unsafe { mzd_init(rows as Rci, columns as Rci) };
194 unsafe {
196 mzd_randomize(mzd);
197 }
198 unsafe { BinMatrix { mzd: nonnull!(mzd) } }
199 }
200
201 pub fn from_mzd(mzd: *mut Mzd) -> BinMatrix {
203 let mzd = ptr::NonNull::new(mzd).expect("Can't be NULL");
204 BinMatrix { mzd }
205 }
206
207 #[inline]
209 pub fn identity(rows: usize) -> BinMatrix {
210 unsafe {
211 let mzd_ptr = mzd_init(rows as c_int, rows as c_int);
212 mzd_set_ui(mzd_ptr, 1);
213 let mzd = nonnull!(mzd_ptr);
214 BinMatrix { mzd }
215 }
216 }
217
218 #[inline]
221 pub fn augmented(&self, other: &BinMatrix) -> BinMatrix {
222 debug_assert_eq!(self.nrows(), other.nrows(), "The rows need to be equal");
223 let mzd = unsafe {
224 nonnull!(mzd_concat(
225 ptr::null_mut(),
226 self.mzd.as_ptr(),
227 other.mzd.as_ptr()
228 ))
229 };
230 BinMatrix { mzd }
231 }
232
233 #[inline]
235 pub fn stacked(&self, other: &BinMatrix) -> BinMatrix {
236 let mzd = unsafe {
237 nonnull!(mzd_stack(
238 ptr::null_mut(),
239 self.mzd.as_ptr(),
240 other.mzd.as_ptr()
241 ))
242 };
243 BinMatrix { mzd }
244 }
245
246 #[inline]
250 pub fn rank(&self) -> usize {
251 self.clone().echelonize()
252 }
253
254 #[inline]
258 pub fn echelonize(&mut self) -> usize {
259 let rank = unsafe { mzd_echelonize(self.mzd.as_ptr(), false as c_int) };
260 rank as usize
261 }
262
263 #[inline]
265 pub fn inverted(&self) -> BinMatrix {
266 let mzd = unsafe { nonnull!(mzd_inv_m4ri(ptr::null_mut(), self.mzd.as_ptr(), 0 as c_int)) };
267 BinMatrix { mzd }
268 }
269
270 #[inline]
272 pub fn transposed(&self) -> BinMatrix {
273 let mzd;
274 unsafe {
275 let mzd_ptr = mzd_transpose(ptr::null_mut(), self.mzd.as_ptr());
276 mzd = nonnull!(mzd_ptr);
277 }
278 BinMatrix { mzd }
279 }
280
281 #[inline]
285 pub fn nrows(&self) -> usize {
286 unsafe { self.mzd.as_ref().nrows as usize }
287 }
288
289 #[inline]
293 pub fn ncols(&self) -> usize {
294 unsafe { self.mzd.as_ref().ncols as usize }
295 }
296
297 pub fn get_word(&self, row: usize, column: usize) -> Word {
299 assert!(row < self.nrows());
300 assert!(column < self.ncols());
301
302 unsafe { self.get_word_unchecked(row, column) }
303 }
304
305 #[inline]
308 pub unsafe fn get_word_unchecked(&self, row: usize, column: usize) -> Word {
309 let row_ptr: *const *mut Word = (*self.mzd.as_ptr()).rows.add(row);
310 let word_ptr: *const Word = ((*row_ptr) as *const Word).add(column);
311 *word_ptr
312 }
313
314 pub fn get_word_mut(&self, row: usize, column: usize) -> &mut Word {
316 assert!(row < self.nrows());
317 assert!(column < self.ncols());
318 unsafe { self.get_word_mut_unchecked(row, column) }
319 }
320
321 #[inline]
323 pub unsafe fn get_word_mut_unchecked(&self, row: usize, column: usize) -> &mut Word {
324 let row_ptr: *const *mut Word = (*self.mzd.as_ptr()).rows.add(row);
325 let word_ptr: *mut Word = ((*row_ptr) as *mut Word).add(column / 64);
326 word_ptr.as_mut().unwrap()
327 }
328
329 pub fn as_vector(&self) -> BinVector {
333 if self.nrows() != 1 {
334 assert_eq!(self.ncols(), 1, "needs to have only one column or row");
335 self.transposed().as_vector()
336 } else {
337 assert_eq!(self.nrows(), 1, "needs to have only one column or row");
338 let mut bits = BinVector::with_capacity(self.ncols());
339 {
340 let collector = unsafe { bits.get_storage_mut() };
341 for i in 0..(self.ncols() / 64) {
342 let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows };
343 let word_ptr: *const Word = unsafe { ((*row_ptr) as *const Word).add(i) };
344 collector.push(unsafe { *word_ptr as usize });
345 }
346 if self.ncols() % 64 != 0 {
348 let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows };
349 let word_ptr: *const Word = unsafe { (*row_ptr).add((self.ncols() - 1) / 64) };
350 let word = unsafe { *word_ptr };
351 collector.push(word as usize);
352 }
353 }
354 unsafe {
355 bits.set_len(self.ncols());
356 bits.mask_last_block();
357 }
358
359 bits
360 }
361 }
362
363 pub fn bit(&self, row: usize, col: usize) -> bool {
365 let bit = unsafe { mzd_read_bit(self.mzd.as_ptr(), row as Rci, col as Rci) };
366 debug_assert!(bit == 0 || bit == 1, "Invalid bool for bit??");
367 bit == 1
368 }
369
370 pub fn get_window(
372 &self,
373 start_row: usize,
374 start_col: usize,
375 high_row: usize,
376 high_col: usize,
377 ) -> BinMatrix {
378 let (rows, cols) = (high_row - start_row, high_col - start_col);
379 debug_assert!(rows > 0 && rows <= self.nrows());
380 debug_assert!(cols > 0 && cols <= self.ncols());
381 let mzd_ptr = unsafe { mzd_init(rows as Rci, cols as Rci) };
382 for (r, i) in (start_row..high_row).enumerate() {
383 for (c, j) in (start_col..high_col).enumerate() {
385 let bit = self.bit(i, j);
386 unsafe {
387 mzd_write_bit(mzd_ptr, r as Rci, c as Rci, bit as BIT);
388 }
389 }
390 }
391 BinMatrix::from_mzd(mzd_ptr)
392 }
393
394 pub fn set_window(&mut self, start_row: usize, start_col: usize, other: &BinMatrix) {
399 let highr = start_row + other.nrows();
400 let highc = start_col + other.ncols();
401 debug_assert!(self.ncols() >= highc, "This matrix is too small!");
402 debug_assert!(self.nrows() >= highr, "This matrix has too few rows !");
403 let mzd_ptr = self.mzd.as_ptr();
404
405 for r in start_row..highr {
406 for c in start_col..highc {
407 let bit = other.bit(r - start_row, c - start_col);
408 unsafe {
409 mzd_write_bit(mzd_ptr, r as Rci, c as Rci, bit as BIT);
410 }
411 }
412 }
413 }
414
415 pub fn mul_slice(&self, other: &[u64]) -> BinMatrix {
417 debug_assert!(
419 self.ncols() <= other.len() * 64,
420 "Mismatched sizes: ({}x{}) * ({}x1) (too big)",
421 self.nrows(),
422 self.ncols(),
423 other.len() * 64
424 );
425 let result = {
426 let other = BinMatrix::from_slices(&[other], self.ncols()).transposed();
427 unsafe { mzd_mul_naive(ptr::null_mut(), self.mzd.as_ptr(), other.mzd.as_ptr()) }
428 };
429 let matresult = BinMatrix::from_mzd(result);
430 matresult
431 }
432}
433
434impl cmp::PartialEq for BinMatrix {
435 fn eq(&self, other: &BinMatrix) -> bool {
436 unsafe { mzd_equal(self.mzd.as_ptr(), other.mzd.as_ptr()) == 1 }
437 }
438}
439
440impl cmp::Eq for BinMatrix {}
441
442impl ops::Mul<BinMatrix> for BinMatrix {
443 type Output = BinMatrix;
444
445 #[inline]
447 fn mul(self, other: BinMatrix) -> Self::Output {
448 &self * &other
449 }
450}
451
452impl std::clone::Clone for BinMatrix {
453 fn clone(&self) -> Self {
454 let mzd = unsafe { nonnull!(mzd_copy(ptr::null_mut(), self.mzd.as_ptr())) };
455 BinMatrix { mzd }
456 }
457}
458
459impl<'a> ops::Mul<&'a BinMatrix> for &'a BinMatrix {
460 type Output = BinMatrix;
461 #[inline]
463 fn mul(self, other: &BinMatrix) -> Self::Output {
464 unsafe {
465 let mzd_ptr = mul_impl!(ptr::null_mut(), self.mzd.as_ptr(), other.mzd.as_ptr());
466
467 BinMatrix {
468 mzd: ptr::NonNull::new(mzd_ptr).expect("Multiplication failed"),
469 }
470 }
471 }
472}
473
474impl<'a> ops::Add<&'a BinMatrix> for &'a BinMatrix {
475 type Output = BinMatrix;
476
477 #[inline]
479 fn add(self, other: &BinMatrix) -> Self::Output {
480 let mzd = unsafe {
481 nonnull!(mzd_add(
482 ptr::null_mut(),
483 self.mzd.as_ptr(),
484 other.mzd.as_ptr()
485 ))
486 };
487 BinMatrix { mzd }
488 }
489}
490
491impl ops::Add<BinMatrix> for BinMatrix {
492 type Output = BinMatrix;
493
494 #[inline]
496 fn add(self, other: BinMatrix) -> Self::Output {
497 let mzd = unsafe {
498 nonnull!(mzd_add(
499 self.mzd.as_ptr(),
500 self.mzd.as_ptr(),
501 other.mzd.as_ptr()
502 ))
503 };
504 BinMatrix { mzd }
505 }
506}
507
508impl ops::AddAssign<BinMatrix> for BinMatrix {
509 #[inline]
511 fn add_assign(&mut self, other: BinMatrix) {
512 unsafe {
513 mzd_add(self.mzd.as_ptr(), self.mzd.as_ptr(), other.mzd.as_ptr());
514 }
515 }
516}
517
518impl<'a> ops::AddAssign<&'a BinMatrix> for BinMatrix {
519 #[inline]
521 fn add_assign(&mut self, other: &BinMatrix) {
522 unsafe {
523 mzd_add(self.mzd.as_ptr(), self.mzd.as_ptr(), other.mzd.as_ptr());
524 }
525 }
526}
527
528impl<'a> ops::Mul<&'a BinVector> for &'a BinMatrix {
529 type Output = BinVector;
530 #[inline]
532 fn mul(self, other: &BinVector) -> Self::Output {
533 self.mul_slice(
534 &other
535 .get_storage()
536 .iter()
537 .copied()
538 .map(|b| b as u64)
539 .collect::<Vec<u64>>(),
540 ).as_vector()
541 }
542}
543
544impl ops::Mul<BinVector> for BinMatrix {
545 type Output = BinVector;
546 fn mul(self, other: BinVector) -> Self::Output {
548 &self * &other
549 }
550}
551
552impl<'a> ops::Mul<&'a BinMatrix> for &'a BinVector {
553 type Output = BinVector;
554
555 #[inline]
556 fn mul(self, other: &BinMatrix) -> Self::Output {
558 let vec_mzd = self.as_matrix();
559 let tmp = &vec_mzd * other;
560
561 tmp.as_vector()
562 }
563}
564
565impl ops::Mul<BinMatrix> for BinVector {
566 type Output = BinVector;
567
568 #[inline]
569 fn mul(self, other: BinMatrix) -> Self::Output {
571 &self * &other
572 }
573}
574
575pub fn solve_left(a: BinMatrix, b: &mut BinMatrix) -> bool {
583 let result = unsafe { mzd_solve_left(a.mzd.as_ptr(), b.mzd.as_ptr(), 0, 1) };
584
585 result == 0
586}
587
#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::*;
    use vob::Vob;

    #[test]
    fn new() {
        let _m = BinMatrix::new(vec![
            BinVector::from(vob![true, false, true]),
            BinVector::from(vob![true, true, true]),
        ]);
    }

    #[test]
    fn identity() {
        let id = BinMatrix::new(vec![
            BinVector::from(vob![
                true, false, false, false, false, false, false, false, false, false
            ]),
            BinVector::from(vob![
                false, true, false, false, false, false, false, false, false, false
            ]),
            BinVector::from(vob![
                false, false, true, false, false, false, false, false, false, false
            ]),
            BinVector::from(vob![
                false, false, false, true, false, false, false, false, false, false
            ]),
            BinVector::from(vob![
                false, false, false, false, true, false, false, false, false, false
            ]),
            BinVector::from(vob![
                false, false, false, false, false, true, false, false, false, false
            ]),
            BinVector::from(vob![
                false, false, false, false, false, false, true, false, false, false
            ]),
            BinVector::from(vob![
                false, false, false, false, false, false, false, true, false, false
            ]),
            BinVector::from(vob![
                false, false, false, false, false, false, false, false, true, false
            ]),
            BinVector::from(vob![
                false, false, false, false, false, false, false, false, false, true
            ]),
        ]);

        let id_gen = BinMatrix::identity(10);
        assert_eq!(id.nrows(), id_gen.nrows());
        assert_eq!(id.ncols(), id_gen.ncols());
        // Compare every entry of the 10x10 matrices, not just a leading block.
        for i in 0..10 {
            for j in 0..10 {
                let m1 = id.mzd.as_ptr();
                let m2 = id_gen.mzd.as_ptr();
                unsafe {
                    assert_eq!(
                        mzd_read_bit(m1, i, j),
                        mzd_read_bit(m2, i, j),
                        "({}, {})",
                        i,
                        j
                    );
                }
            }
        }
        unsafe {
            assert!(mzd_equal(id.mzd.as_ptr(), id_gen.mzd.as_ptr()) != 0);
        }
        assert_eq!(id, id_gen);
    }

    #[test]
    fn mul() {
        let m1 = BinMatrix::identity(8);
        let m2 = BinMatrix::identity(8);
        let m3 = BinMatrix::identity(8);
        let prod = m1 * m2;
        unsafe {
            assert!(mzd_equal(prod.mzd.as_ptr(), m3.mzd.as_ptr()) != 0);
        }
    }

    #[test]
    fn vecmul() {
        let m1 = BinMatrix::identity(10);
        let binvec = BinVector::from(Vob::from_elem(10, true));

        let result: BinVector = &m1 * &binvec;
        assert_eq!(result, binvec);

        let result: BinVector = &binvec * &m1;
        assert_eq!(result, binvec);

        let m1 = BinMatrix::random(10, 3);
        let result = &binvec * &m1;
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_random() {
        BinMatrix::random(10, 1);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_serialize() {
        let m = BinMatrix::identity(3);
        let json = serde_json::to_string(&m).unwrap();
        assert_eq!(json, "{\"matrix\":{\"rows\":[{\"len\":3,\"vec\":[1]},{\"len\":3,\"vec\":[2]},{\"len\":3,\"vec\":[4]}]}}");
    }

    #[test]
    fn test_as_vector_column() {
        for i in 1..25 {
            let m1 = BinMatrix::random(i, 1);
            let vec = m1.as_vector();
            assert_eq!(vec.len(), i);
            assert!(m1 == vec.as_column_matrix());
        }
    }

    #[test]
    fn test_as_vector_row() {
        for i in 1..25 {
            let m1 = BinMatrix::random(1, i);
            let vec = m1.as_vector();
            assert_eq!(vec.len(), i);
            assert!(m1 == vec.as_matrix());
        }
    }

    #[test]
    fn zero() {
        let m1 = BinMatrix::zero(10, 3);
        for i in 0..10 {
            for j in 0..3 {
                assert_eq!(m1.bit(i, j), false);
            }
        }
    }

    #[test]
    fn set_window() {
        let mut m1 = BinMatrix::zero(10, 10);
        m1.set_window(5, 5, &BinMatrix::identity(5));
        for i in 0..5 {
            for j in 0..5 {
                assert_eq!(m1.bit(i, j), false);
            }
        }
        for i in 5..10 {
            for j in 5..10 {
                let bit = m1.bit(i, j);
                assert_eq!(bit, i == j, "bit ({},{}) was {}", i, j, bit);
            }
        }

        let mut m1 = BinMatrix::random(10, 10);
        m1.set_window(5, 5, &BinMatrix::identity(5));
        for i in 5..10 {
            for j in 5..10 {
                let bit = m1.bit(i, j);
                assert_eq!(bit, i == j, "bit ({},{}) was {}", i, j, bit);
            }
        }
    }

    #[test]
    fn test_random_unequal() {
        let m1 = BinMatrix::random(100, 100);
        let m2 = BinMatrix::random(100, 100);
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_count_ones() {
        let rng = &mut rand::thread_rng();
        for _ in 0..1000 {
            let size = rng.gen_range(1..1000);
            let v = BinVector::random(size);
            assert_eq!(v.count_ones(), v.as_matrix().count_ones());
            assert_eq!(v.count_ones(), v.as_column_matrix().count_ones());
        }
    }
}