use rand::SeedableRng;
use crate::config::SearchConfig;
use crate::environment::{Environment, GameState};
use crate::node::{Node, NodeStats};
/// Tree search over an [`Environment`] `E`.
///
/// Holds everything the search loop in `run` needs: the environment the
/// search starts from, the search configuration, the tree itself and a random
/// number generator.
pub struct GameSearch<E: Environment> {
/// Environment the search is rooted at; `run` clones it at the start of
/// every iteration so the original is never advanced.
pub(crate) root_env: E,
/// Search settings; this file reads `iterations` and `rave.enabled` from it.
pub(crate) config: SearchConfig,
/// Flat storage for the tree. Index 0 is the root, and nodes refer to their
/// children by index into this vector.
pub(crate) nodes: Vec<Node<E::Action>>,
/// Random number generator, seeded from entropy (`new`, `restore`) or from
/// an explicit seed (`with_seed`).
// NOTE(review): not read in this part of the file — presumably consumed by
// the search phases in `mod phases`; confirm there.
pub(crate) rng: rand::rngs::StdRng,
}
/// Serializable snapshot of a [`GameSearch`].
///
/// Carries the root environment, the configuration and the node storage.
/// The random number generator is not part of the snapshot.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct GameSearchCheckpoint<E>
where
E: Environment + Clone,
E::Action: serde::Serialize + for<'action> serde::Deserialize<'action>,
{
/// Environment the captured search was rooted at.
pub root_env: E,
/// Configuration of the captured search.
pub config: SearchConfig,
/// Copy of the search tree's node storage (index 0 is the root).
pub nodes: Vec<Node<E::Action>>,
}
impl<E: Environment> GameSearch<E> {
/// Creates a search rooted at `environment` with an entropy-seeded random
/// number generator.
///
/// The tree starts out with a single root node built from the environment's
/// legal actions. Use [`Self::with_seed`] to supply an explicit seed instead.
#[must_use]
pub fn new(environment: E, config: SearchConfig) -> Self {
    // Query the legal actions before `environment` is moved into the struct.
    let root_node = Node::root(environment.legal_actions());
    let rng = rand::rngs::StdRng::from_entropy();
    Self {
        nodes: vec![root_node],
        root_env: environment,
        config,
        rng,
    }
}
/// Creates a search rooted at `environment` whose random number generator is
/// initialised from `seed`.
///
/// Apart from the generator this is identical to [`Self::new`]: the tree
/// starts with a single root node built from the environment's legal actions.
#[must_use]
pub fn with_seed(environment: E, config: SearchConfig, seed: u64) -> Self {
    // The root must be built while `environment` is still borrowable.
    let initial_nodes = vec![Node::root(environment.legal_actions())];
    let seeded_rng = rand::rngs::StdRng::seed_from_u64(seed);
    Self {
        root_env: environment,
        config,
        nodes: initial_nodes,
        rng: seeded_rng,
    }
}
/// Captures the current search as a [`GameSearchCheckpoint`].
///
/// The root environment, configuration and node storage are cloned into the
/// snapshot; the random number generator is not captured.
#[must_use]
pub fn checkpoint(&self) -> GameSearchCheckpoint<E>
where
    E: serde::Serialize + for<'de> serde::Deserialize<'de>,
    E::Action: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
{
    let root_env = self.root_env.clone();
    let config = self.config.clone();
    let nodes = self.nodes.clone();
    GameSearchCheckpoint {
        root_env,
        config,
        nodes,
    }
}
#[must_use]
pub fn restore(checkpoint: GameSearchCheckpoint<E>) -> Self
where
E: serde::Serialize + for<'de> serde::Deserialize<'de>,
E::Action: serde::Serialize + for<'de> serde::Deserialize<'de> + Clone,
{
Self {
root_env: checkpoint.root_env,
config: checkpoint.config,
nodes: checkpoint.nodes,
rng: rand::rngs::StdRng::from_entropy(),
}
}
/// Runs `config.iterations` search iterations and returns the result of
/// `best_root_action`.
///
/// Every iteration works on a fresh clone of the root environment:
///
/// 1. `select` yields a node id together with the path that led to it.
/// 2. The environment is evaluated. If the game is still ongoing and
///    `should_expand` agrees, `expand` is called; when it returns a node
///    other than the selected one, that node is appended to the path.
///    If the game is no longer ongoing, the selected node is flagged as
///    terminal.
/// 3. `simulate` produces a reward, which `backpropagate` applies along the
///    path.
pub fn run(&mut self) -> Option<E::Action> {
    let iterations = self.config.iterations;
    for _ in 0..iterations {
        let mut sim_env = self.root_env.clone();
        let (leaf_id, mut visited) = self.select(&mut sim_env);

        let ongoing = sim_env.evaluate() == GameState::Ongoing;
        if !ongoing {
            // Finished game: remember that this node cannot be expanded.
            self.nodes[leaf_id as usize].terminal = true;
        } else if self.should_expand(leaf_id) {
            let child_id = self.expand(leaf_id, &mut sim_env);
            // `expand` may hand back the same node; only extend the path
            // when a different node was returned.
            if child_id != leaf_id {
                visited.push(child_id);
            }
        }

        let reward = self.simulate(&mut sim_env);
        self.backpropagate(&visited, reward);
    }
    self.best_root_action()
}
/// Returns per-action statistics for the children of the root node.
///
/// Each entry pairs a child's action with its visit count, average reward
/// (`cumulative_reward / visits`, or `0.0` for an unvisited child) and the
/// number of expanded and unexpanded children it has. Children without an
/// action are skipped.
///
/// Returns an empty vector when the tree holds no nodes at all, rather than
/// panicking on the missing root.
///
/// # Panics
///
/// Panics if a child id stored on the root does not index an existing node.
#[must_use]
pub fn root_stats(&self) -> Vec<(E::Action, NodeStats)>
where
    E::Action: Clone,
{
    // No root means no children to report on.
    let root = match self.nodes.first() {
        Some(root) => root,
        None => return Vec::new(),
    };

    root.children
        .iter()
        .filter_map(|&child_id| {
            let child = &self.nodes[child_id as usize];
            let action = child.action.clone()?;
            // Guard the division so unvisited children report 0.0, not NaN.
            let average_reward = if child.visits > 0 {
                child.cumulative_reward / f64::from(child.visits)
            } else {
                0.0
            };
            Some((
                action,
                NodeStats {
                    visits: child.visits,
                    average_reward,
                    children_count: child.children.len(),
                    unexpanded_count: child.unexpanded.len(),
                },
            ))
        })
        .collect()
}
/// Returns the number of nodes currently stored in the search tree,
/// including the root.
#[must_use]
pub fn tree_size(&self) -> usize {
    self.nodes.len()
}
/// Returns the visit count recorded on the root node.
///
/// Returns `0` when the tree holds no nodes, rather than panicking on the
/// missing root.
#[must_use]
pub fn total_simulations(&self) -> u32 {
    self.nodes.first().map_or(0, |root| root.visits)
}
/// Reports whether the `rave.enabled` flag is set in the search configuration.
#[must_use]
pub fn uses_rave(&self) -> bool {
    let rave = &self.config.rave;
    rave.enabled
}
}
mod phases;
#[cfg(test)]
mod tests;