#![allow(clippy::field_reassign_with_default)]
use super::result::MctsResult;
use super::stats::MctsStats;
use crate::search::mcts::config::MctsConfig;
use crate::search::mcts::node::NodeId;
use crate::search::mcts::traits::{Action, ActionSpace, PolicyNetwork, State, StateSpace};
use crate::search::mcts::tree::SearchTree;
use crate::search::mcts::Reward;
/// Monte Carlo Tree Search driver that owns its search tree, configuration and RNG.
///
/// The RNG is only used to pick random actions during rollouts; construct the search
/// with [`MctsSearch::with_seed`] to make runs reproducible.
pub struct MctsSearch<S, A>
where
    S: State,
    A: Action,
{
    /// Tree of explored states, rooted at the state the search was created with.
    tree: SearchTree<S, A>,
    /// Search parameters (iteration budget, exploration constant, rollout depth, ...).
    config: MctsConfig,
    /// Random source for rollout action selection.
    rng: rand::rngs::StdRng,
}
impl<S: State + Send + Sync, A: Action + Send + Sync> MctsSearch<S, A> {
    /// Creates a search rooted at `initial_state`, seeding the rollout RNG from the OS.
    ///
    /// The root's untried actions are `action_space.legal_actions(&initial_state)`.
    pub fn new<AS: ActionSpace<S, A>>(
        initial_state: S,
        action_space: &AS,
        config: MctsConfig,
    ) -> Self {
        use rand::SeedableRng;
        Self::from_parts(initial_state, action_space, config, rand::rngs::StdRng::from_os_rng())
    }

    /// Creates a search rooted at `initial_state` with a deterministic rollout RNG
    /// seeded from `seed`, so repeated runs with the same inputs are reproducible.
    pub fn with_seed<AS: ActionSpace<S, A>>(
        initial_state: S,
        action_space: &AS,
        config: MctsConfig,
        seed: u64,
    ) -> Self {
        use rand::SeedableRng;
        Self::from_parts(initial_state, action_space, config, rand::rngs::StdRng::seed_from_u64(seed))
    }

    /// Shared constructor: builds the root from `initial_state` and its legal actions.
    fn from_parts<AS: ActionSpace<S, A>>(
        initial_state: S,
        action_space: &AS,
        config: MctsConfig,
        rng: rand::rngs::StdRng,
    ) -> Self {
        let actions = action_space.legal_actions(&initial_state);
        let tree = SearchTree::new(initial_state, actions);
        Self { tree, config, rng }
    }

    /// Runs `config.max_iterations` MCTS iterations and summarizes the root's children.
    ///
    /// Each iteration selects a leaf, then:
    /// - terminal leaf: scores it with `state_space.evaluate` and backpropagates;
    /// - expandable leaf: expands one untried action, rolls out from the new child
    ///   and backpropagates the rollout reward;
    /// - non-terminal dead end (no untried actions and no children): scores it with
    ///   `state_space.evaluate` and backpropagates, so visit counts keep moving and
    ///   selection cannot get stuck re-selecting the same leaf.
    ///
    /// `policy`, when given, supplies expansion priors via `predict`; otherwise (or for
    /// actions it does not list) a uniform prior over the parent's actions is used.
    ///
    /// The tree is kept between calls, so calling `search` again keeps growing it.
    /// The reported best action is the root child with the most visits (the last such
    /// child on ties). `avg_simulation_length` is averaged over the rollouts actually
    /// performed, not over all iterations.
    pub fn search<SS, AS>(
        &mut self,
        state_space: &SS,
        action_space: &AS,
        policy: Option<&dyn PolicyNetwork<S, A>>,
    ) -> MctsResult<S, A>
    where
        SS: StateSpace<S, A>,
        AS: ActionSpace<S, A>,
    {
        let mut total_sim_length = 0usize;
        let mut simulations = 0usize;
        let mut max_depth = 0usize;

        for _ in 0..self.config.max_iterations {
            let (leaf_id, depth) = self.select();
            max_depth = max_depth.max(depth);

            let Some(leaf_state) = self.tree.get(leaf_id).map(|n| n.state.clone()) else {
                continue;
            };

            if leaf_state.is_terminal() {
                let reward = state_space.evaluate(&leaf_state);
                self.backpropagate(leaf_id, reward);
                continue;
            }

            let Some(child_id) = self.expand(leaf_id, state_space, action_space, policy) else {
                // Non-terminal dead end. Without a backprop here nothing changes, and the
                // deterministic selection would return this same leaf on every later
                // iteration, wasting the rest of the budget.
                let reward = state_space.evaluate(&leaf_state);
                self.backpropagate(leaf_id, reward);
                continue;
            };

            let Some(child_state) = self.tree.get(child_id).map(|n| n.state.clone()) else {
                continue;
            };

            let (reward, sim_length) = self.simulate(&child_state, state_space, action_space);
            total_sim_length += sim_length;
            simulations += 1;
            self.backpropagate(child_id, reward);
        }

        let avg_simulation_length = if simulations > 0 {
            total_sim_length as f64 / simulations as f64
        } else {
            0.0
        };
        self.collect_result(max_depth, avg_simulation_length)
    }

    /// Builds the search result from the root's children and the gathered statistics.
    fn collect_result(&self, max_depth: usize, avg_simulation_length: f64) -> MctsResult<S, A> {
        let root_visits = self.tree.root().stats.visits;
        // Fetch the root's children once and reuse them for both summaries.
        let children = self.tree.children(self.tree.root_id);

        let action_visits: Vec<(A, usize)> = children
            .iter()
            .filter_map(|child| child.action.clone().map(|a| (a, child.stats.visits)))
            .collect();

        let (best_action, expected_reward, resulting_state) = children
            .iter()
            .max_by_key(|child| child.stats.visits)
            .map_or((None, 0.0, None), |child| {
                (child.action.clone(), child.stats.mean_reward, Some(child.state.clone()))
            });

        MctsResult {
            best_action,
            expected_reward,
            action_visits,
            stats: MctsStats {
                iterations: self.config.max_iterations,
                tree_size: self.tree.size(),
                max_depth,
                avg_simulation_length,
                root_visits,
            },
            resulting_state,
        }
    }

    /// Descends from the root and returns the node to work on next, plus its depth
    /// (root = 0).
    ///
    /// At each node it follows the child with the highest UCB1 score, or PUCT when
    /// `config.use_policy_priors` is set. It stops at a node that has untried actions,
    /// is terminal, or has no children. It also stops if the node or all of its
    /// children are missing from the tree.
    fn select(&self) -> (NodeId, usize) {
        let use_priors = self.config.use_policy_priors;
        let exploration = self.config.exploration_constant;
        let mut current_id = self.tree.root_id;
        let mut depth = 0;

        loop {
            let Some(node) = self.tree.get(current_id) else {
                return (current_id, depth);
            };
            if !node.untried_actions.is_empty()
                || node.state.is_terminal()
                || node.children.is_empty()
            {
                return (current_id, depth);
            }

            let parent_visits = node.stats.visits;
            // Score each child once; `max_by` keeps the last maximum on ties, and
            // incomparable scores (e.g. NaN) are treated as equal.
            let best_child = node
                .children
                .iter()
                .filter_map(|&cid| self.tree.get(cid))
                .map(|child| {
                    let score = if use_priors {
                        child.stats.puct(parent_visits, exploration)
                    } else {
                        child.stats.ucb1(parent_visits, exploration)
                    };
                    (score, child)
                })
                .max_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let Some((_, child)) = best_child else {
                return (current_id, depth);
            };
            current_id = child.id;
            depth += 1;
        }
    }

    /// Expands one untried action of `node_id` into a new child node.
    ///
    /// Returns `None` if the node is missing or has no untried actions left.
    /// The child's prior comes from `policy.predict(parent_state)` when it lists the
    /// action. Otherwise the prior is uniform over all actions available at the
    /// PARENT (untried ones plus already-expanded children).
    fn expand<SS, AS>(
        &mut self,
        node_id: NodeId,
        state_space: &SS,
        action_space: &AS,
        policy: Option<&dyn PolicyNetwork<S, A>>,
    ) -> Option<NodeId>
    where
        SS: StateSpace<S, A>,
        AS: ActionSpace<S, A>,
    {
        let (action, parent_state, branching) = {
            let node = self.tree.get_mut(node_id)?;
            // Each child was created by popping one untried action, so this sum is the
            // parent's full action count (measured before this pop).
            let branching = node.untried_actions.len() + node.children.len();
            let action = node.untried_actions.pop()?;
            node.expanded = node.untried_actions.is_empty();
            (action, node.state.clone(), branching)
        };

        let new_state = state_space.apply(&parent_state, &action);
        let new_actions = action_space.legal_actions(&new_state);

        let uniform_prior = 1.0 / (branching.max(1) as f64);
        let prior = policy
            .and_then(|p| {
                p.predict(&parent_state).iter().find(|(a, _)| a == &action).map(|(_, p)| *p)
            })
            .unwrap_or(uniform_prior);

        Some(self.tree.add_child(node_id, new_state, action, new_actions, prior))
    }

    /// Rolls out from `initial_state` by repeatedly applying a uniformly random legal
    /// action.
    ///
    /// Stops when the state is terminal, has no legal actions, or after
    /// `config.max_simulation_depth` steps. Returns `state_space.evaluate` of the final
    /// state and the number of steps taken.
    fn simulate<SS, AS>(
        &mut self,
        initial_state: &S,
        state_space: &SS,
        action_space: &AS,
    ) -> (Reward, usize)
    where
        SS: StateSpace<S, A>,
        AS: ActionSpace<S, A>,
    {
        use rand::prelude::IndexedRandom;
        let mut state = initial_state.clone();
        let mut depth = 0;
        while !state.is_terminal() && depth < self.config.max_simulation_depth {
            let actions = action_space.legal_actions(&state);
            if actions.is_empty() {
                break;
            }
            if let Some(action) = actions.choose(&mut self.rng) {
                state = state_space.apply(&state, action);
            }
            depth += 1;
        }
        (state_space.evaluate(&state), depth)
    }

    /// Applies `stats.update(reward)` to `leaf_id` and every ancestor up to the root.
    ///
    /// Stops early if a node on the path is missing from the tree.
    /// NOTE(review): the same reward is applied at every depth (no perspective flip),
    /// which assumes rewards are from a single player's point of view; confirm this
    /// for adversarial games.
    fn backpropagate(&mut self, leaf_id: NodeId, reward: Reward) {
        let mut cursor = Some(leaf_id);
        while let Some(node) = cursor.and_then(|id| self.tree.get_mut(id)) {
            node.stats.update(reward);
            cursor = node.parent;
        }
    }

    /// Number of nodes currently in the search tree.
    #[must_use]
    pub fn tree_size(&self) -> usize {
        self.tree.size()
    }

    /// Read-only access to the underlying search tree.
    #[must_use]
    pub fn tree(&self) -> &SearchTree<S, A> {
        &self.tree
    }
}