1use bitvec::slice::BitSlice;
2
3use crate::DimensionMismatch;
4
5pub trait BitBoard: Sized {
6 fn n_rows(&self) -> usize;
8
9 fn n_cols(&self) -> usize;
11
12 fn board_mut(&mut self) -> &mut BitSlice;
14
15 fn board(&self) -> &BitSlice;
17
18 fn index_of(&self, row: usize, col: usize) -> usize {
20 assert!(
21 row <= (self.n_rows() - 1),
22 "row cannot be greater than n_rows"
23 );
24 assert!(
25 col <= (self.n_cols() - 1),
26 "col cannot be greater than n_cols"
27 );
28 (row * self.n_cols()) + col
29 }
30
31 fn row_col_of(&self, index: usize) -> (usize, usize) {
33 let row = index / self.n_cols();
34 let col = index % self.n_cols();
35 (row, col)
36 }
37
38 fn fill(&mut self, value: bool) {
40 self.board_mut().fill(value);
41 }
42
43 fn or(&self, other: &impl BitBoard) -> Result<Self, DimensionMismatch>;
44 fn and(&self, other: &impl BitBoard) -> Result<Self, DimensionMismatch>;
45
46 fn set(&mut self, row: usize, col: usize, value: bool) {
48 let new_ind = self.index_of(row, col);
49 self.board_mut().set(new_ind, value);
50 }
51
52 fn get(&self, row: usize, col: usize) -> bool {
54 if row >= self.n_rows() || col >= self.n_cols() {
55 return false;
56 }
57 let new_ind = self.index_of(row, col);
58 *self.board().get(new_ind).as_deref().unwrap_or(&false)
59 }
60
61 fn set_col(&mut self, col: usize, value: bool) {
63 for r_idx in 0..self.n_rows() {
65 let idx = (r_idx * self.n_cols()) + col;
67 self.board_mut().set(idx, value);
68 }
69 }
70
71 fn get_col(&self, col: usize) -> impl Iterator<Item = bool> {
73 (0..self.n_rows()).map(move |row| self.get(row, col))
74 }
75
76 fn set_row(&mut self, row: usize, value: bool) {
78 for cidx in 0..self.n_cols() {
80 let idx = (row * self.n_cols()) + cidx;
82 self.board_mut().set(idx, value);
83 }
84 }
85
86 fn get_row(&self, row: usize) -> impl Iterator<Item = bool> {
88 (0..self.n_cols()).map(move |col| self.get(row, col))
89 }
90
91 fn set_cardinal_neighbors(&mut self, row: usize, col: usize, value: bool) {
94 if row > 0 {
96 self.set(row - 1, col, value);
97 }
98
99 if row < self.n_rows() - 1 {
101 self.set(row + 1, col, value);
102 }
103
104 if col > 0 {
106 self.set(row, col - 1, value);
107 }
108
109 if col < self.n_cols() - 1 {
111 self.set(row, col + 1, value);
112 }
113 }
114
115 fn set_diagonals(&mut self, row: usize, col: usize, value: bool) {
118 if row > 0 && col > 0 {
120 self.set(row - 1, col - 1, value);
121 }
122
123 if row > 0 && col < self.n_cols() - 1 {
125 self.set(row - 1, col + 1, value);
126 }
127
128 if row < self.n_rows() - 1 && col > 0 {
130 self.set(row + 1, col - 1, value);
131 }
132
133 if row < self.n_rows() - 1 && col < self.n_cols() - 1 {
135 self.set(row + 1, col + 1, value);
136 }
137 }
138
139 fn set_all_neighbors(&mut self, row: usize, col: usize, value: bool) {
142 self.set_cardinal_neighbors(row, col, value);
143 self.set_diagonals(row, col, value);
144 }
145}
146
#[cfg(test)]
mod tests {
    use crate::{bitboard::BitBoard, bitboardstatic::BitBoardStatic};
    use rstest::rstest;

    /// Flattens a board into a row-major `Vec<bool>` using only public accessors.
    fn cells(bb: &impl BitBoard) -> Vec<bool> {
        (0..bb.n_rows())
            .flat_map(move |row| bb.get_row(row))
            .collect()
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 1, 0)]
    #[case(3, 1, 1)]
    fn index_of_and_row_col_of(#[case] index: usize, #[case] row: usize, #[case] col: usize) {
        let bb = BitBoardStatic::<1>::new(2, 2);
        assert_eq!(bb.index_of(row, col), index);
        assert_eq!(bb.row_col_of(index), (row, col));
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 0, 2)]
    #[case(3, 1, 0)]
    #[case(4, 1, 1)]
    #[case(5, 1, 2)]
    #[case(6, 2, 0)]
    #[case(7, 2, 1)]
    #[case(8, 2, 2)]
    fn index_of_and_row_col_of_3x3(#[case] index: usize, #[case] row: usize, #[case] col: usize) {
        let bb = BitBoardStatic::<1>::new(3, 3);
        assert_eq!(bb.index_of(row, col), index);
        assert_eq!(bb.row_col_of(index), (row, col));
    }

    #[test]
    fn index_of_and_row_col_of_2x10() {
        let bb = BitBoardStatic::<1>::new(2, 10);
        for index in 0..20 {
            let (row, col) = bb.row_col_of(index);
            assert_eq!(bb.index_of(row, col), index);
        }
    }

    #[test]
    #[should_panic(expected = "row cannot be greater than n_rows")]
    fn index_of_rejects_row_out_of_range() {
        let bb = BitBoardStatic::<1>::new(2, 2);
        let _ = bb.index_of(2, 0);
    }

    #[test]
    #[should_panic(expected = "col cannot be greater than n_cols")]
    fn index_of_rejects_col_out_of_range() {
        let bb = BitBoardStatic::<1>::new(2, 2);
        let _ = bb.index_of(0, 2);
    }

    #[rstest]
    #[case(0, 0, false)]
    #[case(0, 1, true)]
    #[case(1, 0, false)]
    #[case(1, 1, false)]
    #[case(2, 0, false)]
    #[case(0, 2, false)]
    #[case(2, 2, false)]
    fn get_2x2(#[case] row: usize, #[case] col: usize, #[case] expected: bool) {
        let mut bb = BitBoardStatic::<1>::new(2, 2);
        bb.set(0, 1, true);
        assert_eq!(bb.get(row, col), expected);
    }

    #[rstest]
    #[case(0, vec![true, false, true])]
    #[case(1, vec![false, true, false])]
    #[case(2, vec![true, true, true])]
    #[case(3, vec![false, false, false])]
    #[case(4, vec![false, false, false])]
    fn get_row_3x3(#[case] row: usize, #[case] expected: Vec<bool>) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set(0, 0, true);
        bb.set(0, 2, true);
        bb.set(1, 1, true);
        bb.set_row(2, true);

        assert_eq!(bb.get_row(row).collect::<Vec<bool>>(), expected);
    }

    #[rstest]
    #[case(0, vec![true, false, true])]
    #[case(1, vec![false, true, true])]
    #[case(2, vec![true, false, true])]
    #[case(3, vec![false, false, false])]
    fn get_col_3x3(#[case] col: usize, #[case] expected: Vec<bool>) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set(0, 0, true);
        bb.set(0, 2, true);
        bb.set(1, 1, true);
        bb.set(2, 0, true);
        bb.set(2, 1, true);
        bb.set(2, 2, true);

        assert_eq!(bb.get_col(col).collect::<Vec<bool>>(), expected);
    }

    #[test]
    fn fill_sets_and_clears_every_cell() {
        let mut bb = BitBoardStatic::<1>::new(2, 3);
        bb.fill(true);
        assert_eq!(cells(&bb), vec![true; 6]);
        bb.fill(false);
        assert_eq!(cells(&bb), vec![false; 6]);
    }

    #[test]
    fn set_row_2x3() {
        let mut bb = BitBoardStatic::<1>::new(2, 3);
        bb.set_row(1, true);
        assert_eq!(cells(&bb), vec![false, false, false, true, true, true]);
    }

    #[test]
    fn set_col_3x3() {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_col(1, true);
        assert_eq!(
            cells(&bb),
            vec![false, true, false, false, true, false, false, true, false]
        );
    }

    #[rstest]
    #[case(1, 1, vec![false, true, false, true, false, true, false, true, false])]
    #[case(0, 0, vec![false, true, false, true, false, false, false, false, false])]
    #[case(2, 2, vec![false, false, false, false, false, true, false, true, false])]
    #[case(0, 1, vec![true, false, true, false, true, false, false, false, false])]
    fn set_cardinal_neighbors_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expected: Vec<bool>,
    ) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(cells(&bb), expected);
    }

    #[rstest]
    #[case(1, 1, vec![true, false, true, false, false, false, true, false, true])]
    #[case(0, 0, vec![false, false, false, false, true, false, false, false, false])]
    #[case(2, 2, vec![false, false, false, false, true, false, false, false, false])]
    #[case(0, 1, vec![false, false, false, true, false, true, false, false, false])]
    fn set_diagonals_3x3(#[case] row: usize, #[case] col: usize, #[case] expected: Vec<bool>) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_diagonals(row, col, true);
        assert_eq!(cells(&bb), expected);
    }

    #[test]
    fn set_all_neighbors_corner_3x3() {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_all_neighbors(2, 0, true);
        assert_eq!(
            cells(&bb),
            vec![false, false, false, true, true, false, false, true, false]
        );
    }

    #[test]
    fn set_all_neighbors_clears_around_center() {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.fill(true);
        bb.set_all_neighbors(1, 1, false);
        assert_eq!(
            cells(&bb),
            vec![false, false, false, false, true, false, false, false, false]
        );
    }
}