// myopic_brain/eval/additional_components/opening.rs

1use crate::eval::EvalComponent;
2use crate::{CastleZone, Move, Piece};
3use myopic_board::Square;
4
/// Rewards granted by [`OpeningComponent`] for developing moves in the opening.
///
/// Each piece reward is applied the first time the corresponding central pawn, knight
/// or bishop leaves its starting square, and each castle reward when that side castles.
/// White receives the rewards as given; black receives their negation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpeningRewards {
    /// Reward for first moving the e-pawn.
    pub e_pawn: i32,
    /// Reward for first moving the d-pawn.
    pub d_pawn: i32,
    /// Reward for first moving the knight that starts on the b-file.
    pub b_knight: i32,
    /// Reward for first moving the knight that starts on the g-file.
    pub g_knight: i32,
    /// Reward for first moving the bishop that starts on the c-file.
    pub c_bishop: i32,
    /// Reward for first moving the bishop that starts on the f-file.
    pub f_bishop: i32,
    /// Reward for castling kingside.
    pub k_castle: i32,
    /// Reward for castling queenside.
    pub q_castle: i32,
}

impl Default for OpeningRewards {
    /// The engine's standard reward table: kingside development and kingside castling
    /// are weighted above their queenside counterparts, and the e-pawn above the d-pawn.
    fn default() -> Self {
        OpeningRewards {
            e_pawn: 200,
            d_pawn: 150,
            b_knight: 100,
            g_knight: 150,
            c_bishop: 100,
            f_bishop: 150,
            k_castle: 200,
            q_castle: 100,
        }
    }
}
31
32#[derive(Clone)]
33pub struct OpeningComponent {
34    rewards: OpeningRewards,
35    score: i32,
36    pieces: DevTracker,
37    move_dist: usize,
38}
39
40impl Default for OpeningComponent {
41    fn default() -> Self {
42        OpeningComponent::new(OpeningRewards::default())
43    }
44}
45
46impl OpeningComponent {
47    pub fn new(rewards: OpeningRewards) -> OpeningComponent {
48        OpeningComponent {
49            score: 0,
50            move_dist: 0,
51            pieces: DevTracker {
52                w_e_pawn: PieceTracker::new(Square::E2, rewards.e_pawn),
53                w_d_pawn: PieceTracker::new(Square::D2, rewards.d_pawn),
54                w_b_knight: PieceTracker::new(Square::B1, rewards.b_knight),
55                w_g_knight: PieceTracker::new(Square::G1, rewards.g_knight),
56                w_c_bishop: PieceTracker::new(Square::C1, rewards.c_bishop),
57                w_f_bishop: PieceTracker::new(Square::F1, rewards.f_bishop),
58
59                b_e_pawn: PieceTracker::new(Square::E7, -rewards.e_pawn),
60                b_d_pawn: PieceTracker::new(Square::D7, -rewards.d_pawn),
61                b_b_knight: PieceTracker::new(Square::B8, -rewards.b_knight),
62                b_g_knight: PieceTracker::new(Square::G8, -rewards.g_knight),
63                b_c_bishop: PieceTracker::new(Square::C8, -rewards.c_bishop),
64                b_f_bishop: PieceTracker::new(Square::F8, -rewards.f_bishop),
65            },
66            rewards,
67        }
68    }
69}
70
71#[derive(Debug, Clone)]
72struct DevTracker {
73    // whites
74    w_e_pawn: PieceTracker,
75    w_d_pawn: PieceTracker,
76    w_b_knight: PieceTracker,
77    w_g_knight: PieceTracker,
78    w_c_bishop: PieceTracker,
79    w_f_bishop: PieceTracker,
80    // blacks
81    b_e_pawn: PieceTracker,
82    b_d_pawn: PieceTracker,
83    b_b_knight: PieceTracker,
84    b_g_knight: PieceTracker,
85    b_c_bishop: PieceTracker,
86    b_f_bishop: PieceTracker,
87}
88
89impl DevTracker {
90    fn get_piece_trackers(&mut self, p: Piece) -> Vec<&mut PieceTracker> {
91        match p {
92            Piece::WP => vec![&mut self.w_d_pawn, &mut self.w_e_pawn],
93            Piece::WN => vec![&mut self.w_b_knight, &mut self.w_g_knight],
94            Piece::WB => vec![&mut self.w_c_bishop, &mut self.w_f_bishop],
95            Piece::BP => vec![&mut self.b_d_pawn, &mut self.b_e_pawn],
96            Piece::BN => vec![&mut self.b_b_knight, &mut self.b_g_knight],
97            Piece::BB => vec![&mut self.b_c_bishop, &mut self.b_f_bishop],
98            _ => vec![],
99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104struct PieceTracker {
105    /// The most recent location of the piece on the board
106    loc: Square,
107    /// How many moves this piece has made from it's start
108    /// position
109    count: usize,
110    /// What reward this tracker is associated with
111    reward: i32,
112    /// The move dist at which this piece was removed
113    /// from the board. None if still on board.
114    capture_dist: Option<usize>,
115}
116
117impl PieceTracker {
118    fn new(start: Square, reward: i32) -> PieceTracker {
119        PieceTracker {
120            loc: start,
121            count: 0,
122            capture_dist: None,
123            reward,
124        }
125    }
126
127    fn move_forward(&mut self, new_loc: Square) -> usize {
128        self.loc = new_loc;
129        self.count += 1;
130        self.count
131    }
132
133    fn move_backward(&mut self, old_loc: Square) -> usize {
134        self.loc = old_loc;
135        self.count -= 1;
136        self.count
137    }
138
139    fn remove_piece(&mut self, capture_dist: usize) {
140        self.capture_dist = Some(capture_dist);
141    }
142
143    fn add_piece(&mut self) {
144        self.capture_dist = None;
145    }
146
147    fn matches_on(&self, loc: Square) -> bool {
148        self.capture_dist.is_none() && loc == self.loc
149    }
150
151    fn matches_off(&self, move_dist: usize) -> bool {
152        match self.capture_dist {
153            None => false,
154            Some(cd) => cd == move_dist,
155        }
156    }
157}
158
159impl EvalComponent for OpeningComponent {
160    fn static_eval(&self) -> i32 {
161        self.score
162    }
163
164    fn make(&mut self, mv: &Move) {
165        self.move_dist += 1;
166        match mv {
167            &Move::Standard {
168                moving,
169                from,
170                dest,
171                capture,
172                ..
173            } => {
174                // Update location of moving piece
175                for pt in self.pieces.get_piece_trackers(moving) {
176                    if pt.matches_on(from) && pt.move_forward(dest) == 1 {
177                        self.score += pt.reward
178                    }
179                }
180                // Remove any captured piece
181                if capture.is_some() {
182                    for pt in self.pieces.get_piece_trackers(capture.unwrap()) {
183                        if pt.matches_on(dest) {
184                            pt.remove_piece(self.move_dist);
185                        }
186                    }
187                }
188            }
189            &Move::Castle { zone, .. } => match zone {
190                CastleZone::WK => self.score += self.rewards.k_castle,
191                CastleZone::WQ => self.score += self.rewards.q_castle,
192                CastleZone::BK => self.score -= self.rewards.k_castle,
193                CastleZone::BQ => self.score -= self.rewards.q_castle,
194            },
195            _ => {}
196        }
197    }
198
199    fn unmake(&mut self, mv: &Move) {
200        match mv {
201            &Move::Standard {
202                moving,
203                from,
204                dest,
205                capture,
206                ..
207            } => {
208                // Update location of moving piece
209                for pt in self.pieces.get_piece_trackers(moving) {
210                    if pt.matches_on(dest) && pt.move_backward(from) == 0 {
211                        self.score -= pt.reward
212                    }
213                }
214                // Replace any captured piece
215                if capture.is_some() {
216                    for pt in self.pieces.get_piece_trackers(capture.unwrap()) {
217                        if pt.matches_off(self.move_dist) {
218                            pt.add_piece();
219                        }
220                    }
221                }
222            }
223            &Move::Castle { zone, .. } => match zone {
224                CastleZone::WK => self.score -= self.rewards.k_castle,
225                CastleZone::WQ => self.score -= self.rewards.q_castle,
226                CastleZone::BK => self.score += self.rewards.k_castle,
227                CastleZone::BQ => self.score += self.rewards.q_castle,
228            },
229            _ => {}
230        };
231        self.move_dist -= 1;
232    }
233}
234
#[cfg(test)]
mod test {
    use crate::eval::additional_components::opening::{OpeningComponent, OpeningRewards};
    use crate::eval::EvalComponent;
    use crate::{Board, EvalBoard, Reflectable, UciMove};
    use anyhow::Result;
    use myopic_board::ChessBoard;

    /// Regression test named after issue 97: plays a line in which pieces are captured
    /// and a knight returns to b1, then interleaves unmake/make/unmake. The test only
    /// checks that none of these calls returns an error or panics.
    #[test]
    fn issue_97() -> Result<()> {
        let mut state = EvalBoard::start();
        state.play_uci(
            "e2e4 g7g6 d2d4 f8g7 c2c4 d7d6 b1c3 g8f6 g1f3 e8g8 f1d3 e7e5 f3d2 e5d4 d2b1 d4c3 b2c3",
        )?;
        state.unmake()?;
        // Replay a different capture on c3 (the g-knight, now on b1) before unwinding.
        state.play_uci("b1c3")?;
        state.unmake()?;
        state.unmake()?;
        state.unmake()?;
        Ok(())
    }

    /// Reward table whose entries are distinct powers of ten, so each contribution is
    /// identifiable as a separate decimal digit in the expected totals below.
    fn dummy_rewards() -> OpeningRewards {
        OpeningRewards {
            d_pawn: 1,
            e_pawn: 10,
            g_knight: 100,
            b_knight: 1000,
            f_bishop: 10000,
            c_bishop: 100000,
            k_castle: 1000000,
            q_castle: 10000000,
        }
    }

    /// Each pair is (move, expected static eval after that move), starting from the
    /// standard position; only black's kingside castle is exercised.
    #[test]
    fn case_3() -> Result<()> {
        execute_case(TestCase {
            board: crate::start(),
            moves_evals: vec![
                (UciMove::new("e2e4")?, 10),
                (UciMove::new("g7g6")?, 10),
                (UciMove::new("d2d4")?, 11),
                (UciMove::new("f8g7")?, -9989),
                (UciMove::new("c2c4")?, -9989),
                (UciMove::new("d7d6")?, -9990),
                (UciMove::new("b1c3")?, -8990),
                (UciMove::new("g8f6")?, -9090),
                (UciMove::new("g1f3")?, -8990),
                (UciMove::new("e8g8")?, -1008990),
                (UciMove::new("f1d3")?, -998990),
            ],
        })
    }

    /// Develops every tracked piece for both sides, then both sides castle kingside.
    #[test]
    fn case_1() -> Result<()> {
        execute_case(TestCase {
            board: crate::start(),
            moves_evals: vec![
                (UciMove::new("d2d4")?, 1),
                (UciMove::new("d7d5")?, 0),
                (UciMove::new("e2e4")?, 10),
                (UciMove::new("e7e5")?, 0),
                (UciMove::new("a2a3")?, 0),
                (UciMove::new("g8f6")?, -100),
                (UciMove::new("b1c3")?, 900), // w
                (UciMove::new("b8a6")?, -100),
                (UciMove::new("g1f3")?, 0),
                (UciMove::new("c8d7")?, -100000),
                (UciMove::new("c1d2")?, 0),
                (UciMove::new("f8b4")?, -10000),
                (UciMove::new("f1b5")?, 0),
                (UciMove::new("d8e7")?, 0),
                (UciMove::new("d1e2")?, 0),
                // Castle kingside
                (UciMove::new("e8g8")?, -1000000),
                (UciMove::new("e1g1")?, 0),
            ],
        })
    }

    /// Same development as `case_1`, but both sides castle queenside.
    #[test]
    fn case_2() -> Result<()> {
        execute_case(TestCase {
            board: crate::start(),
            moves_evals: vec![
                (UciMove::new("d2d4")?, 1),
                (UciMove::new("d7d5")?, 0),
                (UciMove::new("e2e4")?, 10),
                (UciMove::new("e7e5")?, 0),
                (UciMove::new("a2a3")?, 0),
                (UciMove::new("g8f6")?, -100),
                (UciMove::new("b1c3")?, 900), // w
                (UciMove::new("b8a6")?, -100),
                (UciMove::new("g1f3")?, 0),
                (UciMove::new("c8d7")?, -100000),
                (UciMove::new("c1d2")?, 0),
                (UciMove::new("f8b4")?, -10000),
                (UciMove::new("f1b5")?, 0),
                (UciMove::new("d8e7")?, 0),
                (UciMove::new("d1e2")?, 0),
                // Castle queenside
                (UciMove::new("e8c8")?, -10000000),
                (UciMove::new("e1c1")?, 0),
            ],
        })
    }

    /// Runs a case on its reflected form first, then on the original.
    fn execute_case(case: TestCase) -> Result<()> {
        execute_case_impl(case.reflect())?;
        execute_case_impl(case)?;
        Ok(())
    }

    /// For every move, checks three things: the eval after `make` matches the
    /// expectation, and `unmake` restores the previous eval. The move is then re-made
    /// and applied to the board before the next one is parsed.
    fn execute_case_impl(case: TestCase) -> Result<()> {
        let mut board = case.board;
        let mut component = OpeningComponent::new(dummy_rewards());
        for (uci_mv, expected_eval) in case.moves_evals {
            let curr_eval = component.static_eval();
            let mv = board.parse_uci(uci_mv.as_str())?;
            component.make(&mv);
            assert_eq!(
                expected_eval,
                component.static_eval(),
                "make {}",
                uci_mv.as_str()
            );
            component.unmake(&mv);
            assert_eq!(
                curr_eval,
                component.static_eval(),
                "unmake {}",
                uci_mv.as_str()
            );
            component.make(&mv);
            board.make(mv)?;
        }
        Ok(())
    }

    /// A starting board plus the (move, expected eval after move) sequence to check.
    struct TestCase {
        board: Board,
        moves_evals: Vec<(UciMove, i32)>,
    }

    // Reflection delegates to the board's and the move/eval list's own `Reflectable`
    // impls (presumably mirroring colours and negating evals — confirm in those impls).
    impl Reflectable for TestCase {
        fn reflect(&self) -> Self {
            TestCase {
                board: self.board.reflect(),
                moves_evals: self.moves_evals.reflect(),
            }
        }
    }
}