ncpig 0.6.1

Non-Cooperative Perfect Information Games, and algorithms to play them.
Documentation
//! Monte Carlo Tree Search selection strategies.

use std::cell::{BorrowMutError, RefCell};
use std::cmp::Ordering;
use std::error::Error;
use std::fmt::Debug;
use std::hash::Hash;

use petgraph::graph::NodeIndex;
use petgraph::{Direction, Graph};
use rand::rngs::ThreadRng;
use rand::seq::IteratorRandom;

use super::{is_fully_expanded, Node};
use crate::prelude::{Game, GameError};

/// Selection strategy for Monte Carlo Tree Search.
///
/// A selection strategy decides how to walk down the tree that has been built so far: starting
/// from a given node, it traverses the tree until it reaches a frontier node.
pub trait Selection<const N: usize, G, T>: Debug
where
    G: Game<N>,
{
    /// Error produced when selection fails.
    type Error: Error;

    /// Walk `tree` starting at `node_idx` and return the index of the frontier node reached.
    fn select(
        &self,
        game: &G,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
        node_idx: NodeIndex,
    ) -> Result<NodeIndex, Self::Error>;
}

/// Errors returned by [`UCT`].
#[derive(Debug, thiserror::Error)]
pub enum UCTError<const N: usize, G>
where
    G: Game<N>,
    G::Action: Debug,
{
    /// An error raised by the [`Game`] implementation.
    #[error(transparent)]
    GameError(G::Error),

    /// A number could not be converted (e.g. a payout into `f64`). Holds a textual
    /// description of the underlying conversion error.
    #[error("{0}")]
    NumberConversionError(String),

    /// A node had no payout for the requested player index.
    #[error("No payout for requested player index")]
    NoPlayerPayout,

    /// Unable to find the child associated with an action.
    #[error("Could not find child associated with action {0:?}")]
    UnknownChildAction(G::Action),

    /// Unable to mutably borrow a [`RefCell`](std::cell::RefCell).
    #[error(transparent)]
    BorrowMutError(#[from] BorrowMutError),
}

impl<const N: usize, G, E> From<E> for UCTError<N, G>
where
    G: Game<N, Error = E>,
    G::Action: Debug,
    E: GameError,
{
    fn from(value: G::Error) -> Self {
        Self::GameError(value)
    }
}

/// UCT (Upper Confidence Bounds applied to Trees) selection strategy[^1].
///
/// Fully expanded nodes pick the child that maximises its payout plus an exploration bonus
/// scaled by `c`. Optionally, nodes with too few visits pick a child at random instead. See
/// [`UCT::new`] for the parameters.
///
/// [^1]: Levente Kocsis and Csaba Szepesvári, "Bandit based Monte-Carlo Planning", *Proceedings
/// of the 17th European Conference on Machine Learning*, 282–293. Springer-Verlag, 2006.
#[derive(Debug)]
pub struct UCT {
    /// Exploration constant; larger values favour exploration over exploitation.
    c: f64,
    /// Optional visit threshold below which children are chosen at random, paired with the RNG
    /// used for that choice. The RNG sits in a `RefCell` because selection only has `&self`
    /// but needs to draw from it mutably.
    threshold: Option<(u32, RefCell<ThreadRng>)>,
}

impl Default for UCT {
    fn default() -> Self {
        Self {
            c: 0.7,
            threshold: Some((30, RefCell::new(ThreadRng::default()))),
        }
    }
}

impl UCT {
    /// Create a new [`UCT`].
    ///
    /// The `c` parameter is a tuning constant to trade off between exploration & exploitation.
    /// Larger values result in more exploration. See Sturtevant[^1] for an analysis of UCT in
    /// multi-player games.
    ///
    /// When `threshold` is `Some(t)`, selection picks a child uniformly at random (instead of
    /// by UCT score) at any fully expanded node that has been visited fewer than `t` times.
    /// With `None`, UCT scores are always used.
    ///
    /// [^1]: Nathan R. Sturtevant, "An Analysis of UCT in Multi-Player Games", *Computers and
    /// Games*, Vol. 5131 of LNCS, 37–49, 2008.
    pub fn new(c: f64, threshold: Option<u32>) -> Self {
        Self {
            c,
            threshold: threshold.map(|t| (t, RefCell::new(ThreadRng::default()))),
        }
    }

    /// UCT score of `node`, which is a child of `parent`.
    ///
    /// The score is `payout + c * sqrt(ln(parent.visits) / node.visits)`. Here `payout` is
    /// `node`'s payout for the player who is active at `parent`, i.e. the player choosing
    /// among `parent`'s children.
    ///
    /// A child that has never been visited scores `f64::INFINITY`, so it is always preferred.
    /// Without this guard the division by zero would yield `+inf` or `NaN` depending on
    /// `parent.visits`.
    ///
    /// # Errors
    ///
    /// - [`UCTError::GameError`] if the active player cannot be determined.
    /// - [`UCTError::NoPlayerPayout`] if `node` has no payout for that player.
    /// - [`UCTError::NumberConversionError`] if the payout cannot be converted to `f64`.
    fn value<const N: usize, G: Game<N>, T>(
        &self,
        game: &G,
        node: &Node<N, G::State, T>,
        parent: &Node<N, G::State, T>,
    ) -> Result<f64, UCTError<N, G>>
    where
        T: Copy + TryInto<f64>,
        <T as TryInto<f64>>::Error: Debug,
        G::Action: Debug,
    {
        let player = game.active_player_index(&parent.state)?;
        // NOTE(review): textbook UCT uses the *mean* reward of the child here; confirm whether
        // `payouts` stores per-visit averages or running totals.
        let payout: f64 = (*node.payouts.get(player).ok_or(UCTError::NoPlayerPayout)?)
            .try_into()
            .map_err(|err| UCTError::NumberConversionError(format!("{:?}", err)))?;

        // Payout errors are reported above even for unvisited children; only the score itself
        // is special-cased.
        if node.visits == 0 {
            return Ok(f64::INFINITY);
        }

        let exploration = ((parent.visits as f64).ln() / (node.visits as f64)).sqrt();
        Ok(payout + self.c * exploration)
    }
}

impl<const N: usize, G: Game<N>, T> Selection<N, G, T> for UCT
where
    T: Copy + TryInto<f64> + Debug,
    <T as TryInto<f64>>::Error: Debug,
    G::Action: Clone + Hash + Eq + Debug,
    Node<N, G::State, T>: Clone,
{
    type Error = UCTError<N, G>;

    /// Descend from `node_idx` to the node that should be worked on next.
    ///
    /// - A node that is not fully expanded is returned as-is.
    /// - If a visit threshold is configured and the node has fewer visits than it, a random
    ///   child is returned.
    /// - Otherwise the child with the highest UCT score is chosen and selection recurses into
    ///   it. Among equally scored children, the last one visited wins.
    /// - A fully expanded node without children is returned as-is.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`is_fully_expanded`] and from scoring children. Returns
    /// [`UCTError::BorrowMutError`] if the random number generator is already borrowed.
    fn select(
        &self,
        game: &G,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
        node_idx: NodeIndex,
    ) -> Result<NodeIndex, UCTError<N, G>> {
        // Total order on scores that ranks NaN below every number. `partial_cmp` alone is not
        // total when NaN is present, which made the chosen child depend on iteration order.
        fn cmp_scores(a: f64, b: f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Less,
                (false, true) => Ordering::Greater,
                (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
            }
        }

        let node = &tree[node_idx];
        log::debug!("running the UCT selection strategy from {:?}", node);
        if !is_fully_expanded(node_idx, game, tree)? {
            log::debug!("Node is not fully expanded, so not using UCT");
            return Ok(node_idx);
        }

        let children = tree.neighbors_directed(node_idx, Direction::Outgoing);

        match &self.threshold {
            Some((threshold, rng)) if node.visits < *threshold => {
                log::debug!("Selecting randomly since below UCT threshold");
                // NOTE(review): unlike the UCT branch below, this returns the chosen child
                // without descending further. Confirm this asymmetry is intended. Recursing
                // would first require releasing this `RefCell` borrow.
                let mut rng = rng.try_borrow_mut()?;
                children.choose(&mut *rng).map(Ok)
            }
            _ => {
                log::debug!("Using UCT strategy since above (or no) threshold");
                children
                    .map(|child_idx| Ok((child_idx, self.value(game, &tree[child_idx], node)?)))
                    .collect::<Result<Vec<_>, UCTError<N, G>>>()?
                    .into_iter()
                    .max_by(|a, b| cmp_scores(a.1, b.1))
                    .map(|(best_child_idx, _)| self.select(game, tree, best_child_idx))
            }
        }
        .unwrap_or(Ok(node_idx))
    }
}