blunders_engine/search/
alpha_beta.rs1use std::cmp;
4use std::time::Instant;
5
6use crate::coretypes::Color::*;
7use crate::coretypes::{Cp, Move, PlyKind, Square};
8use crate::eval::{evaluate_abs, terminal_abs};
9use crate::movelist::Line;
10use crate::search::SearchResult;
11use crate::Position;
12
13pub fn alpha_beta(position: Position, ply: PlyKind) -> SearchResult {
17 debug_assert_ne!(ply, 0);
18
19 let mut pv = Line::new();
20 let mut nodes = 0;
21 let instant = Instant::now();
22 let (score, best_move) = alpha_beta_root(position, ply, &mut nodes, Cp::MIN, Cp::MAX);
23 pv.push(best_move);
24
25 SearchResult {
26 player: position.player,
27 depth: ply,
28 best_move,
29 score,
30 pv,
31 nodes,
32 elapsed: instant.elapsed(),
33 ..Default::default()
34 }
35}
36
37const WHITE: u8 = White as u8;
38const BLACK: u8 = Black as u8;
39
40pub(crate) fn alpha_beta_root(
55 mut position: Position,
56 ply: PlyKind,
57 nodes: &mut u64,
58 mut alpha: Cp,
59 mut beta: Cp,
60) -> (Cp, Move) {
61 *nodes += 1;
62 let cache = position.cache();
63 let legal_moves = position.get_legal_moves();
64 debug_assert_ne!(ply, 0);
65 debug_assert!(legal_moves.len() > 0);
66
67 let mut best_move = Move::new(Square::D2, Square::D4, None);
68
69 if position.player == White {
70 for legal_move in legal_moves {
71 let move_info = position.do_move(legal_move);
72 let move_cp = alpha_beta_impl::<BLACK>(&mut position, ply - 1, nodes, alpha, beta);
73 position.undo_move(move_info, cache);
74
75 if move_cp > alpha {
76 alpha = move_cp;
77 best_move = legal_move;
78 }
79 }
80 (alpha, best_move)
81 } else {
82 for legal_move in legal_moves {
83 let move_info = position.do_move(legal_move);
84 let move_cp = alpha_beta_impl::<WHITE>(&mut position, ply - 1, nodes, alpha, beta);
85 position.undo_move(move_info, cache);
86
87 if move_cp < beta {
88 beta = move_cp;
89 best_move = legal_move;
90 }
91 }
92 (beta, best_move)
93 }
94}
95
96fn alpha_beta_impl<const COLOR: u8>(
97 position: &mut Position,
98 ply: PlyKind,
99 nodes: &mut u64,
100 alpha: Cp,
101 beta: Cp,
102) -> Cp {
103 *nodes += 1;
104 let cache = position.cache();
105 let legal_moves = position.get_legal_moves();
106 let num_moves = legal_moves.len();
107
108 if num_moves == 0 {
110 return terminal_abs(position);
111 } else if ply == 0 {
112 return evaluate_abs(position);
113 }
114
115 if COLOR == White as u8 {
116 let mut best_cp = Cp::MIN;
117 let mut alpha = alpha;
118
119 for legal_move in legal_moves {
120 let move_info = position.do_move(legal_move);
121 let move_cp = alpha_beta_impl::<BLACK>(position, ply - 1, nodes, alpha, beta);
122 position.undo_move(move_info, cache);
123
124 best_cp = cmp::max(best_cp, move_cp);
125 alpha = cmp::max(alpha, best_cp);
126 if alpha >= beta {
127 return best_cp;
129 }
130 }
131 best_cp
132 } else {
133 let mut best_cp = Cp::MAX;
134 let mut beta = beta;
135
136 for legal_move in legal_moves {
137 let move_info = position.do_move(legal_move);
138 let move_cp = alpha_beta_impl::<WHITE>(position, ply - 1, nodes, alpha, beta);
139 position.undo_move(move_info, cache);
140
141 best_cp = cmp::min(best_cp, move_cp);
142 beta = cmp::min(beta, best_cp);
143 if alpha >= beta {
144 return best_cp;
146 }
147 }
148 best_cp
149 }
150}