chess_move_gen/position/
make.rs

1use super::{Position, State};
2use crate::bb::*;
3use crate::castle::*;
4use crate::castling_rights::*;
5use crate::mv::{Move, NULL_MOVE};
6use crate::piece::*;
7use crate::side::{BLACK, Side};
8use crate::square::*;
9
10// If move intersects this mask, then remove castling right
11const CASTLE_MASKS: [BB; 4] = [
12    BB(1u64 | (1u64 << 4)),                // WHITE QS: A1 + E1
13    BB((1u64 | (1u64 << 4)) << 56),        // BLACK QS: A8 + E8
14    BB((1u64 << 4) | (1u64 << 7)),         // WHITE KS: E1 + H1
15    BB(((1u64 << 4) | (1u64 << 7)) << 56), /* BLACK KS: E8 + H8 */
16];
17
18impl Position {
19    /// Returns piece captured and square if any
20    pub fn make(&mut self, mv: Move) -> Option<(Piece, Square)> {
21        debug_assert_ne!(mv, NULL_MOVE);
22
23        let stm = self.state.stm;
24        let initial_state = self.state.clone();
25        let mut move_resets_half_move_clock = false;
26
27        // increment full move clock if black moved
28        if self.state.stm == BLACK {
29            self.state.full_move_number += 1;
30        }
31        self.state.stm = self.state.stm.flip();
32
33        self.state.ep_square = None;
34        let mut captured = None;
35
36        let mut xor_key = 0u64;
37
38        if mv.is_castle() {
39            let castle = mv.castle();
40            self.state.castling_rights.clear_side(stm);
41            self.make_castle(castle, stm);
42
43            xor_key ^= self.hash.castle(castle, stm);
44        } else {
45            let from = mv.from();
46            let to = mv.to();
47
48            if mv.is_capture() {
49                // half move clock reset after all pawn moves and captures
50                move_resets_half_move_clock = true;
51
52                let capture_sq = if mv.is_ep_capture() {
53                    from.along_row_with_col(to)
54                } else {
55                    to
56                };
57
58                let captured_piece = self.at(capture_sq);
59
60                debug_assert!(captured_piece.is_some());
61
62                debug_assert_ne!(captured_piece.kind(), KING);
63
64                self.remove_piece(capture_sq);
65
66                captured = Some((captured_piece, capture_sq));
67
68                xor_key ^= self.hash.capture(captured_piece, capture_sq);
69            }
70
71            let mover = self.at(from);
72            debug_assert!(mover.is_some());
73
74            let move_mask = self.move_piece(from, to);
75            let mut updated_mover = mover;
76
77            // half move clock reset after all pawn moves and captures
78            if mover.kind() == PAWN {
79                move_resets_half_move_clock = true;
80
81                // if double pawn push (pawn move that travels two rows), set ep square
82                if mv.distance() == 16 {
83                    self.state.ep_square = Some(Square((to.raw() + from.raw()) >> 1));
84                }
85            }
86
87            if mv.is_promotion() {
88                updated_mover = mv.promote_to().pc(stm);
89                self.promote_piece(to, updated_mover);
90            }
91
92            xor_key ^= self.hash.push(mover, from, updated_mover, to);
93
94            for (i, mask) in CASTLE_MASKS.iter().enumerate() {
95                if (move_mask & *mask) != EMPTY {
96                    self.state.castling_rights.clear(CastlingRights(1 << i));
97                }
98            }
99        }
100
101        if move_resets_half_move_clock {
102            self.state.half_move_clock = 0;
103        } else {
104            self.state.half_move_clock += 1;
105        }
106
107        xor_key ^= self.hash.state(&initial_state, &self.state);
108
109        self.key ^= xor_key;
110
111        captured
112    }
113
114    pub fn make_null_move(&mut self) -> Option<(Piece, Square)> {
115        let initial_state = self.state.clone();
116
117        // increment full move clock if black moved
118        if self.state.stm == BLACK {
119            self.state.full_move_number += 1;
120        }
121        self.state.half_move_clock += 1;
122        self.state.stm = self.state.stm.flip();
123        self.state.ep_square = None;
124
125        let mut xor_key = 0u64;
126
127        xor_key ^= self.hash.state(&initial_state, &self.state);
128
129        self.key ^= xor_key;
130
131        None
132    }
133
134    pub fn unmake(
135        &mut self,
136        mv: Move,
137        capture: Option<(Piece, Square)>,
138        original_state: &State,
139        original_hash_key: u64,
140    ) {
141        debug_assert_ne!(mv, NULL_MOVE);
142
143        self.state = original_state.clone();
144        self.key = original_hash_key;
145
146        if mv.is_castle() {
147            self.unmake_castle(mv.castle(), original_state.stm);
148            return;
149        }
150
151        if mv.is_promotion() {
152            let mover = PAWN.pc(original_state.stm);
153            self.promote_piece(mv.to(), mover);
154        }
155
156        self.move_piece(mv.to(), mv.from());
157
158        if let Some((captured_piece, capture_sq)) = capture {
159            self.put_piece(captured_piece, capture_sq);
160        }
161    }
162
163    fn unmake_castle(&mut self, castle: Castle, stm: Side) {
164        let (to, from) = castle_king_squares(stm, castle);
165        self.move_piece(from, to);
166        let (to, from) = castle_rook_squares(stm, castle);
167        self.move_piece(from, to);
168    }
169
170    fn make_castle(&mut self, castle: Castle, stm: Side) {
171        let (from, to) = castle_king_squares(stm, castle);
172        self.move_piece(from, to);
173        let (from, to) = castle_rook_squares(stm, castle);
174        self.move_piece(from, to);
175    }
176
177    pub fn unmake_null_move(&mut self, original_state: &State, original_hash_key: u64) {
178        self.state = original_state.clone();
179        self.key = original_hash_key;
180    }
181}
182
#[cfg(test)]
mod test {
    use crate::castle::*;
    use crate::integrity;
    use crate::mv::Move;
    use crate::piece::*;
    use crate::position::Position;
    use crate::square::*;

    /// Plays `mv` from `initial_fen`, checks the result against `expected_fen`, then
    /// unmakes the move and checks that the board and the hash key are back to their
    /// initial values. Position integrity is verified at every stage.
    fn test_make_unmake(initial_fen: &'static str, expected_fen: &'static str, mv: Move) {
        let mut position = Position::from_fen(initial_fen).unwrap();
        assert!(integrity::test(&position).is_none());

        let state = position.state().clone();

        let initial_key = position.hash_key();

        let capture = position.make(mv);
        assert_eq!(position.to_fen(), expected_fen);

        assert!(integrity::test(&position).is_none());

        position.unmake(mv, capture, &state, initial_key);
        assert_eq!(
            position.to_string(),
            Position::from_fen(initial_fen).unwrap().to_string()
        );
        // unmake must hand back exactly the key the position had before make.
        assert_eq!(position.hash_key(), initial_key);
        assert!(integrity::test(&position).is_none());
    }

    #[test]
    fn test_hmc_incremented_by_non_pawn_non_capture() {
        let mut position =
            Position::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w QqKk - 20 1")
                .unwrap();
        position.make(Move::new_push(B1, C3));

        assert_eq!(position.state().half_move_clock, 21);
    }

    #[test]
    fn test_hmc_reset_by_pawn_non_capture() {
        let mut position =
            Position::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w QqKk - 20 1")
                .unwrap();
        position.make(Move::new_push(A2, A3));

        assert_eq!(position.state().half_move_clock, 0);
    }

    #[test]
    fn test_hmc_reset_by_non_pawn_capture() {
        let mut position =
            Position::from_fen("rnbqkbnr/1ppppppp/8/8/8/p7/PPPPPPPP/RNBQKBNR w QqKk - 20 1")
                .unwrap();
        position.make(Move::new_capture(B1, A3));

        assert_eq!(position.state().half_move_clock, 0);
    }

    /// The same position reached through two different move orders must produce the
    /// same FEN and the same incrementally-maintained hash key.
    #[test]
    fn test_hash() {
        let mut position_1 =
            Position::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/R3KBNR w QqKk - 1 1").unwrap();
        let mut position_2 =
            Position::from_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/R3KBNR w QqKk - 1 1").unwrap();
        position_1.make(Move::new_push(D2, D4));
        position_1.make(Move::new_push(B8, C6));
        position_1.make(Move::new_castle(QUEEN_SIDE));
        position_1.make(Move::new_capture(C6, D4));

        position_2.make(Move::new_castle(QUEEN_SIDE));
        position_2.make(Move::new_push(B8, C6));
        position_2.make(Move::new_push(D2, D4));
        position_2.make(Move::new_capture(C6, D4));

        assert_eq!(position_1.to_fen(), position_2.to_fen());
        assert_eq!(position_1.hash_key(), position_2.hash_key());
    }

    #[test]
    fn test_make_unmake_simple_push() {
        test_make_unmake(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR",
            "rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b QqKk - 1 1",
            Move::new_push(B1, C3),
        );
    }

    // White double pawn push: sets the en passant square behind the pawn.
    #[test]
    fn test_make_unmake_double_pawn_push() {
        test_make_unmake(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR",
            "rnbqkbnr/pppppppp/8/8/3P4/8/PPP1PPPP/RNBQKBNR b QqKk d3 0 1",
            Move::new_push(D2, D4),
        );
    }

    #[test]
    fn test_make_unmake_push_with_castle_invalidation() {
        test_make_unmake(
            "rnbqkbnr/1ppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR",
            "1nbqkbnr/rppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b QKk - 1 1",
            Move::new_push(A8, A7),
        );
    }

    #[test]
    fn test_make_unmake_promotion() {
        test_make_unmake(
            "rnbqkbnr/ppppppp1/8/8/8/8/PPPPPPPp/RNBQKBN1 b Qqk",
            "rnbqkbnr/ppppppp1/8/8/8/8/PPPPPPP1/RNBQKBNq w Qqk - 0 2",
            Move::new_promotion(H2, H1, QUEEN),
        );
    }

    #[test]
    fn test_make_unmake_capture_promotion() {
        test_make_unmake(
            "rnbqkbnr/pPpppppp/8/8/8/8/P1PPPPPP/RNBQKBNR w QKqk",
            "Nnbqkbnr/p1pppppp/8/8/8/8/P1PPPPPP/RNBQKBNR b QKk - 0 1",
            Move::new_capture_promotion(B7, A8, KNIGHT),
        );
    }

    #[test]
    fn test_make_unmake_ep_capture() {
        test_make_unmake(
            "rnbqkbnr/pppp1ppp/8/3Pp3/8/8/PPP1PPPP/RNBQKBNR w QqKk e6",
            "rnbqkbnr/pppp1ppp/4P3/8/8/8/PPP1PPPP/RNBQKBNR b QqKk - 0 1",
            Move::new_ep_capture(D5, E6),
        );
    }

    #[test]
    fn test_make_unmake_castle() {
        test_make_unmake(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/R3KBNR w qkQK - 20 10",
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/2KR1BNR b qk - 21 10",
            Move::new_castle(QUEEN_SIDE),
        );
    }

    // Black double pawn push: also bumps the full move number.
    #[test]
    fn test_make_unmake_double_push() {
        test_make_unmake(
            "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b",
            "rnbqkbnr/ppp1pppp/8/3p4/8/8/PPPPPPPP/RNBQKBNR w QqKk d6 0 2",
            Move::new_push(D7, D5),
        );
    }

    #[test]
    fn test_make_unmake_capture() {
        test_make_unmake(
            "rnbqkbnr/pppppppp/7P/8/8/8/PPPPPPP1/RNBQKBNR",
            "rnbqkbnr/ppppppPp/8/8/8/8/PPPPPPP1/RNBQKBNR b QqKk - 0 1",
            Move::new_capture(H6, G7),
        );
    }
}