use rand::{Rng, SeedableRng};
use rand_xoshiro::Xoshiro256PlusPlus as Rng64;
use {
super::*,
atomics::*,
std::{
fmt,
fmt::{Debug, Display, Formatter},
ptr::null_mut,
sync::Mutex,
},
};
use tree_policy::TreePolicy;
/// A Monte Carlo search tree over `Spec::State`, shared between search threads.
///
/// NOTE(review): fields drop in declaration order, so `root_node` (and every node it owns)
/// is dropped before `table` and `orphaned`; keep that in mind before reordering.
pub struct SearchTree<Spec: MCTS> {
/// Root node; owns its descendants through edges whose `MoveInfo::owned` flag is set.
root_node: SearchNode<Spec>,
/// State at the root; every playout starts from a clone of it.
root_state: Spec::State,
tree_policy: Spec::TreePolicy,
table: Spec::TranspositionTable,
eval: Spec::Eval,
manager: Spec,
/// Starts at 1 for the root, grows by one per owned node created in `descend`, and is
/// temporarily raised by one for each in-flight playout (`IncreaseSentinel`).
num_nodes: AtomicUsize,
/// Hands out per-thread ids used to derive seeds in `make_thread_data`.
thread_counter: AtomicUsize,
/// Nodes displaced by a delayed transposition-table hit. They are kept alive (other threads
/// may still be traversing them) and only freed in `advance_root`.
orphaned: Mutex<Vec<Box<SearchNode<Spec>>>>,
/// Diagnostic counters reported by `diagnose`.
transposition_table_hits: AtomicUsize,
delayed_transposition_table_hits: AtomicUsize,
expansion_contention_events: AtomicUsize,
}
/// Lock-free visit/reward accumulators shared by nodes and edges.
struct NodeStats {
    /// Running sum of backed-up evaluations (virtual loss is subtracted while a playout is in flight).
    sum_evaluations: AtomicI64,
    /// Number of playouts that passed through.
    visits: AtomicUsize,
}
/// An edge of the search tree: one move out of a node plus the statistics of taking it.
pub struct MoveInfo<Spec: MCTS> {
/// The move (or, for chance nodes, the chance outcome) this edge represents.
mov: Move<Spec>,
/// Evaluator/policy score for the move; `Default::default()` for chance outcomes.
move_evaluation: MoveEvaluation<Spec>,
/// Child node reached by this move; null until expanded. Published with Release, read with Acquire.
child: AtomicPtr<SearchNode<Spec>>,
/// True iff this edge owns `child` and frees it on drop; false for alias pointers
/// (transposition-table hits) and for edges whose child was moved out by `advance_root`.
owned: AtomicBool,
/// Edge statistics; virtual loss is applied on descent and the child's node stats are
/// copied in during backpropagation.
stats: NodeStats,
}
/// A node of the search tree: a decision node (player to move) or, in closed-loop mode,
/// a chance node whose edges are the possible chance outcomes.
pub struct SearchNode<Spec: MCTS> {
/// Outgoing edges; for decision nodes sorted by the policy's `compare_move_evaluations`.
moves: Vec<MoveInfo<Spec>>,
/// User data (`Spec::NodeData`), initialised with `Default::default()`.
data: Spec::NodeData,
/// Evaluator output computed when the node was created.
evaln: StateEvaluation<Spec>,
stats: NodeStats,
/// Solver result, stored as `ProvenValue as u8` (starts as `ProvenValue::Unknown`).
proven: AtomicU8,
/// Score bounds; `i32::MIN` / `i32::MAX` mean "unbounded" on that side.
score_lower: AtomicI32,
score_upper: AtomicI32,
/// True for closed-loop chance nodes.
is_chance: bool,
/// Outcome probabilities, parallel to `moves`; empty for decision nodes.
chance_probs: Vec<f64>,
}
impl<Spec: MCTS> SearchNode<Spec> {
    /// Builds a decision (non-chance) node over `moves` with evaluation `evaln`.
    ///
    /// Statistics start at zero, the proven value at `ProvenValue::Unknown`, the score
    /// bounds fully open (`[i32::MIN, i32::MAX]`), and `data` at `Default::default()`.
    fn new(moves: Vec<MoveInfo<Spec>>, evaln: StateEvaluation<Spec>) -> Self {
        let unknown = ProvenValue::Unknown as u8;
        Self {
            moves,
            data: Default::default(),
            evaln,
            stats: NodeStats::new(),
            proven: AtomicU8::new(unknown),
            score_lower: AtomicI32::new(i32::MIN),
            score_upper: AtomicI32::new(i32::MAX),
            is_chance: false,
            chance_probs: Vec::new(),
        }
    }

    /// Current solver result for this node.
    pub fn proven_value(&self) -> ProvenValue {
        let raw = self.proven.load(Ordering::Relaxed);
        ProvenValue::from_u8(raw)
    }

    /// Snapshot of the node's score bounds (lower read first, then upper).
    pub fn score_bounds(&self) -> ScoreBounds {
        let lower = self.score_lower.load(Ordering::Relaxed);
        let upper = self.score_upper.load(Ordering::Relaxed);
        ScoreBounds { lower, upper }
    }
}
impl<Spec: MCTS> MoveInfo<Spec> {
fn new(mov: Move<Spec>, move_evaluation: MoveEvaluation<Spec>) -> Self {
MoveInfo {
mov,
move_evaluation,
child: AtomicPtr::default(),
stats: NodeStats::new(),
owned: AtomicBool::new(false),
}
}
pub fn get_move(&self) -> &Move<Spec> {
&self.mov
}
pub fn move_evaluation(&self) -> &MoveEvaluation<Spec> {
&self.move_evaluation
}
pub(crate) fn set_move_evaluation(&mut self, eval: MoveEvaluation<Spec>) {
self.move_evaluation = eval;
}
pub fn visits(&self) -> u64 {
self.stats.visits.load(Ordering::Relaxed) as u64
}
pub fn sum_rewards(&self) -> i64 {
self.stats.sum_evaluations.load(Ordering::Relaxed)
}
pub fn child(&self) -> Option<NodeHandle<'_, Spec>> {
let ptr = self.child.load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
unsafe { Some(NodeHandle { node: &*ptr }) }
}
}
pub fn child_proven_value(&self) -> ProvenValue {
let ptr = self.child.load(Ordering::Acquire);
if ptr.is_null() {
ProvenValue::Unknown
} else {
unsafe { (*ptr).proven_value() }
}
}
pub fn child_score_bounds(&self) -> ScoreBounds {
let ptr = self.child.load(Ordering::Acquire);
if ptr.is_null() {
ScoreBounds::UNBOUNDED
} else {
unsafe { (*ptr).score_bounds() }
}
}
}
impl<Spec: MCTS> Display for MoveInfo<Spec>
where
    Move<Spec>: Display,
{
    /// Formats as `<move> [N visits] [R avg reward]`, or `<move> [0 visits]` when unvisited.
    /// ` [child pointer is alias]` is appended when this edge does not own its child.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let own_str = if self.owned.load(Ordering::Relaxed) {
            ""
        } else {
            " [child pointer is alias]"
        };
        // Read the visit count once: search threads may update it concurrently, and repeated
        // reads could make the shown count, plural suffix and average disagree.
        let visits = self.visits();
        if visits == 0 {
            write!(f, "{} [0 visits]{}", self.mov, own_str)
        } else {
            let sum = self.sum_rewards();
            write!(
                f,
                "{} [{} visit{}] [{} avg reward]{}",
                self.mov,
                visits,
                if visits == 1 { "" } else { "s" },
                sum as f64 / visits as f64,
                own_str
            )
        }
    }
}
impl<Spec: MCTS> Debug for MoveInfo<Spec>
where
    Move<Spec>: Debug,
{
    /// Same layout as the `Display` impl, but renders the move with `{:?}`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let own_str = if self.owned.load(Ordering::Relaxed) {
            ""
        } else {
            " [child pointer is alias]"
        };
        // Read the visit count once so the count, plural suffix and average stay
        // consistent while other threads keep searching.
        let visits = self.visits();
        if visits == 0 {
            write!(f, "{:?} [0 visits]{}", self.mov, own_str)
        } else {
            let sum = self.sum_rewards();
            write!(
                f,
                "{:?} [{} visit{}] [{} avg reward]{}",
                self.mov,
                visits,
                if visits == 1 { "" } else { "s" },
                sum as f64 / visits as f64,
                own_str
            )
        }
    }
}
impl<Spec: MCTS> Drop for MoveInfo<Spec> {
fn drop(&mut self) {
if !*self.owned.get_mut() {
return;
}
let ptr = *self.child.get_mut();
if !ptr.is_null() {
unsafe {
drop(Box::from_raw(ptr));
}
}
}
}
/// Builds a new node for `state` whose children are all unexpanded.
///
/// * If `closed_loop` is set and `state.chance_outcomes()` returns `Some`, the result is a
///   chance node. It has one edge per outcome (with default move evaluations) and the outcome
///   probabilities in `chance_probs`, in the same order.
/// * Otherwise the result is a decision node. The evaluator scores
///   `state.available_moves()`, the policy validates those scores, and the moves are sorted
///   with `policy.compare_move_evaluations`.
///
/// For a decision node with no moves, `terminal_value` / `terminal_score` are consulted:
/// * With `solver_enabled`, the proven value is set, falling back to the sign of the
///   terminal score.
/// * With `score_bounded`, both score bounds are pinned to the terminal score, falling back
///   to +1 / -1 / 0 derived from the terminal value.
fn create_node<Spec: MCTS>(
    eval: &Spec::Eval,
    policy: &Spec::TreePolicy,
    state: &Spec::State,
    handle: Option<SearchHandle<Spec>>,
    solver_enabled: bool,
    score_bounded: bool,
    closed_loop: bool,
) -> SearchNode<Spec> {
    if closed_loop {
        if let Some(outcomes) = state.chance_outcomes() {
            // Mirrors the open-loop check in `sample_chance_outcome`: an empty outcome list
            // would produce a chance node with nothing to sample.
            debug_assert!(!outcomes.is_empty(), "chance_outcomes returned empty vec");
            let probs: Vec<f64> = outcomes.iter().map(|(_, p)| *p).collect();
            // Only the state evaluation is kept; the move evaluations are discarded.
            let avail = state.available_moves();
            let (_, state_eval) = eval.evaluate_new_state(state, &avail, handle);
            let moves: Vec<MoveInfo<Spec>> = outcomes
                .into_iter()
                .map(|(m, _)| MoveInfo::new(m, Default::default()))
                .collect();
            // Reuse the regular constructor so chance and decision nodes cannot drift apart.
            return SearchNode {
                is_chance: true,
                chance_probs: probs,
                ..SearchNode::new(moves, state_eval)
            };
        }
    }
    let moves = state.available_moves();
    let (move_eval, state_eval) = eval.evaluate_new_state(state, &moves, handle);
    policy.validate_evaluations(&move_eval);
    let mut moves: Vec<MoveInfo<Spec>> = moves
        .into_iter()
        .zip(move_eval)
        .map(|(m, e)| MoveInfo::new(m, e))
        .collect();
    moves.sort_by(|a, b| policy.compare_move_evaluations(&a.move_evaluation, &b.move_evaluation));
    let node = SearchNode::new(moves, state_eval);
    if node.moves.is_empty() {
        let tv = state.terminal_value();
        let ts = state.terminal_score();
        // When a game reports both, they must agree in sign.
        #[cfg(debug_assertions)]
        if let (Some(pv), Some(score)) = (tv, ts) {
            match pv {
                ProvenValue::Win => debug_assert!(
                    score > 0,
                    "terminal_value is Win but terminal_score is {score}"
                ),
                ProvenValue::Loss => debug_assert!(
                    score < 0,
                    "terminal_value is Loss but terminal_score is {score}"
                ),
                ProvenValue::Draw => debug_assert!(
                    score == 0,
                    "terminal_value is Draw but terminal_score is {score}"
                ),
                ProvenValue::Unknown => {}
            }
        }
        if solver_enabled {
            let proven = tv.or_else(|| {
                ts.map(|s| {
                    if s > 0 {
                        ProvenValue::Win
                    } else if s < 0 {
                        ProvenValue::Loss
                    } else {
                        ProvenValue::Draw
                    }
                })
            });
            if let Some(pv) = proven {
                node.proven.store(pv as u8, Ordering::Relaxed);
            }
        }
        if score_bounded {
            let score = ts.or_else(|| {
                tv.and_then(|pv| match pv {
                    ProvenValue::Win => Some(1),
                    ProvenValue::Loss => Some(-1),
                    ProvenValue::Draw => Some(0),
                    ProvenValue::Unknown => None,
                })
            });
            if let Some(s) = score {
                node.score_lower.store(s, Ordering::Relaxed);
                node.score_upper.store(s, Ordering::Relaxed);
            }
        }
    }
    node
}
fn try_prove_node<Spec: MCTS>(node: &SearchNode<Spec>) -> ProvenValue {
if node.moves.is_empty() {
return node.proven_value();
}
let mut all_children_proven = true;
let mut all_children_win = true; let mut has_child_draw = false;
for move_info in &node.moves {
let child_ptr = move_info.child.load(Ordering::Acquire);
if child_ptr.is_null() {
all_children_proven = false;
all_children_win = false;
continue;
}
let child_proven = unsafe { (*child_ptr).proven_value() };
match child_proven {
ProvenValue::Unknown => {
all_children_proven = false;
all_children_win = false;
}
ProvenValue::Win => {
}
ProvenValue::Loss => {
return ProvenValue::Win;
}
ProvenValue::Draw => {
has_child_draw = true;
all_children_win = false;
}
}
}
if all_children_proven && all_children_win {
return ProvenValue::Loss;
}
if all_children_proven && has_child_draw {
return ProvenValue::Draw;
}
ProvenValue::Unknown
}
/// Negamax score bounds for a decision node, derived from its children.
///
/// An unexpanded child counts as fully unbounded. For each child, the parent's candidate
/// bounds are `[negate_bound(child_upper), negate_bound(child_lower)]`. The result is the
/// per-side maximum over all children. A node without moves reports its stored bounds.
fn try_tighten_bounds<Spec: MCTS>(node: &SearchNode<Spec>) -> ScoreBounds {
    if node.moves.is_empty() {
        return node.score_bounds();
    }
    let (lower, upper) = node
        .moves
        .iter()
        .map(|edge| {
            let raw = edge.child.load(Ordering::Acquire);
            // SAFETY: non-null child pointers refer to live nodes owned by the tree.
            match unsafe { raw.as_ref() } {
                None => (i32::MIN, i32::MAX),
                Some(child) => (
                    child.score_lower.load(Ordering::Relaxed),
                    child.score_upper.load(Ordering::Relaxed),
                ),
            }
        })
        .fold((i32::MIN, i32::MIN), |(acc_lo, acc_hi), (child_lo, child_hi)| {
            (
                acc_lo.max(negate_bound(child_hi)),
                acc_hi.max(negate_bound(child_lo)),
            )
        });
    ScoreBounds { lower, upper }
}
/// Tries to prove a chance node from its outcome children.
///
/// No player chooses the outcome, so the node is proven only when every outcome is
/// expanded, proven, and all outcomes agree:
/// * all `Win` gives `Win`;
/// * all `Loss` gives `Loss`;
/// * all `Draw` gives `Draw`.
///
/// Any disagreement (e.g. some wins and some losses) leaves the node `Unknown`. Reporting it
/// as `Draw` would let ancestors be falsely proven; the expected value of mixed outcomes is
/// tracked through the score bounds instead. A node without outcomes reports its stored
/// proven value.
fn try_prove_chance_node<Spec: MCTS>(node: &SearchNode<Spec>) -> ProvenValue {
    if node.moves.is_empty() {
        return node.proven_value();
    }
    let mut agreed: Option<ProvenValue> = None;
    for mi in &node.moves {
        let ptr = mi.child.load(Ordering::Acquire);
        if ptr.is_null() {
            return ProvenValue::Unknown;
        }
        // SAFETY: non-null child pointers refer to live nodes owned by the tree.
        let child_proven = unsafe { (*ptr).proven_value() };
        if child_proven == ProvenValue::Unknown {
            return ProvenValue::Unknown;
        }
        match agreed {
            None => agreed = Some(child_proven),
            Some(previous) if previous == child_proven => {}
            Some(_) => return ProvenValue::Unknown,
        }
    }
    agreed.unwrap_or(ProvenValue::Unknown)
}
/// Expected-value score bounds for a chance node.
///
/// Each side is the probability-weighted sum of the children's bounds. The lower bound is
/// rounded down and the upper bound rounded up to integers. The result is fully unbounded
/// while any outcome is unexpanded or still unbounded on either side. A node without
/// outcomes reports its stored bounds.
fn try_tighten_bounds_chance<Spec: MCTS>(node: &SearchNode<Spec>) -> ScoreBounds {
    // Absorbs floating-point error in the weighted sums. For example, three outcomes of
    // probability 1/3 that sum to 0.9999999999999999 would otherwise floor/ceil an exact
    // integer expectation to the neighbouring integer, so the bounds could never meet.
    const ROUNDING_TOLERANCE: f64 = 1e-9;
    if node.moves.is_empty() {
        return node.score_bounds();
    }
    let mut lower_sum: f64 = 0.0;
    let mut upper_sum: f64 = 0.0;
    for (mi, &prob) in node.moves.iter().zip(node.chance_probs.iter()) {
        let ptr = mi.child.load(Ordering::Acquire);
        if ptr.is_null() {
            return ScoreBounds::UNBOUNDED;
        }
        // SAFETY: non-null child pointers refer to live nodes owned by the tree.
        let child = unsafe { &*ptr };
        let child_lower = child.score_lower.load(Ordering::Relaxed);
        let child_upper = child.score_upper.load(Ordering::Relaxed);
        if child_lower == i32::MIN || child_upper == i32::MAX {
            return ScoreBounds::UNBOUNDED;
        }
        lower_sum += prob * f64::from(child_lower);
        upper_sum += prob * f64::from(child_upper);
    }
    ScoreBounds {
        lower: (lower_sum + ROUNDING_TOLERANCE).floor() as i32,
        upper: (upper_sum - ROUNDING_TOLERANCE).ceil() as i32,
    }
}
/// Picks one outcome edge of a chance node, weighted by `chance_probs`.
///
/// It draws a uniform `f64` and returns the first edge whose cumulative probability exceeds
/// it. If rounding leaves the draw above the total, it falls back to the last edge.
fn sample_chance_child<'a, Spec: MCTS>(
    node: &'a SearchNode<Spec>,
    rng: &mut Rng64,
) -> &'a MoveInfo<Spec> {
    debug_assert!(node.is_chance);
    debug_assert!(!node.moves.is_empty(), "chance node has no outcomes");
    let threshold: f64 = rng.gen();
    let mut running = 0.0;
    let picked = node
        .moves
        .iter()
        .zip(&node.chance_probs)
        .find_map(|(edge, &weight)| {
            running += weight;
            if threshold < running {
                Some(edge)
            } else {
                None
            }
        });
    match picked {
        Some(edge) => edge,
        None => node.moves.last().unwrap(),
    }
}
/// Returns true when `current` is the very same object (by address) as an entry of `past`.
fn is_cycle<T>(past: &[&T], current: &T) -> bool {
    let target: *const T = current;
    past.iter().any(|&seen| seen as *const T == target)
}
/// Samples one move from `(move, probability)` pairs by roulette-wheel selection.
///
/// If rounding leaves the uniform draw above the cumulative total, the last outcome is
/// returned.
fn sample_chance_outcome<'a, M>(outcomes: &'a [(M, f64)], rng: &mut Rng64) -> &'a M {
    debug_assert!(!outcomes.is_empty(), "chance_outcomes returned empty vec");
    let threshold: f64 = rng.gen();
    let mut running = 0.0;
    let picked = outcomes.iter().find(|(_, weight)| {
        running += *weight;
        threshold < running
    });
    match picked {
        Some((outcome, _)) => outcome,
        None => &outcomes.last().unwrap().0,
    }
}
impl<Spec: MCTS> SearchTree<Spec> {
    /// Creates a search tree whose root represents `state`.
    ///
    /// The root node is built immediately via `create_node`, using the spec's solver,
    /// score-bounded and closed-loop-chance settings. When `manager.dirichlet_noise()` returns
    /// `Some((epsilon, alpha))`, the tree policy applies Dirichlet noise to the root's moves.
    /// The noise RNG is seeded from `manager.rng_seed()` if set, otherwise from the thread RNG.
    pub fn new(
        state: Spec::State,
        manager: Spec,
        tree_policy: Spec::TreePolicy,
        eval: Spec::Eval,
        table: Spec::TranspositionTable,
    ) -> Self {
        let solver = manager.solver_enabled();
        let score_bounded = manager.score_bounded_enabled();
        let closed_loop = manager.closed_loop_chance();
        let mut root_node = create_node(
            &eval,
            &tree_policy,
            &state,
            None,
            solver,
            score_bounded,
            closed_loop,
        );
        if let Some((epsilon, alpha)) = manager.dirichlet_noise() {
            let mut rng = match manager.rng_seed() {
                Some(seed) => Rng64::seed_from_u64(seed.wrapping_add(u64::MAX)),
                None => Rng64::from_rng(rand::thread_rng()).unwrap(),
            };
            tree_policy.apply_dirichlet_noise(&mut root_node.moves, epsilon, alpha, &mut rng);
        }
        Self {
            root_state: state,
            root_node,
            manager,
            tree_policy,
            eval,
            table,
            num_nodes: 1.into(),
            thread_counter: 0.into(),
            orphaned: Mutex::new(Vec::new()),
            transposition_table_hits: 0.into(),
            delayed_transposition_table_hits: 0.into(),
            expansion_contention_events: 0.into(),
        }
    }

    /// Creates per-thread search data.
    ///
    /// When the spec provides an RNG seed, each call takes the next thread id and derives
    /// deterministic seeds from it for the policy data and the chance-sampling RNG.
    pub fn make_thread_data(&self) -> ThreadDataFull<Spec>
    where
        ThreadData<Spec>: Default,
    {
        let mut tld = ThreadDataFull::<Spec>::default();
        if let Some(base_seed) = self.manager.rng_seed() {
            let thread_id = self.thread_counter.fetch_add(1, Ordering::Relaxed) as u64;
            let seed = base_seed.wrapping_add(thread_id);
            self.tree_policy
                .seed_thread_data(&mut tld.tld.policy_data, seed);
            tld.chance_rng = Rng64::seed_from_u64(seed.wrapping_add(0xCAFE_BABE));
        }
        tld
    }

    /// Discards all search results and rebuilds the tree from the same root state and components.
    pub fn reset(self) -> Self {
        Self::new(
            self.root_state,
            self.manager,
            self.tree_policy,
            self.eval,
            self.table,
        )
    }

    /// The search specification.
    pub fn spec(&self) -> &Spec {
        &self.manager
    }

    /// Node counter (see the `num_nodes` field: it also includes in-flight playouts).
    #[must_use]
    pub fn num_nodes(&self) -> usize {
        self.num_nodes.load(Ordering::SeqCst)
    }

    /// Runs one playout: selection, optional expansion, evaluation and backpropagation.
    ///
    /// Returns `false` without searching when the node limit is reached, the root is already
    /// proven (solver enabled), or the root's score bounds are proven (score-bounded enabled).
    /// Otherwise returns `true`.
    ///
    /// # Panics
    /// When `CycleBehaviour::PanicWhenCycleDetected` is configured and a cycle is reached.
    #[inline(never)]
    pub fn playout(&self, tld: &mut ThreadDataFull<Spec>) -> bool {
        let sentinel = IncreaseSentinel::new(&self.num_nodes);
        if sentinel.num_nodes >= self.manager.node_limit() {
            return false;
        }
        let solver = self.manager.solver_enabled();
        let score_bounded = self.manager.score_bounded_enabled();
        if solver && self.root_node.proven_value() != ProvenValue::Unknown {
            return false;
        }
        if score_bounded {
            let bounds = self.root_node.score_bounds();
            if bounds.is_proven() {
                return false;
            }
        }
        let mut state = self.root_state.clone();
        let path: &mut Vec<&MoveInfo<Spec>> = &mut tld.path.reuse_allocation();
        let node_path: &mut Vec<&SearchNode<Spec>> = &mut tld.node_path.reuse_allocation();
        let players: &mut Vec<Player<Spec>> = &mut tld.players.reuse_allocation();
        let chance_rng = &mut tld.chance_rng;
        let closed_loop = self.manager.closed_loop_chance();
        let tld = &mut tld.tld;
        // Open-loop mode: chance events are resolved by sampling and are not stored in the tree.
        if !closed_loop {
            while let Some(outcomes) = state.chance_outcomes() {
                let outcome = sample_chance_outcome(&outcomes, chance_rng);
                state.make_move(outcome);
            }
        }
        let mut did_we_create = false;
        let mut node = &self.root_node;
        loop {
            if node.moves.is_empty() {
                break;
            }
            if solver && node.proven_value() != ProvenValue::Unknown {
                break;
            }
            if path.len() >= self.manager.max_playout_depth() {
                break;
            }
            if path.len() >= self.manager.max_playout_length() {
                break;
            }
            let choice = if node.is_chance {
                sample_chance_child(node, chance_rng)
            } else {
                // Only the first `max_children` moves, in the order `create_node` sorted them,
                // are eligible for selection.
                let parent_visits = node.stats.visits.load(Ordering::Relaxed) as u64;
                let max_children = state
                    .max_children(parent_visits)
                    .max(1)
                    .min(node.moves.len());
                self.tree_policy.choose_child(
                    node.moves[..max_children].iter(),
                    self.make_handle(node, tld),
                )
            };
            choice.stats.down(&self.manager);
            players.push(state.current_player());
            path.push(choice);
            assert!(
                path.len() <= self.manager.max_playout_length(),
                "playout length exceeded maximum of {} (maybe the transposition table is creating an infinite loop?)",
                self.manager.max_playout_length()
            );
            state.make_move(&choice.mov);
            if !closed_loop {
                while let Some(outcomes) = state.chance_outcomes() {
                    let outcome = sample_chance_outcome(&outcomes, chance_rng);
                    state.make_move(outcome);
                }
            }
            let (new_node, new_did_we_create) = self.descend(&state, choice, node, tld);
            node = new_node;
            did_we_create = new_did_we_create;
            match self.manager.cycle_behaviour() {
                CycleBehaviour::Ignore => (),
                CycleBehaviour::PanicWhenCycleDetected => {
                    if is_cycle(node_path, node) {
                        panic!("cycle detected! you should do one of the following:\n- make states acyclic\n- remove transposition table\n- change cycle_behaviour()");
                    }
                }
                CycleBehaviour::UseCurrentEvalWhenCycleDetected => {
                    if is_cycle(node_path, node) {
                        break;
                    }
                }
                CycleBehaviour::UseThisEvalWhenCycleDetected(e) => {
                    if is_cycle(node_path, node) {
                        self.finish_playout(path, node_path, players, tld, &e);
                        return true;
                    }
                }
            };
            node_path.push(node);
            node.stats.down(&self.manager);
            if node.stats.visits.load(Ordering::Relaxed) as u64
                <= self.manager.visits_before_expansion()
            {
                break;
            }
        }
        // A node created in this playout already carries its fresh evaluation.
        let new_evaln = if did_we_create {
            None
        } else {
            Some(self.eval.evaluate_existing_state(
                &state,
                &node.evaln,
                self.make_handle(node, tld),
            ))
        };
        let evaln = new_evaln.as_ref().unwrap_or(&node.evaln);
        self.finish_playout(path, node_path, players, tld, evaln);
        true
    }

    /// Returns the child reached through `choice`, expanding it if necessary.
    ///
    /// The steps, in order:
    /// 1. An already-published child pointer is reused.
    /// 2. Otherwise the transposition table is consulted; a hit is published as a non-owned
    ///    alias.
    /// 3. Otherwise a new node is created and published with a CAS. Losing that race frees
    ///    the new node.
    /// 4. If the later table insert reports an existing node, the edge is redirected to it
    ///    and the new node moves to the orphan list.
    ///
    /// The flag is `true` only when this call created the node and the edge now owns it.
    fn descend<'a, 'b>(
        &'a self,
        state: &Spec::State,
        choice: &MoveInfo<Spec>,
        current_node: &'b SearchNode<Spec>,
        tld: &'b mut ThreadData<Spec>,
    ) -> (&'a SearchNode<Spec>, bool) {
        let child = choice.child.load(Ordering::Acquire);
        if !child.is_null() {
            return unsafe { (&*child, false) };
        }
        if let Some(node) = self
            .table
            .lookup(state, self.make_handle(current_node, tld))
        {
            let child = choice
                .child
                .compare_exchange(
                    null_mut(),
                    node as *const _ as *mut _,
                    Ordering::Release,
                    Ordering::Acquire,
                )
                .unwrap_or_else(|x| x);
            if child.is_null() {
                self.transposition_table_hits
                    .fetch_add(1, Ordering::Relaxed);
                return (node, false);
            } else {
                return unsafe { (&*child, false) };
            }
        }
        let created = create_node(
            &self.eval,
            &self.tree_policy,
            state,
            Some(self.make_handle(current_node, tld)),
            self.manager.solver_enabled(),
            self.manager.score_bounded_enabled(),
            self.manager.closed_loop_chance(),
        );
        let created = Box::into_raw(Box::new(created));
        let other_child = choice
            .child
            .compare_exchange(null_mut(), created, Ordering::Release, Ordering::Acquire)
            .unwrap_or_else(|x| x);
        if !other_child.is_null() {
            // Another thread expanded this edge first; `created` was never published.
            self.expansion_contention_events
                .fetch_add(1, Ordering::Relaxed);
            unsafe {
                drop(Box::from_raw(created));
                return (&*other_child, false);
            }
        }
        if let Some(existing) = self.table.insert(
            state,
            unsafe { &*created },
            self.make_handle(current_node, tld),
        ) {
            self.delayed_transposition_table_hits
                .fetch_add(1, Ordering::Relaxed);
            let existing_ptr = existing as *const _ as *mut _;
            choice.child.store(existing_ptr, Ordering::Release);
            // Other threads may already hold `created`, so it is kept alive, not freed.
            self.orphaned
                .lock()
                .unwrap()
                .push(unsafe { Box::from_raw(created) });
            return (existing, false);
        }
        choice.owned.store(true, Ordering::Release);
        self.num_nodes.fetch_add(1, Ordering::Relaxed);
        unsafe { (&*created, true) }
    }

    /// Backpropagates `evaln` from the leaf to the root.
    ///
    /// The steps are:
    /// 1. Each node on `node_path` receives the evaluation as interpreted for the player who
    ///    moved into it; this also undoes the virtual loss.
    /// 2. Each edge then copies its child's stats.
    /// 3. `on_backpropagation` is invoked for every node and for the root.
    /// 4. Proven values and score bounds are propagated afterwards when those features are
    ///    enabled.
    fn finish_playout(
        &self,
        path: &[&MoveInfo<Spec>],
        node_path: &[&SearchNode<Spec>],
        players: &[Player<Spec>],
        tld: &mut ThreadData<Spec>,
        evaln: &StateEvaluation<Spec>,
    ) {
        // When the playout stopped on a detected cycle, `node_path` is one shorter than `path`
        // and `zip` leaves out the final edge.
        for ((move_info, player), node) in
            path.iter().zip(players.iter()).zip(node_path.iter()).rev()
        {
            let evaln_value = self.eval.interpret_evaluation_for_player(evaln, player);
            node.stats.up(&self.manager, evaln_value);
            move_info.stats.replace(&node.stats);
            unsafe {
                self.manager.on_backpropagation(
                    evaln,
                    self.make_handle(&*move_info.child.load(Ordering::Acquire), tld),
                );
            }
        }
        self.manager
            .on_backpropagation(evaln, self.make_handle(&self.root_node, tld));
        if self.manager.solver_enabled() {
            self.propagate_proven(path, node_path);
        }
        if self.manager.score_bounded_enabled() {
            self.propagate_score_bounds(path, node_path);
        }
    }

    /// Propagates proven values from the deepest proven child towards the root.
    ///
    /// It stops at the first unproven child or at a parent that cannot be proven yet.
    /// Parents that are already proven are skipped.
    fn propagate_proven(&self, path: &[&MoveInfo<Spec>], node_path: &[&SearchNode<Spec>]) {
        for i in (0..path.len()).rev() {
            let child_ptr = path[i].child.load(Ordering::Acquire);
            if child_ptr.is_null() {
                break;
            }
            let child_proven = unsafe { (*child_ptr).proven_value() };
            if child_proven == ProvenValue::Unknown {
                break;
            }
            let parent = if i == 0 {
                &self.root_node
            } else {
                node_path[i - 1]
            };
            if parent.proven_value() != ProvenValue::Unknown {
                continue;
            }
            let parent_proven = if parent.is_chance {
                try_prove_chance_node(parent)
            } else {
                try_prove_node(parent)
            };
            if parent_proven == ProvenValue::Unknown {
                break;
            }
            let _ = parent.proven.compare_exchange(
                ProvenValue::Unknown as u8,
                parent_proven as u8,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
        }
    }

    /// Re-derives score bounds bottom-up along the playout path.
    ///
    /// For each parent on the path (deepest first), the bounds implied by its children are
    /// merged monotonically into the stored ones: the lower bound can only rise and the upper
    /// bound can only fall. When the solver is enabled and the merged bounds meet, the parent
    /// is also marked proven by the sign of that exact score. Propagation stops at the first
    /// parent whose bounds did not tighten.
    fn propagate_score_bounds(&self, path: &[&MoveInfo<Spec>], node_path: &[&SearchNode<Spec>]) {
        for i in (0..path.len()).rev() {
            let parent = if i == 0 {
                &self.root_node
            } else {
                node_path[i - 1]
            };
            let new_bounds = if parent.is_chance {
                try_tighten_bounds_chance(parent)
            } else {
                try_tighten_bounds(parent)
            };
            // Retry loops rather than a single `compare_exchange_weak`, which may fail
            // spuriously or lose a race and silently drop the tightening.
            let old_lower = Self::atomic_raise_to(&parent.score_lower, new_bounds.lower);
            let old_upper = Self::atomic_reduce_to(&parent.score_upper, new_bounds.upper);
            let lower = old_lower.max(new_bounds.lower);
            let upper = old_upper.min(new_bounds.upper);
            if self.manager.solver_enabled() && lower == upper {
                let pv = if lower > 0 {
                    ProvenValue::Win
                } else if lower < 0 {
                    ProvenValue::Loss
                } else {
                    ProvenValue::Draw
                };
                let _ = parent.proven.compare_exchange(
                    ProvenValue::Unknown as u8,
                    pv as u8,
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                );
            }
            if new_bounds.lower <= old_lower && new_bounds.upper >= old_upper {
                break;
            }
        }
    }

    /// Raises `cell` to `value` if `value` is larger; returns the value observed before the update.
    fn atomic_raise_to(cell: &AtomicI32, value: i32) -> i32 {
        let mut current = cell.load(Ordering::Relaxed);
        while value > current {
            match cell.compare_exchange_weak(current, value, Ordering::Relaxed, Ordering::Relaxed) {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
        current
    }

    /// Lowers `cell` to `value` if `value` is smaller; returns the value observed before the update.
    fn atomic_reduce_to(cell: &AtomicI32, value: i32) -> i32 {
        let mut current = cell.load(Ordering::Relaxed);
        while value < current {
            match cell.compare_exchange_weak(current, value, Ordering::Relaxed, Ordering::Relaxed) {
                Ok(_) => break,
                Err(actual) => current = actual,
            }
        }
        current
    }

    /// Bundles a node, the calling thread's data and the spec for evaluator/policy callbacks.
    fn make_handle<'a>(
        &'a self,
        node: &'a SearchNode<Spec>,
        tld: &'a mut ThreadData<Spec>,
    ) -> SearchHandle<'a, Spec> {
        SearchHandle {
            node,
            tld,
            manager: &self.manager,
        }
    }

    /// The state at the root of the tree.
    pub fn root_state(&self) -> &Spec::State {
        &self.root_state
    }

    /// Handle to the root node.
    pub fn root_node(&self) -> NodeHandle<'_, Spec> {
        NodeHandle {
            node: &self.root_node,
        }
    }

    /// Solver result for the root.
    pub fn root_proven_value(&self) -> ProvenValue {
        self.root_node.proven_value()
    }

    /// Score bounds of the root.
    pub fn root_score_bounds(&self) -> ScoreBounds {
        self.root_node.score_bounds()
    }

    /// Follows the preferred line for up to `num_moves` edges.
    ///
    /// Decision nodes use `select_child_after_search`; chance nodes use the most-visited
    /// outcome. The walk stops at a leaf or at an unexpanded edge (which is still included).
    pub fn principal_variation(&self, num_moves: usize) -> Vec<MoveInfoHandle<'_, Spec>> {
        let mut result = Vec::new();
        let mut crnt = &self.root_node;
        while !crnt.moves.is_empty() && result.len() < num_moves {
            let choice = if crnt.is_chance {
                crnt.moves.iter().max_by_key(|c| c.visits()).unwrap()
            } else {
                self.manager.select_child_after_search(&crnt.moves)
            };
            result.push(choice);
            let child = choice.child.load(Ordering::SeqCst) as *const SearchNode<Spec>;
            if child.is_null() {
                break;
            } else {
                unsafe {
                    crnt = &*child;
                }
            }
        }
        result
    }

    /// Human-readable summary of the tree's diagnostic counters.
    #[must_use]
    pub fn diagnose(&self) -> String {
        let mut s = String::new();
        s.push_str(&format!(
            "{} nodes\n",
            thousands_separate(self.num_nodes.load(Ordering::Relaxed))
        ));
        s.push_str(&format!(
            "{} transposition table hits\n",
            thousands_separate(self.transposition_table_hits.load(Ordering::Relaxed))
        ));
        s.push_str(&format!(
            "{} delayed transposition table hits\n",
            thousands_separate(
                self.delayed_transposition_table_hits
                    .load(Ordering::Relaxed)
            )
        ));
        s.push_str(&format!(
            "{} expansion contention events\n",
            thousands_separate(self.expansion_contention_events.load(Ordering::Relaxed))
        ));
        s.push_str(&format!(
            "{} orphaned nodes\n",
            self.orphaned.lock().unwrap().len()
        ));
        s
    }
}
/// Borrowed reference to an edge of the tree, as returned by `principal_variation`.
pub type MoveInfoHandle<'a, Spec> = &'a MoveInfo<Spec>;
/// Owned snapshot of one root child's statistics (see `root_child_stats`).
pub struct ChildStats<Spec: MCTS> {
/// The move leading to this child.
pub mov: Move<Spec>,
/// Playouts through this move.
pub visits: u64,
/// `sum_rewards / visits`, or 0.0 when unvisited.
pub avg_reward: f64,
/// The move's evaluation from node creation.
pub move_evaluation: MoveEvaluation<Spec>,
/// The child's proven value (`Unknown` when unexpanded).
pub proven_value: ProvenValue,
/// The child's score bounds (`ScoreBounds::UNBOUNDED` when unexpanded).
pub score_bounds: ScoreBounds,
}
impl<Spec: MCTS> Debug for ChildStats<Spec>
where
Move<Spec>: Debug,
MoveEvaluation<Spec>: Debug,
{
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ChildStats")
.field("mov", &self.mov)
.field("visits", &self.visits)
.field("avg_reward", &self.avg_reward)
.field("move_evaluation", &self.move_evaluation)
.field("proven_value", &self.proven_value)
.field("score_bounds", &self.score_bounds)
.finish()
}
}
impl<Spec: MCTS> Clone for ChildStats<Spec>
where
Move<Spec>: Clone,
MoveEvaluation<Spec>: Clone,
{
fn clone(&self) -> Self {
Self {
mov: self.mov.clone(),
visits: self.visits,
avg_reward: self.avg_reward,
move_evaluation: self.move_evaluation.clone(),
proven_value: self.proven_value,
score_bounds: self.score_bounds,
}
}
}
impl<Spec: MCTS> SearchTree<Spec>
where
    MoveEvaluation<Spec>: Clone,
{
    /// Snapshots the statistics of every root child, in root move order.
    pub fn root_child_stats(&self) -> Vec<ChildStats<Spec>> {
        let children = &self.root_node.moves;
        let mut snapshot = Vec::with_capacity(children.len());
        for child in children {
            let visits = child.visits();
            let avg_reward = match visits {
                0 => 0.0,
                n => child.sum_rewards() as f64 / n as f64,
            };
            snapshot.push(ChildStats {
                mov: child.get_move().clone(),
                visits,
                avg_reward,
                move_evaluation: child.move_evaluation().clone(),
                proven_value: child.child_proven_value(),
                score_bounds: child.child_score_bounds(),
            });
        }
        snapshot
    }
}
impl<Spec: MCTS> SearchTree<Spec>
where
    Move<Spec>: Debug,
{
    /// Prints every root move with `{:?}`, most-visited first, one per line to stdout.
    pub fn debug_moves(&self) {
        let mut ranked = self
            .root_node
            .moves
            .iter()
            .collect::<Vec<&MoveInfo<Spec>>>();
        ranked.sort_by_key(|info| -(info.visits() as i64));
        ranked.into_iter().for_each(|info| println!("{:?}", info));
    }
}
impl<Spec: MCTS> SearchTree<Spec>
where
    Move<Spec>: Display,
{
    /// Prints every root move with `{}`, most-visited first, one per line to stdout.
    pub fn display_moves(&self) {
        let mut ranked = self
            .root_node
            .moves
            .iter()
            .collect::<Vec<&MoveInfo<Spec>>>();
        ranked.sort_by_key(|info| -(info.visits() as i64));
        ranked.into_iter().for_each(|info| println!("{}", info));
    }
}
/// Reasons `SearchTree::advance_root` can refuse to move the root.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdvanceError {
    /// No root child matches the requested move.
    MoveNotFound,
    /// The matching child edge has never been expanded.
    ChildNotExpanded,
    /// The matching child edge is an alias and does not own its node.
    ChildNotOwned,
}

impl std::error::Error for AdvanceError {}

impl Display for AdvanceError {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let message = match *self {
            Self::MoveNotFound => "move not found in root children",
            Self::ChildNotExpanded => "child node was never expanded",
            Self::ChildNotOwned => "child node is a transposition alias (not owned)",
        };
        f.write_str(message)
    }
}
impl<Spec: MCTS> SearchTree<Spec>
where
Move<Spec>: PartialEq,
{
pub fn advance_root(&mut self, mov: &Move<Spec>) -> Result<(), AdvanceError> {
let idx = self
.root_node
.moves
.iter()
.position(|m| m.mov == *mov)
.ok_or(AdvanceError::MoveNotFound)?;
let move_info = &self.root_node.moves[idx];
let child_ptr = move_info.child.load(Ordering::SeqCst);
if child_ptr.is_null() {
return Err(AdvanceError::ChildNotExpanded);
}
if !move_info.owned.load(Ordering::SeqCst) {
return Err(AdvanceError::ChildNotOwned);
}
self.root_node.moves[idx]
.child
.store(null_mut(), Ordering::SeqCst);
self.root_node.moves[idx]
.owned
.store(false, Ordering::SeqCst);
let new_root = unsafe { *Box::from_raw(child_ptr) };
self.root_state.make_move(mov);
let old_root = std::mem::replace(&mut self.root_node, new_root);
if let Some((epsilon, alpha)) = self.manager.dirichlet_noise() {
let mut rng = match self.manager.rng_seed() {
Some(seed) => Rng64::seed_from_u64(seed.wrapping_add(u64::MAX)),
None => Rng64::from_rng(rand::thread_rng()).unwrap(),
};
self.tree_policy.apply_dirichlet_noise(
&mut self.root_node.moves,
epsilon,
alpha,
&mut rng,
);
}
self.table.clear();
drop(old_root);
self.orphaned.lock().unwrap().clear();
self.num_nodes.store(1, Ordering::SeqCst);
self.transposition_table_hits.store(0, Ordering::SeqCst);
self.delayed_transposition_table_hits
.store(0, Ordering::SeqCst);
self.expansion_contention_events.store(0, Ordering::SeqCst);
Ok(())
}
}
/// Shared, copyable reference to a [`SearchNode`] handed out to callers and policies.
pub struct NodeHandle<'a, Spec: 'a + MCTS> {
    node: &'a SearchNode<Spec>,
}

// Implemented by hand: `#[derive(Clone, Copy)]` would add `Spec: Clone + Copy` bounds, even
// though a handle only holds a shared reference and is always trivially copyable.
impl<'a, Spec: 'a + MCTS> Clone for NodeHandle<'a, Spec> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, Spec: 'a + MCTS> Copy for NodeHandle<'a, Spec> {}
impl<'a, Spec: MCTS> NodeHandle<'a, Spec> {
    /// User data attached to the node.
    pub fn data(&self) -> &'a Spec::NodeData {
        &self.node.data
    }

    /// Iterator over the node's outgoing edges.
    ///
    /// Like [`NodeHandle::data`], the iterator borrows the node for `'a` rather than the
    /// handle itself. That lets it outlive a temporary handle, e.g.
    /// `let moves = tree.root_node().moves();`.
    pub fn moves(&self) -> Moves<'a, Spec> {
        Moves {
            iter: self.node.moves.iter(),
        }
    }

    /// The node's current solver result.
    pub fn proven_value(&self) -> ProvenValue {
        self.node.proven_value()
    }

    /// The node's current score bounds.
    pub fn score_bounds(&self) -> ScoreBounds {
        self.node.score_bounds()
    }

    /// Type-erased address of the node, reversible with [`NodeHandle::from_raw`].
    pub fn into_raw(&self) -> *const () {
        self.node as *const _ as *const ()
    }

    /// Rebuilds a handle from a pointer produced by [`NodeHandle::into_raw`].
    ///
    /// # Safety
    /// `ptr` must come from `into_raw` on a handle of the same `Spec`. The node it refers to
    /// must remain alive for the whole lifetime `'a` chosen by the caller.
    pub unsafe fn from_raw(ptr: *const ()) -> Self {
        NodeHandle {
            node: &*(ptr as *const SearchNode<Spec>),
        }
    }
}
/// Iterator over the edges of a node, as returned by [`NodeHandle::moves`].
pub struct Moves<'a, Spec: 'a + MCTS> {
    iter: std::slice::Iter<'a, MoveInfo<Spec>>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `Spec: Clone` bound, even though only
// the slice iterator is cloned.
impl<'a, Spec: 'a + MCTS> Clone for Moves<'a, Spec> {
    fn clone(&self) -> Self {
        Moves {
            iter: self.iter.clone(),
        }
    }
}

impl<'a, Spec: 'a + MCTS> Iterator for Moves<'a, Spec> {
    type Item = &'a MoveInfo<Spec>;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    // `ExactSizeIterator` requires `size_hint` to be exact; this also lets `collect` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl<'a, Spec: 'a + MCTS> ExactSizeIterator for Moves<'a, Spec> {
    fn len(&self) -> usize {
        self.iter.len()
    }
}
/// Context passed to evaluator, policy, table and spec callbacks during a search step.
pub struct SearchHandle<'a, Spec: 'a + MCTS> {
/// The node the callback concerns.
node: &'a SearchNode<Spec>,
/// The calling thread's per-thread data.
tld: &'a mut ThreadData<Spec>,
/// The search specification.
manager: &'a Spec,
}
impl<'a, Spec: MCTS> SearchHandle<'a, Spec> {
    /// Read-only handle to the node this callback concerns.
    pub fn node(&self) -> NodeHandle<'a, Spec> {
        let node = self.node;
        NodeHandle { node }
    }

    /// Mutable access to the calling thread's per-thread data.
    pub fn thread_data(&mut self) -> &mut ThreadData<Spec> {
        &mut *self.tld
    }

    /// The search specification driving this search.
    pub fn mcts(&self) -> &'a Spec {
        let manager: &'a Spec = self.manager;
        manager
    }
}
impl NodeStats {
    /// Zeroed statistics.
    fn new() -> Self {
        Self {
            sum_evaluations: AtomicI64::new(0),
            visits: AtomicUsize::new(0),
        }
    }

    /// Descent bookkeeping: applies the spec's virtual loss, then counts the visit.
    fn down<Spec: MCTS>(&self, manager: &Spec) {
        let penalty = manager.virtual_loss();
        self.sum_evaluations.fetch_sub(penalty, Ordering::Relaxed);
        self.visits.fetch_add(1, Ordering::Relaxed);
    }

    /// Backpropagation: adds `evaln` and refunds the virtual loss applied by `down`.
    fn up<Spec: MCTS>(&self, manager: &Spec, evaln: i64) {
        let refunded = evaln + manager.virtual_loss();
        self.sum_evaluations.fetch_add(refunded, Ordering::Relaxed);
    }

    /// Copies `other`'s counters into `self`: visits first, then the evaluation sum.
    fn replace(&self, other: &NodeStats) {
        let visits = other.visits.load(Ordering::Relaxed);
        self.visits.store(visits, Ordering::Relaxed);
        let sum = other.sum_evaluations.load(Ordering::Relaxed);
        self.sum_evaluations.store(sum, Ordering::Relaxed);
    }
}
/// Guard that bumps an atomic counter on creation and undoes the bump when dropped.
///
/// `num_nodes` holds the counter's value from just before the increment.
struct IncreaseSentinel<'a> {
    counter: &'a AtomicUsize,
    num_nodes: usize,
}

impl<'a> IncreaseSentinel<'a> {
    fn new(counter: &'a AtomicUsize) -> Self {
        IncreaseSentinel {
            num_nodes: counter.fetch_add(1, Ordering::Relaxed),
            counter,
        }
    }
}

impl Drop for IncreaseSentinel<'_> {
    fn drop(&mut self) {
        let _ = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}