board_game/ai/simple.rs

1//! Two simple bots: `RandomBot` and `RolloutBot`.
2use std::fmt::{Debug, Formatter};
3
4use internal_iterator::InternalIterator;
5use rand::Rng;
6
7use crate::ai::Bot;
8use crate::board::{Board, BoardDone};
9use crate::pov::NonPov;
10
11/// Bot that chooses moves randomly uniformly among possible moves.
12pub struct RandomBot<R: Rng> {
13    rng: R,
14}
15
16impl<R: Rng> Debug for RandomBot<R> {
17    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18        write!(f, "RandomBot")
19    }
20}
21
22impl<R: Rng> RandomBot<R> {
23    pub fn new(rng: R) -> Self {
24        RandomBot { rng }
25    }
26}
27
28impl<B: Board, R: Rng> Bot<B> for RandomBot<R> {
29    fn select_move(&mut self, board: &B) -> Result<B::Move, BoardDone> {
30        board.random_available_move(&mut self.rng)
31    }
32}
33
34/// Bot that chooses moves after simulating random games for each of them.
35///
36/// The same number of simulations `rollouts / nb_moves` is done for
37/// each move, and the move resulting in the best average score is selected.
38pub struct RolloutBot<R: Rng> {
39    rollouts: u32,
40    rng: R,
41}
42
43impl<R: Rng> Debug for RolloutBot<R> {
44    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45        write!(f, "RolloutBot {{ rollouts: {} }}", self.rollouts)
46    }
47}
48
49impl<R: Rng> RolloutBot<R> {
50    pub fn new(rollouts: u32, rng: R) -> Self {
51        RolloutBot { rollouts, rng }
52    }
53}
54
55impl<B: Board, R: Rng> Bot<B> for RolloutBot<R> {
56    fn select_move(&mut self, board: &B) -> Result<B::Move, BoardDone> {
57        let rollouts_per_move = self.rollouts / board.available_moves().unwrap().count() as u32;
58
59        Ok(board
60            .children()?
61            .max_by_key(|(_, child)| {
62                let score: i64 = (0..rollouts_per_move)
63                    .map(|_| {
64                        let mut copy = child.clone();
65                        while let Ok(mv) = copy.random_available_move(&mut self.rng) {
66                            copy.play(mv).unwrap();
67                        }
68                        copy.outcome().unwrap().pov(board.next_player()).sign::<i64>()
69                    })
70                    .sum();
71                score
72            })
73            .map(|(mv, _)| mv)
74            .unwrap())
75    }
76}