mctrust 0.4.0

Universal search & planning toolkit — MCTS, bandit search, pluggable evaluators, tree reuse, DAG transpositions, root parallelism. Define an Environment and the search handles the rest.
Documentation
//! The `Environment` trait — the core integration point for MCTS.

use crate::reward::Reward;

/// Whether an environment state is still in play or has finished.
///
/// Every variant other than [`Outcome::Ongoing`] is terminal.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum Outcome {
    /// The state is not terminal; search may keep expanding from here.
    Ongoing,
    /// A terminal state carrying an explicitly supplied reward.
    Terminal(Reward),
    /// A terminal state meant for positive results (e.g. a goal was reached),
    /// carrying its own reward.
    Success(Reward),
    /// A terminal state meant for negative results (e.g. a constraint was violated).
    Failure,
    /// A terminal state meant for neutral results (e.g. the budget ran out).
    Neutral,
}

impl Outcome {
    /// Tells whether search must stop at this state.
    ///
    /// # Returns
    ///
    /// `false` only for [`Outcome::Ongoing`]; `true` for every terminal variant.
    pub fn is_terminal(self) -> bool {
        match self {
            Self::Ongoing => false,
            Self::Terminal(_) | Self::Success(_) | Self::Failure | Self::Neutral => true,
        }
    }

    /// Maps the outcome to the reward it represents, if any.
    ///
    /// `Terminal` and `Success` yield their carried reward. `Failure` yields
    /// [`Reward::LOSS`] and `Neutral` yields [`Reward::DRAW`]. `Ongoing`
    /// yields `None`.
    pub fn reward(self) -> Option<Reward> {
        let value = match self {
            // A non-terminal state has no reward to report.
            Self::Ongoing => return None,
            Self::Terminal(reward) | Self::Success(reward) => reward,
            Self::Failure => Reward::LOSS,
            Self::Neutral => Reward::DRAW,
        };
        Some(value)
    }
}

impl std::fmt::Display for Outcome {
    /// Renders the outcome as a lowercase label.
    ///
    /// Reward-carrying variants append their reward in parentheses,
    /// e.g. `success(1)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reward-carrying variants are written immediately. Unit variants fall
        // through to a single `write_str` of their label.
        let label = match *self {
            Self::Terminal(reward) => return write!(f, "terminal({reward})"),
            Self::Success(reward) => return write!(f, "success({reward})"),
            Self::Ongoing => "ongoing",
            Self::Failure => "failure",
            Self::Neutral => "neutral",
        };
        f.write_str(label)
    }
}

/// Optional, domain-supplied hints about a state.
///
/// The [`Default`] value carries no estimate (`value` is `None`).
#[derive(Debug, Default, Clone, Copy, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Heuristic {
    /// Estimated value of the current state, when the domain can supply one.
    pub value: Option<Reward>,
}

impl Heuristic {
    /// Builds a heuristic whose `value` is set to the given reward estimate.
    pub fn from_reward(value: Reward) -> Self {
        let estimate = Some(value);
        Self { value: estimate }
    }
}

/// A problem domain that the MCTS engine can explore.
///
/// Implement this trait for your own domain. It is deliberately general enough
/// to cover:
/// - adversarial games and single-agent planning,
/// - optimization and scheduling,
/// - financial portfolio construction and security scan strategies,
/// - scientific search problems.
///
/// Only [`legal_actions`](Self::legal_actions), [`apply`](Self::apply), and
/// [`evaluate`](Self::evaluate) are required; every other method has a default.
pub trait Environment: Send + Sync + Clone {
    /// The type of move the search chooses between.
    type Action: std::fmt::Debug + PartialEq + Clone + Send + Sync;

    /// Lists every action that is legal in the current state.
    fn legal_actions(&self) -> Vec<Self::Action>;

    /// Mutates the current state by performing `action`.
    fn apply(&mut self, action: &Self::Action);

    /// Reports whether the current state is ongoing or terminal (see [`Outcome`]).
    fn evaluate(&self) -> Outcome;

    /// Index of the agent that acts next.
    ///
    /// The default of `0` suits single-agent optimization and planning.
    /// Multi-agent or adversarial environments should return the index of the
    /// agent to move, alternating between agents.
    ///
    /// The search engine uses this to negate rewards during backpropagation
    /// in multi-agent environments.
    fn current_player(&self) -> usize {
        0
    }

    /// Number of agents taking part in this environment.
    ///
    /// With the default of `1`, backpropagation is standard. With a value
    /// greater than `1`, the engine flips rewards negamax-style.
    fn num_players(&self) -> usize {
        1
    }

    /// Optional value estimate for a non-terminal state.
    ///
    /// Used when a rollout reaches its depth cap, or when the domain wants to
    /// inject prior knowledge into the search. The default carries no estimate.
    fn heuristic(&self) -> Heuristic {
        Heuristic::default()
    }

    /// Optional per-state override of the depth budget; `None` by default.
    fn max_depth(&self) -> Option<usize> {
        None
    }

    /// Optional action priors for PUCT-style search; `None` by default.
    ///
    /// When provided, the vector must match `actions` in both length and order.
    /// Invalid priors are ignored, and the engine falls back to a uniform prior.
    fn action_priors(&self, _actions: &[Self::Action]) -> Option<Vec<f64>> {
        None
    }

    /// Optional deterministic hash of the environment state; `None` by default.
    ///
    /// When implemented, the engine keys a transposition table with it. States
    /// reached through different paths are merged, so the search tree becomes
    /// a directed acyclic graph (DAG). This can substantially speed up convergence.
    ///
    /// Equal states must hash to equal values. Distinct states should collide
    /// as rarely as possible, because a collision may cause unrelated states
    /// to be merged.
    fn state_hash(&self) -> Option<u64> {
        None
    }
}

/// Pluggable leaf evaluator that replaces random rollouts with a
/// domain-specific or neural-network-based estimate.
///
/// This is the hook for AlphaZero-style search. Rather than playing a random
/// rollout out to a terminal state, the engine calls
/// [`evaluate`](Self::evaluate) and uses the returned [`Reward`] as-is.
///
/// # Default
///
/// Without an evaluator, the engine falls back to its built-in random rollout
/// simulator, which is the standard MCTS behavior.
///
/// # Examples
///
/// ```rust
/// use mctrust::{Evaluator, Environment, Reward};
///
/// struct MyNeuralEval;
///
/// impl<E: Environment> Evaluator<E> for MyNeuralEval {
///     fn evaluate(&self, _env: &E) -> Reward {
///         // In a real implementation, pass the state through a neural network
///         // and return the value head output.
///         Reward::new(0.5)
///     }
/// }
/// ```
pub trait Evaluator<E>: Send + Sync
where
    E: Environment,
{
    /// Scores `env` and returns a reward estimate for it.
    ///
    /// The engine uses this value directly for backpropagation, instead of
    /// the default random-rollout simulation.
    fn evaluate(&self, env: &E) -> Reward;
}

#[cfg(test)]
mod tests {
    use super::{Heuristic, Outcome, Reward};

    /// Every variant except `Ongoing` must be reported as terminal.
    #[test]
    fn outcome_terminal_detection() {
        assert!(Outcome::Terminal(Reward::WIN).is_terminal());
        assert!(Outcome::Success(Reward::WIN).is_terminal());
        assert!(Outcome::Failure.is_terminal());
        assert!(Outcome::Neutral.is_terminal());
        assert!(!Outcome::Ongoing.is_terminal());
    }

    /// Pins the variant-to-reward mapping of `Outcome::reward`, including the
    /// fixed `LOSS`/`DRAW` rewards for `Failure`/`Neutral`.
    #[test]
    fn outcome_reward_mapping() {
        assert_eq!(Outcome::Ongoing.reward(), None);
        assert_eq!(Outcome::Terminal(Reward::new(0.25)).reward(), Some(Reward::new(0.25)));
        assert_eq!(Outcome::Success(Reward::WIN).reward(), Some(Reward::WIN));
        assert_eq!(Outcome::Failure.reward(), Some(Reward::LOSS));
        assert_eq!(Outcome::Neutral.reward(), Some(Reward::DRAW));
    }

    #[test]
    fn heuristic_default_and_constructor() {
        assert_eq!(Heuristic::default(), Heuristic { value: None });
        let h = Heuristic::from_reward(Reward::new(0.25));
        assert_eq!(h.value, Some(Reward::new(0.25)));
    }

    /// Covers the `Display` output of every variant.
    #[test]
    fn format_terminal_states() {
        assert_eq!(Outcome::Ongoing.to_string(), "ongoing");
        assert_eq!(Outcome::Terminal(Reward::WIN).to_string(), "terminal(1)");
        assert_eq!(Outcome::Success(Reward::WIN).to_string(), "success(1)");
        assert_eq!(Outcome::Failure.to_string(), "failure");
        assert_eq!(Outcome::Neutral.to_string(), "neutral");
    }
}