1use core::cmp;
4use core::fmt;
5use core::mem;
6use core::ops::RangeBounds;
7use core::ops::{Index, IndexMut};
8use core::slice;
9
10use crate::local_prelude::*;
11use crate::util::{div_rem, round_up_to_next};
12
/// An immutable view into a contiguous range of rows of a bit matrix.
pub struct BitSubMatrix<'a> {
    /// Backing storage: whole blocks, row-major. Each row occupies
    /// `round_up_to_next(row_bits, BITS) / BITS` blocks.
    pub(crate) slice: &'a [Block],
    /// Number of meaningful (column) bits in each row.
    pub(crate) row_bits: usize,
}
18
/// A mutable view into a contiguous range of rows of a bit matrix.
pub struct BitSubMatrixMut<'a> {
    /// Backing storage: whole blocks, row-major. Each row occupies
    /// `round_up_to_next(row_bits, BITS) / BITS` blocks.
    pub(crate) slice: &'a mut [Block],
    /// Number of meaningful (column) bits in each row.
    pub(crate) row_bits: usize,
}
24
25impl<'a> BitSubMatrix<'a> {
26 pub fn new(slice: &[Block], row_bits: usize) -> BitSubMatrix {
28 BitSubMatrix {
29 slice: slice,
30 row_bits: row_bits,
31 }
32 }
33
34 #[inline]
36 pub unsafe fn from_raw_parts(ptr: *const Block, rows: usize, row_bits: usize) -> Self {
37 BitSubMatrix {
38 slice: slice::from_raw_parts(ptr, round_up_to_next(row_bits, BITS) / BITS * rows),
39 row_bits: row_bits,
40 }
41 }
42
43 pub fn iter(&self) -> impl Iterator<Item = &BitSlice> {
45 fn f(arg: &[Block]) -> &BitSlice {
46 unsafe { mem::transmute(arg) }
47 }
48 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
49 self.slice.chunks(row_size).map(f)
50 }
51}
52
53impl<'a> BitSubMatrixMut<'a> {
54 pub fn new(slice: &mut [Block], row_bits: usize) -> BitSubMatrixMut {
56 BitSubMatrixMut {
57 slice: slice,
58 row_bits: row_bits,
59 }
60 }
61
62 #[inline]
64 pub unsafe fn from_raw_parts(ptr: *mut Block, rows: usize, row_bits: usize) -> Self {
65 BitSubMatrixMut {
66 slice: slice::from_raw_parts_mut(ptr, round_up_to_next(row_bits, BITS) / BITS * rows),
67 row_bits: row_bits,
68 }
69 }
70
71 #[inline]
73 fn num_rows(&self) -> usize {
74 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
75 if row_size == 0 {
76 0
77 } else {
78 self.slice.len() / row_size
79 }
80 }
81
82 #[inline]
88 pub fn set(&mut self, row: usize, col: usize, enabled: bool) {
89 let row_size_in_bits = round_up_to_next(self.row_bits, BITS);
90 let bit = row * row_size_in_bits + col;
91 let (block, i) = div_rem(bit, BITS);
92 assert!(block < self.slice.len() && col < self.row_bits);
93 unsafe {
94 let elt = self.slice.get_unchecked_mut(block);
95 if enabled {
96 *elt |= 1 << i;
97 } else {
98 *elt &= !(1 << i);
99 }
100 }
101 }
102
103 pub fn sub_matrix<R: RangeBounds<usize>>(&self, range: R) -> BitSubMatrix {
105 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
106 BitSubMatrix {
107 slice: &self.slice[(
108 range.start_bound().map(|&s| s * row_size),
109 range.end_bound().map(|&e| e * row_size),
110 )],
111 row_bits: self.row_bits,
112 }
113 }
114
115 #[inline]
121 pub fn split_at(&self, row: usize) -> (BitSubMatrix, BitSubMatrix) {
122 (
123 self.sub_matrix(0..row),
124 self.sub_matrix(row..self.num_rows()),
125 )
126 }
127
128 #[inline]
131 pub fn split_at_mut(&mut self, row: usize) -> (BitSubMatrixMut, BitSubMatrixMut) {
132 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
133 let (first, second) = self.slice.split_at_mut(row * row_size);
134 (
135 BitSubMatrixMut::new(first, self.row_bits),
136 BitSubMatrixMut::new(second, self.row_bits),
137 )
138 }
139
140 pub fn transitive_closure(&mut self) {
144 assert_eq!(self.num_rows(), self.row_bits);
145 for pos in 0..self.row_bits {
146 let (mut rows0, mut rows1a) = self.split_at_mut(pos);
147 let (row, mut rows1b) = rows1a.split_at_mut(1);
148 for dst_row in rows0.iter_mut().chain(rows1b.iter_mut()) {
149 if dst_row[pos] {
150 for (dst, src) in dst_row.iter_blocks_mut().zip(row[0].iter_blocks()) {
151 *dst |= src;
152 }
153 }
154 }
155 }
156 }
157
158 pub fn reflexive_closure(&mut self) {
160 for i in 0..cmp::min(self.row_bits, self.num_rows()) {
161 self.set(i, i, true);
162 }
163 }
164
165 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut BitSlice> {
167 fn f(arg: &mut [Block]) -> &mut BitSlice {
168 unsafe { mem::transmute(arg) }
169 }
170 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
171 self.slice.chunks_mut(row_size).map(f)
172 }
173}
174
175impl<'a> Index<usize> for BitSubMatrixMut<'a> {
177 type Output = BitSlice;
178
179 #[inline]
180 fn index(&self, row: usize) -> &BitSlice {
181 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
182 unsafe { mem::transmute(&self.slice[row * row_size..(row + 1) * row_size]) }
183 }
184}
185
186impl<'a> IndexMut<usize> for BitSubMatrixMut<'a> {
188 #[inline]
189 fn index_mut(&mut self, row: usize) -> &mut BitSlice {
190 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
191 unsafe { mem::transmute(&mut self.slice[row * row_size..(row + 1) * row_size]) }
192 }
193}
194
195impl<'a> Index<usize> for BitSubMatrix<'a> {
197 type Output = BitSlice;
198
199 #[inline]
200 fn index(&self, row: usize) -> &BitSlice {
201 let row_size = round_up_to_next(self.row_bits, BITS) / BITS;
202 unsafe { mem::transmute(&self.slice[row * row_size..(row + 1) * row_size]) }
203 }
204}
205
206impl<'a> fmt::Debug for BitSubMatrix<'a> {
207 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
208 for row in self.iter() {
209 for bit in row.iter_bits(self.row_bits) {
210 write!(fmt, "{}", if bit { 1 } else { 0 })?;
211 }
212 write!(fmt, "\n")?;
213 }
214 Ok(())
215 }
216}