myopic_core/
square.rs

1use crate::{BitBoard, Dir};
2use anyhow::{anyhow, Error, Result};
3use std::str::FromStr;
4
/// Type representing a square on a chessboard.
///
/// Variants are declared rank by rank, from the first rank up to the
/// eighth, and within each rank from the h-file across to the a-file.
/// The implicit discriminants therefore run from `H1` = 0 to `A8` = 63.
/// The derived `PartialOrd`/`Ord` follow this declaration order.
/// `rank_index`, `file_index` and `lift` treat the discriminant as the
/// square's index, so the variant order must not be changed.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[rustfmt::skip]
pub enum Square {
    H1, G1, F1, E1, D1, C1, B1, A1,
    H2, G2, F2, E2, D2, C2, B2, A2,
    H3, G3, F3, E3, D3, C3, B3, A3,
    H4, G4, F4, E4, D4, C4, B4, A4,
    H5, G5, F5, E5, D5, C5, B5, A5,
    H6, G6, F6, E6, D6, C6, B6, A6,
    H7, G7, F7, E7, D7, C7, B7, A7,
    H8, G8, F8, E8, D8, C8, B8, A8,
}
18
19impl std::fmt::Display for Square {
20    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
21        write!(f, "{}", format!("{:?}", self).to_lowercase())
22    }
23}
24
25impl FromStr for Square {
26    type Err = Error;
27
28    fn from_str(s: &str) -> Result<Self, Self::Err> {
29        let lower = s.to_lowercase();
30        Square::iter()
31            .find(|sq| sq.to_string() == lower)
32            .ok_or(anyhow!("Cannot parse {} as a Square", s))
33    }
34}
35
36impl Square {
37    /// Return an iterator traversing all squares in order.
38    pub fn iter() -> impl Iterator<Item = Square> {
39        ALL.iter().cloned()
40    }
41
42    /// Retrieve a square by it's corresponding index.
43    pub fn from_index(i: usize) -> Square {
44        ALL[i]
45    }
46
47    /// Return the index of the rank on which this square resides.
48    pub const fn rank_index(self) -> usize {
49        (self as usize) / 8
50    }
51
52    /// Return the index of the file on which this square resides.
53    pub const fn file_index(self) -> usize {
54        (self as usize) % 8
55    }
56
57    /// Return a bitboard representing the rank on which this square
58    /// resides.
59    pub fn rank(self) -> BitBoard {
60        BitBoard::RANKS[self.rank_index()]
61    }
62
63    /// Return a bitboard representing the file on which this square
64    /// resides.
65    pub fn file(self) -> BitBoard {
66        BitBoard::FILES[self.file_index()]
67    }
68
69    /// 'Lifts' this square to a singleton set of squares.
70    pub const fn lift(self) -> BitBoard {
71        BitBoard(1u64 << (self as u64))
72    }
73
74    /// Finds the next square on a chessboard from this square in a
75    /// given direction if it exists.
76    pub fn next(self, dir: Dir) -> Option<Square> {
77        let dr = match dir {
78            Dir::E | Dir::W => 0,
79            Dir::N | Dir::NE | Dir::NEE | Dir::NW | Dir::NWW => 1,
80            Dir::NNE | Dir::NNW => 2,
81            Dir::S | Dir::SE | Dir::SEE | Dir::SW | Dir::SWW => -1,
82            Dir::SSE | Dir::SSW => -2,
83        };
84        let df = match dir {
85            Dir::N | Dir::S => 0,
86            Dir::W | Dir::NW | Dir::NNW | Dir::SW | Dir::SSW => 1,
87            Dir::NWW | Dir::SWW => 2,
88            Dir::E | Dir::NE | Dir::NNE | Dir::SE | Dir::SSE => -1,
89            Dir::NEE | Dir::SEE => -2,
90        };
91        let new_rank = (self.rank_index() as i8) + dr;
92        let new_file = (self.file_index() as i8) + df;
93        if -1 < new_rank && new_rank < 8 && -1 < new_file && new_file < 8 {
94            Some(ALL[(8 * new_rank + new_file) as usize])
95        } else {
96            None
97        }
98    }
99
100    /// Find all squares in a given direction from this square and
101    /// returns them as a set.
102    pub fn search(self, dir: Dir) -> BitBoard {
103        self.search_vec(dir).into_iter().collect()
104    }
105
106    /// Find all squares in a given direction from this square and
107    /// returns them as a vector where the squares are ordered in
108    /// increasing distance from this square.
109    pub fn search_vec(self, dir: Dir) -> Vec<Square> {
110        itertools::iterate(Some(self), |op| op.and_then(|sq| sq.next(dir)))
111            .skip(1)
112            .take_while(|op| op.is_some())
113            .map(|op| op.unwrap())
114            .collect()
115    }
116
117    /// Find all squares in all directions in a given vector and
118    /// returns them as a set.
119    pub fn search_all(self, dirs: &Vec<Dir>) -> BitBoard {
120        dirs.iter().flat_map(|&dir| self.search(dir)).collect()
121    }
122
123    /// Find the squares adjacent to this square in all of the
124    /// given directions and returns them as a set.
125    pub fn search_one(self, dirs: &Vec<Dir>) -> BitBoard {
126        dirs.iter()
127            .flat_map(|&dir| self.next(dir).into_iter())
128            .collect()
129    }
130}
131
132impl std::ops::Shl<usize> for Square {
133    type Output = Square;
134    fn shl(self, rhs: usize) -> Self::Output {
135        Square::from_index(self as usize + rhs)
136    }
137}
138
139impl std::ops::Shr<usize> for Square {
140    type Output = Square;
141    fn shr(self, rhs: usize) -> Self::Output {
142        Square::from_index(self as usize - rhs)
143    }
144}
145
146impl std::ops::Not for Square {
147    type Output = BitBoard;
148    fn not(self) -> Self::Output {
149        !self.lift()
150    }
151}
152
153impl std::ops::BitOr<Square> for Square {
154    type Output = BitBoard;
155    fn bitor(self, other: Square) -> Self::Output {
156        self.lift() | other.lift()
157    }
158}
159
160impl std::ops::BitOr<BitBoard> for Square {
161    type Output = BitBoard;
162    fn bitor(self, other: BitBoard) -> Self::Output {
163        self.lift() | other
164    }
165}
166
167impl std::ops::BitAnd<BitBoard> for Square {
168    type Output = BitBoard;
169    fn bitand(self, other: BitBoard) -> Self::Output {
170        self.lift() & other
171    }
172}
173
174impl std::ops::Sub<BitBoard> for Square {
175    type Output = BitBoard;
176    fn sub(self, other: BitBoard) -> Self::Output {
177        self.lift() - other
178    }
179}
180
/// Every square, listed in the same order as the `Square` declaration.
/// Entry `ALL[i]` is therefore the variant whose discriminant is `i`
/// (`H1` = 0 … `A8` = 63).
///
/// `Square::from_index` and `Square::iter` read from this table. Its
/// order must stay in sync with the enum declaration, or index-based
/// lookups will return the wrong square.
#[rustfmt::skip]
const ALL: [Square; 64] = [
    Square::H1, Square::G1, Square::F1, Square::E1, Square::D1, Square::C1, Square::B1, Square::A1,
    Square::H2, Square::G2, Square::F2, Square::E2, Square::D2, Square::C2, Square::B2, Square::A2,
    Square::H3, Square::G3, Square::F3, Square::E3, Square::D3, Square::C3, Square::B3, Square::A3,
    Square::H4, Square::G4, Square::F4, Square::E4, Square::D4, Square::C4, Square::B4, Square::A4,
    Square::H5, Square::G5, Square::F5, Square::E5, Square::D5, Square::C5, Square::B5, Square::A5,
    Square::H6, Square::G6, Square::F6, Square::E6, Square::D6, Square::C6, Square::B6, Square::A6,
    Square::H7, Square::G7, Square::F7, Square::E7, Square::D7, Square::C7, Square::B7, Square::A7,
    Square::H8, Square::G8, Square::F8, Square::E8, Square::D8, Square::C8, Square::B8, Square::A8,
];
192
#[cfg(test)]
mod test {
    use crate::square::Square;
    use crate::square::Square::*;
    use crate::Dir::*;

    #[test]
    fn test_rank() {
        assert_eq!(A1 | B1 | C1 | D1 | E1 | F1 | G1 | H1, F1.rank());
        assert_eq!(A4 | B4 | C4 | D4 | E4 | F4 | G4 | H4, D4.rank());
        assert_eq!(A8 | B8 | C8 | D8 | E8 | F8 | G8 | H8, A8.rank());
    }

    #[test]
    fn test_file() {
        assert_eq!(B1 | B2 | B3 | B4 | B5 | B6 | B7 | B8, B4.file());
    }

    #[test]
    fn test_from_index_round_trip() {
        // `from_index` must be the inverse of the discriminant cast.
        for i in 0..64 {
            assert_eq!(i, Square::from_index(i) as usize);
        }
    }

    #[test]
    fn test_partial_ord() {
        for i in 0..64 {
            let pivot = Square::from_index(i);

            for smaller in Square::iter().take(i) {
                assert!(smaller < pivot, "{:?} should be < {:?}", smaller, pivot);
            }

            for larger in Square::iter().skip(i + 1) {
                assert!(pivot < larger, "{:?} should be < {:?}", pivot, larger);
            }
        }
    }

    #[test]
    fn test_display() {
        assert_eq!("h1", H1.to_string());
        assert_eq!("a1", A1.to_string());
        assert_eq!("e4", E4.to_string());
        assert_eq!("h8", H8.to_string());
        assert_eq!("a8", A8.to_string());
    }

    #[test]
    fn test_from_str() {
        assert_eq!(E4, "e4".parse::<Square>().unwrap());
        // The file letter is case-insensitive.
        assert_eq!(E4, "E4".parse::<Square>().unwrap());
        assert_eq!(A8, "A8".parse::<Square>().unwrap());

        // Display and FromStr must round-trip for every square.
        for sq in Square::iter() {
            assert_eq!(Some(sq), sq.to_string().parse::<Square>().ok());
        }

        for bad in &["", "e", "e0", "e9", "i4", "e44", " e4", "44", "4e"] {
            assert!(bad.parse::<Square>().is_err(), "{:?} should not parse", bad);
        }
    }

    #[test]
    fn test_search() {
        assert_eq!(D3.search(S), D2 | D1);
    }

    #[test]
    fn test_search_vec() {
        assert_eq!(D3.search_vec(S), vec![D2, D1]);
    }

    #[test]
    fn test_search_one() {
        assert_eq!(D3.search_one(&vec!(S, E)), D2 | E3);
        assert_eq!(A8.search_one(&vec!(N, NWW, SE)), B7.lift());
    }

    #[test]
    fn test_search_all() {
        assert_eq!(C3.search_all(&vec!(SSW, SWW, S)), B1 | A2 | C2 | C1);
    }

    #[test]
    fn test_next() {
        assert_eq!(C3.next(N), Some(C4));
        assert_eq!(C3.next(E), Some(D3));
        assert_eq!(C3.next(S), Some(C2));
        assert_eq!(C3.next(W), Some(B3));
        assert_eq!(C3.next(NE), Some(D4));
        assert_eq!(C3.next(SE), Some(D2));
        assert_eq!(C3.next(SW), Some(B2));
        assert_eq!(C3.next(NW), Some(B4));
        assert_eq!(C3.next(NNE), Some(D5));
        assert_eq!(C3.next(NEE), Some(E4));
        assert_eq!(C3.next(SEE), Some(E2));
        assert_eq!(C3.next(SSE), Some(D1));
        assert_eq!(C3.next(SSW), Some(B1));
        assert_eq!(C3.next(SWW), Some(A2));
        assert_eq!(C3.next(NWW), Some(A4));
        assert_eq!(C3.next(NNW), Some(B5));

        assert_eq!(G8.next(N), None);
        assert_eq!(H6.next(E), None);
        assert_eq!(B1.next(S), None);
        assert_eq!(A4.next(W), None);
    }
}
273}