// bit_board/lib.rs

use std::error::Error;
use std::fmt;

use bitvec::prelude::*;

/// Error returned when a binary operation is applied to two [`BitBoard`]s whose
/// `n_rows` or `n_cols` differ.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DimensionMismatch;

impl fmt::Display for DimensionMismatch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Dimensions do not match.")
    }
}

impl Error for DimensionMismatch {}

17/// BitBoard is a 2D array of booleans, stored in the bits of integers. It does
18/// assumes that the boundaries are hard, and going past a boundary does *not* take
19/// you back to the other side.
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct BitBoard {
22    // The slice of bits that represent the board.
23    pub board: BitVec,
24
25    /// How many rows does the board have
26    pub n_rows: usize,
27
28    /// How many columns does the board have
29    pub n_cols: usize,
30}
31
32impl fmt::Display for BitBoard {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        // column indices
35        write!(f, "   ")?; // space for row labels
36        for col in 0..self.n_cols {
37            write!(f, "{}", col % 10)?; // wrap every 10 for readability
38        }
39        writeln!(f)?;
40
41        for row in 0..self.n_rows {
42            // row index, right-aligned to 2 spaces
43            write!(f, "{:>2} ", row)?;
44            for col in 0..self.n_cols {
45                let idx = row * self.n_cols + col;
46                let bit = self.board[idx];
47                let c = if bit { 'X' } else { '.' };
48                write!(f, "{}", c)?;
49            }
50            writeln!(f)?;
51        }
52        Ok(())
53    }
54}
55
56impl BitBoard {
57    /// Create a new empty board with `n_rows` and `n_cols`.
58    pub fn new(n_rows: usize, n_cols: usize) -> Self {
59        BitBoard {
60            board: bitvec![0; n_rows * n_cols],
61            n_rows,
62            n_cols,
63        }
64    }
65
66    /// Get the index that we can use to directly access a certain spot on the board
67    pub fn index_of(&self, row: usize, col: usize) -> usize {
68        assert!(
69            row <= (self.n_rows - 1),
70            "row cannot be greater than n_rows"
71        );
72        assert!(
73            col <= (self.n_cols - 1),
74            "col cannot be greater than n_cols"
75        );
76        (row * self.n_cols) + col
77    }
78
79    /// Set all bits to the desired value.
80    pub fn fill(&mut self, value: bool) {
81        self.board.fill(value);
82    }
83
84    pub fn or(&self, other: &BitBoard) -> Result<BitBoard, DimensionMismatch> {
85        if (self.n_rows != other.n_rows) || (self.n_cols != other.n_cols) {
86            return Err(DimensionMismatch);
87        }
88        let mut new_board = BitBoard::new(self.n_rows, self.n_cols);
89        new_board.board = self.board.clone() | other.board.clone();
90        Ok(new_board)
91    }
92
93    pub fn and(&self, other: &BitBoard) -> Result<BitBoard, DimensionMismatch> {
94        if (self.n_rows != other.n_rows) || (self.n_cols != other.n_cols) {
95            return Err(DimensionMismatch);
96        }
97        let mut new_board = BitBoard::new(self.n_rows, self.n_cols);
98        new_board.board = self.board.clone() & other.board.clone();
99        Ok(new_board)
100    }
101
102    /// Set the value at index [row, col] to be the `new_val`.
103    pub fn set(&mut self, row: usize, col: usize, value: bool) {
104        let new_ind = self.index_of(row, col);
105        self.board.set(new_ind, value);
106    }
107
108    /// Set an entire column to a certain value
109    pub fn set_col(&mut self, col: usize, value: bool) {
110        // For each row
111        for r_idx in 0..self.n_rows {
112            // Calculate the index
113            let idx = (r_idx * self.n_cols) + col;
114            self.board.set(idx, value);
115        }
116    }
117
118    /// Set an entire row to a certain value
119    pub fn set_row(&mut self, row: usize, value: bool) {
120        // For each column in the row
121        for cidx in 0..self.n_cols {
122            // Calculate the index
123            let idx = (row * self.n_cols) + cidx;
124            self.board.set(idx, value);
125        }
126    }
127
128    /// Will set the neighbors immediately above, below, left, and right to `value`. If
129    /// the neighbor is out of bounds, nothing will happen
130    pub fn set_cardinal_neighbors(&mut self, row: usize, col: usize, value: bool) {
131        // Above
132        if row > 0 {
133            self.set(row - 1, col, value);
134        }
135
136        // Below
137        if row < self.n_rows - 1 {
138            self.set(row + 1, col, value);
139        }
140
141        // Left
142        if col > 0 {
143            self.set(row, col - 1, value);
144        }
145
146        // Right
147        if col < self.n_cols - 1 {
148            self.set(row, col + 1, value);
149        }
150    }
151
152    /// Set just the spots diagonal from the given position to `value`. If
153    /// the neighbor is out of bounds, nothing will happen
154    pub fn set_diagonals(&mut self, row: usize, col: usize, value: bool) {
155        // Above left
156        if row > 0 && col > 0 {
157            self.set(row - 1, col - 1, value);
158        }
159
160        // Above right
161        if row > 0 && col < self.n_cols - 1 {
162            self.set(row - 1, col + 1, value);
163        }
164
165        // Below left
166        if row < self.n_rows - 1 && col > 0 {
167            self.set(row + 1, col - 1, value);
168        }
169
170        // Below right
171        if row < self.n_rows - 1 && col < self.n_cols - 1 {
172            self.set(row + 1, col + 1, value);
173        }
174    }
175
176    /// Set the cardinal neighbors and the diagonal neighbors to `value`. If
177    /// the neighbor is out of bounds, nothing will happen
178    pub fn set_all_neighbors(&mut self, row: usize, col: usize, value: bool) {
179        self.set_cardinal_neighbors(row, col, value);
180        self.set_diagonals(row, col, value);
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[test]
    fn can_construct() {
        let bb = BitBoard::new(2, 3);
        assert_eq!(bb.n_rows, 2);
        assert_eq!(bb.n_cols, 3);
        assert_eq!(bb.board.len(), 6);
        assert!(bb.board.not_any());
    }

    #[test]
    fn can_construct_empty() {
        let bb = BitBoard::new(0, 0);
        assert!(bb.board.is_empty());
    }

    #[test]
    #[should_panic(expected = "row cannot be greater than n_rows")]
    fn row_too_big() {
        let bb = BitBoard::new(2, 2);
        bb.index_of(10, 0);
    }

    #[test]
    #[should_panic(expected = "col cannot be greater than n_col")]
    fn col_too_big() {
        let bb = BitBoard::new(2, 2);
        bb.index_of(0, 10);
    }

    #[test]
    #[should_panic]
    fn index_of_on_empty_board() {
        // No cell exists on a 0x0 board, so any lookup must panic.
        let bb = BitBoard::new(0, 0);
        bb.index_of(0, 0);
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(0, 1, 1)]
    #[case(1, 0, 2)]
    #[case(1, 1, 3)]
    fn index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(2, 2);
        assert_eq!(expected, bb.index_of(row, col));
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 0, 2)]
    #[case(3, 0, 3)]
    #[case(4, 0, 4)]
    fn col_vec_index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(5, 1);
        assert_eq!(expected, bb.index_of(row, col));
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(0, 1, 1)]
    #[case(0, 2, 2)]
    #[case(0, 3, 3)]
    #[case(0, 4, 4)]
    fn row_vec_index_of(#[case] row: usize, #[case] col: usize, #[case] expected: usize) {
        let bb = BitBoard::new(1, 5);
        assert_eq!(expected, bb.index_of(row, col));
    }

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(2)]
    #[case(3)]
    #[case(4)]
    fn set_col(#[case] col: usize) {
        let mut bb = BitBoard::new(5, 5);
        bb.set_col(col, true);
        for ridx in 0..bb.n_rows {
            for cidx in 0..bb.n_cols {
                // Exactly the chosen column is set.
                assert_eq!(bb.board[bb.index_of(ridx, cidx)], cidx == col);
            }
        }
    }

    #[test]
    #[should_panic]
    fn set_col_out_of_bounds() {
        let mut bb = BitBoard::new(5, 5);
        bb.set_col(5, true);
    }

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(2)]
    #[case(3)]
    #[case(4)]
    fn set_row(#[case] row: usize) {
        let mut bb = BitBoard::new(5, 5);
        bb.set_row(row, true);
        for ridx in 0..bb.n_rows {
            for cidx in 0..bb.n_cols {
                // Exactly the chosen row is set.
                assert_eq!(bb.board[bb.index_of(ridx, cidx)], ridx == row);
            }
        }
    }

    #[test]
    #[should_panic]
    fn set_row_out_of_bounds() {
        let mut bb = BitBoard::new(5, 5);
        bb.set_row(5, true);
    }

    #[test]
    fn can_set_all_bits() {
        let nr = 3;
        let nc = 3;
        let mut bb = BitBoard::new(nr, nc);

        // Set each bit, and check that all bits are 1
        for ridx in 0..nr {
            for cidx in 0..nc {
                bb.set(ridx, cidx, true);
            }
        }
        assert!(bb.board.all());

        // Unset each bit, and check that all bits are 0
        for ridx in 0..nr {
            for cidx in 0..nc {
                bb.set(ridx, cidx, false);
            }
        }
        assert!(bb.board.not_any());
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 0, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 1, BitBoard { board: bitvec![0, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    fn set_cardinal_neighbors_2x2(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expect: BitBoard,
    ) {
        let mut bb = BitBoard::new(2, 2);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 0, 1, 0, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 0, 1, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 2, BitBoard { board: bitvec![0, 1, 0, 0, 0, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 0, 0, 0, 1, 0, 1, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 1, BitBoard { board: bitvec![0, 1, 0, 1, 0, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 2, BitBoard { board: bitvec![0, 0, 1, 0, 1, 0, 0, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 0, BitBoard { board: bitvec![0, 0, 0, 1, 0, 0, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(2, 1, BitBoard { board: bitvec![0, 0, 0, 0, 1, 0, 1, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 2, BitBoard { board: bitvec![0, 0, 0, 0, 0, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    fn set_cardinal_neighbors_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expect: BitBoard,
    ) {
        let mut bb = BitBoard::new(3, 3);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 1, 1], n_rows: 2, n_cols: 2 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 1, 0, 1], n_rows: 2, n_cols: 2 })]
    #[case(1, 1, BitBoard { board: bitvec![1, 1, 1, 0], n_rows: 2, n_cols: 2 })]
    fn set_all_neighbors_2x2(#[case] row: usize, #[case] col: usize, #[case] expect: BitBoard) {
        let mut bb = BitBoard::new(2, 2);
        bb.set_all_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(0, 0, BitBoard { board: bitvec![0, 1, 0, 1, 1, 0, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 1, BitBoard { board: bitvec![1, 0, 1, 1, 1, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(0, 2, BitBoard { board: bitvec![0, 1, 0, 0, 1, 1, 0, 0, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 0, BitBoard { board: bitvec![1, 1, 0, 0, 1, 0, 1, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(1, 1, BitBoard { board: bitvec![1, 1, 1, 1, 0, 1, 1, 1, 1], n_rows: 3, n_cols: 3 })]
    #[case(1, 2, BitBoard { board: bitvec![0, 1, 1, 0, 1, 0, 0, 1, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 0, BitBoard { board: bitvec![0, 0, 0, 1, 1, 0, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    #[case(2, 1, BitBoard { board: bitvec![0, 0, 0, 1, 1, 1, 1, 0, 1], n_rows: 3, n_cols: 3 })]
    #[case(2, 2, BitBoard { board: bitvec![0, 0, 0, 0, 1, 1, 0, 1, 0], n_rows: 3, n_cols: 3 })]
    fn set_all_neighbors_3x3(#[case] row: usize, #[case] col: usize, #[case] expect: BitBoard) {
        let mut bb = BitBoard::new(3, 3);
        bb.set_all_neighbors(row, col, true);
        assert_eq!(expect, bb);
    }

    #[rstest]
    #[case(1, 1, 1, 2)]
    #[case(2, 1, 1, 2)]
    #[case(2, 1, 2, 7)]
    fn and_dimension_mismatch(
        #[case] b1r: usize,
        #[case] b1c: usize,
        #[case] b2r: usize,
        #[case] b2c: usize,
    ) {
        let bb1 = BitBoard::new(b1r, b1c);
        let bb2 = BitBoard::new(b2r, b2c);
        assert!(bb1.and(&bb2).is_err());
    }

    #[rstest]
    #[case(1, 1, 1, 2)]
    #[case(2, 1, 1, 2)]
    #[case(2, 1, 2, 7)]
    fn or_dimension_mismatch(
        #[case] b1r: usize,
        #[case] b1c: usize,
        #[case] b2r: usize,
        #[case] b2c: usize,
    ) {
        let bb1 = BitBoard::new(b1r, b1c);
        let bb2 = BitBoard::new(b2r, b2c);
        assert!(bb1.or(&bb2).is_err());
    }

    #[rstest]
    #[case(bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0])] // empty AND empty
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])] // full AND full
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 1, 1, 1], bitvec![0, 0, 0, 0])] // empty AND full
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 0, 0, 1], bitvec![1, 0, 0, 1])] // full AND partial
    #[case(bitvec![1, 0, 1, 0], bitvec![0, 1, 0, 1], bitvec![0, 0, 0, 0])] // alternating patterns
    #[case(bitvec![1, 1, 0, 0], bitvec![1, 0, 1, 0], bitvec![1, 0, 0, 0])] // partial patterns
    fn and_operations(#[case] board1: BitVec, #[case] board2: BitVec, #[case] expected: BitVec) {
        let bb1 = BitBoard {
            board: board1,
            n_rows: 2,
            n_cols: 2,
        };
        let bb2 = BitBoard {
            board: board2,
            n_rows: 2,
            n_cols: 2,
        };

        let result = bb1.and(&bb2).unwrap();
        assert_eq!(result.board, expected);
        assert_eq!(result.n_rows, 2);
        assert_eq!(result.n_cols, 2);
    }

    #[rstest]
    #[case(bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0], bitvec![0, 0, 0, 0])] // empty OR empty
    #[case(bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])] // full OR full
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 1, 1, 1], bitvec![1, 1, 1, 1])] // empty OR full
    #[case(bitvec![0, 0, 0, 0], bitvec![1, 0, 0, 1], bitvec![1, 0, 0, 1])] // empty OR partial
    #[case(bitvec![1, 0, 1, 0], bitvec![0, 1, 0, 1], bitvec![1, 1, 1, 1])] // alternating patterns
    #[case(bitvec![1, 1, 0, 0], bitvec![0, 0, 1, 1], bitvec![1, 1, 1, 1])] // complementary patterns
    #[case(bitvec![1, 0, 0, 1], bitvec![0, 1, 1, 0], bitvec![1, 1, 1, 1])] // diagonal patterns
    fn or_operations(#[case] board1: BitVec, #[case] board2: BitVec, #[case] expected: BitVec) {
        let bb1 = BitBoard {
            board: board1,
            n_rows: 2,
            n_cols: 2,
        };
        let bb2 = BitBoard {
            board: board2,
            n_rows: 2,
            n_cols: 2,
        };

        let result = bb1.or(&bb2).unwrap();
        assert_eq!(result.board, expected);
        assert_eq!(result.n_rows, 2);
        assert_eq!(result.n_cols, 2);
    }

    #[test]
    fn and_or_larger_boards() {
        let mut bb1 = BitBoard::new(3, 3);
        bb1.set_row(0, true); // First row all true
        bb1.set(2, 2, true); // Bottom right corner

        let mut bb2 = BitBoard::new(3, 3);
        bb2.set_col(0, true); // First column all true
        bb2.set(1, 1, true); // Center

        // AND: only (0, 0) is set in both inputs. Check every cell.
        let and_result = bb1.and(&bb2).unwrap();
        assert_eq!(and_result.board, bitvec![1, 0, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!((and_result.n_rows, and_result.n_cols), (3, 3));

        // OR: only (1, 2) and (2, 1) are set in neither input.
        let or_result = bb1.or(&bb2).unwrap();
        assert_eq!(or_result.board, bitvec![1, 1, 1, 1, 1, 0, 1, 0, 1]);
        assert_eq!((or_result.n_rows, or_result.n_cols), (3, 3));
    }

    #[test]
    fn and_or_preserve_original_boards() {
        let mut bb1 = BitBoard::new(2, 2);
        bb1.set(0, 0, true);
        let bb1_original = bb1.clone();

        let mut bb2 = BitBoard::new(2, 2);
        bb2.set(1, 1, true);
        let bb2_original = bb2.clone();

        // Perform operations
        let _and_result = bb1.and(&bb2).unwrap();
        let _or_result = bb1.or(&bb2).unwrap();

        // Original boards should be unchanged
        assert_eq!(bb1, bb1_original);
        assert_eq!(bb2, bb2_original);
    }

    #[test]
    fn display_renders_grid() {
        let mut bb = BitBoard::new(2, 3);
        bb.set(0, 0, true);
        bb.set(1, 2, true);
        assert_eq!(bb.to_string(), "   012\n 0 X..\n 1 ..X\n");
    }
}