bit_board/
bitboard.rs

1use bitvec::slice::BitSlice;
2
3use crate::DimensionMismatch;
4
5pub trait BitBoard: Sized {
6    /// Returns the number of rows in the board.
7    fn n_rows(&self) -> usize;
8
9    /// Returns the number of columns in the board.
10    fn n_cols(&self) -> usize;
11
12    /// Returns a mutable reference to the underlying bits.
13    fn board_mut(&mut self) -> &mut BitSlice;
14
15    /// Returns an immutable reference to the underlying bits.
16    fn board(&self) -> &BitSlice;
17
18    /// Get the index that we can use to directly access a certain spot on the board
19    fn index_of(&self, row: usize, col: usize) -> usize {
20        assert!(
21            row <= (self.n_rows() - 1),
22            "row cannot be greater than n_rows"
23        );
24        assert!(
25            col <= (self.n_cols() - 1),
26            "col cannot be greater than n_cols"
27        );
28        (row * self.n_cols()) + col
29    }
30
31    /// Get the row and column of the linear index
32    fn row_col_of(&self, index: usize) -> (usize, usize) {
33        let row = index / self.n_cols();
34        let col = index % self.n_cols();
35        (row, col)
36    }
37
38    /// Set all bits to the desired value.
39    fn fill(&mut self, value: bool) {
40        self.board_mut().fill(value);
41    }
42
43    fn or(&self, other: &impl BitBoard) -> Result<Self, DimensionMismatch>;
44    fn and(&self, other: &impl BitBoard) -> Result<Self, DimensionMismatch>;
45
46    /// Set the value at index [row, col] to be the `new_val`.
47    fn set(&mut self, row: usize, col: usize, value: bool) {
48        let new_ind = self.index_of(row, col);
49        self.board_mut().set(new_ind, value);
50    }
51
52    /// Get the value at index [row, col]. If the index is out of bounds, return false.
53    fn get(&self, row: usize, col: usize) -> bool {
54        if row >= self.n_rows() || col >= self.n_cols() {
55            return false;
56        }
57        let new_ind = self.index_of(row, col);
58        *self.board().get(new_ind).as_deref().unwrap_or(&false)
59    }
60
61    /// Set an entire column to a certain value
62    fn set_col(&mut self, col: usize, value: bool) {
63        // For each row
64        for r_idx in 0..self.n_rows() {
65            // Calculate the index
66            let idx = (r_idx * self.n_cols()) + col;
67            self.board_mut().set(idx, value);
68        }
69    }
70
71    /// Get the values in a given col
72    fn get_col(&self, col: usize) -> impl Iterator<Item = bool> {
73        (0..self.n_rows()).map(move |row| self.get(row, col))
74    }
75
76    /// Set an entire row to a certain value
77    fn set_row(&mut self, row: usize, value: bool) {
78        // For each column in the row
79        for cidx in 0..self.n_cols() {
80            // Calculate the index
81            let idx = (row * self.n_cols()) + cidx;
82            self.board_mut().set(idx, value);
83        }
84    }
85
86    /// Get the values in a given row
87    fn get_row(&self, row: usize) -> impl Iterator<Item = bool> {
88        (0..self.n_cols()).map(move |col| self.get(row, col))
89    }
90
91    /// Will set the neighbors immediately above, below, left, and right to `value`. If
92    /// the neighbor is out of bounds, nothing will happen
93    fn set_cardinal_neighbors(&mut self, row: usize, col: usize, value: bool) {
94        // Above
95        if row > 0 {
96            self.set(row - 1, col, value);
97        }
98
99        // Below
100        if row < self.n_rows() - 1 {
101            self.set(row + 1, col, value);
102        }
103
104        // Left
105        if col > 0 {
106            self.set(row, col - 1, value);
107        }
108
109        // Right
110        if col < self.n_cols() - 1 {
111            self.set(row, col + 1, value);
112        }
113    }
114
115    /// Set just the spots diagonal from the given position to `value`. If
116    /// the neighbor is out of bounds, nothing will happen
117    fn set_diagonals(&mut self, row: usize, col: usize, value: bool) {
118        // Above left
119        if row > 0 && col > 0 {
120            self.set(row - 1, col - 1, value);
121        }
122
123        // Above right
124        if row > 0 && col < self.n_cols() - 1 {
125            self.set(row - 1, col + 1, value);
126        }
127
128        // Below left
129        if row < self.n_rows() - 1 && col > 0 {
130            self.set(row + 1, col - 1, value);
131        }
132
133        // Below right
134        if row < self.n_rows() - 1 && col < self.n_cols() - 1 {
135            self.set(row + 1, col + 1, value);
136        }
137    }
138
139    /// Set the cardinal neighbors and the diagonal neighbors to `value`. If
140    /// the neighbor is out of bounds, nothing will happen
141    fn set_all_neighbors(&mut self, row: usize, col: usize, value: bool) {
142        self.set_cardinal_neighbors(row, col, value);
143        self.set_diagonals(row, col, value);
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::{bitboard::BitBoard, bitboardstatic::BitBoardStatic};
    use rstest::rstest;

    /// Returns the coordinates of every set cell on `bb`, in row-major order.
    fn set_cells(bb: &impl BitBoard) -> Vec<(usize, usize)> {
        (0..bb.n_rows())
            .flat_map(|row| (0..bb.n_cols()).map(move |col| (row, col)))
            .filter(|&(row, col)| bb.get(row, col))
            .collect()
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 1, 0)]
    #[case(3, 1, 1)]
    fn index_of_and_row_col_of(#[case] index: usize, #[case] row: usize, #[case] col: usize) {
        let bb = BitBoardStatic::<1>::new(2, 2);
        assert_eq!(bb.index_of(row, col), index);
        assert_eq!(bb.row_col_of(index), (row, col));
    }

    #[rstest]
    #[case(0, 0, 0)]
    #[case(1, 0, 1)]
    #[case(2, 0, 2)]
    #[case(3, 1, 0)]
    #[case(4, 1, 1)]
    #[case(5, 1, 2)]
    #[case(6, 2, 0)]
    #[case(7, 2, 1)]
    #[case(8, 2, 2)]
    fn index_of_and_row_col_of_3x3(#[case] index: usize, #[case] row: usize, #[case] col: usize) {
        let bb = BitBoardStatic::<1>::new(3, 3);
        assert_eq!(bb.index_of(row, col), index);
        assert_eq!(bb.row_col_of(index), (row, col));
    }

    #[test]
    fn index_of_and_row_col_of_2x10() {
        let bb = BitBoardStatic::<1>::new(2, 10);
        for index in 0..20 {
            let (row, col) = bb.row_col_of(index);
            assert_eq!(bb.index_of(row, col), index);
        }
    }

    #[rstest]
    #[case(0, 0, false)]
    #[case(0, 1, true)]
    #[case(1, 0, false)]
    #[case(1, 1, false)]
    #[case(2, 0, false)]
    #[case(0, 2, false)]
    #[case(2, 2, false)]
    fn get_2x2(#[case] row: usize, #[case] col: usize, #[case] expected: bool) {
        let mut bb = BitBoardStatic::<1>::new(2, 2);
        bb.set(0, 1, true);
        assert_eq!(bb.get(row, col), expected);
    }

    #[rstest]
    #[case(0, vec![true, false, true])]
    #[case(1, vec![false, true, false])]
    #[case(2, vec![true, true, true])]
    #[case(3, vec![false, false, false])]
    #[case(4, vec![false, false, false])]
    fn get_row_3x3(#[case] row: usize, #[case] expected: Vec<bool>) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set(0, 0, true);
        bb.set(0, 2, true);
        bb.set(1, 1, true);
        bb.set_row(2, true);

        assert_eq!(bb.get_row(row).collect::<Vec<bool>>(), expected);
    }

    #[rstest]
    #[case(0, vec![true, false, true])]
    #[case(1, vec![false, true, true])]
    #[case(2, vec![true, false, true])]
    #[case(3, vec![false, false, false])]
    fn get_col_3x3(#[case] col: usize, #[case] expected: Vec<bool>) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set(0, 0, true);
        bb.set(0, 2, true);
        bb.set(1, 1, true);
        bb.set(2, 0, true);
        bb.set(2, 1, true);
        bb.set(2, 2, true);

        assert_eq!(bb.get_col(col).collect::<Vec<bool>>(), expected);
    }

    #[test]
    fn set_col_3x3() {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_col(1, true);
        assert_eq!(set_cells(&bb), vec![(0, 1), (1, 1), (2, 1)]);

        bb.set_col(1, false);
        assert!(set_cells(&bb).is_empty());
    }

    #[test]
    fn fill_3x3() {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.fill(true);
        assert_eq!(set_cells(&bb).len(), 9);
        // Out-of-bounds reads stay `false` even on a full board.
        assert!(!bb.get(3, 0));

        bb.fill(false);
        assert!(set_cells(&bb).is_empty());
    }

    #[rstest]
    #[case(1, 1, vec![(0, 1), (1, 0), (1, 2), (2, 1)])]
    #[case(0, 0, vec![(0, 1), (1, 0)])]
    #[case(2, 2, vec![(1, 2), (2, 1)])]
    fn set_cardinal_neighbors_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expected: Vec<(usize, usize)>,
    ) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_cardinal_neighbors(row, col, true);
        assert_eq!(set_cells(&bb), expected);
    }

    #[rstest]
    #[case(1, 1, vec![(0, 0), (0, 2), (2, 0), (2, 2)])]
    #[case(0, 0, vec![(1, 1)])]
    #[case(0, 2, vec![(1, 1)])]
    #[case(2, 2, vec![(1, 1)])]
    fn set_diagonals_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expected: Vec<(usize, usize)>,
    ) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_diagonals(row, col, true);
        assert_eq!(set_cells(&bb), expected);
    }

    #[rstest]
    #[case(1, 1, vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1), (2, 2)])]
    #[case(0, 0, vec![(0, 1), (1, 0), (1, 1)])]
    fn set_all_neighbors_3x3(
        #[case] row: usize,
        #[case] col: usize,
        #[case] expected: Vec<(usize, usize)>,
    ) {
        let mut bb = BitBoardStatic::<1>::new(3, 3);
        bb.set_all_neighbors(row, col, true);
        assert_eq!(set_cells(&bb), expected);
    }

    #[test]
    fn set_all_neighbors_2x3_bottom_right() {
        let mut bb = BitBoardStatic::<1>::new(2, 3);
        bb.set_all_neighbors(1, 2, true);
        assert_eq!(set_cells(&bb), vec![(0, 1), (0, 2), (1, 1)]);
    }
}