// File: board_game/util/game_stats.rs

//! Utilities for collecting game statistics and testing game and bot implementations.
2use std::collections::{HashMap, HashSet};
3use std::hash::Hash;
4
5use internal_iterator::InternalIterator;
6use rand::Rng;
7
8use crate::ai::Bot;
9use crate::board::{Board, Player};
10use crate::pov::NonPov;
11use crate::wdl::WDL;
12
13/// The number of legal positions reachable after `depth` moves, including duplicates.
14/// See <https://www.chessprogramming.org/Perft>.
15pub fn perft<B: Board + Hash>(board: &B, depth: u32) -> u64 {
16    let mut map = HashMap::default();
17    perft_recurse(&mut map, board.clone(), depth)
18}
19
20fn perft_recurse<B: Board + Hash>(map: &mut HashMap<(B, u32), u64>, board: B, depth: u32) -> u64 {
21    if depth == 0 {
22        return 1;
23    }
24    if board.is_done() {
25        return 0;
26    }
27    if depth == 1 {
28        return board.available_moves().unwrap().count() as u64;
29    }
30
31    // we need keys (B, depth) because otherwise we risk miscounting if the same board is encountered at different depths
32    let key = (board, depth);
33    let board = &key.0;
34
35    if let Some(&p) = map.get(&key) {
36        return p;
37    }
38
39    let mut p = 0;
40    board.children().unwrap().for_each(|(_, child)| {
41        p += perft_recurse(map, child, depth - 1);
42    });
43
44    map.insert(key, p);
45    p
46}
47
48/// Same as [perft] but without any caching of perft values for visited boards.
49pub fn perft_naive<B: Board>(board: &B, depth: u32) -> u64 {
50    if depth == 0 {
51        return 1;
52    }
53    if board.is_done() {
54        return 0;
55    }
56    if depth == 1 {
57        return board.available_moves().unwrap().count() as u64;
58    }
59
60    let mut p = 0;
61    board.available_moves().unwrap().for_each(|mv: B::Move| {
62        p += perft_naive(&board.clone_and_play(mv).unwrap(), depth - 1);
63    });
64    p
65}
66
67/// Structure returned by [`average_game_stats`].
68#[derive(Debug)]
69pub struct GameStats {
70    pub game_length: f32,
71    pub available_moves: f32,
72    pub total_wdl_a: WDL<u64>,
73}
74
75/// Return `GameStats` estimated from `n` games starting from `start` played by `bot`.
76pub fn average_game_stats<B: Board>(mut start: impl FnMut() -> B, mut bot: impl Bot<B>, n: u64) -> GameStats {
77    let mut total_moves = 0;
78    let mut total_positions = 0;
79    let mut total_wdl_a = WDL::default();
80
81    for _ in 0..n {
82        let mut board = start();
83
84        let outcome = loop {
85            total_moves += board.available_moves().unwrap().count();
86            total_positions += 1;
87
88            board.play(bot.select_move(&board).unwrap()).unwrap();
89
90            if let Some(outcome) = board.outcome() {
91                break outcome;
92            }
93        };
94
95        total_wdl_a += outcome.pov(Player::A).to_wdl();
96    }
97
98    GameStats {
99        game_length: total_positions as f32 / n as f32,
100        available_moves: total_moves as f32 / total_positions as f32,
101        total_wdl_a,
102    }
103}
104
105/// Generate the set of all possible board positions reachable from the given board, in `depth` moves or less.
106/// The returned vec does not contain duplicate elements.
107///
108/// **Warning**: This function can easily take a long time to terminate or not terminate at all depending on the game.
109pub fn all_possible_boards<B: Board + Hash>(start: &B, depth: u32, include_done: bool) -> Vec<B> {
110    let mut set = HashSet::new();
111    let mut result = vec![];
112    all_possible_boards_impl(start, depth, include_done, &mut result, &mut set);
113    result
114}
115
116fn all_possible_boards_impl<B: Board + Hash>(
117    start: &B,
118    depth: u32,
119    include_done: bool,
120    result: &mut Vec<B>,
121    set: &mut HashSet<B>,
122) {
123    if !include_done && start.is_done() {
124        return;
125    }
126    if !set.insert(start.clone()) {
127        return;
128    }
129    result.push(start.clone());
130    if start.is_done() || depth == 0 {
131        return;
132    }
133
134    start
135        .children()
136        .unwrap()
137        .for_each(|(_, child)| all_possible_boards_impl(&child, depth - 1, include_done, result, set))
138}
139
140/// Collect all available moves form `n` games played until the end with random moves.
141/// Also returns the number of time each move was availalbe.
142pub fn all_available_moves_sampled<B: Board>(start: &B, n: u64, rng: &mut impl Rng) -> HashMap<B::Move, u64>
143where
144    B::Move: Hash,
145{
146    let mut moves = HashMap::default();
147
148    for _ in 0..n {
149        let mut curr = start.clone();
150        while !curr.is_done() {
151            curr.available_moves().unwrap().for_each(|mv| {
152                *moves.entry(mv).or_default() += 1;
153            });
154            curr.play_random_available_move(rng).unwrap();
155        }
156    }
157
158    moves
159}