//! This is a library for Monte Carlo tree search.
//!
//! It is still under development and its documentation is incomplete. However, the
//! following example may be helpful:
//!
//! ```
//! use mcts::*;
//! use mcts::tree_policy::*;
//! use mcts::transposition_table::*;
//!
//! // A really simple game. There's one player and one number. In each move the player can
//! // increase or decrease the number. The player's score is the number.
//! // The game ends when the number reaches 100.
//! //
//! // The best strategy is to increase the number at every step.
//!
//! #[derive(Clone, Debug, PartialEq)]
//! struct CountingGame(i64);
//!
//! #[derive(Clone, Debug, PartialEq)]
//! enum Move {
//!     Add, Sub
//! }
//!
//! impl GameState for CountingGame {
//!     type Move = Move;
//!     type Player = ();
//!     type MoveList = Vec<Move>;
//!
//!     fn current_player(&self) -> Self::Player {
//!         ()
//!     }
//!     fn available_moves(&self) -> Vec<Move> {
//!         let x = self.0;
//!         if x == 100 {
//!             vec![]
//!         } else {
//!             vec![Move::Add, Move::Sub]
//!         }
//!     }
//!     fn make_move(&mut self, mov: &Self::Move) {
//!         match *mov {
//!             Move::Add => self.0 += 1,
//!             Move::Sub => self.0 -= 1,
//!         }
//!     }
//! }
//!
//! impl TranspositionHash for CountingGame {
//!     fn hash(&self) -> u64 {
//!         self.0 as u64
//!     }
//! }
//!
//! struct MyEvaluator;
//!
//! impl Evaluator<MyMCTS> for MyEvaluator {
//!     type StateEvaluation = i64;
//!
//!     fn evaluate_new_state(&self, state: &CountingGame, moves: &Vec<Move>,
//!         _: Option<SearchHandle<MyMCTS>>)
//!         -> (Vec<()>, i64) {
//!         (vec![(); moves.len()], state.0)
//!     }
//!     fn interpret_evaluation_for_player(&self, evaln: &i64, _player: &()) -> i64 {
//!         *evaln
//!     }
//!     fn evaluate_existing_state(&self, _: &CountingGame, evaln: &i64, _: SearchHandle<MyMCTS>) -> i64 {
//!         *evaln
//!     }
//! }
//!
//! #[derive(Default)]
//! struct MyMCTS;
//!
//! impl MCTS for MyMCTS {
//!     type State = CountingGame;
//!     type Eval = MyEvaluator;
//!     type NodeData = ();
//!     type ExtraThreadData = ();
//!     type TreePolicy = UCTPolicy;
//!     type TranspositionTable = ApproxTable<Self>;
//!
//!     fn cycle_behaviour(&self) -> CycleBehaviour<Self> {
//!         CycleBehaviour::UseCurrentEvalWhenCycleDetected
//!     }
//! }
//!
//! let game = CountingGame(0);
//! let mut mcts = MCTSManager::new(game, MyMCTS, MyEvaluator, UCTPolicy::new(0.5),
//!     ApproxTable::new(1024));
//! mcts.playout_n_parallel(10000, 4); // 10000 playouts, 4 search threads
//! mcts.tree().debug_moves();
//! assert_eq!(mcts.best_move().unwrap(), Move::Add);
//! assert_eq!(mcts.principal_variation(50),
//!     vec![Move::Add; 50]);
//! assert_eq!(mcts.principal_variation_states(5),
//!     vec![
//!         CountingGame(0),
//!         CountingGame(1),
//!         CountingGame(2),
//!         CountingGame(3),
//!         CountingGame(4),
//!         CountingGame(5)]);
//! ```
106
107extern crate crossbeam;
108extern crate smallvec;
109
110mod search_tree;
111mod atomics;
112pub mod tree_policy;
113pub mod transposition_table;
114
115pub use search_tree::*;
116use tree_policy::*;
117use transposition_table::*;
118
119use atomics::*;
120use std::sync::Arc;
121use std::thread::JoinHandle;
122use std::time::Duration;
123
124pub trait MCTS: Sized + Sync {
125    type State: GameState + Sync;
126    type Eval: Evaluator<Self>;
127    type TreePolicy: TreePolicy<Self>;
128    type NodeData: Default + Sync + Send;
129    type TranspositionTable: TranspositionTable<Self>;
130    type ExtraThreadData;
131
132    fn virtual_loss(&self) -> i64 {
133        0
134    }
135    fn visits_before_expansion(&self) -> u64 {
136        1
137    }
138    fn node_limit(&self) -> usize {
139        std::usize::MAX
140    }
141    fn select_child_after_search<'a>(&self, children: &'a [MoveInfo<Self>]) -> &'a MoveInfo<Self> {
142        children.into_iter().max_by_key(|child| child.visits()).unwrap()
143    }
144    /// `playout` panics when this length is exceeded. Defaults to one million.
145    fn max_playout_length(&self) -> usize {
146        1_000_000
147    }
148    fn on_backpropagation(&self, _evaln: &StateEvaluation<Self>, _handle: SearchHandle<Self>) {}
149    fn cycle_behaviour(&self) -> CycleBehaviour<Self> {
150        if std::mem::size_of::<Self::TranspositionTable>() == 0 {
151            CycleBehaviour::Ignore
152        } else {
153            CycleBehaviour::PanicWhenCycleDetected
154        }
155    }
156}
157
158pub struct ThreadData<Spec: MCTS> {
159    pub policy_data: TreePolicyThreadData<Spec>,
160    pub extra_data: Spec::ExtraThreadData,
161}
162
163impl<Spec: MCTS> Default for ThreadData<Spec>
164    where TreePolicyThreadData<Spec>: Default, Spec::ExtraThreadData: Default
165{
166    fn default() -> Self {
167        Self {
168            policy_data: Default::default(),
169            extra_data: Default::default(),
170        }
171    }
172} 
173
174pub type MoveEvaluation<Spec> = <<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::MoveEvaluation;
175pub type StateEvaluation<Spec> = <<Spec as MCTS>::Eval as Evaluator<Spec>>::StateEvaluation;
176pub type Move<Spec> = <<Spec as MCTS>::State as GameState>::Move;
177pub type MoveList<Spec> = <<Spec as MCTS>::State as GameState>::MoveList;
178pub type Player<Spec> = <<Spec as MCTS>::State as GameState>::Player;
179pub type TreePolicyThreadData<Spec> = <<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::ThreadLocalData;
180
/// A game position that the search can explore.
///
/// States must be `Clone`. For example, `MCTSManager::principal_variation_states` clones
/// a state before applying each move to it.
pub trait GameState: Clone {
    /// A move, applied with `make_move`.
    type Move: Clone + Send + Sync;
    /// Identifies a player; passed to `Evaluator::interpret_evaluation_for_player`.
    type Player: Sync;
    /// The collection returned by `available_moves`.
    type MoveList: IntoIterator<Item = Self::Move>;

    /// Returns the player to move in this state.
    fn current_player(&self) -> Self::Player;

    /// Returns the moves available from this state.
    ///
    /// In the crate-level example, returning no moves marks the end of the game.
    fn available_moves(&self) -> Self::MoveList;

    /// Applies `mov` to this state in place.
    fn make_move(&mut self, mov: &Self::Move);
}
190
191pub trait Evaluator<Spec: MCTS>: Sync {
192    type StateEvaluation: Sync + Send;
193
194    fn evaluate_new_state(&self,
195        state: &Spec::State, moves: &MoveList<Spec>,
196        handle: Option<SearchHandle<Spec>>)
197        -> (Vec<MoveEvaluation<Spec>>, Self::StateEvaluation);
198
199    fn evaluate_existing_state(&self, state: &Spec::State, existing_evaln: &Self::StateEvaluation,
200        handle: SearchHandle<Spec>)
201        -> Self::StateEvaluation;
202
203    fn interpret_evaluation_for_player(&self,
204        evaluation: &Self::StateEvaluation,
205        player: &Player<Spec>) -> i64;
206}
207
208
209pub struct MCTSManager<Spec: MCTS> {
210    search_tree: SearchTree<Spec>,
211    // thread local data when we have no asynchronous workers
212    single_threaded_tld: Option<ThreadData<Spec>>,
213    print_on_playout_error: bool,
214}
215
216impl<Spec: MCTS> MCTSManager<Spec> where ThreadData<Spec>: Default {
217    pub fn new(state: Spec::State, manager: Spec, eval: Spec::Eval, tree_policy: Spec::TreePolicy,
218            table: Spec::TranspositionTable) -> Self {
219        let search_tree = SearchTree::new(state, manager, tree_policy, eval, table);
220        let single_threaded_tld = None;
221        Self {search_tree, single_threaded_tld, print_on_playout_error: true}
222    }
223
224    pub fn print_on_playout_error(&mut self, v: bool) -> &mut Self {
225        self.print_on_playout_error = v;
226        self
227    }
228
229    pub fn playout(&mut self) {
230        // Avoid overhead of thread creation
231        if self.single_threaded_tld.is_none() {
232            self.single_threaded_tld = Some(Default::default());
233        }
234        self.search_tree.playout(self.single_threaded_tld.as_mut().unwrap());
235    }
236    pub fn playout_until<Predicate: FnMut() -> bool>(&mut self, mut pred: Predicate) {
237        while !pred() {
238            self.playout();
239        }
240    }
241    pub fn playout_n(&mut self, n: u64) {
242        for _ in 0..n {
243            self.playout();
244        }
245    }
246    unsafe fn spawn_worker_thread(&self, stop_signal: Arc<AtomicBool>) -> JoinHandle<()> {
247        let search_tree = &self.search_tree;
248        let print_on_playout_error = self.print_on_playout_error;
249        crossbeam::spawn_unsafe(move || {
250            let mut tld = Default::default();
251            loop {
252                if stop_signal.load(Ordering::SeqCst) {
253                    break;
254                }
255                if !search_tree.playout(&mut tld) {
256                    if print_on_playout_error {
257                        eprintln!("Node limit of {} reached. Halting search.",
258                            search_tree.spec().node_limit());
259                    }
260                    break;
261                }
262            }
263        })
264    }
265    pub fn playout_parallel_async<'a>(&'a mut self, num_threads: usize) -> AsyncSearch<'a, Spec> {
266        assert!(num_threads != 0);
267        let stop_signal = Arc::new(AtomicBool::new(false));
268        let threads = (0..num_threads).map(|_| {
269            let stop_signal = stop_signal.clone();
270            unsafe {
271                self.spawn_worker_thread(stop_signal)
272            }
273        }).collect();
274        AsyncSearch {
275            manager: self,
276            stop_signal,
277            threads,
278        }
279    }
280    pub fn into_playout_parallel_async(self, num_threads: usize) -> AsyncSearchOwned<Spec> {
281        assert!(num_threads != 0);
282        let self_box = Box::new(self);
283        let stop_signal = Arc::new(AtomicBool::new(false));
284        let threads = (0..num_threads).map(|_| {
285            let stop_signal = stop_signal.clone();
286            unsafe {
287                self_box.spawn_worker_thread(stop_signal)
288            }
289        }).collect();
290        AsyncSearchOwned {
291            manager: Some(self_box),
292            stop_signal,
293            threads
294        }
295    }
296    pub fn playout_parallel_for(&mut self, duration: Duration, num_threads: usize) {
297        let search = self.playout_parallel_async(num_threads);
298        std::thread::sleep(duration);
299        search.halt();
300    }
301    pub fn playout_n_parallel(&mut self, n: u32, num_threads: usize) {
302        if n == 0 {
303            return;
304        }
305        assert!(num_threads != 0);
306        let counter = AtomicIsize::new(n as isize);
307        let search_tree = &self.search_tree;
308        crossbeam::scope(|scope| {
309            for _ in 0..num_threads {
310                scope.spawn(|| {
311                    let mut tld = Default::default();
312                    loop {
313                        let count = counter.fetch_sub(1, Ordering::SeqCst);
314                        if count <= 0 {
315                            break;
316                        }
317                        search_tree.playout(&mut tld);
318                    }
319                });
320            }
321        });
322    }
323    pub fn principal_variation_info(&self, num_moves: usize) -> Vec<MoveInfoHandle<Spec>> {
324        self.search_tree.principal_variation(num_moves)
325    }
326    pub fn principal_variation(&self, num_moves: usize) -> Vec<Move<Spec>> {
327        self.search_tree.principal_variation(num_moves)
328            .into_iter()
329            .map(|x| x.get_move())
330            .map(|x| x.clone())
331            .collect()
332    }
333    pub fn principal_variation_states(&self, num_moves: usize)
334            -> Vec<Spec::State> {
335        let moves = self.principal_variation(num_moves);
336        let mut states = vec![self.search_tree.root_state().clone()];
337        for mov in moves {
338            let mut state = states[states.len() - 1].clone();
339            state.make_move(&mov);
340            states.push(state);
341        }
342        states
343    }
344    pub fn tree(&self) -> &SearchTree<Spec> {&self.search_tree}
345    pub fn best_move(&self) -> Option<Move<Spec>> {
346        self.principal_variation(1).get(0).map(|x| x.clone())
347    }
348    pub fn perf_test<F>(&mut self, num_threads: usize, mut f: F) where F: FnMut(usize) {
349        let search = self.playout_parallel_async(num_threads);
350        for _ in 0..10 {
351            let n1 = search.manager.search_tree.num_nodes();
352            std::thread::sleep(Duration::from_secs(1));
353            let n2 = search.manager.search_tree.num_nodes();
354            let diff = if n2 > n1 {
355                n2 - n1
356            } else {
357                0
358            };
359            f(diff);
360        }
361    }
362    pub fn perf_test_to_stderr(&mut self, num_threads: usize) {
363        self.perf_test(num_threads, |x| eprintln!("{} nodes/sec", thousands_separate(x)));
364    }
365    pub fn reset(self) -> Self {
366        Self {
367            search_tree: self.search_tree.reset(),
368            print_on_playout_error: self.print_on_playout_error,
369            single_threaded_tld: None,
370        }
371    }
372}
373
/// Formats `x` in decimal with a comma between every group of three digits, counted
/// from the right (e.g. `1234567` becomes `"1,234,567"`).
fn thousands_separate(x: usize) -> String {
    let digits = x.to_string();
    let total = digits.len();
    let mut out = String::with_capacity(total + total / 3);
    for (i, ch) in digits.chars().enumerate() {
        // A comma goes before any digit (other than the first) that starts a
        // three-digit group counted from the right. All digits are ASCII, so the
        // char index equals the byte index.
        if i > 0 && (total - i) % 3 == 0 {
            out.push(',');
        }
        out.push(ch);
    }
    out
}
382
383#[must_use]
384pub struct AsyncSearch<'a, Spec: 'a + MCTS> {
385    manager: &'a mut MCTSManager<Spec>,
386    stop_signal: Arc<AtomicBool>,
387    threads: Vec<JoinHandle<()>>,
388}
389
390impl<'a, Spec: MCTS> AsyncSearch<'a, Spec> {
391    pub fn halt(self) {}
392    pub fn num_threads(&self) -> usize {
393        self.threads.len()
394    }
395}
396
397impl<'a, Spec: MCTS> Drop for AsyncSearch<'a, Spec> {
398    fn drop(&mut self) {
399        self.stop_signal.store(true, Ordering::SeqCst);
400        drain_join_unwrap(&mut self.threads);
401    }
402}
403
404#[must_use]
405pub struct AsyncSearchOwned<Spec: MCTS> {
406    manager: Option<Box<MCTSManager<Spec>>>,
407    stop_signal: Arc<AtomicBool>,
408    threads: Vec<JoinHandle<()>>,
409}
410
411impl<Spec: MCTS> AsyncSearchOwned<Spec> {
412    fn stop_threads(&mut self) {
413        self.stop_signal.store(true, Ordering::SeqCst);
414        drain_join_unwrap(&mut self.threads);
415    }
416    pub fn halt(mut self) -> MCTSManager<Spec> {
417        self.stop_threads();
418        *self.manager.take().unwrap()
419    }
420    pub fn num_threads(&self) -> usize {
421        self.threads.len()
422    }
423}
424
425impl<Spec: MCTS> Drop for AsyncSearchOwned<Spec> {
426    fn drop(&mut self) {
427        self.stop_threads();
428    }
429}
430
431impl<Spec: MCTS> From<MCTSManager<Spec>> for AsyncSearchOwned<Spec> {
432    /// An `MCTSManager` is an `AsyncSearchOwned` with zero threads searching.
433    fn from(m: MCTSManager<Spec>) -> Self {
434        Self {
435            manager: Some(Box::new(m)),
436            stop_signal: Arc::new(AtomicBool::new(false)),
437            threads: Vec::new(),
438        }
439    }
440}
441
/// Joins every handle in `threads`, leaving the vector empty.
///
/// Every thread is joined before any result is inspected. As a result, no worker is
/// still running (and possibly still borrowing the search tree) when this returns or
/// panics.
///
/// # Panics
///
/// Panics (via `unwrap`) if any worker panicked. The exception is when the current
/// thread is already unwinding: this function is called from `Drop` impls, and a second
/// panic during unwinding would abort the process, so worker panics are ignored then.
fn drain_join_unwrap(threads: &mut Vec<JoinHandle<()>>) {
    let join_results: Vec<_> = threads.drain(..).map(|handle| handle.join()).collect();
    if std::thread::panicking() {
        return;
    }
    for result in join_results {
        result.unwrap();
    }
}
448
449pub enum CycleBehaviour<Spec: MCTS> {
450    Ignore,
451    UseCurrentEvalWhenCycleDetected,
452    PanicWhenCycleDetected,
453    UseThisEvalWhenCycleDetected(StateEvaluation<Spec>),
454}