ncpig 0.6.1

Non-Cooperative Perfect Information Games, and algorithms to play them.
Documentation
//! Monte Carlo Tree Search backpropagation strategies.

use std::array;
use std::error::Error;
use std::fmt::Debug;

use iterstats::rank::Rank;
use iterstats::Iterstats;
use itertools::Itertools;

use super::Node;
#[cfg(doc)]
use super::{MonteCarloTreeSearch, Playout};
use crate::prelude::Game;

/// MCTS backpropagation strategy.
///
/// Propagates the result of a [`Playout`] up the tree.
pub trait Backpropagation<const N: usize, G: Game<N>, T> {
    /// Errors returned by the backpropagation strategy.
    type Error: Error;

    /// Incorporate the information of a playout `result` into `node`.
    ///
    /// `result` holds one [`Game::Score`] per player (`N` entries). The strategies defined in
    /// this module update only the node they are given; they do not visit its parents
    /// themselves.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the strategy cannot incorporate `result` into `node`; see the
    /// implementing strategy for the specific conditions.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error>;
}

/// Errors produced by the [`AverageScore`] [`Backpropagation`] strategy.
///
/// Each variant wraps the error returned by one of the numeric conversions the strategy performs
/// while updating a node's running average.
#[derive(Debug, thiserror::Error)]
pub enum AverageError<const N: usize, G: Game<N>, T>
where
    G::Score: TryInto<f64>,
    <G::Score as TryInto<f64>>::Error: Error,
    T: TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
{
    /// A [`Game::Score`] from the playout result could not be converted to an [`f64`].
    #[error("Error converting game score to a f64: {0}")]
    GameScoreToF64Conversion(<G::Score as TryInto<f64>>::Error),

    /// A stored [`MonteCarloTreeSearch`] [`Node`] payout could not be converted to an [`f64`].
    #[error("Error converting MCTS payout to a f64: {0}")]
    PayoutToF64Conversion(<T as TryInto<f64>>::Error),

    /// The recomputed [`f64`] average could not be converted back to a
    /// [`MonteCarloTreeSearch`] [`Node`] payout.
    #[error("Error converting an f64 to a MCTS payout: {0}")]
    F64ToPayoutConversion(<f64 as TryInto<T>>::Error),
}

/// Average backpropagation method.
///
/// For every player, keeps the running mean of the raw [`Game::Score`] over all the times the
/// node has been visited.
///
/// This is a stateless unit strategy, so it is `Copy`, like [`AverageRank`].
#[derive(Debug, Clone, Copy, Default)]
pub struct AverageScore;

impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for AverageScore
where
    T: Copy + Debug + TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
    G::Score: Copy + TryInto<f64>,
    <G::Score as TryInto<f64>>::Error: Error,
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
    G::Action: Debug,
{
    type Error = AverageError<N, G, T>;

    /// Fold each player's score from `result` into the running mean stored in `node`'s payouts.
    ///
    /// The previous mean is weighted by `node.visits`. The updated mean is
    /// `(old * visits + score) / (visits + 1)`. `node.visits` itself is only read here, never
    /// modified.
    ///
    /// # Errors
    ///
    /// Returns the matching [`AverageError`] variant when a score, a stored payout, or the new
    /// average fails to convert. Payouts processed before the failing one keep their new values.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the average score backpropagation strategy from {:?}",
            node
        );
        let prior_visits: f64 = node.visits.into();
        let total_visits = prior_visits + 1.0;

        node.payouts
            .iter_mut()
            .zip(result.iter().copied())
            .try_for_each(|(slot, score)| -> Result<(), AverageError<N, G, T>> {
                let sample: f64 = score
                    .try_into()
                    .map_err(AverageError::GameScoreToF64Conversion)?;
                let previous: f64 = (*slot)
                    .try_into()
                    .map_err(AverageError::PayoutToF64Conversion)?;
                let updated = (previous * prior_visits + sample) / total_visits;
                *slot = updated
                    .try_into()
                    .map_err(AverageError::F64ToPayoutConversion)?;
                Ok(())
            })
    }
}

/// Errors produced by the [`WinRate`] [`Backpropagation`] strategy.
///
/// Both variants wrap the error of a numeric conversion between a node payout `T` and the
/// [`f64`] used to compute the updated win rate.
#[derive(Debug, thiserror::Error)]
pub enum WinRateError<T>
where
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
    T: TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
{
    /// A stored [`MonteCarloTreeSearch`] [`Node`] payout could not be converted to an [`f64`].
    #[error("Error converting MCTS payout to a f64: {0}")]
    PayoutToF64Conversion(<T as TryInto<f64>>::Error),

    /// The recomputed [`f64`] value could not be converted back to a [`MonteCarloTreeSearch`]
    /// [`Node`] payout.
    #[error("Error converting an f64 to a MCTS payout: {0}")]
    F64ToPayoutConversion(<f64 as TryInto<T>>::Error),
}

/// Win rate backpropagation method.
///
/// If there is an outright winner, they get 1.0 and all other players get 0.0. If there is a tie,
/// every tied winner gets 0.5. The node keeps the running mean of these values over all visits.
///
/// This is a stateless unit strategy, so it is `Copy`, like [`AverageRank`].
#[derive(Debug, Clone, Copy, Default)]
pub struct WinRate;

impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for WinRate
where
    T: Debug + Copy + TryInto<f64> + TryFrom<f64>,
    <T as TryInto<f64>>::Error: Error,
    <f64 as TryInto<T>>::Error: Error,
    G::Score: PartialEq + PartialOrd,
{
    type Error = WinRateError<T>;

    /// Credit each player with a win value and fold it into the running mean in `node`'s payouts.
    ///
    /// Players whose score equals the best score are winners. A sole winner is credited 1.0.
    /// When several players share the best score, each is credited 0.5. Everyone else gets 0.0.
    /// `node.visits` is only read here, never modified.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0` ("0-player game") or if two scores cannot be ordered by
    /// `partial_cmp` ("could not determine order").
    ///
    /// # Errors
    ///
    /// Returns a [`WinRateError`] when a stored payout or the new rate fails to convert. Payouts
    /// processed before the failing one keep their new values.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the win rate backpropagation strategy from {:?}",
            node
        );
        let prior_visits: f64 = node.visits.into();
        let total_visits = prior_visits + 1.0;

        let best = result
            .iter()
            .max_by(|lhs, rhs| lhs.partial_cmp(rhs).expect("could not determine order"))
            .expect("0-player game");
        let won: [bool; N] = array::from_fn(|idx| &result[idx] == best);
        let tied = won.iter().filter(|&&hit| hit).count() > 1;
        let credit = if tied { 0.5 } else { 1.0 };

        for (slot, is_winner) in node.payouts.iter_mut().zip(won) {
            let target = if is_winner { credit } else { 0.0 };
            let previous: f64 = (*slot)
                .try_into()
                .map_err(WinRateError::PayoutToF64Conversion)?;
            *slot = ((previous * prior_visits + target) / total_visits)
                .try_into()
                .map_err(WinRateError::F64ToPayoutConversion)?;
        }
        Ok(())
    }
}

/// Average-rank backpropagation method.
///
/// Each playout scores a player by their placing: N-1 for the highest score and 0 for the
/// lowest. The node keeps the running mean of that placing over all visits.
#[derive(Default, Debug, Copy, Clone)]
pub struct AverageRank;

impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for AverageRank
where
    T: Debug + Copy + TryInto<f64> + TryFrom<f64>,
    <T as TryInto<f64>>::Error: Error,
    <f64 as TryInto<T>>::Error: Error,
    G::Score: Rank + Clone,
{
    type Error = std::convert::Infallible;

    /// Fold each player's placing in `result` into the running mean stored in `node`'s payouts.
    ///
    /// The placing is N-1 for the best score and 0 for the worst. The previous mean is weighted
    /// by `node.visits`, which is only read here, never modified.
    ///
    /// # Panics
    ///
    /// The error type is [`Infallible`](std::convert::Infallible), so conversion failures are
    /// reported by panicking instead. This function panics if:
    /// - a stored payout cannot be converted to an `f64`;
    /// - the updated average cannot be converted back to `T`;
    /// - the number of ranks differs from the number of payouts (`zip_eq`).
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the average rank backpropagation strategy from {:?}",
            node
        );
        let oldvisits: f64 = node.visits.into();
        let newvisits = oldvisits + 1.0;
        // Fed straight into `zip_eq` instead of being collected into a temporary Vec.
        // Make the rank N-1 if you're in 1st and 0 if in last.
        let ranks = result.iter().cloned().rank().map(|r| N - r - 1);
        for (currentrank, newrank) in node.payouts.iter_mut().zip_eq(ranks) {
            let oldpayout: f64 = (*currentrank)
                .try_into()
                .expect("MCTS payout could not be converted to an f64");
            let newrank = newrank as f64;
            *currentrank = ((oldpayout * oldvisits + newrank) / newvisits)
                .try_into()
                .expect("average rank could not be converted back to an MCTS payout");
        }
        Ok(())
    }
}