use core::cmp;
use core::ops::{Index, IndexMut, RangeBounds};

use bit_vec::BitVec;

use super::{FALSE, TRUE};
use crate::local_prelude::*;
use crate::row::Iter;
use crate::util::round_up_to_next;

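/// A matrix of bits.
///
/// Bits are stored row by row in a `BitVec`, with each row padded up to a
/// multiple of the block size.
///
/// A sketch of typical usage, assuming this crate is available as
/// `bit_matrix`:
///
/// ```
/// use bit_matrix::BitMatrix;
///
/// // Three rows of eight bits, all initially false.
/// let mut matrix = BitMatrix::new(3, 8);
/// matrix.set(1, 4, true);
/// assert!(matrix[(1, 4)]);
/// assert!(!matrix[(0, 0)]);
/// ```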
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    feature = "miniserde",
    derive(miniserde::Serialize, miniserde::Deserialize)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BitMatrix {
    bit_vec: BitVec,
    row_bits: usize,
}

impl BitMatrix {
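    /// Creates a new matrix with `rows` rows and `row_bits` columns, with
    /// all bits set to false.
    ///
    /// A minimal sketch, assuming this crate is available as `bit_matrix`:
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let matrix = BitMatrix::new(2, 5);
    /// assert_eq!(matrix.size(), (2, 5));
    /// ```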
    pub fn new(rows: usize, row_bits: usize) -> Self {
        BitMatrix {
            bit_vec: BitVec::from_elem(round_up_to_next(row_bits, BITS) * rows, false),
            row_bits,
        }
    }

    /// Returns the number of rows, computed from the length of the
    /// underlying storage.
    #[inline]
    fn num_rows(&self) -> usize {
        if self.row_bits == 0 {
            0
        } else {
            let row_blocks = round_up_to_next(self.row_bits, BITS) / BITS;
            self.bit_vec.storage().len() / row_blocks
        }
    }

    /// Returns the size of the matrix as `(rows, columns)`.
    pub fn size(&self) -> (usize, usize) {
        (self.num_rows(), self.row_bits)
    }

    /// Sets the bit at `(row, col)` to `enabled`.
    ///
    /// # Panics
    ///
    /// Panics if the position is out of bounds of the underlying bit vector.
    #[inline]
    pub fn set(&mut self, row: usize, col: usize, enabled: bool) {
        let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
        self.bit_vec.set(row * row_size_in_bits + col, enabled);
    }

    /// Sets all bits in the matrix to `enabled`.
    #[inline]
    pub fn set_all(&mut self, enabled: bool) {
        if enabled {
            self.bit_vec.set_all();
        } else {
            self.bit_vec.clear();
        }
    }

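    /// Grows the matrix in place, appending `num_rows` rows filled with
    /// `value`.
    ///
    /// A sketch, assuming this crate is available as `bit_matrix`:
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let mut matrix = BitMatrix::new(1, 4);
    /// matrix.grow(2, true);
    /// assert_eq!(matrix.size(), (3, 4));
    /// assert!(matrix[(2, 3)]);
    /// ```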
    pub fn grow(&mut self, num_rows: usize, value: bool) {
        self.bit_vec
            .grow(round_up_to_next(self.row_bits, BITS) * num_rows, value);
    }

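    /// Truncates the matrix to at most `num_rows` rows. Has no effect when
    /// the matrix is already short enough.
    ///
    /// A sketch, assuming this crate is available as `bit_matrix`:
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let mut matrix = BitMatrix::new(3, 4);
    /// matrix.truncate(1);
    /// assert_eq!(matrix.size(), (1, 4));
    /// ```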
    pub fn truncate(&mut self, num_rows: usize) {
        self.bit_vec
            .truncate(round_up_to_next(self.row_bits, BITS) * num_rows);
    }

    /// Returns an immutable view of the rows within `range`.
    #[inline]
    pub fn sub_matrix<R: RangeBounds<usize>>(&self, range: R) -> BitSubMatrix {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        BitSubMatrix {
            // Scale the row range to a range over the storage blocks.
            slice: &self.bit_vec.storage()[(
                range.start_bound().map(|&s| s * row_size),
                range.end_bound().map(|&e| e * row_size),
            )],
            row_bits: self.row_bits,
        }
    }

    /// Splits the matrix at the given row, returning immutable views of the
    /// rows above and the rows below.
    #[inline]
    pub fn split_at(&self, row: usize) -> (BitSubMatrix, BitSubMatrix) {
        (
            self.sub_matrix(0..row),
            self.sub_matrix(row..self.num_rows()),
        )
    }

    /// Like `split_at`, but returns mutable views of the two halves.
    #[inline]
    pub fn split_at_mut(&mut self, row: usize) -> (BitSubMatrixMut, BitSubMatrixMut) {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        let (first, second) = unsafe { self.bit_vec.storage_mut().split_at_mut(row * row_size) };
        (
            BitSubMatrixMut::new(first, self.row_bits),
            BitSubMatrixMut::new(second, self.row_bits),
        )
    }

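    /// Iterates over the bits in the given row.
    ///
    /// A sketch, assuming this crate is available as `bit_matrix`:
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let mut matrix = BitMatrix::new(2, 3);
    /// matrix.set(0, 1, true);
    /// // The iterator yields exactly `row_bits` bits; padding is excluded.
    /// assert_eq!(matrix.iter_row(0).count(), 3);
    /// ```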
    pub fn iter_row(&self, row: usize) -> Iter {
        BitSlice::new(&self[row].slice).iter_bits(self.row_bits)
    }

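    /// Computes the transitive closure of the binary relation represented by
    /// this square matrix, in place, with a variant of Warshall's algorithm.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// A sketch, assuming this crate is available as `bit_matrix`: the edges
    /// 0 -> 1 and 1 -> 2 imply 0 -> 2 in the closure.
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let mut matrix = BitMatrix::new(3, 3);
    /// matrix.set(0, 1, true);
    /// matrix.set(1, 2, true);
    /// matrix.transitive_closure();
    /// assert!(matrix[(0, 2)]);
    /// ```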
    pub fn transitive_closure(&mut self) {
        assert_eq!(self.num_rows(), self.row_bits);
        for pos in 0..self.row_bits {
            // Borrow the rows above `pos`, the row at `pos`, and the rows
            // below `pos` as three disjoint mutable views.
            let (mut rows0, mut rows1a) = self.split_at_mut(pos);
            let (row, mut rows1b) = rows1a.split_at_mut(1);
            // Any row that reaches `pos` also reaches everything `pos`
            // reaches, so OR in the row at `pos`, block by block.
            for dst_row in rows0.iter_mut().chain(rows1b.iter_mut()) {
                if dst_row[pos] {
                    for (dst, src) in dst_row.iter_blocks_mut().zip(row[0].iter_blocks()) {
                        *dst |= *src;
                    }
                }
            }
        }
    }

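    /// Computes the reflexive closure of the relation by setting every bit
    /// on the main diagonal.
    ///
    /// A sketch, assuming this crate is available as `bit_matrix`:
    ///
    /// ```
    /// use bit_matrix::BitMatrix;
    ///
    /// let mut matrix = BitMatrix::new(2, 2);
    /// matrix.reflexive_closure();
    /// assert!(matrix[(0, 0)] && matrix[(1, 1)]);
    /// ```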
    pub fn reflexive_closure(&mut self) {
        for i in 0..cmp::min(self.row_bits, self.num_rows()) {
            self.set(i, i, true);
        }
    }
}

impl Index<usize> for BitMatrix {
    type Output = BitSlice;

    /// Returns the given row as an immutable bit slice.
    #[inline]
    fn index(&self, row: usize) -> &BitSlice {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        BitSlice::new(&self.bit_vec.storage()[row * row_size..(row + 1) * row_size])
    }
}

impl IndexMut<usize> for BitMatrix {
    /// Returns the given row as a mutable bit slice.
    #[inline]
    fn index_mut(&mut self, row: usize) -> &mut BitSlice {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        unsafe {
            BitSlice::new_mut(&mut self.bit_vec.storage_mut()[row * row_size..(row + 1) * row_size])
        }
    }
}

impl Index<(usize, usize)> for BitMatrix {
    type Output = bool;

    /// Returns the bit at `(row, col)` as a reference to one of the shared
    /// `TRUE`/`FALSE` constants, since individual bits cannot be referenced
    /// directly.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &bool {
        let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
        if self.bit_vec.get(row * row_size_in_bits + col).unwrap() {
            &TRUE
        } else {
            &FALSE
        }
    }
}

#[cfg(feature = "memusage")]
impl MemoryReport for BitMatrix {
    fn indirect(&self) -> usize {
        // Convert the capacity in bits to bytes, rounded up to whole
        // 32-bit storage blocks.
        (self.bit_vec.capacity() + 31) / 32 * 4
    }
}

#[test]
fn test_empty() {
    let mut matrix = BitMatrix::new(0, 0);
    for _ in 0..3 {
        assert_eq!(matrix.num_rows(), 0);
        assert_eq!(matrix.size(), (0, 0));
        matrix.transitive_closure();
    }
}