use rand::Rng;
use rand::SeedableRng;
use crate::config::SearchConfig;
use crate::environment::{Environment, Outcome};
use crate::node::{Node, NodeStats};
/// Tree-search driver that owns the root environment, the node storage, and the
/// random number generator used while searching.
///
/// Node ids used in this file are `u32` indices into `nodes`; index 0 is treated
/// as the root everywhere in this file.
#[derive(Clone)]
pub struct TreeSearch<E: Environment> {
/// Environment state at the root of the tree; cloned at the start of each step.
pub(crate) root_env: E,
/// Search parameters (iteration count, optional time budget, RAVE settings, ...).
pub(crate) config: SearchConfig,
/// Flat storage for every node of the tree; `nodes[0]` is the root.
pub(crate) nodes: Vec<Node<E::Action>>,
/// Generator seeded either from system entropy or from an explicit seed.
pub(crate) rng: rand_chacha::ChaCha8Rng,
/// `Some` while DAG mode is enabled, `None` otherwise.
/// NOTE(review): presumably maps a state hash to a node id — confirm against
/// the code that populates it (not visible in this file).
#[cfg(feature = "dag")]
pub transposition_table: Option<hashbrown::HashMap<u64, u32>>,
/// Optional evaluator installed via `with_evaluator`; how it is consulted is
/// not visible in this file.
pub(crate) evaluator: Option<std::sync::Arc<dyn crate::environment::Evaluator<E>>>,
/// Optional node limit installed via `with_max_nodes`; enforcement is not
/// visible in this file.
pub(crate) max_nodes: Option<usize>,
}
/// Serializable snapshot of a [`TreeSearch`]: root environment, configuration,
/// node storage, and a seed from which the RNG is rebuilt on restore.
///
/// The evaluator, the node limit, and the transposition table are not part of
/// the snapshot.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TreeSearchCheckpoint<E>
where
E: Environment + Clone,
E::Action: serde::Serialize + for<'action> serde::Deserialize<'action>,
{
/// Environment state at the root of the saved tree.
pub root_env: E,
/// Search configuration in effect when the snapshot was taken.
pub config: SearchConfig,
/// Copy of the node storage; index 0 is the root.
pub nodes: Vec<Node<E::Action>>,
/// Seed used to rebuild the RNG on restore; `None` means "seed from system
/// entropy" when restoring.
pub rng_seed: Option<u64>,
}
impl<E: Environment> TreeSearch<E> {
/// Creates a search rooted at `environment` with an RNG seeded from system entropy.
///
/// The tree starts with a single root node built from the environment's legal
/// actions. No evaluator or node limit is installed, and (with the `dag`
/// feature) the transposition table starts disabled.
///
/// # Panics
///
/// Panics if the system RNG fails to provide a seed.
pub fn new(environment: E, config: SearchConfig) -> Self {
    let nodes = vec![Node::root(environment.legal_actions())];
    let rng = entropy_rng();
    Self {
        root_env: environment,
        config,
        nodes,
        rng,
        #[cfg(feature = "dag")]
        transposition_table: None,
        evaluator: None,
        max_nodes: None,
    }
}
/// Creates a search rooted at `environment` whose RNG is seeded from `seed`.
///
/// Apart from the RNG, the initial state matches [`Self::new`]: a single root
/// node built from the environment's legal actions, no evaluator, no node
/// limit, and (with the `dag` feature) a disabled transposition table.
pub fn with_seed(environment: E, config: SearchConfig, seed: u64) -> Self {
    let nodes = vec![Node::root(environment.legal_actions())];
    let rng = rand_chacha::ChaCha8Rng::seed_from_u64(seed);
    Self {
        root_env: environment,
        config,
        nodes,
        rng,
        #[cfg(feature = "dag")]
        transposition_table: None,
        evaluator: None,
        max_nodes: None,
    }
}
/// Captures the current search state as a serializable checkpoint.
///
/// The stored seed is the next `u64` that a *copy* of the live RNG produces,
/// so taking a checkpoint never advances the live generator. The evaluator,
/// node limit, and transposition table are not captured.
pub fn checkpoint(&self) -> TreeSearchCheckpoint<E>
where
    E: serde::Serialize + for<'de> serde::Deserialize<'de>,
    E::Action: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
{
    // Draw from a clone so `self.rng` is left untouched.
    let rng_seed = Some(self.rng.clone().next_u64());
    TreeSearchCheckpoint {
        root_env: self.root_env.clone(),
        config: self.config.clone(),
        nodes: self.nodes.clone(),
        rng_seed,
    }
}
pub fn restore(checkpoint: TreeSearchCheckpoint<E>) -> Self
where
E: serde::Serialize + for<'de> serde::Deserialize<'de>,
E::Action: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
{
let rng = if let Some(seed) = checkpoint.rng_seed {
rand_chacha::ChaCha8Rng::seed_from_u64(seed)
} else {
entropy_rng()
};
Self {
root_env: checkpoint.root_env,
config: checkpoint.config,
nodes: checkpoint.nodes,
rng,
#[cfg(feature = "dag")]
transposition_table: None,
evaluator: None,
max_nodes: None,
}
}
/// Runs up to `config.iterations` search steps and returns the best root action.
///
/// When `config.time_budget` is set, the clock is checked before every step
/// and the loop stops as soon as the budget has elapsed.
pub fn run(&mut self) -> Option<E::Action> {
    let deadline = self
        .config
        .time_budget
        .map(|budget| std::time::Instant::now() + budget);
    for _ in 0..self.config.iterations {
        let out_of_time = deadline.is_some_and(|limit| std::time::Instant::now() >= limit);
        if out_of_time {
            break;
        }
        self.run_step();
    }
    self.best_root_action()
}
/// Performs one search iteration: select, optionally expand, simulate, backpropagate.
///
/// Works on a clone of the root environment. If the selected position is no
/// longer ongoing, the selected node is flagged terminal instead of being
/// expanded. A newly expanded node is appended to the backpropagation path.
pub fn run_step(&mut self) {
    let mut env = self.root_env.clone();
    let (node_id, mut path) = self.select(&mut env);
    let state = env.evaluate();
    if state != Outcome::Ongoing {
        self.nodes[node_id as usize].terminal = true;
    } else if self.should_expand(node_id) {
        let expanded = self.expand(node_id, &mut env);
        // `expand` handing back the same id means no new node joined the path.
        if expanded != node_id {
            path.push(expanded);
        }
    }
    let reward = self.simulate(&mut env);
    self.backpropagate(&path, reward);
}
/// Runs `threads` independent searches in parallel and merges their root statistics.
///
/// Each worker is a fresh search built with `with_seed` from a clone of the
/// root environment and configuration, seeded by a value drawn from this
/// search's RNG. `config.iterations` is split across the workers so that the
/// shares add up to exactly the configured total: the first
/// `iterations % threads` workers run one extra iteration. This keeps the
/// remainder from being dropped and keeps workers from being starved when
/// `threads > iterations`.
///
/// Per-action statistics are merged by summing visits and taking the
/// visit-weighted mean of the average rewards. The action with the most
/// merged visits is returned.
///
/// This search's own tree is not modified; only its RNG advances (one draw
/// per worker). Because workers are built with `with_seed`, they start without
/// an evaluator or node limit. Returns `None` when `threads` is 0 or no
/// worker produced any root statistics.
#[cfg(feature = "parallel")]
pub fn run_parallel(&mut self, threads: usize) -> Option<E::Action>
where
    E: Send + Sync,
    E::Action: Eq + std::hash::Hash + Send + Sync,
{
    use rayon::prelude::*;

    // Split the iteration budget so the shares sum to `config.iterations`.
    let workers = threads.max(1);
    let base_share = self.config.iterations / workers;
    let extra = self.config.iterations % workers;

    // Seeds are drawn sequentially here so the parent RNG advances once per worker.
    let mut jobs = Vec::with_capacity(threads);
    for index in 0..threads {
        let share = base_share + usize::from(index < extra);
        jobs.push((self.rng.next_u64(), share));
    }

    let merged_stats = jobs
        .into_par_iter()
        .map(|(seed, share)| {
            let mut search =
                TreeSearch::with_seed(self.root_env.clone(), self.config.clone(), seed);
            search.config.iterations = share;
            #[cfg(feature = "dag")]
            if self.transposition_table.is_some() {
                search.enable_dag();
            }
            search.run();
            search.root_stats()
        })
        .reduce(Vec::new, |mut acc, thread_stats| {
            if acc.is_empty() {
                return thread_stats;
            }
            for (action, stat) in thread_stats {
                if let Some((_, acc_stat)) = acc.iter_mut().find(|(a, _)| a == &action) {
                    let previous_visits = acc_stat.visits;
                    acc_stat.visits += stat.visits;
                    let total_visits = acc_stat.visits;
                    if total_visits > 0 {
                        // Visit-weighted mean of the two partial averages.
                        acc_stat.average_reward = ((acc_stat.average_reward
                            * f64::from(previous_visits))
                            + (stat.average_reward * f64::from(stat.visits)))
                            / f64::from(total_visits);
                    }
                } else {
                    acc.push((action, stat));
                }
            }
            acc
        });
    merged_stats
        .into_iter()
        .max_by_key(|(_, stat)| stat.visits)
        .map(|(action, _)| action)
}
/// Returns one `(action, stats)` entry per root child that carries an action.
///
/// The average reward is `cumulative_reward / visits`, or `0.0` for a child
/// that has never been visited. Entries follow the order of the root's child
/// list.
pub fn root_stats(&self) -> Vec<(E::Action, NodeStats)>
where
    E::Action: Clone,
{
    let mut stats = Vec::new();
    for &child_id in &self.nodes[0].children {
        let child = &self.nodes[child_id as usize];
        // Children without an action are skipped.
        let Some(action) = child.action.clone() else {
            continue;
        };
        let average_reward = match child.visits {
            0 => 0.0,
            visits => child.cumulative_reward / f64::from(visits),
        };
        stats.push((
            action,
            NodeStats {
                visits: child.visits,
                average_reward,
                children_count: child.children.len(),
                unexpanded_count: child.unexpanded.len(),
            },
        ));
    }
    stats
}
/// Returns the number of nodes currently stored, including the root.
pub fn tree_size(&self) -> usize {
self.nodes.len()
}
/// Returns the visit count recorded on the root node.
pub fn total_simulations(&self) -> u32 {
    let root = &self.nodes[0];
    root.visits
}
/// Installs `evaluator` on this search, replacing any previously installed one.
///
/// Despite the `with_` prefix this mutates in place and returns nothing.
pub fn with_evaluator(
&mut self,
evaluator: std::sync::Arc<dyn crate::environment::Evaluator<E>>,
) {
self.evaluator = Some(evaluator);
}
/// Sets the node limit to `limit`, replacing any previous limit.
///
/// Despite the `with_` prefix this mutates in place and returns nothing.
/// NOTE(review): where the limit is enforced is not visible in this file.
pub fn with_max_nodes(&mut self, limit: usize) {
self.max_nodes = Some(limit);
}
/// Runs search steps until `predicate` returns `true`, then returns the best root action.
///
/// The predicate is evaluated before every step, so no step runs if it is
/// already satisfied. There is no other stopping condition: a predicate that
/// never returns `true` makes this loop forever.
pub fn run_until<F>(&mut self, mut predicate: F) -> Option<E::Action>
where
    F: FnMut(&Self) -> bool,
{
    while !predicate(self) {
        self.run_step();
    }
    self.best_root_action()
}
/// Returns the environment states along the principal variation.
///
/// The first element is a clone of the root environment; each following
/// element is the state after applying the next principal-variation action.
pub fn principal_variation_states(&self) -> Vec<E>
where
    E::Action: Clone,
{
    let line = self.principal_variation();
    let mut state = self.root_env.clone();
    let mut snapshots = Vec::with_capacity(line.len() + 1);
    snapshots.push(state.clone());
    for step in line.iter() {
        state.apply(step);
        snapshots.push(state.clone());
    }
    snapshots
}
/// Renders the tree as a Graphviz DOT digraph, starting at the root.
///
/// `depth` bounds how many levels get a labelled node; with `depth == 0` only
/// the graph header and footer are produced.
pub fn export_dot(&self, depth: usize) -> String
where
    E::Action: std::fmt::Debug,
{
    let mut out = String::new();
    out.push_str("digraph mctrust {\n rankdir=TB;\n node [shape=record, style=filled, fillcolor=\"#f0f0f0\"];\n");
    self.export_dot_recursive(0, 0, depth, &mut out);
    out.push_str("}\n");
    out
}
/// Appends the DOT description of `node_id` and its descendants to `dot`.
///
/// Nodes at `current_depth >= max_depth` are not emitted. For every emitted
/// node a record-shaped label (`<action>|V:<visits> R:<average reward>`) is
/// written, followed by an edge to each child and a recursive call for that
/// child.
///
/// The `Debug` text of the action is escaped before being embedded. Without
/// escaping, characters that are special in a quoted record label (`{`, `}`,
/// `|`, `<`, `>`, `"`, `\`) would corrupt the label or the DOT syntax.
/// Actions whose `Debug` text contains none of these produce the same output
/// as before.
fn export_dot_recursive(
    &self,
    node_id: u32,
    current_depth: usize,
    max_depth: usize,
    dot: &mut String,
) where
    E::Action: std::fmt::Debug,
{
    use std::fmt::Write;

    /// Backslash-escapes characters that are special inside a quoted record label.
    fn escape_record_text(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            if matches!(ch, '\\' | '"' | '{' | '}' | '|' | '<' | '>') {
                escaped.push('\\');
            }
            escaped.push(ch);
        }
        escaped
    }

    if current_depth >= max_depth {
        return;
    }
    let node = &self.nodes[node_id as usize];
    let avg = if node.visits > 0 {
        node.cumulative_reward / f64::from(node.visits)
    } else {
        0.0
    };
    let label = if let Some(ref action) = node.action {
        // Only the action text is escaped; the `|` below is the record field separator.
        let action_text = escape_record_text(&format!("{action:?}"));
        format!("{action_text}|V:{} R:{avg:.3}", node.visits)
    } else {
        format!("root|V:{} R:{avg:.3}", node.visits)
    };
    // Writing into a `String` cannot fail, so the results are deliberately ignored.
    let _ = writeln!(dot, " n{node_id} [label=\"{{{label}}}\"];");
    for &child_id in &node.children {
        let _ = writeln!(dot, " n{node_id} -> n{child_id};");
        self.export_dot_recursive(child_id, current_depth + 1, max_depth, dot);
    }
}
/// Returns whether RAVE is enabled in the search configuration.
pub fn uses_rave(&self) -> bool {
self.config.rave.enabled
}
/// Enables DAG mode by installing a fresh, empty transposition table.
///
/// Any existing table is discarded.
#[cfg(feature = "dag")]
pub fn enable_dag(&mut self) {
self.transposition_table = Some(hashbrown::HashMap::new());
}
/// Disables DAG mode by dropping the transposition table.
#[cfg(feature = "dag")]
pub fn disable_dag(&mut self) {
self.transposition_table = None;
}
/// Returns the number of entries in the transposition table, or 0 when DAG
/// mode is disabled.
///
/// NOTE(review): despite the name, this reports the table size rather than a
/// count of lookup hits — confirm that this is what callers expect.
#[cfg(feature = "dag")]
pub fn dag_hit_count(&self) -> usize {
    match self.transposition_table.as_ref() {
        Some(table) => table.len(),
        None => 0,
    }
}
pub fn principal_variation(&self) -> Vec<E::Action>
where
E::Action: Clone,
{
let mut pv = Vec::new();
let mut current = 0u32;
let mut visited = std::collections::HashSet::new();
visited.insert(current);
loop {
let node = &self.nodes[current as usize];
if node.children.is_empty() {
break;
}
let best_child = node
.children
.iter()
.copied()
.max_by_key(|&id| self.nodes[id as usize].visits);
match best_child {
Some(child_id) => {
if !visited.insert(child_id) {
break;
}
if let Some(ref action) = self.nodes[child_id as usize].action {
pv.push(action.clone());
}
current = child_id;
}
None => break,
}
}
pv
}
/// Returns the average reward of the most-visited root child.
///
/// Yields `None` when the root has no children and `Some(0.0)` when the
/// chosen child has never been visited.
pub fn best_root_reward(&self) -> Option<f64> {
    let best = &self.nodes[self.best_root_child_id()? as usize];
    let mean = if best.visits == 0 {
        0.0
    } else {
        best.cumulative_reward / f64::from(best.visits)
    };
    Some(mean)
}
/// Returns the id of the root child with the highest visit count, or `None`
/// when the root has no children.
fn best_root_child_id(&self) -> Option<u32> {
    let visits_of = |id: u32| self.nodes[id as usize].visits;
    self.nodes[0]
        .children
        .iter()
        .copied()
        .max_by_key(|&id| visits_of(id))
}
/// Re-roots the tree at the root child reached by `action`.
///
/// Returns `true` when such a child exists and the tree was advanced, and
/// `false` (leaving the search untouched) when no root child carries `action`.
pub fn advance_to_action(&mut self, action: &E::Action) -> bool {
    let target = self.nodes[0]
        .children
        .iter()
        .copied()
        .find(|&id| self.nodes[id as usize].action.as_ref() == Some(action));
    let Some(id) = target else {
        return false;
    };
    self.advance_to_child(id);
    true
}
fn advance_to_child(&mut self, child_id: u32) {
if let Some(ref action) = self.nodes[child_id as usize].action {
self.root_env.apply(action);
}
let mut old_to_new = std::collections::HashMap::new();
let mut queue = std::collections::VecDeque::new();
let mut new_nodes = Vec::new();
queue.push_back(child_id);
old_to_new.insert(child_id, 0u32);
while let Some(old_id) = queue.pop_front() {
let node = &self.nodes[old_id as usize];
for &child in &node.children {
if !old_to_new.contains_key(&child) {
let new_id = u32::try_from(old_to_new.len()).unwrap_or(u32::MAX);
old_to_new.insert(child, new_id);
queue.push_back(child);
}
}
}
new_nodes.resize_with(old_to_new.len(), || Node::root(Vec::new()));
for (&old_id, &new_id) in &old_to_new {
let old_node = &self.nodes[old_id as usize];
let mut new_node = old_node.clone();
new_node.parent = old_node.parent.and_then(|p| old_to_new.get(&p).copied());
new_node.children = old_node
.children
.iter()
.filter_map(|c| old_to_new.get(c).copied())
.collect();
new_nodes[new_id as usize] = new_node;
}
if let Some(root) = new_nodes.first_mut() {
root.parent = None;
root.action = None;
}
self.nodes = new_nodes;
#[cfg(feature = "dag")]
{
if self.transposition_table.is_some() {
self.transposition_table = Some(hashbrown::HashMap::new());
}
}
}
}
/// Builds a `ChaCha8Rng` seeded from the system random number generator.
///
/// # Panics
///
/// Panics if the system RNG reports an error while producing the seed.
fn entropy_rng() -> rand_chacha::ChaCha8Rng {
    rand_chacha::ChaCha8Rng::try_from_rng(&mut rand::rngs::SysRng)
        .unwrap_or_else(|error| panic!("failed to seed ChaCha8Rng from system RNG: {error}"))
}
// NOTE(review): `select`, `should_expand`, `expand`, `simulate`,
// `backpropagate`, and `best_root_action` are called in this file but not
// defined here; presumably they live in this submodule — confirm.
mod phases;
// Unit tests for this module; compiled only for test builds.
#[cfg(test)]
mod tests;