use core::cmp;
use core::ops::Range;
use core::ops::{Index, IndexMut};

#[cfg(all(feature = "miniserde", not(feature = "serde")))]
use miniserde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "memusage")]
use memusage::MemoryReport;

use bit_vec::BitVec;

use super::{FALSE, TRUE};
use crate::local_prelude::*;
use crate::row::Iter;
use crate::util::round_up_to_next;

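/// A matrix of bits, stored row by row in a `BitVec`. Each row is padded to a
/// whole number of `BITS`-bit blocks.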
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    any(feature = "serde", feature = "miniserde"),
    derive(Serialize, Deserialize)
)]
pub struct BitMatrix {
    bit_vec: BitVec,
    row_bits: usize,
}

impl BitMatrix {
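    /// Creates a new `BitMatrix` with `rows` rows and `row_bits` columns,
    /// with all bits set to `false`.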
    pub fn new(rows: usize, row_bits: usize) -> Self {
        BitMatrix {
            bit_vec: BitVec::from_elem(round_up_to_next(row_bits, BITS) * rows, false),
            row_bits,
        }
    }

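    /// Returns the number of rows, derived from the length of the underlying
    /// storage and the padded row size.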
    #[inline]
    fn num_rows(&self) -> usize {
        if self.row_bits == 0 {
            0
        } else {
            let row_blocks = round_up_to_next(self.row_bits, BITS) / BITS;
            self.bit_vec.storage().len() / row_blocks
        }
    }

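    /// Returns the dimensions of the matrix as `(rows, columns)`.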
    pub fn size(&self) -> (usize, usize) {
        (self.num_rows(), self.row_bits)
    }

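    /// Sets the bit at `(row, col)` to `enabled`.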
    #[inline]
    pub fn set(&mut self, row: usize, col: usize, enabled: bool) {
        let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
        self.bit_vec.set(row * row_size_in_bits + col, enabled);
    }

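    /// Sets or clears every bit in the underlying bit vector.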
    #[inline]
    pub fn set_all(&mut self, enabled: bool) {
        if enabled {
            self.bit_vec.set_all();
        } else {
            self.bit_vec.clear();
        }
    }

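    /// Grows the matrix in place, appending `num_rows` rows filled with
    /// `value`.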
    pub fn grow(&mut self, num_rows: usize, value: bool) {
        self.bit_vec
            .grow(round_up_to_next(self.row_bits, BITS) * num_rows, value);
    }

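    /// Truncates the matrix to at most `num_rows` rows.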
    pub fn truncate(&mut self, num_rows: usize) {
        self.bit_vec
            .truncate(round_up_to_next(self.row_bits, BITS) * num_rows);
    }

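    /// Returns a view of the rows in `range`.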
    #[inline]
    pub fn sub_matrix(&self, range: Range<usize>) -> BitSubMatrix {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        BitSubMatrix::new(
            &self.bit_vec.storage()[range.start * row_size..range.end * row_size],
            self.row_bits,
        )
    }

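    /// Splits the matrix into two contiguous views of rows: `0..row` and
    /// `row..num_rows()`.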
    #[inline]
    pub fn split_at(&self, row: usize) -> (BitSubMatrix, BitSubMatrix) {
        (
            self.sub_matrix(0..row),
            self.sub_matrix(row..self.num_rows()),
        )
    }

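    /// Splits the matrix into two mutable views of rows at `row`, analogous
    /// to `split_at`.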
    #[inline]
    pub fn split_at_mut(&mut self, row: usize) -> (BitSubMatrixMut, BitSubMatrixMut) {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        let (first, second) = unsafe { self.bit_vec.storage_mut().split_at_mut(row * row_size) };
        (
            BitSubMatrixMut::new(first, self.row_bits),
            BitSubMatrixMut::new(second, self.row_bits),
        )
    }

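    /// Returns an iterator over the bits of the given row.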
    pub fn iter_row(&self, row: usize) -> Iter {
        BitSlice::new(&self[row].slice).iter_bits(self.row_bits)
    }

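    /// Computes the transitive closure of the binary relation represented by
    /// this matrix, in place. The matrix must be square.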
    pub fn transitive_closure(&mut self) {
        // Warshall-style closure, computed in place: each row `pos` acts in
        // turn as the intermediate vertex.
        assert_eq!(self.num_rows(), self.row_bits);
        for pos in 0..self.row_bits {
            // Split the matrix so that `row` borrows row `pos` while every
            // other row stays mutably accessible.
            let (mut rows0, mut rows1a) = self.split_at_mut(pos);
            let (row, mut rows1b) = rows1a.split_at_mut(1);
            for dst_row in rows0.iter_mut().chain(rows1b.iter_mut()) {
                if dst_row[pos] {
                    // `dst_row` reaches `pos`, so it also reaches everything
                    // reachable from `pos`: OR in row `pos` block by block.
                    for (dst, src) in dst_row.iter_blocks_mut().zip(row[0].iter_blocks()) {
                        *dst |= *src;
                    }
                }
            }
        }
    }

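    /// Sets every bit on the main diagonal, making the relation reflexive.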
    pub fn reflexive_closure(&mut self) {
        for i in 0..cmp::min(self.row_bits, self.num_rows()) {
            self.set(i, i, true);
        }
    }
}

impl Index<usize> for BitMatrix {
    type Output = BitSlice;

    #[inline]
    fn index(&self, row: usize) -> &BitSlice {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        BitSlice::new(&self.bit_vec.storage()[row * row_size..(row + 1) * row_size])
    }
}

impl IndexMut<usize> for BitMatrix {
    #[inline]
    fn index_mut(&mut self, row: usize) -> &mut BitSlice {
        let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
        unsafe {
            BitSlice::new_mut(&mut self.bit_vec.storage_mut()[row * row_size..(row + 1) * row_size])
        }
    }
}

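// Indexing by `(row, col)` must return a reference, so the shared `TRUE` and
// `FALSE` values stand in for the value of the selected bit.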
impl Index<(usize, usize)> for BitMatrix {
    type Output = bool;

    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &bool {
        let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
        if self.bit_vec.get(row * row_size_in_bits + col).unwrap() {
            &TRUE
        } else {
            &FALSE
        }
    }
}

#[cfg(feature = "memusage")]
impl MemoryReport for BitMatrix {
    fn indirect(&self) -> usize {
        // Heap size of the backing storage: capacity in bits, rounded up to
        // whole 32-bit blocks of 4 bytes each.
        (self.bit_vec.capacity() + 31) / 32 * 4
    }
}

#[test]
fn test_empty() {
    let mut matrix = BitMatrix::new(0, 0);
    for _ in 0..3 {
        assert_eq!(matrix.num_rows(), 0);
        assert_eq!(matrix.size(), (0, 0));
        matrix.transitive_closure();
    }
}
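
// Supplementary tests sketching typical usage of the public API defined
// above: setting bits, indexing by `(row, col)`, computing closures, and
// resizing. They are examples added alongside `test_empty` and rely only on
// the methods in this module.
#[test]
fn test_closures_sketch() {
    let mut matrix = BitMatrix::new(3, 3);
    // Edges 0 -> 1 and 1 -> 2.
    matrix.set(0, 1, true);
    matrix.set(1, 2, true);
    assert!(matrix[(0, 1)]);
    assert!(!matrix[(0, 2)]);

    // The transitive closure adds the derived edge 0 -> 2.
    matrix.transitive_closure();
    assert!(matrix[(0, 2)]);

    // The reflexive closure adds every self-loop.
    matrix.reflexive_closure();
    for i in 0..3 {
        assert!(matrix[(i, i)]);
    }
    assert_eq!(matrix.size(), (3, 3));
}

#[test]
fn test_grow_truncate_sketch() {
    let mut matrix = BitMatrix::new(2, 5);
    assert_eq!(matrix.size(), (2, 5));
    // `grow` appends rows filled with the given value.
    matrix.grow(3, true);
    assert_eq!(matrix.size(), (5, 5));
    assert!(matrix[(4, 0)]);
    // `truncate` drops rows from the end.
    matrix.truncate(2);
    assert_eq!(matrix.size(), (2, 5));
}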