mod atomics;
pub mod batch;
mod search_tree;
pub mod transposition_table;
pub mod tree_policy;
pub use batch::*;
pub use search_tree::*;
use {transposition_table::*, tree_policy::*};
use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus as Rng64;
use std::cell::RefCell;
use {
atomics::*,
std::{sync::Arc, thread::JoinHandle, time::Duration},
vec_storage_reuse::VecStorageForReuse,
};
/// Bundle of associated types and tunable settings that parameterises a search.
///
/// Every method has a default body, so an implementor only has to name the
/// associated types and override the settings it cares about.
pub trait MCTS: Sized + Send + Sync + 'static {
    /// Game state type; must implement [`GameState`].
    type State: GameState + Send + Sync + 'static;
    /// Evaluator type; must implement [`Evaluator`] for this spec.
    type Eval: Evaluator<Self> + Send + 'static;
    /// Tree policy type; must implement `TreePolicy` for this spec.
    type TreePolicy: TreePolicy<Self> + Send + 'static;
    /// User-defined node data; must be `Default`.
    type NodeData: Default + Sync + Send + 'static;
    /// Transposition table type; must implement `TranspositionTable` for this spec.
    type TranspositionTable: TranspositionTable<Self> + Send + 'static;
    /// User-defined per-thread data, stored in [`ThreadData::extra_data`].
    type ExtraThreadData: 'static;

    /// Virtual-loss setting. Defaults to `0`.
    fn virtual_loss(&self) -> i64 {
        0
    }

    /// FPU value setting (presumably "first play urgency" — confirm against
    /// the tree policy). Defaults to `f64::INFINITY`.
    fn fpu_value(&self) -> f64 {
        f64::INFINITY
    }

    /// Visit threshold before expansion. Defaults to `1`.
    fn visits_before_expansion(&self) -> u64 {
        1
    }

    /// Node limit; reported in the "Node limit of {} reached" message when a
    /// playout fails. Defaults to `usize::MAX`.
    fn node_limit(&self) -> usize {
        usize::MAX
    }

    /// Chooses which of `children` to report once searching is finished.
    ///
    /// Rules are tried in this order:
    /// 1. With [`solver_enabled`](Self::solver_enabled): the first child whose
    ///    proven value is `Loss`, otherwise the first whose proven value is
    ///    `Draw`.
    /// 2. With [`score_bounded_enabled`](Self::score_bounded_enabled): if any
    ///    child's negated upper score bound is above `i32::MIN`, the child
    ///    maximising that negated bound.
    /// 3. The child with the most visits.
    ///
    /// # Panics
    /// Panics if `children` is empty.
    fn select_child_after_search<'a>(&self, children: &'a [MoveInfo<Self>]) -> &'a MoveInfo<Self> {
        if self.solver_enabled() {
            // NOTE(review): child values are presumably from the child's own
            // side-to-move perspective, making a child `Loss` a win for the
            // mover — confirm. Also note that a proven `Draw` is taken ahead
            // of any unproven child, and proven-`Win` children are not
            // excluded from the visit-count fallback below.
            for wanted in [ProvenValue::Loss, ProvenValue::Draw] {
                let hit = children
                    .iter()
                    .find(|c| c.child_proven_value() == wanted);
                if let Some(found) = hit {
                    return found;
                }
            }
        }

        if self.score_bounded_enabled() {
            let any_bounded = children
                .iter()
                .any(|c| negate_bound(c.child_score_bounds().upper) > i32::MIN);
            if any_bounded {
                return children
                    .iter()
                    .max_by_key(|c| negate_bound(c.child_score_bounds().upper))
                    .unwrap();
            }
        }

        children.iter().max_by_key(|c| c.visits()).unwrap()
    }

    /// Maximum playout length. Defaults to `1_000_000`.
    fn max_playout_length(&self) -> usize {
        1_000_000
    }

    /// Maximum playout depth. Defaults to `usize::MAX`.
    fn max_playout_depth(&self) -> usize {
        usize::MAX
    }

    /// Optional RNG seed; `None` (the default) makes the manager seed its
    /// selection RNG from `rand::thread_rng()` instead.
    fn rng_seed(&self) -> Option<u64> {
        None
    }

    /// Optional Dirichlet-noise parameters. Defaults to `None`.
    /// NOTE(review): the meaning of the two `f64`s is not visible here.
    fn dirichlet_noise(&self) -> Option<(f64, f64)> {
        None
    }

    /// Temperature used by `MCTSManager::best_move`; values below `1e-8`
    /// select the first principal-variation move. Defaults to `0.0`.
    fn selection_temperature(&self) -> f64 {
        0.0
    }

    /// Enables the proven-value rule in
    /// [`select_child_after_search`](Self::select_child_after_search).
    /// Defaults to `false`.
    fn solver_enabled(&self) -> bool {
        false
    }

    /// Enables the score-bound rule in
    /// [`select_child_after_search`](Self::select_child_after_search).
    /// Defaults to `false`.
    fn score_bounded_enabled(&self) -> bool {
        false
    }

    /// Closed-loop chance setting. Defaults to `false`.
    fn closed_loop_chance(&self) -> bool {
        false
    }

    /// Backpropagation hook; the default does nothing.
    fn on_backpropagation(&self, _evaln: &StateEvaluation<Self>, _handle: SearchHandle<Self>) {}

    /// Cycle handling: `Ignore` when the transposition table type is
    /// zero-sized, otherwise `PanicWhenCycleDetected`.
    fn cycle_behaviour(&self) -> CycleBehaviour<Self> {
        match std::mem::size_of::<Self::TranspositionTable>() {
            0 => CycleBehaviour::Ignore,
            _ => CycleBehaviour::PanicWhenCycleDetected,
        }
    }
}
/// Per-thread data made up of the tree policy's thread-local state and the
/// user-defined extra data.
pub struct ThreadData<Spec: MCTS> {
/// Thread-local state of the tree policy (`TreePolicy::ThreadLocalData`).
pub policy_data: TreePolicyThreadData<Spec>,
/// User-defined per-thread data (`MCTS::ExtraThreadData`).
pub extra_data: Spec::ExtraThreadData,
}
/// `ThreadData` is `Default` whenever both of its components are.
impl<Spec: MCTS> Default for ThreadData<Spec>
where
    TreePolicyThreadData<Spec>: Default,
    Spec::ExtraThreadData: Default,
{
    /// Builds each field from its own `Default` implementation.
    fn default() -> Self {
        let policy_data = <TreePolicyThreadData<Spec> as Default>::default();
        let extra_data = <Spec::ExtraThreadData as Default>::default();
        Self {
            policy_data,
            extra_data,
        }
    }
}
/// Complete per-thread working set: the public [`ThreadData`] plus reusable
/// scratch buffers and a private RNG.
pub struct ThreadDataFull<Spec: MCTS> {
// Publicly visible part of the thread data.
tld: ThreadData<Spec>,
// Reusable buffer of raw `MoveInfo` pointers.
// NOTE(review): presumably the edges walked by the current playout — confirm in search_tree.
path: VecStorageForReuse<*const MoveInfo<Spec>>,
// Reusable buffer of raw `SearchNode` pointers (presumably the nodes on the same path — confirm).
node_path: VecStorageForReuse<*const SearchNode<Spec>>,
// Reusable buffer of players.
players: VecStorageForReuse<Player<Spec>>,
// Per-thread random number generator (name suggests it samples chance outcomes — confirm).
chance_rng: Rng64,
}
/// `ThreadDataFull` is `Default` whenever `ThreadData<Spec>` is.
impl<Spec: MCTS> Default for ThreadDataFull<Spec>
where
    ThreadData<Spec>: Default,
{
    /// Starts with default thread data, empty reusable buffers and an RNG
    /// seeded from `rand::thread_rng()`.
    fn default() -> Self {
        let tld = <ThreadData<Spec> as Default>::default();
        let path = VecStorageForReuse::default();
        let node_path = VecStorageForReuse::default();
        let players = VecStorageForReuse::default();
        let chance_rng = Rng64::from_rng(rand::thread_rng()).unwrap();
        Self {
            tld,
            path,
            node_path,
            players,
            chance_rng,
        }
    }
}
/// Per-move evaluation type defined by the spec's tree policy.
pub type MoveEvaluation<Spec> = <<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::MoveEvaluation;
/// State evaluation type defined by the spec's evaluator.
pub type StateEvaluation<Spec> = <<Spec as MCTS>::Eval as Evaluator<Spec>>::StateEvaluation;
/// Move type of the spec's game state.
pub type Move<Spec> = <<Spec as MCTS>::State as GameState>::Move;
/// Move-list type returned by the game state's `available_moves`.
pub type MoveList<Spec> = <<Spec as MCTS>::State as GameState>::MoveList;
/// Player type of the spec's game state.
pub type Player<Spec> = <<Spec as MCTS>::State as GameState>::Player;
/// Thread-local data type defined by the spec's tree policy.
pub type TreePolicyThreadData<Spec> =
<<Spec as MCTS>::TreePolicy as TreePolicy<Spec>>::ThreadLocalData;
/// Proven outcome marker with a stable one-byte encoding.
///
/// NOTE(review): whose perspective `Win`/`Loss` refer to is not defined in
/// this block — confirm against the search tree.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum ProvenValue {
    /// Nothing proven; also the decoding of any unrecognised byte.
    Unknown = 0,
    /// Proven win.
    Win = 1,
    /// Proven loss.
    Loss = 2,
    /// Proven draw.
    Draw = 3,
}

impl ProvenValue {
    /// Decodes a byte produced by `value as u8`.
    ///
    /// Bytes that do not match a known discriminant decode to
    /// [`ProvenValue::Unknown`].
    pub const fn from_u8(v: u8) -> Self {
        if v == ProvenValue::Win as u8 {
            ProvenValue::Win
        } else if v == ProvenValue::Loss as u8 {
            ProvenValue::Loss
        } else if v == ProvenValue::Draw as u8 {
            ProvenValue::Draw
        } else {
            ProvenValue::Unknown
        }
    }
}
/// Closed interval `[lower, upper]` of integer scores.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ScoreBounds {
    /// Lower end of the interval.
    pub lower: i32,
    /// Upper end of the interval.
    pub upper: i32,
}

impl ScoreBounds {
    /// The widest possible interval: `i32::MIN..=i32::MAX`.
    pub const UNBOUNDED: Self = Self {
        upper: i32::MAX,
        lower: i32::MIN,
    };

    /// Interval collapsed onto the single score `v`.
    pub const fn exact(v: i32) -> Self {
        Self { upper: v, lower: v }
    }

    /// `true` when both ends coincide, i.e. the score is pinned to one value.
    pub const fn is_proven(&self) -> bool {
        self.upper == self.lower
    }
}
pub const fn negate_bound(v: i32) -> i32 {
match v {
i32::MIN => i32::MAX,
i32::MAX => i32::MIN,
_ => -v,
}
}
/// Game rules as seen by the search: who is to move, which moves exist and
/// how a move changes the state. Implementors must be `Clone`.
pub trait GameState: Clone {
/// Move type.
type Move: Sync + Send + Clone;
/// Player type.
type Player: Sync;
/// Collection of moves returned by `available_moves`; iterating it yields `Move`s.
type MoveList: std::iter::IntoIterator<Item = Self::Move>;
/// Returns the current player.
fn current_player(&self) -> Self::Player;
/// Returns the moves available in this state.
fn available_moves(&self) -> Self::MoveList;
/// Applies `mov` to this state in place.
fn make_move(&mut self, mov: &Self::Move);
/// Child limit as a function of a visit count. Defaults to `usize::MAX`.
/// NOTE(review): presumably used for progressive widening — confirm in search_tree.
fn max_children(&self, _visits: u64) -> usize {
usize::MAX
}
/// Optional proven value for a terminal state. Defaults to `None`.
fn terminal_value(&self) -> Option<ProvenValue> {
None
}
/// Optional score for a terminal state. Defaults to `None`.
fn terminal_score(&self) -> Option<i32> {
None
}
/// Optional list of chance outcomes. Defaults to `None`.
/// NOTE(review): the `f64` is presumably the outcome's probability — confirm.
fn chance_outcomes(&self) -> Option<Vec<(Self::Move, f64)>> {
None
}
}
/// Produces evaluations for states and converts them into per-player scores.
pub trait Evaluator<Spec: MCTS>: Sync {
/// Evaluation attached to a state.
type StateEvaluation: Sync + Send;
/// Evaluates a state that has not been evaluated before.
///
/// Receives the state, its move list and an optional search handle; returns
/// a vector of move evaluations together with the state's evaluation.
/// NOTE(review): the vector presumably has one entry per move in `moves`,
/// in the same order — confirm against the caller.
fn evaluate_new_state(
&self,
state: &Spec::State,
moves: &MoveList<Spec>,
handle: Option<SearchHandle<Spec>>,
) -> (Vec<MoveEvaluation<Spec>>, Self::StateEvaluation);
/// Evaluates a state that already carries `existing_evaln` and returns the
/// evaluation to use.
fn evaluate_existing_state(
&self,
state: &Spec::State,
existing_evaln: &Self::StateEvaluation,
handle: SearchHandle<Spec>,
) -> Self::StateEvaluation;
/// Converts `evaluation` into an integer score for `player`.
fn interpret_evaluation_for_player(
&self,
evaluation: &Self::StateEvaluation,
player: &Player<Spec>,
) -> i64;
}
/// Owner of a search tree and the entry point for running playouts on it.
pub struct MCTSManager<Spec: MCTS> {
// Shared handle to the tree; worker threads hold clones of this `Arc`.
search_tree: Arc<SearchTree<Spec>>,
// Thread data for the single-threaded `playout` path, created on first use.
single_threaded_tld: Option<ThreadDataFull<Spec>>,
// Whether worker threads print a message when a playout fails.
print_on_playout_error: bool,
// RNG used by temperature-based move selection; in a `RefCell` because
// selection runs through `&self`.
selection_rng: RefCell<Rng64>,
}
impl<Spec: MCTS> MCTSManager<Spec>
where
ThreadData<Spec>: Default,
{
pub fn new(
state: Spec::State,
manager: Spec,
eval: Spec::Eval,
tree_policy: Spec::TreePolicy,
table: Spec::TranspositionTable,
) -> Self {
let selection_rng = match manager.rng_seed() {
Some(seed) => Rng64::seed_from_u64(seed.wrapping_add(u64::MAX / 2)),
None => Rng64::from_rng(rand::thread_rng()).unwrap(),
};
let search_tree = Arc::new(SearchTree::new(state, manager, tree_policy, eval, table));
let single_threaded_tld = None;
Self {
search_tree,
single_threaded_tld,
print_on_playout_error: true,
selection_rng: RefCell::new(selection_rng),
}
}
pub fn print_on_playout_error(&mut self, v: bool) -> &mut Self {
self.print_on_playout_error = v;
self
}
pub fn playout(&mut self) {
if self.single_threaded_tld.is_none() {
self.single_threaded_tld = Some(self.search_tree.make_thread_data());
}
self.search_tree
.playout(self.single_threaded_tld.as_mut().unwrap());
}
pub fn playout_until<Predicate: FnMut() -> bool>(&mut self, mut pred: Predicate) {
while !pred() {
self.playout();
}
}
pub fn playout_n(&mut self, n: u64) {
for _ in 0..n {
self.playout();
}
}
pub fn playout_parallel_async<'a>(&'a mut self, num_threads: usize) -> AsyncSearch<'a, Spec> {
assert!(num_threads != 0);
let stop_signal = Arc::new(AtomicBool::new(false));
let threads = (0..num_threads)
.map(|_| {
spawn_search_thread(
Arc::clone(&self.search_tree),
Arc::clone(&stop_signal),
self.print_on_playout_error,
)
})
.collect();
AsyncSearch {
manager: self,
stop_signal,
threads,
}
}
pub fn into_playout_parallel_async(self, num_threads: usize) -> AsyncSearchOwned<Spec> {
assert!(num_threads != 0);
let self_box = Box::new(self);
let stop_signal = Arc::new(AtomicBool::new(false));
let threads = (0..num_threads)
.map(|_| {
spawn_search_thread(
Arc::clone(&self_box.search_tree),
Arc::clone(&stop_signal),
self_box.print_on_playout_error,
)
})
.collect();
AsyncSearchOwned {
manager: Some(self_box),
stop_signal,
threads,
}
}
pub fn playout_parallel_for(&mut self, duration: Duration, num_threads: usize) {
assert!(num_threads != 0);
let stop_signal = AtomicBool::new(false);
let search_tree = &*self.search_tree;
let print_on_playout_error = self.print_on_playout_error;
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| {
let mut tld = search_tree.make_thread_data();
loop {
if stop_signal.load(Ordering::Acquire) {
break;
}
if !search_tree.playout(&mut tld) {
if print_on_playout_error {
eprintln!(
"Node limit of {} reached. Halting search.",
search_tree.spec().node_limit()
);
}
break;
}
}
});
}
std::thread::sleep(duration);
stop_signal.store(true, Ordering::Release);
});
}
pub fn playout_n_parallel(&mut self, n: u32, num_threads: usize) {
if n == 0 {
return;
}
assert!(num_threads != 0);
let counter = AtomicI64::new(n as i64);
let search_tree = &*self.search_tree;
std::thread::scope(|s| {
for _ in 0..num_threads {
s.spawn(|| {
let mut tld = search_tree.make_thread_data();
loop {
let count = counter.fetch_sub(1, Ordering::SeqCst);
if count <= 0 {
break;
}
search_tree.playout(&mut tld);
}
});
}
});
}
#[must_use]
pub fn principal_variation_info(&self, num_moves: usize) -> Vec<MoveInfoHandle<'_, Spec>> {
self.search_tree.principal_variation(num_moves)
}
#[must_use]
pub fn principal_variation(&self, num_moves: usize) -> Vec<Move<Spec>> {
self.search_tree
.principal_variation(num_moves)
.into_iter()
.map(|x| x.get_move().clone())
.collect()
}
#[must_use]
pub fn principal_variation_states(&self, num_moves: usize) -> Vec<Spec::State> {
let moves = self.principal_variation(num_moves);
let mut states = vec![self.search_tree.root_state().clone()];
for mov in moves {
let mut state = states[states.len() - 1].clone();
state.make_move(&mov);
states.push(state);
}
states
}
pub fn tree(&self) -> &SearchTree<Spec> {
&self.search_tree
}
#[must_use]
pub fn root_proven_value(&self) -> ProvenValue {
self.search_tree.root_proven_value()
}
#[must_use]
pub fn root_score_bounds(&self) -> ScoreBounds {
self.search_tree.root_score_bounds()
}
#[must_use]
pub fn best_move(&self) -> Option<Move<Spec>> {
let temperature = self.search_tree.spec().selection_temperature();
if temperature < 1e-8 {
self.principal_variation(1).first().cloned()
} else {
self.select_move_by_temperature(temperature)
}
}
fn select_move_by_temperature(&self, temperature: f64) -> Option<Move<Spec>> {
let inv_temp = 1.0 / temperature;
let weighted: Vec<_> = self
.search_tree
.root_node()
.moves()
.filter(|c| c.visits() > 0)
.map(|c| (c.get_move().clone(), (c.visits() as f64).powf(inv_temp)))
.collect();
if weighted.is_empty() {
return None;
}
let total: f64 = weighted.iter().map(|(_, w)| w).sum();
let mut roll: f64 = self.selection_rng.borrow_mut().gen::<f64>() * total;
for (mov, weight) in &weighted {
roll -= weight;
if roll <= 0.0 {
return Some(mov.clone());
}
}
Some(weighted.last().unwrap().0.clone())
}
pub fn perf_test<F>(&mut self, num_threads: usize, mut f: F)
where
F: FnMut(usize),
{
let search = self.playout_parallel_async(num_threads);
for _ in 0..10 {
let n1 = search.manager.search_tree.num_nodes();
std::thread::sleep(Duration::from_secs(1));
let n2 = search.manager.search_tree.num_nodes();
let diff = n2.saturating_sub(n1);
f(diff);
}
}
pub fn perf_test_to_stderr(&mut self, num_threads: usize) {
self.perf_test(num_threads, |x| {
eprintln!("{} nodes/sec", thousands_separate(x))
});
}
#[must_use]
pub fn reset(self) -> Self {
let search_tree = Arc::try_unwrap(self.search_tree)
.unwrap_or_else(|_| panic!("Cannot reset while async search is running"));
let selection_rng = match search_tree.spec().rng_seed() {
Some(seed) => Rng64::seed_from_u64(seed.wrapping_add(u64::MAX / 2)),
None => Rng64::from_rng(rand::thread_rng()).unwrap(),
};
Self {
search_tree: Arc::new(search_tree.reset()),
print_on_playout_error: self.print_on_playout_error,
single_threaded_tld: None,
selection_rng: RefCell::new(selection_rng),
}
}
}
impl<Spec: MCTS> MCTSManager<Spec>
where
    Move<Spec>: PartialEq,
    ThreadData<Spec>: Default,
{
    /// Re-roots the search tree at the position reached by playing `mov`.
    ///
    /// On success the cached single-threaded thread data is dropped so the
    /// next `playout` allocates fresh data.
    ///
    /// # Errors
    /// Propagates the error returned by the tree's `advance_root`.
    ///
    /// # Panics
    /// Panics if the tree is still shared with other `Arc` holders.
    pub fn advance(&mut self, mov: &Move<Spec>) -> Result<(), AdvanceError> {
        Arc::get_mut(&mut self.search_tree)
            .expect("Cannot advance while async search is running")
            .advance_root(mov)?;
        self.single_threaded_tld = None;
        Ok(())
    }
}
impl<Spec: MCTS> MCTSManager<Spec>
where
MoveEvaluation<Spec>: Clone,
{
/// Returns the statistics the search tree reports for the root's children.
#[must_use]
pub fn root_child_stats(&self) -> Vec<ChildStats<Spec>> {
self.search_tree.root_child_stats()
}
}
/// Formats `x` in decimal with a comma between every group of three digits,
/// counted from the right (e.g. `1234567` becomes `"1,234,567"`).
fn thousands_separate(x: usize) -> String {
    let digits = x.to_string();
    let len = digits.len();
    let mut out = String::with_capacity(len + len / 3);
    for (idx, ch) in digits.chars().enumerate() {
        // A comma goes before a digit when a whole number of three-digit
        // groups remains, except at the very start.
        if idx != 0 && (len - idx) % 3 == 0 {
            out.push(',');
        }
        out.push(ch);
    }
    out
}
/// Guard for a background search that borrows its manager.
/// Dropping it signals the worker threads to stop and joins them.
#[must_use]
pub struct AsyncSearch<'a, Spec: 'a + MCTS> {
// Manager borrowed mutably for as long as the search runs.
manager: &'a mut MCTSManager<Spec>,
// Flag shared with the workers; set to `true` to ask them to stop.
stop_signal: Arc<AtomicBool>,
// Join handles of the worker threads.
threads: Vec<JoinHandle<()>>,
}
impl<'a, Spec: MCTS> AsyncSearch<'a, Spec> {
/// Stops the search by consuming the guard; the actual stopping and
/// joining is done by the `Drop` implementation.
pub fn halt(self) {}
/// Number of worker threads held by this guard.
pub fn num_threads(&self) -> usize {
self.threads.len()
}
}
impl<'a, Spec: MCTS> Drop for AsyncSearch<'a, Spec> {
/// Raises the stop flag, then joins every worker thread.
/// Panics if a worker thread panicked.
fn drop(&mut self) {
// The flag must be set before joining, otherwise the join would not return.
self.stop_signal.store(true, Ordering::Release);
drain_join_unwrap(&mut self.threads);
}
}
/// Guard for a background search that owns its manager.
/// `halt` returns the manager; dropping the guard stops and joins the workers.
#[must_use]
pub struct AsyncSearchOwned<Spec: MCTS> {
// Owned manager; `None` once `halt` has taken it out.
manager: Option<Box<MCTSManager<Spec>>>,
// Flag shared with the workers; set to `true` to ask them to stop.
stop_signal: Arc<AtomicBool>,
// Join handles of the worker threads.
threads: Vec<JoinHandle<()>>,
}
impl<Spec: MCTS> AsyncSearchOwned<Spec> {
    /// Raises the stop flag and joins every worker thread.
    ///
    /// Safe to call more than once: joined handles are drained, so later
    /// calls find nothing left to join.
    fn stop_threads(&mut self) {
        self.stop_signal.store(true, Ordering::Release);
        drain_join_unwrap(&mut self.threads);
    }

    /// Stops the search and hands the manager back to the caller.
    ///
    /// # Panics
    /// Panics if a worker thread panicked.
    pub fn halt(mut self) -> MCTSManager<Spec> {
        self.stop_threads();
        let boxed = self.manager.take().unwrap();
        *boxed
    }

    /// Number of worker threads held by this guard.
    pub fn num_threads(&self) -> usize {
        let handles = &self.threads;
        handles.len()
    }
}
impl<Spec: MCTS> Drop for AsyncSearchOwned<Spec> {
/// Stops and joins the worker threads when the guard goes out of scope.
fn drop(&mut self) {
self.stop_threads();
}
}
/// Wraps a manager in an owned guard without starting any worker threads.
impl<Spec: MCTS> From<MCTSManager<Spec>> for AsyncSearchOwned<Spec> {
    /// The resulting guard has no threads and an unset stop flag.
    fn from(m: MCTSManager<Spec>) -> Self {
        let manager = Some(Box::new(m));
        let stop_signal = Arc::new(AtomicBool::new(false));
        let threads = Vec::new();
        Self {
            manager,
            stop_signal,
            threads,
        }
    }
}
/// Spawns one worker thread that runs playouts on `search_tree` until
/// `stop_signal` is set or a playout reports failure.
///
/// On failure the worker prints the node-limit message to stderr when
/// `print_on_playout_error` is `true`, then exits.
fn spawn_search_thread<Spec: MCTS>(
    search_tree: Arc<SearchTree<Spec>>,
    stop_signal: Arc<AtomicBool>,
    print_on_playout_error: bool,
) -> JoinHandle<()>
where
    ThreadData<Spec>: Default,
{
    std::thread::spawn(move || {
        let mut tld = search_tree.make_thread_data();
        // The stop flag is polled before every playout.
        while !stop_signal.load(Ordering::Acquire) {
            if search_tree.playout(&mut tld) {
                continue;
            }
            if print_on_playout_error {
                eprintln!(
                    "Node limit of {} reached. Halting search.",
                    search_tree.spec().node_limit()
                );
            }
            break;
        }
    })
}
/// Joins every thread in `threads`, leaving the vector empty, and then
/// panics if any of them panicked.
///
/// All threads are joined before the first result is inspected, so a
/// panicking worker never leaves its siblings unjoined.
fn drain_join_unwrap(threads: &mut Vec<JoinHandle<()>>) {
    let mut outcomes = Vec::with_capacity(threads.len());
    for handle in threads.drain(..) {
        outcomes.push(handle.join());
    }
    for outcome in outcomes {
        outcome.unwrap();
    }
}
/// What the search should do when it detects a cycle; returned by
/// `MCTS::cycle_behaviour`.
pub enum CycleBehaviour<Spec: MCTS> {
/// Ignore cycles; the `MCTS::cycle_behaviour` default for zero-sized
/// transposition tables.
Ignore,
/// Use the current evaluation when a cycle is detected.
UseCurrentEvalWhenCycleDetected,
/// Panic when a cycle is detected; the `MCTS::cycle_behaviour` default for
/// transposition tables that are not zero-sized.
PanicWhenCycleDetected,
/// Use the supplied evaluation when a cycle is detected.
UseThisEvalWhenCycleDetected(StateEvaluation<Spec>),
}