// timecat/search.rs

1use super::*;
2
3// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
4#[derive(Clone, Debug)]
5pub struct Searcher<P: PositionEvaluation> {
6    id: usize,
7    score: Score,
8    initial_position: ChessPosition,
9    board: Board,
10    evaluator: P,
11    transposition_table: Arc<TranspositionTable>,
12    pv_table: PVTable,
13    best_moves: Vec<Move>,
14    move_sorter: MoveSorter,
15    num_nodes_searched: Arc<AtomicUsize>,
16    selective_depth: Arc<AtomicUsize>,
17    ply: Ply,
18    root_score_cached: Score,
19    depth_completed: Depth,
20    is_outside_aspiration_window: bool,
21    clock: Instant,
22    stop_command: Arc<AtomicBool>,
23    properties: EngineProperties,
24}
25
26impl<P: PositionEvaluation> Searcher<P> {
27    pub fn new(
28        id: usize,
29        last_score: Option<Score>,
30        board: Board,
31        mut evaluator: P,
32        transposition_table: Arc<TranspositionTable>,
33        num_nodes_searched: Arc<AtomicUsize>,
34        selective_depth: Arc<AtomicUsize>,
35        stop_command: Arc<AtomicBool>,
36        properties: EngineProperties,
37    ) -> Self {
38        Self {
39            id,
40            score: board.score_flipped(last_score.unwrap_or_else(|| evaluator.evaluate(&board))),
41            initial_position: board.get_position().to_owned(),
42            board,
43            evaluator,
44            transposition_table,
45            pv_table: PVTable::new(),
46            best_moves: Vec::new(),
47            move_sorter: MoveSorter::new(),
48            num_nodes_searched,
49            selective_depth,
50            ply: 0,
51            root_score_cached: -INFINITY,
52            depth_completed: 0,
53            is_outside_aspiration_window: false,
54            clock: Instant::now(),
55            stop_command,
56            properties,
57        }
58    }
59
60    #[inline]
61    pub fn is_main_threaded(&self) -> bool {
62        self.get_id() == 0
63    }
64
65    #[inline]
66    pub fn get_id(&self) -> usize {
67        self.id
68    }
69
70    #[inline]
71    pub fn get_initial_position(&self) -> &ChessPosition {
72        &self.initial_position
73    }
74
75    #[inline]
76    pub fn get_board(&self) -> &Board {
77        &self.board
78    }
79
80    #[inline]
81    pub fn get_evaluator(&self) -> &P {
82        &self.evaluator
83    }
84
85    #[inline]
86    pub fn get_evaluator_mut(&mut self) -> &mut P {
87        &mut self.evaluator
88    }
89
90    #[inline]
91    pub fn get_transposition_table(&self) -> &TranspositionTable {
92        &self.transposition_table
93    }
94
95    #[inline]
96    pub fn get_pv_table(&self) -> &PVTable {
97        &self.pv_table
98    }
99
100    #[inline]
101    pub fn get_best_moves(&self) -> &[Move] {
102        &self.best_moves
103    }
104
105    #[inline]
106    pub fn get_move_sorter(&self) -> &MoveSorter {
107        &self.move_sorter
108    }
109
110    #[inline]
111    pub fn get_ply(&self) -> Ply {
112        self.ply
113    }
114
115    #[inline]
116    pub fn get_stop_command(&self) -> Arc<AtomicBool> {
117        self.stop_command.clone()
118    }
119
120    #[inline]
121    pub fn get_num_nodes_searched(&self) -> usize {
122        self.num_nodes_searched.load(MEMORY_ORDERING)
123    }
124
125    #[inline]
126    pub fn get_selective_depth(&self) -> Ply {
127        self.selective_depth.load(MEMORY_ORDERING)
128    }
129
130    #[inline]
131    pub fn get_clock(&self) -> Instant {
132        self.clock
133    }
134
135    #[inline]
136    pub fn get_time_elapsed(&self) -> Duration {
137        self.clock.elapsed()
138    }
139
140    #[inline]
141    pub fn get_pv(&self) -> impl Iterator<Item = &Move> {
142        self.pv_table.get_pv(0)
143    }
144
145    #[inline]
146    pub fn get_pv_from_t_table(&self) -> Vec<Move> {
147        extract_pv_from_t_table(&self.initial_position, &self.transposition_table)
148    }
149
150    #[inline]
151    pub fn get_nth_pv_move(&self, n: usize) -> Option<&Move> {
152        self.pv_table.get_pv(0).nth(n)
153    }
154
155    #[inline]
156    pub fn get_best_move(&self) -> Option<&Move> {
157        self.get_nth_pv_move(0)
158    }
159
160    #[inline]
161    pub fn get_ponder_move(&self) -> Option<&Move> {
162        self.get_nth_pv_move(1)
163    }
164
165    #[inline]
166    pub fn get_score(&self) -> Score {
167        self.score
168    }
169
170    #[inline]
171    pub fn get_depth_completed(&self) -> Depth {
172        self.depth_completed
173    }
174
175    #[inline]
176    pub fn is_outside_aspiration_window(&self) -> bool {
177        self.is_outside_aspiration_window
178    }
179
180    #[inline]
181    pub fn get_search_info(&self) -> SearchInfo {
182        self.into()
183    }
184
185    #[inline]
186    pub fn stop_search_at_every_node(
187        &mut self,
188        controller: Option<&mut impl SearchControl<Self>>,
189    ) -> bool {
190        self.stop_command.load(MEMORY_ORDERING)
191            || controller.is_some_and(|controller| controller.stop_search_at_every_node(self))
192    }
193
194    fn pop(&mut self) -> Result<ValidOrNullMove> {
195        self.ply -= 1;
196        self.board.pop()
197    }
198
199    fn evaluate_flipped(&mut self) -> Score {
200        self.evaluator.evaluate_flipped(&self.board)
201    }
202
203    #[inline]
204    pub fn print_root_node_info(
205        board: &Board,
206        curr_move: Move,
207        depth: Depth,
208        score: Score,
209        num_nodes_searched: usize,
210        time_elapsed: Duration,
211    ) {
212        println_wasm!(
213            "{} {} {} {} {} {} {} {} {} {} {}",
214            "info".colorize(INFO_MESSAGE_STYLE),
215            "curr move".colorize(INFO_MESSAGE_STYLE),
216            curr_move.stringify_move(board).unwrap(),
217            "depth".colorize(INFO_MESSAGE_STYLE),
218            depth,
219            "score".colorize(INFO_MESSAGE_STYLE),
220            if GLOBAL_TIMECAT_STATE.is_in_console_mode() {
221                board.score_flipped(score)
222            } else {
223                score
224            }
225            .stringify(),
226            "nodes".colorize(INFO_MESSAGE_STYLE),
227            num_nodes_searched,
228            "time".colorize(INFO_MESSAGE_STYLE),
229            time_elapsed.stringify(),
230        );
231    }
232
233    fn is_draw_move(&self, valid_or_null_move: ValidOrNullMove) -> bool {
234        self.board.gives_threefold_repetition(valid_or_null_move)
235            || self
236                .board
237                .gives_claimable_threefold_repetition(valid_or_null_move)
238    }
239
240    fn update_best_moves(&mut self) {
241        if let Some(&best_move) = self.get_best_move() {
242            self.best_moves
243                .retain(|&valid_or_null_move| valid_or_null_move != best_move);
244            self.best_moves.insert(0, best_move);
245        }
246    }
247
248    fn get_sorted_root_node_moves(
249        &mut self,
250        moves_to_search: Option<Vec<Move>>,
251    ) -> Vec<(Move, MoveWeight)> {
252        let best_move = self.get_best_move().copied();
253
254        let mut moves_vec_sorted = self
255            .move_sorter
256            .get_weighted_moves_sorted(
257                &self.board,
258                moves_to_search
259                    .unwrap_or_else(|| self.board.generate_legal_moves().into_iter().collect_vec()),
260                &self.transposition_table,
261                0,
262                self.transposition_table
263                    .read_best_move(self.board.get_hash()),
264                best_move,
265            )
266            .map(|WeightedMove { move_, .. }| {
267                (
268                    move_,
269                    MoveSorter::score_root_moves(
270                        &self.board,
271                        &mut self.evaluator,
272                        move_,
273                        best_move, // PV Move
274                        &self.best_moves,
275                    ),
276                )
277            })
278            .collect_vec();
279        moves_vec_sorted.sort_by_key(|&t| Reverse(t.1));
280        moves_vec_sorted
281    }
282
283    fn search_root(
284        &mut self,
285        depth: Depth,
286        mut alpha: Score,
287        beta: Score,
288        mut controller: Option<&mut impl SearchControl<Self>>,
289        print_move_info: bool,
290    ) -> Option<Infallible> {
291        if FOLLOW_PV {
292            self.move_sorter.follow_pv();
293        }
294        if self.is_main_threaded() {
295            self.selective_depth.store(0, MEMORY_ORDERING);
296        }
297        if self.board.is_game_over() {
298            self.root_score_cached = if self.board.is_checkmate() {
299                -self.evaluator.evaluate_checkmate_in(0)
300            } else {
301                self.evaluator.evaluate_draw()
302            };
303            return None;
304        }
305        let moves_to_search = controller
306            .as_ref()
307            .and_then(|controller| controller.get_root_moves_to_search())
308            .map(|moves| moves.to_vec());
309        if !(depth > 1 && self.is_main_threaded()) {
310            controller = None;
311        }
312        if self.stop_search_at_every_node(controller.as_deref_mut()) {
313            return None;
314        }
315        let key = self.board.get_hash();
316        self.root_score_cached = -INFINITY;
317        let mut flag = EntryFlagHash::Alpha;
318        let is_endgame = self.board.is_endgame();
319        let moves = self.get_sorted_root_node_moves(moves_to_search);
320        for (move_index, &(move_, _)) in moves.iter().enumerate() {
321            if !is_endgame
322                && self.is_draw_move(move_.into())
323                && self.root_score_cached > -DRAW_SCORE
324            {
325                continue;
326            }
327            let clock = Instant::now();
328            unsafe { self.push_unchecked(move_) };
329            if move_index == 0
330                || -self.alpha_beta(depth - 1, -alpha - 1, -alpha, controller.as_deref_mut())?
331                    > alpha
332            {
333                self.root_score_cached =
334                    -self.alpha_beta(depth - 1, -beta, -alpha, controller.as_deref_mut())?;
335            }
336            self.pop().unwrap();
337            if print_move_info && self.is_main_threaded() {
338                let time_elapsed = clock.elapsed();
339                if time_elapsed > PRINT_MOVE_INFO_DURATION_THRESHOLD {
340                    Self::print_root_node_info(
341                        &self.board,
342                        move_,
343                        depth,
344                        self.root_score_cached,
345                        self.get_num_nodes_searched(),
346                        time_elapsed,
347                    )
348                }
349            }
350            if self.root_score_cached > alpha {
351                flag = EntryFlagHash::Exact;
352                alpha = self.root_score_cached;
353                self.pv_table.update_table(self.ply, move_);
354                if self.root_score_cached >= beta {
355                    self.transposition_table.write(
356                        key,
357                        depth,
358                        self.ply,
359                        beta,
360                        EntryFlagHash::Beta,
361                        Some(move_),
362                    );
363                    self.root_score_cached = beta;
364                    return None;
365                }
366            }
367        }
368        self.transposition_table.write(
369            key,
370            depth,
371            self.ply,
372            alpha,
373            flag,
374            self.get_best_move().copied(),
375        );
376        self.update_best_moves();
377        self.root_score_cached = alpha;
378        None
379    }
380
381    fn get_lmr_reduction(depth: Depth, move_index: usize, is_pv_node: bool) -> Depth {
382        let mut reduction =
383            LMR_BASE_REDUCTION + (depth as f64).ln() * (move_index as f64).ln() / LMR_MOVE_DIVIDER;
384        // let mut reduction = (depth as f64 - 1.0).max(0.0).sqrt() + (move_index as f64 - 1.0).max(0.0).sqrt();
385        if is_pv_node {
386            // reduction /= 3.0;
387            reduction *= 2.0 / 3.0;
388        }
389        reduction.round() as Depth
390    }
391
392    fn alpha_beta(
393        &mut self,
394        mut depth: Depth,
395        mut alpha: Score,
396        mut beta: Score,
397        mut controller: Option<&mut impl SearchControl<Self>>,
398    ) -> Option<Score> {
399        self.pv_table.set_length(self.ply, self.ply);
400        let mate_score = self.evaluator.evaluate_checkmate_in(self.ply);
401        let draw_score = self.evaluator.evaluate_draw();
402        if self.board.is_other_draw() {
403            return Some(draw_score);
404        }
405        if self.properties.use_mate_distance_pruning() {
406            // mate distance pruning
407            alpha = alpha.max(-mate_score);
408            beta = beta.min(mate_score - 1);
409            if alpha >= beta {
410                return Some(alpha);
411            }
412        }
413        let checkers = self.board.get_checkers();
414        if depth > 10 {
415            depth += checkers.popcnt() as Depth;
416        }
417        let min_depth = self.move_sorter.is_following_pv() as Depth;
418        depth = depth.max(min_depth);
419        let is_pv_node = alpha != beta - 1;
420        let key = self.board.get_hash();
421        let best_move = if is_pv_node && self.is_main_threaded() {
422            self.transposition_table.read_best_move(key)
423        } else {
424            let (optional_data, best_move) = self.transposition_table.read(key, depth, self.ply);
425            if let Some((score, flag)) = optional_data {
426                // match flag {
427                //     HashExact => return Some(score),
428                //     HashAlpha => alpha = alpha.max(score),
429                //     HashBeta => beta = beta.min(score),
430                // }
431                // if alpha >= beta {
432                //     return Some(alpha);
433                // }
434                match flag {
435                    EntryFlagHash::Exact => return Some(score),
436                    EntryFlagHash::Alpha => {
437                        if score <= alpha {
438                            return Some(score);
439                        }
440                    }
441                    EntryFlagHash::Beta => {
442                        if score >= beta {
443                            return Some(score);
444                        }
445                    }
446                }
447            }
448            best_move
449        };
450        if self.ply == MAX_PLY - 1 {
451            return Some(self.evaluate_flipped());
452        }
453        // enable_controller &= depth > 3;
454        if self.stop_search_at_every_node(controller.as_deref_mut()) {
455            return None;
456        }
457        if depth == 0 {
458            return self.quiescence(alpha, beta, controller.as_deref_mut());
459        }
460        if self.is_main_threaded() && is_pv_node {
461            self.selective_depth.fetch_max(self.ply, MEMORY_ORDERING);
462        }
463        self.num_nodes_searched.fetch_add(1, MEMORY_ORDERING);
464        let not_in_check = checkers.is_empty();
465        let mut futility_pruning = false;
466        if not_in_check && !DISABLE_ALL_PRUNINGS {
467            // static evaluation
468            let static_evaluation = self.evaluate_flipped();
469            if depth < 3 && !is_pv_node && !is_checkmate(beta) {
470                let eval_margin = ((6 * PAWN_VALUE) / 5) * depth as Score;
471                let new_score = static_evaluation - eval_margin;
472                if new_score >= beta {
473                    return Some(new_score);
474                }
475            }
476            // razoring
477            static RAZORING_DEPTH: Depth = 3;
478            if !is_pv_node && depth <= RAZORING_DEPTH && !is_checkmate(beta) {
479                let mut score = static_evaluation + const { (5 * PAWN_VALUE) / 4 };
480                if score < beta {
481                    if depth == 1 {
482                        let new_score = self.quiescence(alpha, beta, controller.as_deref_mut())?;
483                        return Some(new_score.max(score));
484                    }
485                    score += const { (7 * PAWN_VALUE) / 4 };
486                    if score < beta && depth < RAZORING_DEPTH {
487                        let new_score = self.quiescence(alpha, beta, controller.as_deref_mut())?;
488                        if new_score < beta {
489                            return Some(new_score.max(score));
490                        }
491                    }
492                }
493            }
494            // null move pruning
495            if !is_pv_node && depth >= NULL_MOVE_MIN_DEPTH && static_evaluation >= beta {
496                // let r = NULL_MOVE_MIN_REDUCTION
497                //     + (depth.max(NULL_MOVE_MIN_DEPTH) as f64 / NULL_MOVE_DEPTH_DIVIDER as f64)
498                //         .round() as Depth;
499                // let reduced_depth = depth - r - 1;
500                let r = 1920 + (depth as u32) * 2368;
501                let reduced_depth = ((depth as u32) - r / 4096) as Depth;
502                unsafe { self.push_unchecked(ValidOrNullMove::NullMove) };
503                let score =
504                    -self.alpha_beta(reduced_depth, -beta, -beta + 1, controller.as_deref_mut())?;
505                self.pop().unwrap();
506                if score >= beta {
507                    return Some(beta);
508                }
509            }
510            // futility pruning condition
511            if depth < 4 && alpha < mate_score {
512                let futility_margin = get_item_unchecked!(
513                    const { [0, PAWN_VALUE, Knight.evaluate(), Rook.evaluate()] },
514                    depth as usize,
515                );
516                futility_pruning = static_evaluation + futility_margin <= alpha;
517            }
518        }
519        let mut flag = EntryFlagHash::Alpha;
520        let weighted_moves = self.move_sorter.get_weighted_moves_sorted(
521            &self.board,
522            self.board.generate_legal_moves(),
523            &self.transposition_table,
524            self.ply,
525            best_move,
526            self.get_nth_pv_move(self.ply).copied(),
527        );
528        if weighted_moves.is_empty() {
529            return if not_in_check {
530                Some(draw_score)
531            } else {
532                Some(-mate_score)
533            };
534        }
535        for (move_index, WeightedMove { move_, .. }) in weighted_moves.enumerate() {
536            let not_capture_move = !self.board.is_capture(move_);
537            let not_an_interesting_position = not_capture_move
538                && not_in_check
539                && move_.get_promotion().is_none()
540                && !self.move_sorter.is_killer_move(move_, self.ply);
541            if move_index != 0 && futility_pruning && not_an_interesting_position {
542                continue;
543            }
544            let mut safe_to_apply_lmr = self.properties.use_lmr()
545                && move_index >= FULL_DEPTH_SEARCH_LMR
546                && depth >= REDUCTION_LIMIT_LMR
547                && not_an_interesting_position;
548            unsafe { self.push_unchecked(move_) };
549            safe_to_apply_lmr &= !self.board.is_check();
550            let mut score: Score;
551            if move_index == 0 {
552                score = -self.alpha_beta(depth - 1, -beta, -alpha, controller.as_deref_mut())?;
553            } else {
554                if safe_to_apply_lmr {
555                    let lmr_reduction = Self::get_lmr_reduction(depth, move_index, is_pv_node);
556                    score = if depth > lmr_reduction {
557                        -self.alpha_beta(
558                            depth - 1 - lmr_reduction,
559                            -alpha - 1,
560                            -alpha,
561                            controller.as_deref_mut(),
562                        )?
563                    } else {
564                        alpha + 1
565                    }
566                } else {
567                    score = alpha + 1;
568                }
569                if score > alpha {
570                    score = -self.alpha_beta(
571                        depth - 1,
572                        -alpha - 1,
573                        -alpha,
574                        controller.as_deref_mut(),
575                    )?;
576                    if score > alpha && score < beta {
577                        score = -self.alpha_beta(
578                            depth - 1,
579                            -beta,
580                            -alpha,
581                            controller.as_deref_mut(),
582                        )?;
583                    }
584                }
585            }
586            self.pop().unwrap();
587            if score > alpha {
588                flag = EntryFlagHash::Exact;
589                self.pv_table.update_table(self.ply, move_);
590                alpha = score;
591                if not_capture_move {
592                    self.move_sorter.add_history_move(move_, &self.board, depth);
593                }
594                if score >= beta {
595                    self.transposition_table.write(
596                        key,
597                        depth,
598                        self.ply,
599                        beta,
600                        EntryFlagHash::Beta,
601                        Some(move_),
602                    );
603                    if not_capture_move {
604                        self.move_sorter.update_killer_moves(move_, self.ply);
605                    }
606                    return Some(beta);
607                }
608            }
609        }
610        self.transposition_table.write(
611            key,
612            depth,
613            self.ply,
614            alpha,
615            flag,
616            self.get_nth_pv_move(self.ply).copied(),
617        );
618        Some(alpha)
619    }
620
621    fn quiescence(
622        &mut self,
623        mut alpha: Score,
624        beta: Score,
625        mut controller: Option<&mut impl SearchControl<Self>>,
626    ) -> Option<Score> {
627        if self.ply == MAX_PLY - 1 {
628            return Some(self.evaluate_flipped());
629        }
630        self.pv_table.set_length(self.ply, self.ply);
631        if self.board.is_other_draw() {
632            return Some(self.evaluator.evaluate_draw());
633        }
634        let is_pv_node = alpha != beta - 1;
635        if self.is_main_threaded() && is_pv_node {
636            self.selective_depth.fetch_max(self.ply, MEMORY_ORDERING);
637        }
638        self.num_nodes_searched.fetch_add(1, MEMORY_ORDERING);
639        let evaluation = self.evaluate_flipped();
640        if evaluation >= beta {
641            return Some(beta);
642        }
643        if self.stop_search_at_every_node(controller.as_deref_mut()) {
644            return None;
645        }
646        alpha = alpha.max(evaluation);
647        for WeightedMove { move_, weight } in self
648            .move_sorter
649            .get_weighted_capture_moves_sorted(&self.board, &self.transposition_table)
650        {
651            if weight.is_negative() {
652                break;
653            }
654            unsafe { self.push_unchecked(move_) };
655            let score = -self.quiescence(-beta, -alpha, controller.as_deref_mut())?;
656            self.pop().unwrap();
657            if score >= beta {
658                return Some(beta);
659            }
660            if score > alpha {
661                self.pv_table.update_table(self.ply, move_);
662                alpha = score;
663            }
664            // delta pruning
665            let mut delta = const { Queen.evaluate() };
666            if let Some(piece) = move_.get_promotion() {
667                delta += piece.evaluate() - PAWN_VALUE;
668            }
669            if score + delta < alpha {
670                return Some(alpha);
671            }
672        }
673        Some(alpha)
674    }
675
676    pub fn search(
677        &mut self,
678        config: &SearchConfig,
679        mut controller: impl SearchControl<Self>,
680        verbose: bool,
681    ) {
682        let legal_moves = self.board.generate_legal_moves();
683        if legal_moves.len() == 1 {
684            self.pv_table
685                .update_table(self.ply, legal_moves.into_iter().next().unwrap());
686            return;
687        }
688        controller.on_receiving_search_config(config, self);
689        let mut alpha = -INFINITY;
690        let mut beta = INFINITY;
691        self.depth_completed = 0;
692        while self.depth_completed < Depth::MAX
693            && !self.stop_command.load(MEMORY_ORDERING)
694            && !controller.stop_search_at_root_node(self)
695        {
696            let last_score = self.score;
697            self.search_root(
698                self.depth_completed + 1,
699                alpha,
700                beta,
701                Some(&mut controller),
702                verbose,
703            );
704            if self.root_score_cached != -INFINITY {
705                self.score = self.root_score_cached;
706            }
707            let search_info = self.get_search_info();
708            if verbose && self.is_main_threaded() {
709                search_info.print_info();
710            }
711            controller.on_each_search_completion(self);
712            self.is_outside_aspiration_window = self.score <= alpha || self.score >= beta;
713            if self.is_outside_aspiration_window {
714                if verbose && self.is_main_threaded() {
715                    search_info.print_warning_message(alpha, beta);
716                }
717                alpha = -INFINITY;
718                beta = INFINITY;
719                self.score = last_score;
720                continue;
721            }
722            alpha = self.score - ASPIRATION_WINDOW_CUTOFF;
723            beta = self.score + ASPIRATION_WINDOW_CUTOFF;
724            self.depth_completed += 1;
725        }
726    }
727}
728
729impl<P: PositionEvaluation> SearcherMethodOverload<Move> for Searcher<P> {
730    unsafe fn push_unchecked(&mut self, move_: Move) {
731        self.board.push_unchecked(move_);
732        self.ply += 1;
733    }
734}
735
736impl<P: PositionEvaluation> SearcherMethodOverload<ValidOrNullMove> for Searcher<P> {
737    unsafe fn push_unchecked(&mut self, valid_or_null_move: ValidOrNullMove) {
738        self.board.push_unchecked(valid_or_null_move);
739        self.ply += 1;
740    }
741}