use super::GameSearch;
use crate::config::{ProgressiveWideningConfig, TreePolicy};
use crate::environment::{Environment, GameState};
use crate::node::Node;
use crate::reward::Reward;
use rand::seq::SliceRandom;
use rand::Rng;
impl<E: Environment> GameSearch<E> {
    /// Chooses the child of `current` with the highest [`Self::selection_score`].
    ///
    /// `env` is expected to be positioned at `current`'s state (as [`Self::select`]
    /// maintains), because under [`TreePolicy::Puct`] the legal actions and action
    /// priors are queried from it to look up each child's prior.
    ///
    /// Ties keep the earliest child. A child scoring NaN or `-inf` is never chosen, so
    /// this returns `None` when `current` has no children or none of them qualifies.
    pub(crate) fn pick_best_child(&mut self, current: u32, env: &mut E) -> Option<u32> {
        let parent_visits = self.nodes[current as usize].visits;
        // Legal actions and priors only feed the PUCT term; skip generating them for the
        // other policies since this runs once per tree level on every iteration.
        let (legal, priors) = if matches!(self.config.tree_policy, TreePolicy::Puct { .. }) {
            let legal = env.legal_actions();
            let priors = env.action_priors(&legal);
            (Some(legal), priors)
        } else {
            (None, None)
        };
        let legal_actions: &[E::Action] = legal.as_deref().unwrap_or(&[]);

        let mut best_child = None;
        let mut best_score = f64::NEG_INFINITY;
        // Iterate by index rather than cloning `children`: `selection_score` needs
        // `&mut self` (for the RNG), and scoring never modifies the child list.
        let child_count = self.nodes[current as usize].children.len();
        for idx in 0..child_count {
            let child = self.nodes[current as usize].children[idx];
            let score =
                self.selection_score(child, parent_visits, legal_actions, priors.as_ref());
            // `>` is false for NaN, so NaN scores are skipped.
            if score > best_score {
                best_score = score;
                best_child = Some(child);
            }
        }
        best_child
    }

    /// Descends from the root (node `0`), applying each chosen child's action to `env`.
    ///
    /// Stops at the first node that is terminal, has no children, should still be
    /// expanded ([`Self::should_expand`]), or has no selectable child. Returns that
    /// node's id together with the root-to-node path (both ends inclusive).
    pub(crate) fn select(&mut self, env: &mut E) -> (u32, Vec<u32>) {
        let mut current = 0u32;
        let mut path = vec![current];
        loop {
            let node = &self.nodes[current as usize];
            if node.terminal || node.children.is_empty() || self.should_expand(current) {
                return (current, path);
            }
            let Some(child) = self.pick_best_child(current, env) else {
                return (current, path);
            };
            // Keep `env` in step with the tree so the next level is scored from the
            // child's state.
            if let Some(action) = &self.nodes[child as usize].action {
                env.apply(action);
            }
            path.push(child);
            current = child;
        }
    }

    /// Scores `child` for selection under the configured [`TreePolicy`].
    ///
    /// The base score is the node's `uct_score_with_rave`, using the configured
    /// exploration constant and RAVE settings. On top of that:
    /// - `Puct`: adds `prior * prior_weight`. `prior` is the entry of `priors` at the
    ///   position of the child's action in `legal_actions`, clamped to be `>= 0`. It is
    ///   `1.0` when priors are absent, the child has no action, the action is not in
    ///   `legal_actions`, or the index is out of range.
    /// - `ThompsonSampling`: adds uniform noise in `[-t, t)`, where `t = max(temperature, 0)`.
    /// - `Uct`: returns the base score unchanged.
    pub(crate) fn selection_score(
        &mut self,
        child: u32,
        parent_visits: u32,
        legal_actions: &[E::Action],
        priors: Option<&Vec<f64>>,
    ) -> f64 {
        let node = &self.nodes[child as usize];
        let score = node.uct_score_with_rave(
            parent_visits,
            self.config.exploration_constant,
            self.config.rave.enabled,
            self.config.rave.bias,
        );
        match self.config.tree_policy {
            TreePolicy::Puct { prior_weight } => {
                let prior = priors
                    .zip(node.action.as_ref())
                    .and_then(|(table, action)| {
                        let idx = legal_actions.iter().position(|a| a == action)?;
                        table.get(idx).copied()
                    })
                    .map_or(1.0, |p| p.max(0.0));
                // NOTE(review): this prior bonus is a constant per child and does not
                // shrink with the child's visit count (AlphaZero-style PUCT scales it by
                // sqrt(N) / (1 + n)). Confirm the additive form is intended.
                score + prior * prior_weight
            }
            TreePolicy::ThompsonSampling { temperature } => {
                // NOTE(review): uniform jitter on the score, not posterior sampling.
                let noise = (self.rng.gen_range(0.0..1.0) - 0.5) * 2.0 * temperature.max(0.0);
                score + noise
            }
            TreePolicy::Uct => score,
        }
    }

    /// Whether `node_id` should receive a new child before selection descends further.
    ///
    /// Fully expanded nodes never do. Under progressive widening, a node may only grow
    /// while it has fewer children than [`Self::progressive_limit`] allows for its visit
    /// count. Without widening, any node that is not fully expanded expands.
    pub(crate) fn should_expand(&self, node_id: u32) -> bool {
        let node = &self.nodes[node_id as usize];
        if node.is_fully_expanded() {
            return false;
        }
        match &self.config.progressive_widening {
            Some(cfg) => node.children.len() < Self::progressive_limit(node.visits, cfg),
            None => true,
        }
    }

    /// Maximum number of children progressive widening allows for a node with
    /// `parent_visits` visits: `max(minimum_children, floor(coefficient * visits^exponent))`.
    ///
    /// A non-finite or non-positive budget (e.g. `0^negative`, or NaN config values)
    /// counts as zero, so the result never drops below `minimum_children`.
    // The f64 -> usize cast only runs on a finite, positive budget. Values beyond
    // `usize::MAX` saturate under `as`, so truncation and sign loss are both handled.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub(crate) fn progressive_limit(parent_visits: u32, cfg: &ProgressiveWideningConfig) -> usize {
        let budget = cfg.coefficient * f64::from(parent_visits).powf(cfg.exponent);
        let budget = if budget.is_finite() && budget > 0.0 {
            budget.floor() as usize
        } else {
            0
        };
        cfg.minimum_children.max(budget)
    }

    /// Materialises one untried action of `parent_id` as a new child node.
    ///
    /// Pops the next action from the parent's `unexpanded` list and applies it to `env`,
    /// which is expected to be at the parent's state. It then stores the resulting legal
    /// actions in the child and marks the child terminal when `env.evaluate()` is not
    /// `Ongoing`. Returns the new child's id, with `env` left at the child's state.
    ///
    /// Returns `parent_id`, with `env` and the parent untouched, when the parent has no
    /// untried actions or the node arena cannot hold another `u32` id.
    pub(crate) fn expand(&mut self, parent_id: u32, env: &mut E) -> u32 {
        // Check id capacity *before* popping/applying. Bailing out afterwards would drop
        // the action from `unexpanded` and leave `env` one move ahead of `parent_id`.
        let Ok(child_id) = u32::try_from(self.nodes.len()) else {
            return parent_id;
        };
        let Some(action) = self.nodes[parent_id as usize].unexpanded.pop() else {
            return parent_id;
        };
        env.apply(&action);
        let legal = env.legal_actions();
        let mut child = Node::child(parent_id, action, legal);
        if env.evaluate() != GameState::Ongoing {
            child.terminal = true;
        }
        self.nodes.push(child);
        self.nodes[parent_id as usize].children.push(child_id);
        child_id
    }

    /// Plays uniformly random legal moves from `env`'s current state and returns the
    /// outcome.
    ///
    /// Terminates when `env.evaluate()` reports a finished game. At the depth limit
    /// (`env.max_depth()`, falling back to `config.max_depth`), it returns the
    /// environment heuristic scaled by `config.heuristic_weight`, or a draw if no
    /// heuristic value is available. A position with no legal moves is scored as a draw.
    pub(crate) fn simulate(&mut self, env: &mut E) -> Reward {
        let depth_limit = env.max_depth().unwrap_or(self.config.max_depth);
        let mut depth = 0usize;
        loop {
            match env.evaluate() {
                GameState::Win(r) | GameState::Terminal(r) => return r,
                GameState::Loss => return Reward::LOSS,
                GameState::Draw => return Reward::DRAW,
                GameState::Ongoing => {}
            }
            if depth >= depth_limit {
                if let Some(heuristic) = env.heuristic().value {
                    return Reward::new(heuristic.value() * self.config.heuristic_weight);
                }
                return Reward::DRAW;
            }
            // `choose` yields `None` for an empty move list, so a stuck position is a draw.
            let Some(action) = env.legal_actions().choose(&mut self.rng).cloned() else {
                return Reward::DRAW;
            };
            env.apply(&action);
            depth += 1;
        }
    }

    /// Credits `reward` to every node on `path` through `apply_uct_update`, and also
    /// through `apply_rave_update` when RAVE is enabled.
    ///
    /// NOTE(review): every node receives the same value with no sign flip. That is
    /// correct for a single fixed perspective. For alternating-turn games, confirm
    /// perspective is handled elsewhere (e.g. inside `Node` or `Reward`). Likewise, RAVE
    /// statistics here are fed only each path node's own outcome (no AMAF sharing with
    /// siblings). Verify this is intended.
    pub(crate) fn backpropagate(&mut self, path: &[u32], reward: Reward) {
        let value = reward.value();
        let rave_enabled = self.config.rave.enabled;
        for &node_id in path {
            let node = &mut self.nodes[node_id as usize];
            node.apply_uct_update(value);
            if rave_enabled {
                node.apply_rave_update(value);
            }
        }
    }

    /// Action of the root's most-visited child: the robust final move choice.
    ///
    /// Returns `None` when the tree is empty, the root has no children, or the chosen
    /// child carries no action. On equal visit counts, the last such child wins
    /// (`Iterator::max_by_key` semantics).
    pub(crate) fn best_root_action(&self) -> Option<E::Action> {
        let root = self.nodes.first()?;
        let best = root
            .children
            .iter()
            .copied()
            .max_by_key(|&id| self.nodes[id as usize].visits)?;
        self.nodes[best as usize].action.clone()
    }
}