1
2extern crate crossbeam;
108extern crate smallvec;
109
110mod search_tree;
111mod atomics;
112pub mod tree_policy;
113pub mod transposition_table;
114
115pub use search_tree::*;
116use tree_policy::*;
117use transposition_table::*;
118
119use atomics::*;
120use std::sync::Arc;
121use std::thread::JoinHandle;
122use std::time::Duration;
123
124pub trait MCTS: Sized + Sync {
125 type State: GameState + Sync;
126 type Eval: Evaluator<Self>;
127 type TreePolicy: TreePolicy<Self>;
128 type NodeData: Default + Sync + Send;
129 type TranspositionTable: TranspositionTable<Self>;
130 type ExtraThreadData;
131
132 fn virtual_loss(&self) -> i64 {
133 0
134 }
135 fn visits_before_expansion(&self) -> u64 {
136 1
137 }
138 fn node_limit(&self) -> usize {
139 std::usize::MAX
140 }
141 fn select_child_after_search<'a>(&self, children: &'a [MoveInfo<Self>]) -> &'a MoveInfo<Self> {
142 children.into_iter().max_by_key(|child| child.visits()).unwrap()
143 }
144 fn max_playout_length(&self) -> usize {
146 1_000_000
147 }
148 fn on_backpropagation(&self, _evaln: &StateEvaluation<Self>, _handle: SearchHandle<Self>) {}
149 fn cycle_behaviour(&self) -> CycleBehaviour<Self> {
150 if std::mem::size_of::<Self::TranspositionTable>() == 0 {
151 CycleBehaviour::Ignore
152 } else {
153 CycleBehaviour::PanicWhenCycleDetected
154 }
155 }
156}
157
158pub struct ThreadData<Spec: MCTS> {
159 pub policy_data: TreePolicyThreadData<Spec>,
160 pub extra_data: Spec::ExtraThreadData,
161}
162
163impl<Spec: MCTS> Default for ThreadData<Spec>
164 where TreePolicyThreadData<Spec>: Default, Spec::ExtraThreadData: Default
165{
166 fn default() -> Self {
167 Self {
168 policy_data: Default::default(),
169 extra_data: Default::default(),
170 }
171 }
172}
173
174pub type MoveEvaluation<Spec> = <<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::MoveEvaluation;
175pub type StateEvaluation<Spec> = <<Spec as MCTS>::Eval as Evaluator<Spec>>::StateEvaluation;
176pub type Move<Spec> = <<Spec as MCTS>::State as GameState>::Move;
177pub type MoveList<Spec> = <<Spec as MCTS>::State as GameState>::MoveList;
178pub type Player<Spec> = <<Spec as MCTS>::State as GameState>::Player;
179pub type TreePolicyThreadData<Spec> = <<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::ThreadLocalData;
180
/// A game position that the search can clone and advance in place.
pub trait GameState: Clone {
    /// A move; must be shareable and sendable between search threads.
    type Move: Sync + Send + Clone;
    /// Identifies a player (e.g. whose turn it is).
    type Player: Sync;
    /// Collection of moves returned by `available_moves`.
    type MoveList: IntoIterator<Item = Self::Move>;

    /// Returns the player to move in this state.
    fn current_player(&self) -> Self::Player;

    /// Returns the moves available from this state.
    fn available_moves(&self) -> Self::MoveList;

    /// Applies `mov` to this state, mutating it in place.
    fn make_move(&mut self, mov: &Self::Move);
}
190
191pub trait Evaluator<Spec: MCTS>: Sync {
192 type StateEvaluation: Sync + Send;
193
194 fn evaluate_new_state(&self,
195 state: &Spec::State, moves: &MoveList<Spec>,
196 handle: Option<SearchHandle<Spec>>)
197 -> (Vec<MoveEvaluation<Spec>>, Self::StateEvaluation);
198
199 fn evaluate_existing_state(&self, state: &Spec::State, existing_evaln: &Self::StateEvaluation,
200 handle: SearchHandle<Spec>)
201 -> Self::StateEvaluation;
202
203 fn interpret_evaluation_for_player(&self,
204 evaluation: &Self::StateEvaluation,
205 player: &Player<Spec>) -> i64;
206}
207
208
209pub struct MCTSManager<Spec: MCTS> {
210 search_tree: SearchTree<Spec>,
211 single_threaded_tld: Option<ThreadData<Spec>>,
213 print_on_playout_error: bool,
214}
215
216impl<Spec: MCTS> MCTSManager<Spec> where ThreadData<Spec>: Default {
217 pub fn new(state: Spec::State, manager: Spec, eval: Spec::Eval, tree_policy: Spec::TreePolicy,
218 table: Spec::TranspositionTable) -> Self {
219 let search_tree = SearchTree::new(state, manager, tree_policy, eval, table);
220 let single_threaded_tld = None;
221 Self {search_tree, single_threaded_tld, print_on_playout_error: true}
222 }
223
224 pub fn print_on_playout_error(&mut self, v: bool) -> &mut Self {
225 self.print_on_playout_error = v;
226 self
227 }
228
229 pub fn playout(&mut self) {
230 if self.single_threaded_tld.is_none() {
232 self.single_threaded_tld = Some(Default::default());
233 }
234 self.search_tree.playout(self.single_threaded_tld.as_mut().unwrap());
235 }
236 pub fn playout_until<Predicate: FnMut() -> bool>(&mut self, mut pred: Predicate) {
237 while !pred() {
238 self.playout();
239 }
240 }
241 pub fn playout_n(&mut self, n: u64) {
242 for _ in 0..n {
243 self.playout();
244 }
245 }
246 unsafe fn spawn_worker_thread(&self, stop_signal: Arc<AtomicBool>) -> JoinHandle<()> {
247 let search_tree = &self.search_tree;
248 let print_on_playout_error = self.print_on_playout_error;
249 crossbeam::spawn_unsafe(move || {
250 let mut tld = Default::default();
251 loop {
252 if stop_signal.load(Ordering::SeqCst) {
253 break;
254 }
255 if !search_tree.playout(&mut tld) {
256 if print_on_playout_error {
257 eprintln!("Node limit of {} reached. Halting search.",
258 search_tree.spec().node_limit());
259 }
260 break;
261 }
262 }
263 })
264 }
265 pub fn playout_parallel_async<'a>(&'a mut self, num_threads: usize) -> AsyncSearch<'a, Spec> {
266 assert!(num_threads != 0);
267 let stop_signal = Arc::new(AtomicBool::new(false));
268 let threads = (0..num_threads).map(|_| {
269 let stop_signal = stop_signal.clone();
270 unsafe {
271 self.spawn_worker_thread(stop_signal)
272 }
273 }).collect();
274 AsyncSearch {
275 manager: self,
276 stop_signal,
277 threads,
278 }
279 }
280 pub fn into_playout_parallel_async(self, num_threads: usize) -> AsyncSearchOwned<Spec> {
281 assert!(num_threads != 0);
282 let self_box = Box::new(self);
283 let stop_signal = Arc::new(AtomicBool::new(false));
284 let threads = (0..num_threads).map(|_| {
285 let stop_signal = stop_signal.clone();
286 unsafe {
287 self_box.spawn_worker_thread(stop_signal)
288 }
289 }).collect();
290 AsyncSearchOwned {
291 manager: Some(self_box),
292 stop_signal,
293 threads
294 }
295 }
296 pub fn playout_parallel_for(&mut self, duration: Duration, num_threads: usize) {
297 let search = self.playout_parallel_async(num_threads);
298 std::thread::sleep(duration);
299 search.halt();
300 }
301 pub fn playout_n_parallel(&mut self, n: u32, num_threads: usize) {
302 if n == 0 {
303 return;
304 }
305 assert!(num_threads != 0);
306 let counter = AtomicIsize::new(n as isize);
307 let search_tree = &self.search_tree;
308 crossbeam::scope(|scope| {
309 for _ in 0..num_threads {
310 scope.spawn(|| {
311 let mut tld = Default::default();
312 loop {
313 let count = counter.fetch_sub(1, Ordering::SeqCst);
314 if count <= 0 {
315 break;
316 }
317 search_tree.playout(&mut tld);
318 }
319 });
320 }
321 });
322 }
323 pub fn principal_variation_info(&self, num_moves: usize) -> Vec<MoveInfoHandle<Spec>> {
324 self.search_tree.principal_variation(num_moves)
325 }
326 pub fn principal_variation(&self, num_moves: usize) -> Vec<Move<Spec>> {
327 self.search_tree.principal_variation(num_moves)
328 .into_iter()
329 .map(|x| x.get_move())
330 .map(|x| x.clone())
331 .collect()
332 }
333 pub fn principal_variation_states(&self, num_moves: usize)
334 -> Vec<Spec::State> {
335 let moves = self.principal_variation(num_moves);
336 let mut states = vec![self.search_tree.root_state().clone()];
337 for mov in moves {
338 let mut state = states[states.len() - 1].clone();
339 state.make_move(&mov);
340 states.push(state);
341 }
342 states
343 }
344 pub fn tree(&self) -> &SearchTree<Spec> {&self.search_tree}
345 pub fn best_move(&self) -> Option<Move<Spec>> {
346 self.principal_variation(1).get(0).map(|x| x.clone())
347 }
348 pub fn perf_test<F>(&mut self, num_threads: usize, mut f: F) where F: FnMut(usize) {
349 let search = self.playout_parallel_async(num_threads);
350 for _ in 0..10 {
351 let n1 = search.manager.search_tree.num_nodes();
352 std::thread::sleep(Duration::from_secs(1));
353 let n2 = search.manager.search_tree.num_nodes();
354 let diff = if n2 > n1 {
355 n2 - n1
356 } else {
357 0
358 };
359 f(diff);
360 }
361 }
362 pub fn perf_test_to_stderr(&mut self, num_threads: usize) {
363 self.perf_test(num_threads, |x| eprintln!("{} nodes/sec", thousands_separate(x)));
364 }
365 pub fn reset(self) -> Self {
366 Self {
367 search_tree: self.search_tree.reset(),
368 print_on_playout_error: self.print_on_playout_error,
369 single_threaded_tld: None,
370 }
371 }
372}
373
/// Formats `x` in decimal with a comma between every group of three digits,
/// e.g. `1234567` -> `"1,234,567"`.
fn thousands_separate(x: usize) -> String {
    let digits = x.to_string();
    let len = digits.len();
    let mut out = String::with_capacity(len + len / 3);
    for (i, ch) in digits.chars().enumerate() {
        // A separator goes before every digit whose remaining count (itself included) is a
        // positive multiple of three, except the leading digit.
        if i > 0 && (len - i) % 3 == 0 {
            out.push(',');
        }
        out.push(ch);
    }
    out
}
382
383#[must_use]
384pub struct AsyncSearch<'a, Spec: 'a + MCTS> {
385 manager: &'a mut MCTSManager<Spec>,
386 stop_signal: Arc<AtomicBool>,
387 threads: Vec<JoinHandle<()>>,
388}
389
390impl<'a, Spec: MCTS> AsyncSearch<'a, Spec> {
391 pub fn halt(self) {}
392 pub fn num_threads(&self) -> usize {
393 self.threads.len()
394 }
395}
396
397impl<'a, Spec: MCTS> Drop for AsyncSearch<'a, Spec> {
398 fn drop(&mut self) {
399 self.stop_signal.store(true, Ordering::SeqCst);
400 drain_join_unwrap(&mut self.threads);
401 }
402}
403
404#[must_use]
405pub struct AsyncSearchOwned<Spec: MCTS> {
406 manager: Option<Box<MCTSManager<Spec>>>,
407 stop_signal: Arc<AtomicBool>,
408 threads: Vec<JoinHandle<()>>,
409}
410
411impl<Spec: MCTS> AsyncSearchOwned<Spec> {
412 fn stop_threads(&mut self) {
413 self.stop_signal.store(true, Ordering::SeqCst);
414 drain_join_unwrap(&mut self.threads);
415 }
416 pub fn halt(mut self) -> MCTSManager<Spec> {
417 self.stop_threads();
418 *self.manager.take().unwrap()
419 }
420 pub fn num_threads(&self) -> usize {
421 self.threads.len()
422 }
423}
424
425impl<Spec: MCTS> Drop for AsyncSearchOwned<Spec> {
426 fn drop(&mut self) {
427 self.stop_threads();
428 }
429}
430
431impl<Spec: MCTS> From<MCTSManager<Spec>> for AsyncSearchOwned<Spec> {
432 fn from(m: MCTSManager<Spec>) -> Self {
434 Self {
435 manager: Some(Box::new(m)),
436 stop_signal: Arc::new(AtomicBool::new(false)),
437 threads: Vec::new(),
438 }
439 }
440}
441
/// Joins every handle in `threads` (leaving the vector empty), then panics if any of the
/// joined threads panicked.
///
/// All threads are joined before any result is inspected, so one panicked worker never
/// leaves the others running.
fn drain_join_unwrap(threads: &mut Vec<JoinHandle<()>>) {
    let outcomes: Vec<_> = threads.drain(..).map(JoinHandle::join).collect();
    outcomes.into_iter().for_each(|outcome| outcome.unwrap());
}
448
449pub enum CycleBehaviour<Spec: MCTS> {
450 Ignore,
451 UseCurrentEvalWhenCycleDetected,
452 PanicWhenCycleDetected,
453 UseThisEvalWhenCycleDetected(StateEvaluation<Spec>),
454}