1use crate::{BitBoard, Dir};
2use anyhow::{anyhow, Error, Result};
3use std::str::FromStr;
4
/// A square on the chess board.
///
/// Variants are declared in board-index order: the discriminant (`square as usize`)
/// starts at `H1` = 0, runs across each rank from the h-file to the a-file, and then
/// moves up the ranks, ending at `A8` = 63. The derived `PartialOrd`/`Ord` follow this
/// order. The same index is the square's bit position in a `BitBoard` (see `Square::lift`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[rustfmt::skip]
pub enum Square {
    H1, G1, F1, E1, D1, C1, B1, A1,
    H2, G2, F2, E2, D2, C2, B2, A2,
    H3, G3, F3, E3, D3, C3, B3, A3,
    H4, G4, F4, E4, D4, C4, B4, A4,
    H5, G5, F5, E5, D5, C5, B5, A5,
    H6, G6, F6, E6, D6, C6, B6, A6,
    H7, G7, F7, E7, D7, C7, B7, A7,
    H8, G8, F8, E8, D8, C8, B8, A8,
}
18
19impl std::fmt::Display for Square {
20 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
21 write!(f, "{}", format!("{:?}", self).to_lowercase())
22 }
23}
24
25impl FromStr for Square {
26 type Err = Error;
27
28 fn from_str(s: &str) -> Result<Self, Self::Err> {
29 let lower = s.to_lowercase();
30 Square::iter()
31 .find(|sq| sq.to_string() == lower)
32 .ok_or(anyhow!("Cannot parse {} as a Square", s))
33 }
34}
35
36impl Square {
37 pub fn iter() -> impl Iterator<Item = Square> {
39 ALL.iter().cloned()
40 }
41
42 pub fn from_index(i: usize) -> Square {
44 ALL[i]
45 }
46
47 pub const fn rank_index(self) -> usize {
49 (self as usize) / 8
50 }
51
52 pub const fn file_index(self) -> usize {
54 (self as usize) % 8
55 }
56
57 pub fn rank(self) -> BitBoard {
60 BitBoard::RANKS[self.rank_index()]
61 }
62
63 pub fn file(self) -> BitBoard {
66 BitBoard::FILES[self.file_index()]
67 }
68
69 pub const fn lift(self) -> BitBoard {
71 BitBoard(1u64 << (self as u64))
72 }
73
74 pub fn next(self, dir: Dir) -> Option<Square> {
77 let dr = match dir {
78 Dir::E | Dir::W => 0,
79 Dir::N | Dir::NE | Dir::NEE | Dir::NW | Dir::NWW => 1,
80 Dir::NNE | Dir::NNW => 2,
81 Dir::S | Dir::SE | Dir::SEE | Dir::SW | Dir::SWW => -1,
82 Dir::SSE | Dir::SSW => -2,
83 };
84 let df = match dir {
85 Dir::N | Dir::S => 0,
86 Dir::W | Dir::NW | Dir::NNW | Dir::SW | Dir::SSW => 1,
87 Dir::NWW | Dir::SWW => 2,
88 Dir::E | Dir::NE | Dir::NNE | Dir::SE | Dir::SSE => -1,
89 Dir::NEE | Dir::SEE => -2,
90 };
91 let new_rank = (self.rank_index() as i8) + dr;
92 let new_file = (self.file_index() as i8) + df;
93 if -1 < new_rank && new_rank < 8 && -1 < new_file && new_file < 8 {
94 Some(ALL[(8 * new_rank + new_file) as usize])
95 } else {
96 None
97 }
98 }
99
100 pub fn search(self, dir: Dir) -> BitBoard {
103 self.search_vec(dir).into_iter().collect()
104 }
105
106 pub fn search_vec(self, dir: Dir) -> Vec<Square> {
110 itertools::iterate(Some(self), |op| op.and_then(|sq| sq.next(dir)))
111 .skip(1)
112 .take_while(|op| op.is_some())
113 .map(|op| op.unwrap())
114 .collect()
115 }
116
117 pub fn search_all(self, dirs: &Vec<Dir>) -> BitBoard {
120 dirs.iter().flat_map(|&dir| self.search(dir)).collect()
121 }
122
123 pub fn search_one(self, dirs: &Vec<Dir>) -> BitBoard {
126 dirs.iter()
127 .flat_map(|&dir| self.next(dir).into_iter())
128 .collect()
129 }
130}
131
132impl std::ops::Shl<usize> for Square {
133 type Output = Square;
134 fn shl(self, rhs: usize) -> Self::Output {
135 Square::from_index(self as usize + rhs)
136 }
137}
138
139impl std::ops::Shr<usize> for Square {
140 type Output = Square;
141 fn shr(self, rhs: usize) -> Self::Output {
142 Square::from_index(self as usize - rhs)
143 }
144}
145
146impl std::ops::Not for Square {
147 type Output = BitBoard;
148 fn not(self) -> Self::Output {
149 !self.lift()
150 }
151}
152
153impl std::ops::BitOr<Square> for Square {
154 type Output = BitBoard;
155 fn bitor(self, other: Square) -> Self::Output {
156 self.lift() | other.lift()
157 }
158}
159
160impl std::ops::BitOr<BitBoard> for Square {
161 type Output = BitBoard;
162 fn bitor(self, other: BitBoard) -> Self::Output {
163 self.lift() | other
164 }
165}
166
167impl std::ops::BitAnd<BitBoard> for Square {
168 type Output = BitBoard;
169 fn bitand(self, other: BitBoard) -> Self::Output {
170 self.lift() & other
171 }
172}
173
174impl std::ops::Sub<BitBoard> for Square {
175 type Output = BitBoard;
176 fn sub(self, other: BitBoard) -> Self::Output {
177 self.lift() - other
178 }
179}
180
/// Every square in board-index order, so `ALL[sq as usize] == sq` for every `sq`.
///
/// The entries must stay in exactly the same order as the `Square` variants.
/// `Square::iter`, `Square::from_index` and `Square::next` all index into this table.
#[rustfmt::skip]
const ALL: [Square; 64] = [
    Square::H1, Square::G1, Square::F1, Square::E1, Square::D1, Square::C1, Square::B1, Square::A1,
    Square::H2, Square::G2, Square::F2, Square::E2, Square::D2, Square::C2, Square::B2, Square::A2,
    Square::H3, Square::G3, Square::F3, Square::E3, Square::D3, Square::C3, Square::B3, Square::A3,
    Square::H4, Square::G4, Square::F4, Square::E4, Square::D4, Square::C4, Square::B4, Square::A4,
    Square::H5, Square::G5, Square::F5, Square::E5, Square::D5, Square::C5, Square::B5, Square::A5,
    Square::H6, Square::G6, Square::F6, Square::E6, Square::D6, Square::C6, Square::B6, Square::A6,
    Square::H7, Square::G7, Square::F7, Square::E7, Square::D7, Square::C7, Square::B7, Square::A7,
    Square::H8, Square::G8, Square::F8, Square::E8, Square::D8, Square::C8, Square::B8, Square::A8,
];
192
#[cfg(test)]
mod test {
    use crate::square::Square;
    use crate::square::Square::*;
    use crate::Dir::*;

    #[test]
    fn test_rank() {
        assert_eq!(A1 | B1 | C1 | D1 | E1 | F1 | G1 | H1, F1.rank());
        assert_eq!(A4 | B4 | C4 | D4 | E4 | F4 | G4 | H4, D4.rank());
        assert_eq!(A8 | B8 | C8 | D8 | E8 | F8 | G8 | H8, A8.rank());
    }

    #[test]
    fn test_file() {
        assert_eq!(B1 | B2 | B3 | B4 | B5 | B6 | B7 | B8, B4.file());
    }

    #[test]
    fn test_partial_ord() {
        // The derived ordering must follow board-index order.
        for i in 0..64 {
            let pivot = Square::from_index(i);
            for smaller in Square::iter().take(i) {
                assert!(smaller < pivot, "{:?} should sort before {:?}", smaller, pivot);
            }
            for larger in Square::iter().skip(i + 1) {
                assert!(pivot < larger, "{:?} should sort before {:?}", pivot, larger);
            }
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(H1.to_string(), "h1");
        assert_eq!(E4.to_string(), "e4");
        assert_eq!(A8.to_string(), "a8");
    }

    #[test]
    fn test_from_str() {
        assert_eq!("e4".parse::<Square>().unwrap(), E4);
        assert_eq!("E4".parse::<Square>().unwrap(), E4);
        assert_eq!("h1".parse::<Square>().unwrap(), H1);
        assert!("".parse::<Square>().is_err());
        assert!("i1".parse::<Square>().is_err());
        assert!("a9".parse::<Square>().is_err());
        assert!("a0".parse::<Square>().is_err());
        assert!(" e4".parse::<Square>().is_err());
        assert!("e44".parse::<Square>().is_err());
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        for sq in Square::iter() {
            assert_eq!(sq.to_string().parse::<Square>().unwrap(), sq);
        }
    }

    #[test]
    fn test_shift() {
        assert_eq!(H1 << 8, H2);
        assert_eq!(H1 << 63, A8);
        assert_eq!(A8 >> 8, A7);
        assert_eq!(A8 >> 63, H1);
    }

    #[test]
    fn test_search() {
        assert_eq!(D3.search(S), D2 | D1);
    }

    #[test]
    fn test_search_vec() {
        assert_eq!(D3.search_vec(S), vec![D2, D1]);
    }

    #[test]
    fn test_search_one() {
        assert_eq!(D3.search_one(&vec!(S, E)), D2 | E3);
        assert_eq!(A8.search_one(&vec!(N, NWW, SE)), B7.lift());
    }

    #[test]
    fn test_search_all() {
        assert_eq!(C3.search_all(&vec!(SSW, SWW, S)), B1 | A2 | C2 | C1);
    }

    #[test]
    fn test_next() {
        assert_eq!(C3.next(N), Some(C4));
        assert_eq!(C3.next(E), Some(D3));
        assert_eq!(C3.next(S), Some(C2));
        assert_eq!(C3.next(W), Some(B3));
        assert_eq!(C3.next(NE), Some(D4));
        assert_eq!(C3.next(SE), Some(D2));
        assert_eq!(C3.next(SW), Some(B2));
        assert_eq!(C3.next(NW), Some(B4));
        assert_eq!(C3.next(NNE), Some(D5));
        assert_eq!(C3.next(NEE), Some(E4));
        assert_eq!(C3.next(SEE), Some(E2));
        assert_eq!(C3.next(SSE), Some(D1));
        assert_eq!(C3.next(SSW), Some(B1));
        assert_eq!(C3.next(SWW), Some(A2));
        assert_eq!(C3.next(NWW), Some(A4));
        assert_eq!(C3.next(NNW), Some(B5));

        assert_eq!(G8.next(N), None);
        assert_eq!(H6.next(E), None);
        assert_eq!(B1.next(S), None);
        assert_eq!(A4.next(W), None);
    }
}