// mcts_lib/boards/tic_tac_toe.rs

1use crate::board::{Board, GameOutcome, Player};
2use std::fmt::Debug;
3
4/// An implementation of the `Board` trait for the game of Tic-Tac-Toe.
5///
6/// The board is represented by a 9-element array, where each element corresponds to a cell.
7/// A move is represented by a `u8` from 0 to 8.
8pub struct TicTacToeBoard {
9    root_player: TTTPlayer,
10    current_player: TTTPlayer,
11    field: [Option<TTTPlayer>; 9],
12    outcome: GameOutcome,
13}
14
15impl TicTacToeBoard {
16    fn new(root_player: TTTPlayer) -> Self {
17        Self {
18            root_player,
19            current_player: TTTPlayer::X,
20            field: [None; 9],
21            outcome: GameOutcome::InProgress,
22        }
23    }
24}
25
26impl Default for TicTacToeBoard {
27    /// Creates a new Tic-Tac-Toe board with player 'X' starting.
28    fn default() -> Self {
29        TicTacToeBoard::new(TTTPlayer::X)
30    }
31}
32
33impl Clone for TicTacToeBoard {
34    fn clone(&self) -> Self {
35        let mut copied_field = [None; 9];
36        copied_field.copy_from_slice(&self.field);
37        Self {
38            root_player: self.root_player,
39            current_player: self.current_player,
40            field: copied_field,
41            outcome: self.outcome,
42        }
43    }
44}
45
46impl Board for TicTacToeBoard {
47    type Move = u8;
48
49    fn get_current_player(&self) -> Player {
50        match self.current_player == self.root_player {
51            true => Player::Me,
52            false => Player::Other,
53        }
54    }
55
56    fn get_outcome(&self) -> GameOutcome {
57        if self.field[0].is_some()
58            && (self.field[0] == self.field[1] && self.field[0] == self.field[2]
59                || self.field[0] == self.field[3] && self.field[0] == self.field[6])
60        {
61            return if self.field[0].unwrap() == self.root_player {
62                GameOutcome::Win
63            } else {
64                GameOutcome::Lose
65            };
66        }
67
68        if self.field[8].is_some()
69            && (self.field[8] == self.field[2] && self.field[8] == self.field[5]
70                || self.field[8] == self.field[6] && self.field[8] == self.field[7])
71        {
72            return if self.field[8].unwrap() == self.root_player {
73                GameOutcome::Win
74            } else {
75                GameOutcome::Lose
76            };
77        }
78
79        if self.field[4].is_some()
80            && (self.field[4] == self.field[1] && self.field[4] == self.field[7]
81                || self.field[4] == self.field[3] && self.field[4] == self.field[5]
82                || self.field[4] == self.field[0] && self.field[4] == self.field[8]
83                || self.field[4] == self.field[2] && self.field[4] == self.field[6])
84        {
85            return if self.field[4].unwrap() == self.root_player {
86                GameOutcome::Win
87            } else {
88                GameOutcome::Lose
89            };
90        }
91
92        if self.field.iter().any(|x| x.is_none()) {
93            GameOutcome::InProgress
94        } else {
95            GameOutcome::Draw
96        }
97    }
98
99    fn get_available_moves(&self) -> Vec<Self::Move> {
100        if self.outcome != GameOutcome::InProgress {
101            return Vec::new();
102        }
103
104        self.field
105            .iter()
106            .enumerate()
107            .filter(|(_, x)| x.is_none())
108            .map(|(i, _)| i as u8)
109            .collect()
110    }
111
112    fn perform_move(&mut self, b_move: &Self::Move) {
113        self.field[*b_move as usize] = Some(self.current_player);
114        self.current_player = match self.current_player {
115            TTTPlayer::X => TTTPlayer::O,
116            TTTPlayer::O => TTTPlayer::X,
117        };
118        self.outcome = self.get_outcome();
119    }
120
121    fn get_hash(&self) -> u128 {
122        let mut hash = 0;
123        for (i, &cell) in self.field.iter().enumerate() {
124            let cell_value = match cell {
125                None => 0,
126                Some(TTTPlayer::X) => 1,
127                Some(TTTPlayer::O) => 2,
128            };
129            hash += cell_value * 3u128.pow(i as u32);
130        }
131        hash
132    }
133}
134
/// The two sides of a Tic-Tac-Toe game.
///
/// Equality between players is total, so `Eq` is derived alongside `PartialEq`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
enum TTTPlayer {
    /// The side that moves first on a new board.
    X,
    /// The side that moves second.
    O,
}
140
#[cfg(test)]
mod tests {
    use super::TicTacToeBoard;
    use crate::mcts::MonteCarloTreeSearch;
    use crate::random::CustomNumberGenerator;

    /// Runs 20 000 iterations with alpha-beta pruning explicitly switched off and pins the
    /// resulting root statistics.
    ///
    /// NOTE(review): the exact counts presumably rely on `CustomNumberGenerator::default()`
    /// producing a fixed sequence; confirm before changing the generator.
    #[test]
    fn test1_usual() {
        // Search setup.
        let mut search = MonteCarloTreeSearch::builder(TicTacToeBoard::default())
            .with_alpha_beta_pruning(false)
            .with_random_generator(CustomNumberGenerator::default())
            .build();

        // Run the search.
        search.iterate_n_times(20000);

        // The centre cell (index 4) is expected to be the preferred opening move.
        let best = &search.get_root().get_best_child().unwrap().value();
        assert_eq!(best.prev_move.unwrap(), 4);
        let stats = &search.get_root().value();
        assert_eq!(stats.wins, 13867);
        assert_eq!(stats.draws, 2104);
        assert_eq!(stats.visits, 20000);
        assert!(!stats.is_fully_calculated);
    }

    /// Runs 20 000 iterations with the builder's default pruning setting (presumably
    /// alpha-beta pruning enabled, going by the test name) and pins the root statistics.
    #[test]
    fn test2_abp() {
        // Search setup.
        let mut search = MonteCarloTreeSearch::builder(TicTacToeBoard::default())
            .with_random_generator(CustomNumberGenerator::default())
            .build();

        // Run the search.
        search.iterate_n_times(20000);

        // The centre cell (index 4) is expected to be the preferred opening move.
        let best = &search.get_root().get_best_child().unwrap().value();
        assert_eq!(best.prev_move.unwrap(), 4);
        let stats = &search.get_root().value();
        assert_eq!(stats.wins, 10758);
        assert_eq!(stats.draws, 3808);
        assert_eq!(stats.visits, 20000);
        assert!(!stats.is_fully_calculated);
    }

    /// Requests 50 000 iterations with the builder's default pruning setting and expects
    /// the root to end up flagged as fully calculated.
    ///
    /// The root records fewer visits (37 432) than iterations requested; presumably the
    /// search stops doing work once the tree is fully calculated (confirm in the MCTS code).
    #[test]
    fn test3_abp_fully_calculated() {
        // Search setup.
        let mut search = MonteCarloTreeSearch::builder(TicTacToeBoard::default())
            .with_random_generator(CustomNumberGenerator::default())
            .build();

        // Run the search.
        search.iterate_n_times(50000);

        // The centre cell (index 4) is expected to be the preferred opening move.
        let best = &search.get_root().get_best_child().unwrap().value();
        assert_eq!(best.prev_move.unwrap(), 4);
        let stats = &search.get_root().value();
        assert_eq!(stats.wins, 18225);
        assert_eq!(stats.draws, 10342);
        assert_eq!(stats.visits, 37432);
        assert!(stats.is_fully_calculated);
    }
}