1use std::error::Error;
2use std::fmt;
3
4use bitvec::prelude::*;
5
/// Error returned when two boards of different shapes are combined.
///
/// [`BitBoard::and`] and [`BitBoard::or`] return this when the operands do
/// not have the same number of rows and columns.
///
/// The type carries no data, so it is cheap to copy and can be compared
/// directly (for example `assert_eq!(a.and(&b), Err(DimensionMismatch))`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DimensionMismatch;

impl fmt::Display for DimensionMismatch {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message has no interpolated arguments, so it is written directly
        // instead of going through the formatting machinery.
        f.write_str("Dimensions do not match.")
    }
}

impl Error for DimensionMismatch {}
16
17#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct BitBoard {
22 pub board: BitVec,
24
25 pub n_rows: usize,
27
28 pub n_cols: usize,
30}
31
32impl fmt::Display for BitBoard {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 write!(f, " ")?; for col in 0..self.n_cols {
37 write!(f, "{}", col % 10)?; }
39 writeln!(f)?;
40
41 for row in 0..self.n_rows {
42 write!(f, "{:>2} ", row)?;
44 for col in 0..self.n_cols {
45 let idx = row * self.n_cols + col;
46 let bit = self.board[idx];
47 let c = if bit { 'X' } else { '.' };
48 write!(f, "{}", c)?;
49 }
50 writeln!(f)?;
51 }
52 Ok(())
53 }
54}
55
56impl BitBoard {
57 pub fn new(n_rows: usize, n_cols: usize) -> Self {
59 BitBoard {
60 board: bitvec![0; n_rows * n_cols],
61 n_rows,
62 n_cols,
63 }
64 }
65
66 pub fn index_of(&self, row: usize, col: usize) -> usize {
68 assert!(
69 row <= (self.n_rows - 1),
70 "row cannot be greater than n_rows"
71 );
72 assert!(
73 col <= (self.n_cols - 1),
74 "col cannot be greater than n_cols"
75 );
76 (row * self.n_cols) + col
77 }
78
79 pub fn fill(&mut self, value: bool) {
81 self.board.fill(value);
82 }
83
84 pub fn or(&self, other: &BitBoard) -> Result<BitBoard, DimensionMismatch> {
85 if (self.n_rows != other.n_rows) || (self.n_cols != other.n_cols) {
86 return Err(DimensionMismatch);
87 }
88 let mut new_board = BitBoard::new(self.n_rows, self.n_cols);
89 new_board.board = self.board.clone() | other.board.clone();
90 Ok(new_board)
91 }
92
93 pub fn and(&self, other: &BitBoard) -> Result<BitBoard, DimensionMismatch> {
94 if (self.n_rows != other.n_rows) || (self.n_cols != other.n_cols) {
95 return Err(DimensionMismatch);
96 }
97 let mut new_board = BitBoard::new(self.n_rows, self.n_cols);
98 new_board.board = self.board.clone() & other.board.clone();
99 Ok(new_board)
100 }
101
102 pub fn set(&mut self, row: usize, col: usize, value: bool) {
104 let new_ind = self.index_of(row, col);
105 self.board.set(new_ind, value);
106 }
107
108 pub fn set_col(&mut self, col: usize, value: bool) {
110 for r_idx in 0..self.n_rows {
112 let idx = (r_idx * self.n_cols) + col;
114 self.board.set(idx, value);
115 }
116 }
117
118 pub fn set_row(&mut self, row: usize, value: bool) {
120 for cidx in 0..self.n_cols {
122 let idx = (row * self.n_cols) + cidx;
124 self.board.set(idx, value);
125 }
126 }
127
128 pub fn set_cardinal_neighbors(&mut self, row: usize, col: usize, value: bool) {
131 if row > 0 {
133 self.set(row - 1, col, value);
134 }
135
136 if row < self.n_rows - 1 {
138 self.set(row + 1, col, value);
139 }
140
141 if col > 0 {
143 self.set(row, col - 1, value);
144 }
145
146 if col < self.n_cols - 1 {
148 self.set(row, col + 1, value);
149 }
150 }
151
152 pub fn set_diagonals(&mut self, row: usize, col: usize, value: bool) {
155 if row > 0 && col > 0 {
157 self.set(row - 1, col - 1, value);
158 }
159
160 if row > 0 && col < self.n_cols - 1 {
162 self.set(row - 1, col + 1, value);
163 }
164
165 if row < self.n_rows - 1 && col > 0 {
167 self.set(row + 1, col - 1, value);
168 }
169
170 if row < self.n_rows - 1 && col < self.n_cols - 1 {
172 self.set(row + 1, col + 1, value);
173 }
174 }
175
176 pub fn set_all_neighbors(&mut self, row: usize, col: usize, value: bool) {
179 self.set_cardinal_neighbors(row, col, value);
180 self.set_diagonals(row, col, value);
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Wraps four bits in a 2x2 board.
    fn board_2x2(bits: BitVec) -> BitBoard {
        BitBoard {
            board: bits,
            n_rows: 2,
            n_cols: 2,
        }
    }

    #[test]
    fn can_construct() {
        let bb = BitBoard::new(2, 2);
        assert_eq!(bb.n_rows, 2);
        assert_eq!(bb.n_cols, 2);
        assert_eq!(bb.board.len(), 4);
        assert!(bb.board.not_any());
    }

    #[test]
    #[should_panic(expected = "row cannot be greater than n_rows")]
    fn row_too_big() {
        let bb = BitBoard::new(2, 2);
        bb.index_of(10, 0);
    }

    #[test]
    #[should_panic(expected = "col cannot be greater than n_cols")]
    fn col_too_big() {
        let bb = BitBoard::new(2, 2);
        bb.index_of(0, 10);
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(0, 1, 1)]
    #[case(1, 0, 2)]
    #[case(1, 1, 3)]
    fn index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(2, 2);
        assert_eq!(expected, bb.index_of(row, col))
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 0, 2)]
    #[case(3, 0, 3)]
    #[case(4, 0, 4)]
    fn col_vec_index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(5, 1);
        assert_eq!(expected, bb.index_of(row, col))
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(0, 1, 1)]
    #[case(0, 2, 2)]
    #[case(0, 3, 3)]
    #[case(0, 4, 4)]
    fn row_vec_index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(1, 5);
        assert_eq!(expected, bb.index_of(row, col))
    }

    #[test]
    fn fill_sets_every_bit() {
        let mut bb = BitBoard::new(3, 4);
        bb.fill(true);
        assert!(bb.board.all());
        bb.fill(false);
        assert!(bb.board.not_any());
    }

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(2)]
    #[case(3)]
    #[case(4)]
    fn set_col(#[case] col: usize) {
        let mut bb = BitBoard::new(5, 5);
        bb.set_col(col, true);
        for ridx in 0..bb.n_rows {
            for cidx in 0..bb.n_cols {
                // Only cells in the chosen column may be set.
                assert_eq!(bb.board[bb.index_of(ridx, cidx)], cidx == col);
            }
        }
    }

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(2)]
    #[case(3)]
    #[case(4)]
    fn set_row(#[case] row: usize) {
        let mut bb = BitBoard::new(5, 5);
        bb.set_row(row, true);
        for ridx in 0..bb.n_rows {
            for cidx in 0..bb.n_cols {
                // Only cells in the chosen row may be set.
                assert_eq!(bb.board[bb.index_of(ridx, cidx)], ridx == row);
            }
        }
    }

    #[test]
    fn can_set_all_bits() {
        let nr = 3;
        let nc = 3;
        let mut bb = BitBoard::new(nr, nc);

        for ridx in 0..nr {
            for cidx in 0..nc {
                bb.set(ridx, cidx, true);
            }
        }
        assert!(bb.board.all());

        for ridx in 0..nr {
            for cidx in 0..nc {
                bb.set(ridx, cidx, false);
            }
        }
        assert!(bb.board.not_any());
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 0, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 1, BitBoard { board: bitvec![0, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    fn set_cardinal_neighbors_2x2(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expect: BitBoard,
    ) {
        let mut bb = BitBoard::new(2, 2);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 0, 1, 0, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 0, 1, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 2, BitBoard { board: bitvec![0, 1, 0, 0, 0, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 0, 0, 0, 1, 0, 1, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 1, BitBoard { board: bitvec![0, 1, 0, 1, 0, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 2, BitBoard { board: bitvec![0, 0, 1, 0, 1, 0, 0, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 0, BitBoard { board: bitvec![0, 0, 0, 1, 0, 0, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(2, 1, BitBoard { board: bitvec![0, 0, 0, 0, 1, 0, 1, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 2, BitBoard { board: bitvec![0, 0, 0, 0, 0, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    fn set_cardinal_neighbors_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expect: BitBoard,
    ) {
        let mut bb = BitBoard::new(3, 3);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[test]
    fn set_diagonals_center_3x3() {
        let mut bb = BitBoard::new(3, 3);
        bb.set_diagonals(1, 1, true);
        assert_eq!(bb.board, bitvec![1, 0, 1, 0, 0, 0, 1, 0, 1]);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 1, 1], n_rows: 2, n_cols: 2 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 1, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 1, BitBoard { board: bitvec![1, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    fn set_all_neighbors_2x2(#[case] row: usize, #[case] col: usize, #[case] expect: BitBoard) {
        let mut bb = BitBoard::new(2, 2);
        bb.set_all_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 0, 1, 1, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 1, 1, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 2, BitBoard { board: bitvec![0, 1, 0, 0, 1, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 1, 0, 0, 1, 0, 1, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 1, BitBoard { board: bitvec![1, 1, 1, 1, 0, 1, 1, 1, 1], n_rows: 3, n_cols: 3 })]
    #[case(1, 2, BitBoard { board: bitvec![0, 1, 1, 0, 1, 0, 0, 1, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 0, BitBoard { board: bitvec![0, 0, 0, 1, 1, 0, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(2, 1, BitBoard { board: bitvec![0, 0, 0, 1, 1, 1, 1, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 2, BitBoard { board: bitvec![0, 0, 0, 0, 1, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    fn set_all_neighbors_3x3(#[case] row: usize, #[case] col: usize, #[case] expect: BitBoard) {
        let mut bb = BitBoard::new(3, 3);
        bb.set_all_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(1, 1, 1, 2)]
    #[case(2, 1, 1, 2)]
    #[case(2, 1, 2, 7)]
    fn and_dimension_mismatch(
        #[case] b1r: usize,
        #[case] b1c: usize,
        #[case] b2r: usize,
        #[case] b2c: usize,
    ) {
        let bb1 = BitBoard::new(b1r, b1c);
        let bb2 = BitBoard::new(b2r, b2c);
        assert!(bb1.and(&bb2).is_err());
    }

    #[rstest]
    #[case(1, 1, 1, 2)]
    #[case(2, 1, 1, 2)]
    #[case(2, 1, 2, 7)]
    fn or_dimension_mismatch(
        #[case] b1r: usize,
        #[case] b1c: usize,
        #[case] b2r: usize,
        #[case] b2c: usize,
    ) {
        let bb1 = BitBoard::new(b1r, b1c);
        let bb2 = BitBoard::new(b2r, b2c);
        assert!(bb1.or(&bb2).is_err());
    }

    #[rstest]
    #[case(bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0])]
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])]
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 1, 1, 1], bitvec![0, 0, 0, 0])]
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 0, 0, 1], bitvec![1, 0, 0, 1])]
    #[case(bitvec![1, 0, 1, 0], bitvec![0, 1, 0, 1], bitvec![0, 0, 0, 0])]
    #[case(bitvec![1, 1, 0, 0], bitvec![1, 0, 1, 0], bitvec![1, 0, 0, 0])]
    fn and_operations(#[case] board1: BitVec, #[case] board2: BitVec, #[case] expected: BitVec) {
        let bb1 = board_2x2(board1);
        let bb2 = board_2x2(board2);

        let result = bb1.and(&bb2).unwrap();
        assert_eq!(result.board, expected);
        assert_eq!(result.n_rows, 2);
        assert_eq!(result.n_cols, 2);
    }

    #[rstest]
    #[case(bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0])]
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])]
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])]
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 0, 0, 1], bitvec![1, 0, 0, 1])]
    #[case(bitvec![1, 0, 1, 0], bitvec![0, 1, 0, 1], bitvec![1, 1, 1, 1])]
    #[case(bitvec![1, 1, 0, 0], bitvec![0, 0, 1, 1], bitvec![1, 1, 1, 1])]
    #[case(bitvec![1, 0, 0, 1], bitvec![0, 1, 1, 0], bitvec![1, 1, 1, 1])]
    fn or_operations(#[case] board1: BitVec, #[case] board2: BitVec, #[case] expected: BitVec) {
        let bb1 = board_2x2(board1);
        let bb2 = board_2x2(board2);

        let result = bb1.or(&bb2).unwrap();
        assert_eq!(result.board, expected);
        assert_eq!(result.n_rows, 2);
        assert_eq!(result.n_cols, 2);
    }

    #[test]
    fn and_or_larger_boards() {
        // bb1: the whole top row plus the bottom-right corner.
        let mut bb1 = BitBoard::new(3, 3);
        bb1.set_row(0, true);
        bb1.set(2, 2, true);

        // bb2: the whole left column plus the centre cell.
        let mut bb2 = BitBoard::new(3, 3);
        bb2.set_col(0, true);
        bb2.set(1, 1, true);

        // The boards overlap only in the top-left corner.
        let and_result = bb1.and(&bb2).unwrap();
        assert_eq!(and_result.board, bitvec![1, 0, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(and_result.n_rows, 3);
        assert_eq!(and_result.n_cols, 3);

        let or_result = bb1.or(&bb2).unwrap();
        assert_eq!(or_result.board, bitvec![1, 1, 1, 1, 1, 0, 1, 0, 1]);
        assert_eq!(or_result.n_rows, 3);
        assert_eq!(or_result.n_cols, 3);
    }

    #[test]
    fn and_or_preserve_original_boards() {
        let mut bb1 = BitBoard::new(2, 2);
        bb1.set(0, 0, true);
        let bb1_original = bb1.clone();

        let mut bb2 = BitBoard::new(2, 2);
        bb2.set(1, 1, true);
        let bb2_original = bb2.clone();

        let _and_result = bb1.and(&bb2).unwrap();
        let _or_result = bb1.or(&bb2).unwrap();

        assert_eq!(bb1, bb1_original);
        assert_eq!(bb2, bb2_original);
    }
}