m4ri_rust/friendly/
binary_matrix.rs

1use ffi::*;
2use friendly::binary_vector::BinVector;
3use libc::c_int;
4use std::cmp;
5use std::ops;
6use std::ptr;
7#[cfg(feature = "serde")]
8use vob::Vob;
9
// Serde "remote" description of `ptr::NonNull<Mzd>`: it tells the derive how to
// serialize the raw matrix pointer stored in `BinMatrix`.
// Only compiled when the `serde` feature is enabled.
#[cfg(feature = "serde")]
#[derive(Serialize)]
#[serde(remote = "ptr::NonNull<Mzd>")]
struct MzdSerializer {
    // Not read from a real field: the value is produced by calling `mzd_to_vecs`
    // on the pointer being serialized.
    #[serde(getter = "mzd_to_vecs")]
    rows: Vec<Vob>,
}
17
/// Convert the matrix behind `mzd` into one `Vob` per row.
///
/// This is the serde getter used by `MzdSerializer`. The pointer is only
/// borrowed: it is wrapped in a temporary `BinMatrix` so the safe accessors
/// can be used, and that wrapper must never run its destructor.
#[cfg(feature = "serde")]
fn mzd_to_vecs(mzd: &ptr::NonNull<Mzd>) -> Vec<Vob> {
    // `ManuallyDrop` guarantees the borrowed matrix is not dropped, even if one
    // of the calls below panics. A trailing `mem::forget` would be skipped
    // during unwinding and the real owner would later drop it a second time.
    let m = std::mem::ManuallyDrop::new(BinMatrix { mzd: *mzd });
    (0..m.nrows())
        .map(|r| m.get_window(r, 0, r + 1, m.ncols()).as_vector().into_vob())
        .collect()
}
29
30/// Structure to represent matrices
31#[derive(Debug)]
32#[cfg_attr(feature = "serde", derive(Serialize))]
33pub struct BinMatrix {
34    #[cfg_attr(feature = "serde", serde(with = "MzdSerializer", rename = "matrix"))]
35    mzd: ptr::NonNull<Mzd>,
36}
37
// Manual thread-safety promises: the raw `NonNull<Mzd>` field would otherwise
// make `BinMatrix` neither `Sync` nor `Send`.
// NOTE(review): nothing in this file proves these are sound; in particular
// `get_word_mut` hands out `&mut Word` from `&self` — confirm before relying on `Sync`.
unsafe impl Sync for BinMatrix {}
unsafe impl Send for BinMatrix {}
40
// Dropping a `BinMatrix` runs the destructor of the `Mzd` it points to, in place.
// NOTE(review): this only releases the matrix if `Mzd` itself implements `Drop`
// (presumably calling `mzd_free` in `ffi`) — confirm, otherwise the allocation leaks.
impl ops::Drop for BinMatrix {
    fn drop(&mut self) {
        // The pointer is `NonNull` by construction; every `BinMatrix` is assumed
        // to be the unique owner of its `Mzd`.
        unsafe { ptr::drop_in_place(self.mzd.as_ptr()) }
    }
}
46
// Wrap a raw pointer expression in `ptr::NonNull` WITHOUT checking for NULL.
// Expands to `new_unchecked`, so it must be used inside an `unsafe` block and
// the caller is responsible for the pointer being non-null.
macro_rules! nonnull {
    ($exp:expr) => {
        ptr::NonNull::new_unchecked($exp)
    };
}
52
// Multiplication strategy selected when ONLY the `m4rm_mul` feature is enabled:
// `mul_impl!(dest, a, b)` expands to `mzd_mul_m4rm(dest, a, b, 0)`.
// NOTE(review): the trailing `0` is presumably "let m4ri pick the parameter" — confirm.
#[cfg(all(
    feature = "m4rm_mul",
    not(any(feature = "strassen_mul", feature = "naive_mul"))
))]
macro_rules! mul_impl {
    ($dest:expr, $a:expr, $b:expr) => {
        mzd_mul_m4rm($dest, $a, $b, 0)
    };
}
62
// Multiplication strategy selected when ONLY `strassen_mul` is enabled, or when
// none of the three strategy features is set (i.e. this is the default):
// `mul_impl!(dest, a, b)` expands to `mzd_mul(dest, a, b, 0)`.
#[cfg(any(
    all(
        feature = "strassen_mul",
        not(any(feature = "m4rm_mul", feature = "naive_mul"))
    ),
    not(any(feature = "strassen_mul", feature = "m4rm_mul", feature = "naive_mul"))
))]
macro_rules! mul_impl {
    ($dest:expr, $a:expr, $b:expr) => {
        mzd_mul($dest, $a, $b, 0)
    };
}
75
// Multiplication strategy selected when ONLY the `naive_mul` feature is enabled:
// `mul_impl!(dest, a, b)` expands to `mzd_mul_naive(dest, a, b)`.
#[cfg(all(
    feature = "naive_mul",
    not(any(feature = "m4rm_mul", feature = "strassen_mul"))
))]
macro_rules! mul_impl {
    ($dest:expr, $a:expr, $b:expr) => {
        mzd_mul_naive($dest, $a, $b)
    };
}
85
// Guard for conflicting configuration: when two or more strategy features are
// enabled at once, any use of `mul_impl!` becomes a compile error instead of
// silently picking one of them.
#[cfg(any(
    all(feature = "naive_mul", feature = "m4rm_mul"),
    all(feature = "strassen_mul", feature = "naive_mul"),
    all(feature = "m4rm_mul", feature = "strassen_mul")
))]
macro_rules! mul_impl {
    ($($a:expr),*) => {
        compile_error!("You need to set only one of the feature flags as mul strategy")
    };
}
96
97impl BinMatrix {
98    /// Create a zero matrix
99    pub fn zero(rows: usize, cols: usize) -> BinMatrix {
100        if rows == 0 || cols == 0 {
101            panic!("Can't create a 0 matrix");
102        }
103        let mzd = unsafe { nonnull!(mzd_init(rows as c_int, cols as c_int)) };
104        BinMatrix { mzd }
105    }
106
107    /// Create a new matrix
108    pub fn new(rows: Vec<BinVector>) -> BinMatrix {
109        let rowlen = rows[0].len();
110        let storage: Vec<Vec<u64>> = rows
111            .iter()
112            .map(|vec| {
113                vec.get_storage()
114                    .into_iter()
115                    .copied()
116                    .map(|b| b as u64)
117                    .collect()
118            })
119            .collect();
120        BinMatrix::from_slices(&storage, rowlen)
121    }
122
123    /// Create a new matrix from slices
124    pub fn from_slices<T: AsRef<[u64]>>(rows: &[T], rowlen: usize) -> BinMatrix {
125        if rows.is_empty() || rowlen == 0 {
126            panic!("Can't create a 0 matrix");
127        }
128
129        for row in rows {
130            debug_assert!(row.as_ref().len() * 64 >= rowlen, "expected len {} bits but got only {} blocks", rowlen, row.as_ref().len());
131        }
132
133        let mzd_ptr = unsafe { mzd_init(rows.len() as c_int, rowlen as c_int) };
134
135        let blocks_per_row = rowlen / 64 + if rowlen % 64 == 0 { 0 } else { 1 };
136        // Directly write to the underlying Mzd storage
137        for (row_index, row) in rows.into_iter().enumerate() {
138            let row_ptr: *const *mut Word = unsafe { (*mzd_ptr).rows.add(row_index) };
139            for (block_index, row_block) in row
140                .as_ref()
141                .iter()
142                .take(blocks_per_row)
143                .copied()
144                .enumerate()
145            {
146                assert_eq!(
147                    ::std::mem::size_of::<usize>(),
148                    ::std::mem::size_of::<u64>(),
149                    "only works on 64 bit"
150                );
151                let row_block = if block_index == rowlen / 64 {
152                    row_block & ((1 << (rowlen % 64)) - 1)
153                } else {
154                    row_block
155                };
156                unsafe {
157                    *((*row_ptr).add(block_index)) = row_block as u64;
158                }
159            }
160        }
161
162        unsafe {
163            BinMatrix {
164                mzd: nonnull!(mzd_ptr),
165            }
166        }
167    }
168
169    /// Get the hamming weight for single-row or single-column matrices (ie. vectors)
170    ///
171    /// **Panics** if ``nrows > 1 && ncols > 1``
172    pub fn count_ones(&self) -> u32 {
173        assert!(self.nrows() == 1 || self.ncols() == 1, "only works on single row or single column matrices");
174        let mut accumulator = 0;
175        for row in 0..self.nrows() {
176            let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows.add(row) };
177            for i in 0..(self.ncols() / 64) {
178                let word_ptr: *const Word = unsafe { (*row_ptr).add(i) };
179                accumulator += unsafe { (*word_ptr).count_ones() };
180            }
181            // process last block
182            if self.ncols() % 64 != 0 {
183                let word_ptr: *const Word = unsafe { (*row_ptr).add((self.ncols() - 1) / 64) };
184                let word = unsafe { *word_ptr } & ((1 << self.ncols() % 64) - 1);
185                accumulator += word.count_ones();
186            }
187        }
188        accumulator
189    }
190
191    /// Construct a randomized matrix
192    pub fn random(rows: usize, columns: usize) -> BinMatrix {
193        let mzd = unsafe { mzd_init(rows as Rci, columns as Rci) };
194        // Randomize
195        unsafe {
196            mzd_randomize(mzd);
197        }
198        unsafe { BinMatrix { mzd: nonnull!(mzd) } }
199    }
200
201    /// Construct a BinMatrix from the raw mzd pointer
202    pub fn from_mzd(mzd: *mut Mzd) -> BinMatrix {
203        let mzd = ptr::NonNull::new(mzd).expect("Can't be NULL");
204        BinMatrix { mzd }
205    }
206
207    /// Get an identity matrix
208    #[inline]
209    pub fn identity(rows: usize) -> BinMatrix {
210        unsafe {
211            let mzd_ptr = mzd_init(rows as c_int, rows as c_int);
212            mzd_set_ui(mzd_ptr, 1);
213            let mzd = nonnull!(mzd_ptr);
214            BinMatrix { mzd }
215        }
216    }
217
218    /// Augment the matrix:
219    ///  ``[A] [B] => [A B]``
220    #[inline]
221    pub fn augmented(&self, other: &BinMatrix) -> BinMatrix {
222        debug_assert_eq!(self.nrows(), other.nrows(), "The rows need to be equal");
223        let mzd = unsafe {
224            nonnull!(mzd_concat(
225                ptr::null_mut(),
226                self.mzd.as_ptr(),
227                other.mzd.as_ptr()
228            ))
229        };
230        BinMatrix { mzd }
231    }
232
233    /// Stack the matrix with another and return the result
234    #[inline]
235    pub fn stacked(&self, other: &BinMatrix) -> BinMatrix {
236        let mzd = unsafe {
237            nonnull!(mzd_stack(
238                ptr::null_mut(),
239                self.mzd.as_ptr(),
240                other.mzd.as_ptr()
241            ))
242        };
243        BinMatrix { mzd }
244    }
245
246    /// Get the rank of the matrix
247    ///
248    /// Does an echelonization and throws it away!
249    #[inline]
250    pub fn rank(&self) -> usize {
251        self.clone().echelonize()
252    }
253
254    /// Echelonize this matrix in-place
255    ///
256    /// Return: the rank of the matrix
257    #[inline]
258    pub fn echelonize(&mut self) -> usize {
259        let rank = unsafe { mzd_echelonize(self.mzd.as_ptr(), false as c_int) };
260        rank as usize
261    }
262
263    /// Compute the inverse of this matrix, returns a new matrix
264    #[inline]
265    pub fn inverted(&self) -> BinMatrix {
266        let mzd = unsafe { nonnull!(mzd_inv_m4ri(ptr::null_mut(), self.mzd.as_ptr(), 0 as c_int)) };
267        BinMatrix { mzd }
268    }
269
270    /// Compute the transpose of the matrix
271    #[inline]
272    pub fn transposed(&self) -> BinMatrix {
273        let mzd;
274        unsafe {
275            let mzd_ptr = mzd_transpose(ptr::null_mut(), self.mzd.as_ptr());
276            mzd = nonnull!(mzd_ptr);
277        }
278        BinMatrix { mzd }
279    }
280
281    /// Get the number of rows
282    ///
283    /// O(1)
284    #[inline]
285    pub fn nrows(&self) -> usize {
286        unsafe { self.mzd.as_ref().nrows as usize }
287    }
288
289    /// Get the number of columns
290    ///
291    /// O(1)
292    #[inline]
293    pub fn ncols(&self) -> usize {
294        unsafe { self.mzd.as_ref().ncols as usize }
295    }
296
297    /// Get a single word from the matrix at a certain offset
298    pub fn get_word(&self, row: usize, column: usize) -> Word {
299        assert!(row < self.nrows());
300        assert!(column < self.ncols());
301
302        unsafe { self.get_word_unchecked(row, column) }
303    }
304
305    /// Get a particular word from the matrix
306    /// Does not do any bounds checking!
307    #[inline]
308    pub unsafe fn get_word_unchecked(&self, row: usize, column: usize) -> Word {
309        let row_ptr: *const *mut Word = (*self.mzd.as_ptr()).rows.add(row);
310        let word_ptr: *const Word = ((*row_ptr) as *const Word).add(column);
311        *word_ptr
312    }
313
314    /// Get a mutable reference to a particular word in the matrix
315    pub fn get_word_mut(&self, row: usize, column: usize) -> &mut Word {
316        assert!(row < self.nrows());
317        assert!(column < self.ncols());
318        unsafe { self.get_word_mut_unchecked(row, column) }
319    }
320
321    /// Get a mutable reference to a particular word in the matrix without bounds checking.
322    #[inline]
323    pub unsafe fn get_word_mut_unchecked(&self, row: usize, column: usize) -> &mut Word {
324        let row_ptr: *const *mut Word = (*self.mzd.as_ptr()).rows.add(row);
325        let word_ptr: *mut Word = ((*row_ptr) as *mut Word).add(column / 64);
326        word_ptr.as_mut().unwrap()
327    }
328
329    /// Get as a vector
330    ///
331    /// Works both on single-column and single-row matrices
332    pub fn as_vector(&self) -> BinVector {
333        if self.nrows() != 1 {
334            assert_eq!(self.ncols(), 1, "needs to have only one column or row");
335            self.transposed().as_vector()
336        } else {
337            assert_eq!(self.nrows(), 1, "needs to have only one column or row");
338            let mut bits = BinVector::with_capacity(self.ncols());
339            {
340                let collector = unsafe { bits.get_storage_mut() };
341                for i in 0..(self.ncols() / 64) {
342                    let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows };
343                    let word_ptr: *const Word = unsafe { ((*row_ptr) as *const Word).add(i) };
344                    collector.push(unsafe { *word_ptr as usize });
345                }
346                // process last block
347                if self.ncols() % 64 != 0 {
348                    let row_ptr: *const *mut Word = unsafe { (*self.mzd.as_ptr()).rows };
349                    let word_ptr: *const Word = unsafe { (*row_ptr).add((self.ncols() - 1) / 64) };
350                    let word = unsafe { *word_ptr };
351                    collector.push(word as usize);
352                }
353            }
354            unsafe {
355                bits.set_len(self.ncols());
356                bits.mask_last_block();
357            }
358
359            bits
360        }
361    }
362
363    /// Get a certain bit
364    pub fn bit(&self, row: usize, col: usize) -> bool {
365        let bit = unsafe { mzd_read_bit(self.mzd.as_ptr(), row as Rci, col as Rci) };
366        debug_assert!(bit == 0 || bit == 1, "Invalid bool for bit??");
367        bit == 1
368    }
369
370    /// Get a window from the matrix. Makes a copy.
371    pub fn get_window(
372        &self,
373        start_row: usize,
374        start_col: usize,
375        high_row: usize,
376        high_col: usize,
377    ) -> BinMatrix {
378        let (rows, cols) = (high_row - start_row, high_col - start_col);
379        debug_assert!(rows > 0 && rows <= self.nrows());
380        debug_assert!(cols > 0 && cols <= self.ncols());
381        let mzd_ptr = unsafe { mzd_init(rows as Rci, cols as Rci) };
382        for (r, i) in (start_row..high_row).enumerate() {
383            // FIXME speed
384            for (c, j) in (start_col..high_col).enumerate() {
385                let bit = self.bit(i, j);
386                unsafe {
387                    mzd_write_bit(mzd_ptr, r as Rci, c as Rci, bit as BIT);
388                }
389            }
390        }
391        BinMatrix::from_mzd(mzd_ptr)
392    }
393
394    /// Set a window in the matrix to another matrix
395    ///
396    /// Currently does bit-by-bit, should use more optimal means
397    /// if alignment allows it
398    pub fn set_window(&mut self, start_row: usize, start_col: usize, other: &BinMatrix) {
399        let highr = start_row + other.nrows();
400        let highc = start_col + other.ncols();
401        debug_assert!(self.ncols() >= highc, "This matrix is too small!");
402        debug_assert!(self.nrows() >= highr, "This matrix has too few rows !");
403        let mzd_ptr = self.mzd.as_ptr();
404
405        for r in start_row..highr {
406            for c in start_col..highc {
407                let bit = other.bit(r - start_row, c - start_col);
408                unsafe {
409                    mzd_write_bit(mzd_ptr, r as Rci, c as Rci, bit as BIT);
410                }
411            }
412        }
413    }
414
415    /// Multiply a matrix by a vector represented as a [u64]
416    pub fn mul_slice(&self, other: &[u64]) -> BinMatrix {
417        // I've tried to use thread-local storage for the temporary here, but it wasn't faster.
418        debug_assert!(
419            self.ncols() <= other.len() * 64,
420            "Mismatched sizes: ({}x{}) * ({}x1) (too big)",
421            self.nrows(),
422            self.ncols(),
423            other.len() * 64
424        );
425        let result = {
426            let other = BinMatrix::from_slices(&[other], self.ncols()).transposed();
427            unsafe { mzd_mul_naive(ptr::null_mut(), self.mzd.as_ptr(), other.mzd.as_ptr()) }
428        };
429        let matresult = BinMatrix::from_mzd(result);
430        matresult
431    }
432}
433
434impl cmp::PartialEq for BinMatrix {
435    fn eq(&self, other: &BinMatrix) -> bool {
436        unsafe { mzd_equal(self.mzd.as_ptr(), other.mzd.as_ptr()) == 1 }
437    }
438}
439
// Marker impl: the `PartialEq` comparison above is declared to be a full equivalence relation.
impl cmp::Eq for BinMatrix {}
441
442impl ops::Mul<BinMatrix> for BinMatrix {
443    type Output = BinMatrix;
444
445    /// Computes the product of two matrices
446    #[inline]
447    fn mul(self, other: BinMatrix) -> Self::Output {
448        &self * &other
449    }
450}
451
452impl std::clone::Clone for BinMatrix {
453    fn clone(&self) -> Self {
454        let mzd = unsafe { nonnull!(mzd_copy(ptr::null_mut(), self.mzd.as_ptr())) };
455        BinMatrix { mzd }
456    }
457}
458
459impl<'a> ops::Mul<&'a BinMatrix> for &'a BinMatrix {
460    type Output = BinMatrix;
461    /// Computes the product of two matrices
462    #[inline]
463    fn mul(self, other: &BinMatrix) -> Self::Output {
464        unsafe {
465            let mzd_ptr = mul_impl!(ptr::null_mut(), self.mzd.as_ptr(), other.mzd.as_ptr());
466
467            BinMatrix {
468                mzd: ptr::NonNull::new(mzd_ptr).expect("Multiplication failed"),
469            }
470        }
471    }
472}
473
474impl<'a> ops::Add<&'a BinMatrix> for &'a BinMatrix {
475    type Output = BinMatrix;
476
477    /// Add up two matrices
478    #[inline]
479    fn add(self, other: &BinMatrix) -> Self::Output {
480        let mzd = unsafe {
481            nonnull!(mzd_add(
482                ptr::null_mut(),
483                self.mzd.as_ptr(),
484                other.mzd.as_ptr()
485            ))
486        };
487        BinMatrix { mzd }
488    }
489}
490
491impl ops::Add<BinMatrix> for BinMatrix {
492    type Output = BinMatrix;
493
494    /// Add up two matrices, re-uses memory of A
495    #[inline]
496    fn add(self, other: BinMatrix) -> Self::Output {
497        let mzd = unsafe {
498            nonnull!(mzd_add(
499                self.mzd.as_ptr(),
500                self.mzd.as_ptr(),
501                other.mzd.as_ptr()
502            ))
503        };
504        BinMatrix { mzd }
505    }
506}
507
508impl ops::AddAssign<BinMatrix> for BinMatrix {
509    /// Add up two matrices, re-uses memory of A
510    #[inline]
511    fn add_assign(&mut self, other: BinMatrix) {
512        unsafe {
513            mzd_add(self.mzd.as_ptr(), self.mzd.as_ptr(), other.mzd.as_ptr());
514        }
515    }
516}
517
518impl<'a> ops::AddAssign<&'a BinMatrix> for BinMatrix {
519    /// Add up two matrices, re-uses memory of A
520    #[inline]
521    fn add_assign(&mut self, other: &BinMatrix) {
522        unsafe {
523            mzd_add(self.mzd.as_ptr(), self.mzd.as_ptr(), other.mzd.as_ptr());
524        }
525    }
526}
527
528impl<'a> ops::Mul<&'a BinVector> for &'a BinMatrix {
529    type Output = BinVector;
530    /// Computes (A * v^T)
531    #[inline]
532    fn mul(self, other: &BinVector) -> Self::Output {
533        self.mul_slice(
534            &other
535                .get_storage()
536                .iter()
537                .copied()
538                .map(|b| b as u64)
539                .collect::<Vec<u64>>(),
540        ).as_vector()
541    }
542}
543
544impl ops::Mul<BinVector> for BinMatrix {
545    type Output = BinVector;
546    /// Computes (A * v^T)
547    fn mul(self, other: BinVector) -> Self::Output {
548        &self * &other
549    }
550}
551
552impl<'a> ops::Mul<&'a BinMatrix> for &'a BinVector {
553    type Output = BinVector;
554
555    #[inline]
556    /// computes v^T * A
557    fn mul(self, other: &BinMatrix) -> Self::Output {
558        let vec_mzd = self.as_matrix();
559        let tmp = &vec_mzd * other;
560
561        tmp.as_vector()
562    }
563}
564
565impl ops::Mul<BinMatrix> for BinVector {
566    type Output = BinVector;
567
568    #[inline]
569    /// computes v^T * A
570    fn mul(self, other: BinMatrix) -> Self::Output {
571        &self * &other
572    }
573}
574
575/// Solve AX = B for X
576///
577/// Modifies B in-place
578///
579/// B will contain the solution afterwards
580///
581/// Return True if it succeeded
582pub fn solve_left(a: BinMatrix, b: &mut BinMatrix) -> bool {
583    let result = unsafe { mzd_solve_left(a.mzd.as_ptr(), b.mzd.as_ptr(), 0, 1) };
584
585    result == 0
586}
587
588#[cfg(test)]
589mod test {
590    use super::*;
591    use rand::prelude::*;
592    use vob::Vob;
593
594    #[test]
595    fn new() {
596        let _m = BinMatrix::new(vec![
597            BinVector::from(vob![true, false, true]),
598            BinVector::from(vob![true, true, true]),
599        ]);
600    }
601
602    #[test]
603    fn identity() {
604        let id = BinMatrix::new(vec![
605            BinVector::from(vob![
606                true, false, false, false, false, false, false, false, false, false
607            ]),
608            BinVector::from(vob![
609                false, true, false, false, false, false, false, false, false, false
610            ]),
611            BinVector::from(vob![
612                false, false, true, false, false, false, false, false, false, false
613            ]),
614            BinVector::from(vob![
615                false, false, false, true, false, false, false, false, false, false
616            ]),
617            BinVector::from(vob![
618                false, false, false, false, true, false, false, false, false, false
619            ]),
620            BinVector::from(vob![
621                false, false, false, false, false, true, false, false, false, false
622            ]),
623            BinVector::from(vob![
624                false, false, false, false, false, false, true, false, false, false
625            ]),
626            BinVector::from(vob![
627                false, false, false, false, false, false, false, true, false, false
628            ]),
629            BinVector::from(vob![
630                false, false, false, false, false, false, false, false, true, false
631            ]),
632            BinVector::from(vob![
633                false, false, false, false, false, false, false, false, false, true
634            ]),
635        ]);
636
637        let id_gen = BinMatrix::identity(10);
638        assert_eq!(id.nrows(), id_gen.nrows());
639        assert_eq!(id.ncols(), id_gen.ncols());
640        for i in 0..8 {
641            for j in 0..8 {
642                let m1 = id.mzd.as_ptr();
643                let m2 = id_gen.mzd.as_ptr();
644                unsafe {
645                    assert_eq!(
646                        mzd_read_bit(m1, i, j),
647                        mzd_read_bit(m2, i, j),
648                        "({}, {})",
649                        i,
650                        j
651                    );
652                }
653            }
654        }
655        unsafe {
656            assert!(mzd_equal(id.mzd.as_ptr(), id_gen.mzd.as_ptr()) != 0);
657        }
658        assert_eq!(id, id_gen);
659    }
660
661    #[test]
662    fn mul() {
663        let m1 = BinMatrix::identity(8);
664        let m2 = BinMatrix::identity(8);
665        let m3 = BinMatrix::identity(8);
666        let prod = m1 * m2;
667        unsafe {
668            assert!(mzd_equal(prod.mzd.as_ptr(), m3.mzd.as_ptr()) != 0);
669        }
670    }
671
672    #[test]
673    fn vecmul() {
674        let m1 = BinMatrix::identity(10);
675        let binvec = BinVector::from(Vob::from_elem(10, true));
676
677        let result: BinVector = &m1 * &binvec;
678        assert_eq!(result, binvec);
679
680        let result: BinVector = &binvec * &m1;
681        assert_eq!(result, binvec);
682
683        let m1 = BinMatrix::random(10, 3);
684        let result = &binvec * &m1;
685        assert_eq!(result.len(), 3);
686    }
687
688    #[test]
689    fn test_random() {
690        BinMatrix::random(10, 1);
691    }
692
693    #[cfg(feature = "serde")]
694    #[test]
695    fn test_serialize() {
696        let m = BinMatrix::identity(3);
697        let json = serde_json::to_string(&m).unwrap();
698        assert_eq!(json, "{\"matrix\":{\"rows\":[{\"len\":3,\"vec\":[1]},{\"len\":3,\"vec\":[2]},{\"len\":3,\"vec\":[4]}]}}");
699    }
700
701    #[test]
702    fn test_as_vector_column() {
703        for i in 1..25 {
704            let m1 = BinMatrix::random(i, 1);
705            let vec = m1.as_vector();
706            assert_eq!(vec.len(), i);
707            assert!(m1 == vec.as_column_matrix());
708        }
709    }
710
711    #[test]
712    fn test_as_vector_row() {
713        for i in 1..25 {
714            let m1 = BinMatrix::random(1, i);
715            let vec = m1.as_vector();
716            assert_eq!(vec.len(), i);
717            assert!(m1 == vec.as_matrix());
718        }
719    }
720
721    #[test]
722    fn zero() {
723        let m1 = BinMatrix::zero(10, 3);
724        for i in 0..10 {
725            for j in 0..3 {
726                assert_eq!(m1.bit(i, j), false);
727            }
728        }
729    }
730
731    #[test]
732    fn set_window() {
733        let mut m1 = BinMatrix::zero(10, 10);
734        m1.set_window(5, 5, &BinMatrix::identity(5));
735        for i in 0..5 {
736            for j in 0..5 {
737                assert_eq!(m1.bit(i, j), false);
738            }
739        }
740        for i in 5..10 {
741            for j in 5..10 {
742                let bit = m1.bit(i, j);
743                assert_eq!(bit, i == j, "bit ({},{}) was {}", i, j, bit);
744            }
745        }
746
747        let mut m1 = BinMatrix::random(10, 10);
748        m1.set_window(5, 5, &BinMatrix::identity(5));
749        for i in 5..10 {
750            for j in 5..10 {
751                let bit = m1.bit(i, j);
752                assert_eq!(bit, i == j, "bit ({},{}) was {}", i, j, bit);
753            }
754        }
755    }
756
757    #[test]
758    fn test_random_unequal() {
759        let m1 = BinMatrix::random(100, 100);
760        let m2 = BinMatrix::random(100, 100);
761        assert_ne!(m1, m2);
762    }
763
764    #[test]
765    fn test_count_ones() {
766        let rng = &mut rand::thread_rng();
767        for _ in 0..1000 {
768            let size = rng.gen_range(1..1000);
769            let v = BinVector::random(size);
770            assert_eq!(v.count_ones(), v.as_matrix().count_ones());
771            assert_eq!(v.count_ones(), v.as_column_matrix().count_ones());
772        }
773    }
774}