mctrust 0.4.0

Universal search & planning toolkit — MCTS, bandit search, pluggable evaluators, tree reuse, DAG transpositions, root parallelism. Define an Environment, search handles the rest.
Documentation
use super::TreeSearch;
use crate::config::{ProgressiveWideningConfig, TreePolicy};
use crate::environment::{Environment, Outcome};
use crate::node::Node;
use crate::reward::Reward;
use rand::prelude::IndexedRandom;
use rand::RngExt;

impl<E: Environment> TreeSearch<E> {
    // ── Internal phases ─────────────────────────────────────────

    /// Picks the child of `current` that the configured tree policy scores highest.
    ///
    /// `env` must be in the state of `current`; it is only queried (for legal
    /// actions and priors) when the policy is PUCT.
    ///
    /// * Gumbel at the root (`current == 0`): every child's completed Q-value is
    ///   perturbed with one Gumbel(0, 1) sample and the arg-max is returned. One
    ///   uniform draw is taken from `self.rng` per child, in child order, so results
    ///   are reproducible for a given seed. `sampled_actions` is not consulted here
    ///   and no sequential halving is performed.
    /// * Every other case (UCT, PUCT, Thompson, non-root Gumbel) delegates to
    ///   [`selection_score`](Self::selection_score).
    ///
    /// Returns `None` when `current` has no children or every score is NaN. Ties keep
    /// the earliest child.
    pub(crate) fn pick_best_child(&mut self, current: u32, env: &mut E) -> Option<u32> {
        let parent = current as usize;

        if current == 0 {
            if let TreePolicy::Gumbel {
                max_completions_coeff,
                ..
            } = &self.config.tree_policy
            {
                let coeff = *max_completions_coeff;
                let mut best = None;
                let mut best_score = f64::NEG_INFINITY;

                // Borrowing `self.nodes` while drawing from `self.rng` is fine: the
                // two fields are disjoint, so no defensive clone of `children` is needed.
                for &child_id in &self.nodes[parent].children {
                    let child = &self.nodes[child_id as usize];
                    let q = completed_q(child.cumulative_reward, child.visits, coeff);
                    let u: f64 =
                        self.rng.random_range(GUMBEL_UNIFORM_EPS..1.0 - GUMBEL_UNIFORM_EPS);
                    let score = q + gumbel_from_uniform(u);

                    // `>` is false for NaN, so NaN scores are never selected.
                    if score > best_score {
                        best_score = score;
                        best = Some(child_id);
                    }
                }

                return best;
            }
        }

        let parent_visits = self.nodes[parent].visits;

        // Only compute legal actions and priors when PUCT needs them.
        let (legal, priors) = if matches!(self.config.tree_policy, TreePolicy::Puct { .. }) {
            let legal = env.legal_actions();
            let priors = env.action_priors(&legal);
            (legal, priors)
        } else {
            (Vec::new(), None)
        };

        let mut best_child = None;
        let mut best_score = f64::NEG_INFINITY;

        // Index loop rather than cloning `children`: `selection_score` takes
        // `&mut self` (for the RNG), so no borrow of the child list may span the call.
        for i in 0..self.nodes[parent].children.len() {
            let child = self.nodes[parent].children[i];
            let score = self.selection_score(child, parent_visits, &legal, priors.as_ref());

            if score > best_score {
                best_score = score;
                best_child = Some(child);
            }
        }
        best_child
    }

    /// Selection: descend from the root via the configured tree policy, applying each
    /// chosen child's action to `env`.
    ///
    /// `env` is expected to start in the root state; it is presumably a per-iteration
    /// clone (TODO confirm against callers). Returns the node where descent stopped,
    /// together with the root-to-that-node path (both ends inclusive).
    ///
    /// Descent stops at a terminal node, a node without children, a node for which
    /// [`should_expand`](Self::should_expand) is true, or a node whose children cannot
    /// be scored.
    ///
    /// When the `dag` feature is enabled, transposition reuse can create graph cycles.
    /// A node revisited during the same descent is treated as a leaf to prevent
    /// infinite descent.
    pub(crate) fn select(&mut self, env: &mut E) -> (u32, Vec<u32>) {
        let mut current = 0u32;
        let mut path = vec![current];

        // DAG cycle guard: track which node IDs we've visited on this descent.
        #[cfg(feature = "dag")]
        let mut visited = std::collections::HashSet::new();
        #[cfg(feature = "dag")]
        visited.insert(current);

        loop {
            let node = &self.nodes[current as usize];

            // Stop here (for expansion/evaluation) if this is a terminal node, a leaf,
            // or we are still below progressive widening capacity.
            if node.terminal || node.children.is_empty() || self.should_expand(current) {
                return (current, path);
            }

            let Some(child) = self.pick_best_child(current, env) else {
                // No scorable child: treat as leaf.
                return (current, path);
            };

            // DAG cycle detection: if we've already visited this node on this
            // descent path, treat the current node as a leaf to break the cycle.
            #[cfg(feature = "dag")]
            if self.transposition_table.is_some() && !visited.insert(child) {
                return (current, path);
            }

            // NOTE(review): for a node reused via the transposition table, `action` is
            // the action of the edge that first created it, which may not be the action
            // that leads to it from `current`. Confirm whether DAG edges need per-edge
            // actions.
            if let Some(ref action) = self.nodes[child as usize].action {
                env.apply(action);
            }
            path.push(child);
            current = child;
        }
    }

    /// Scores `child` for selection under the configured tree policy (higher is better).
    ///
    /// * `parent_visits`: visit count of `child`'s parent. It feeds the UCT/RAVE base
    ///   score and the PUCT exploration bonus.
    /// * `legal_actions` / `priors`: used only by PUCT, where `priors[i]` is taken as
    ///   the prior of `legal_actions[i]`. A missing, empty, negative or non-finite prior
    ///   falls back to `1.0`.
    ///
    /// Thompson sampling and Gumbel draw from `self.rng`, hence `&mut self`.
    pub(crate) fn selection_score(
        &mut self,
        child: u32,
        parent_visits: u32,
        legal_actions: &[E::Action],
        priors: Option<&Vec<f64>>,
    ) -> f64 {
        let node = &self.nodes[child as usize];
        let score = node.uct_score_with_rave(
            parent_visits,
            self.config.exploration_constant,
            self.config.rave.enabled,
            self.config.rave.bias,
        );

        match &self.config.tree_policy {
            TreePolicy::Puct { prior_weight } => {
                // Non-negative, finite prior for this child's action (default 1.0).
                let prior = priors
                    .filter(|p| !p.is_empty())
                    .and_then(|p| {
                        let action = node.action.as_ref()?;
                        let i = legal_actions.iter().position(|a| a == action)?;
                        Some(p.get(i).copied().unwrap_or(1.0))
                    })
                    .filter(|p| p.is_finite() && *p >= 0.0)
                    .unwrap_or(1.0);

                // Prior-driven exploration bonus added on top of the UCT/RAVE score:
                // c_puct * P(a) * sqrt(parent_visits) / (1 + n(a)).
                let c_puct = *prior_weight;
                let parent_sqrt = f64::from(parent_visits).sqrt();
                let child_visits = f64::from(node.visits);
                let puct_bonus = if c_puct.is_finite() {
                    c_puct * prior * parent_sqrt / (1.0 + child_visits)
                } else {
                    0.0
                };

                score + puct_bonus
            }
            TreePolicy::ThompsonSampling { temperature } => {
                // Uniform jitter in [-t, t), with negative temperatures clamped to 0.
                let jitter: f64 = self.rng.random_range(0.0..1.0);
                score + (jitter - 0.5) * 2.0 * temperature.max(0.0)
            }
            TreePolicy::Gumbel {
                max_completions_coeff,
                ..
            } => {
                // Non-root Gumbel: completed Q-value plus one Gumbel(0, 1) sample.
                // (The UCT/RAVE `score` is not used by this policy.)
                let u: f64 = self.rng.random_range(GUMBEL_UNIFORM_EPS..1.0 - GUMBEL_UNIFORM_EPS);
                completed_q(node.cumulative_reward, node.visits, *max_completions_coeff)
                    + gumbel_from_uniform(u)
            }
            TreePolicy::Uct => score,
        }
    }

    /// Returns whether `node_id` should gain a new child rather than be descended through.
    ///
    /// A node for which `is_fully_expanded()` is true never expands. Otherwise, without
    /// progressive widening the answer is always `true`. With progressive widening,
    /// expansion is allowed only while the node has fewer children than
    /// [`progressive_limit`](Self::progressive_limit) permits for its visit count.
    pub(crate) fn should_expand(&self, node_id: u32) -> bool {
        let node = &self.nodes[node_id as usize];

        if node.is_fully_expanded() {
            return false;
        }

        if let Some(cfg) = &self.config.progressive_widening {
            let max_children = Self::progressive_limit(node.visits, cfg);
            node.children.len() < max_children
        } else {
            true
        }
    }

    /// Progressive widening cap on the number of children:
    /// `max(minimum_children, floor(coefficient * parent_visits^exponent))`.
    ///
    /// A non-finite or non-positive budget counts as 0, so the result is then
    /// `minimum_children`.
    // The f64 -> usize cast below is guarded: the budget is finite, positive,
    // floored and clamped to `usize::MAX` before casting.
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_precision_loss,
        clippy::cast_sign_loss
    )]
    pub(crate) fn progressive_limit(parent_visits: u32, cfg: &ProgressiveWideningConfig) -> usize {
        let budget = cfg.coefficient * (f64::from(parent_visits)).powf(cfg.exponent);
        let budget = if budget.is_finite() && budget > 0.0 {
            budget.floor().min(usize::MAX as f64) as usize
        } else {
            0
        };
        cfg.minimum_children.max(budget)
    }

    /// Expansion: pop one unexpanded action of `parent_id`, apply it to `env`, and create
    /// the resulting child node.
    ///
    /// The child is marked terminal when `env.evaluate()` is not `Outcome::Ongoing`.
    /// Returns the new child's id.
    ///
    /// Returns `parent_id` unchanged, without touching `env` or the parent's unexpanded
    /// actions, in three cases:
    /// * the node budget (`max_nodes`) is exhausted;
    /// * node ids would no longer fit in `u32`;
    /// * the parent has nothing left to expand.
    ///
    /// With the `dag` feature and a transposition table, a state already in the table is
    /// linked as a child of `parent_id` (once) instead of being duplicated, and its
    /// existing id is returned.
    pub(crate) fn expand(&mut self, parent_id: u32, env: &mut E) -> u32 {
        // Node budget guard: stop allocating new nodes if limit is reached.
        if self.max_nodes.is_some_and(|limit| self.nodes.len() >= limit) {
            return parent_id;
        }

        // Reserve the id *before* mutating anything. Previously this check ran after
        // the action had been popped and applied, so a failure silently lost the
        // action and left `env` one step ahead of the returned node.
        let Ok(child_id) = u32::try_from(self.nodes.len()) else {
            return parent_id;
        };

        let Some(action) = self.nodes[parent_id as usize].unexpanded.pop() else {
            return parent_id;
        };

        env.apply(&action);

        #[cfg(feature = "dag")]
        {
            if let Some(hash) = env.state_hash() {
                if let Some(ref table) = self.transposition_table {
                    if let Some(&existing_id) = table.get(&hash) {
                        // Several actions can reach the same state. Link it only once so
                        // progressive widening counts and noisy policies are not skewed
                        // by duplicate edges.
                        let siblings = &mut self.nodes[parent_id as usize].children;
                        if !siblings.contains(&existing_id) {
                            siblings.push(existing_id);
                        }
                        return existing_id;
                    }
                }
            }
        }

        let legal = env.legal_actions();
        let mut child = Node::child(parent_id, action, legal);

        let state = env.evaluate();
        if state != Outcome::Ongoing {
            child.terminal = true;
        }

        self.nodes.push(child);
        self.nodes[parent_id as usize].children.push(child_id);

        #[cfg(feature = "dag")]
        {
            if let Some(hash) = env.state_hash() {
                if let Some(ref mut table) = self.transposition_table {
                    table.insert(hash, child_id);
                }
            }
        }

        child_id
    }

    /// Simulation: estimate the value of `env`'s current state.
    ///
    /// If an evaluator is attached, its result is returned directly. Otherwise a
    /// uniformly random rollout is played (mutating `env`). The rollout stops when:
    /// * the environment reports a non-`Ongoing` outcome;
    /// * no legal action remains, giving `Reward::DRAW`;
    /// * the depth limit is reached (`env.max_depth()`, else `config.max_depth`).
    ///
    /// At the depth limit the heuristic, scaled by `config.heuristic_weight`, is
    /// returned if one is available; otherwise the result is `Reward::DRAW`.
    pub(crate) fn simulate(&mut self, env: &mut E) -> Reward {
        // If a pluggable evaluator is attached, use it instead of random rollout.
        if let Some(ref evaluator) = self.evaluator {
            return evaluator.evaluate(env);
        }

        let depth_limit = env.max_depth().unwrap_or(self.config.max_depth);
        let mut depth = 0usize;

        loop {
            match env.evaluate() {
                Outcome::Success(r) | Outcome::Terminal(r) => return r,
                Outcome::Failure => return Reward::LOSS,
                Outcome::Neutral => return Reward::DRAW,
                Outcome::Ongoing => {}
            }

            if depth >= depth_limit {
                if let Some(heuristic) = env.heuristic().value {
                    return Reward::new(heuristic.value() * self.config.heuristic_weight);
                }
                return Reward::DRAW;
            }

            // `choose` yields `None` only for an empty action list: a dead end.
            let actions = env.legal_actions();
            let Some(action) = actions.choose(&mut self.rng).cloned() else {
                return Reward::DRAW;
            };

            env.apply(&action);
            depth += 1;
        }
    }

    /// Backpropagation: apply `reward` to every node on `path`.
    ///
    /// `path` is typically the root-to-leaf path returned by [`select`](Self::select).
    /// RAVE statistics are updated as well when RAVE is enabled.
    ///
    /// NOTE(review): the same value is applied at every depth, with no sign flip
    /// between plies. Confirm this is intended for adversarial environments.
    pub(crate) fn backpropagate(&mut self, path: &[u32], reward: Reward) {
        let value = reward.value();
        let rave_enabled = self.config.rave.enabled;
        for &node_id in path {
            let node = &mut self.nodes[node_id as usize];
            node.apply_uct_update(value);
            if rave_enabled {
                node.apply_rave_update(value);
            }
        }
    }

    /// Returns the root child with the most visits (robust child policy).
    ///
    /// Use this to retrieve the best action after manual iteration via
    /// [`run_step()`](TreeSearch::run_step) or [`run_until()`](TreeSearch::run_until).
    pub fn best_root_action(&self) -> Option<E::Action> {
        let id = self.best_root_child_id()?;
        self.nodes[id as usize].action.clone()
    }
}

/// Lower/upper margin for the uniform sample fed to [`gumbel_from_uniform`].
///
/// Keeping `u` away from 0 and 1 keeps both logarithms finite.
const GUMBEL_UNIFORM_EPS: f64 = 1e-10;

/// Inverse-CDF transform of a uniform sample `u` in (0, 1) into a Gumbel(0, 1)
/// sample: `-ln(-ln(u))`.
fn gumbel_from_uniform(u: f64) -> f64 {
    -(-u.ln()).ln()
}

/// Gumbel "completed" Q-value used by the Gumbel tree policy.
///
/// The empirical mean `cumulative_reward / visits` is shrunk towards 0 with weight
/// `coeff / (coeff + visits)`, so it approaches the raw mean as visits grow.
/// Unvisited nodes score `0.0`.
fn completed_q(cumulative_reward: f64, visits: u32, coeff: f64) -> f64 {
    if visits == 0 {
        return 0.0;
    }
    let n = f64::from(visits);
    let mix = coeff / (coeff + n);
    (1.0 - mix) * (cumulative_reward / n)
}