use std::array;
use std::error::Error;
use std::fmt::Debug;
use iterstats::rank::Rank;
use iterstats::Iterstats;
use itertools::Itertools;
use super::Node;
#[cfg(doc)]
use super::{MonteCarloTreeSearch, Playout};
use crate::prelude::Game;
/// Strategy for folding a finished playout's scores back into a search-tree
/// node's statistics.
///
/// Used by [`MonteCarloTreeSearch`] after a [`Playout`] completes; `result`
/// carries one [`Game::Score`](crate::prelude::Game) per player (index `i`
/// is player `i`'s score), and implementations update `node`'s stored
/// payouts accordingly.
pub trait Backpropagation<const N: usize, G: Game<N>, T> {
    /// Error produced when a payout update cannot be performed (typically a
    /// numeric conversion failure).
    type Error: Error;

    /// Folds `result` into `node`'s accumulated statistics.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when the update fails, e.g. a score or
    /// payout cannot be converted between numeric representations.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error>;
}
/// Errors produced by the [`AverageScore`] backpropagation strategy, which
/// round-trips scores and payouts through `f64` to maintain a running mean.
#[derive(Debug, thiserror::Error)]
pub enum AverageError<const N: usize, G: Game<N>, T>
where
    T: TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
    G::Score: TryInto<f64>,
    <G::Score as TryInto<f64>>::Error: Error,
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
{
    /// A game score could not be converted to `f64`.
    #[error("Error converting game score to a f64: {0}")]
    GameScoreToF64Conversion(<G::Score as TryInto<f64>>::Error),
    /// A stored MCTS payout could not be converted to `f64`.
    #[error("Error converting MCTS payout to a f64: {0}")]
    PayoutToF64Conversion(<T as TryInto<f64>>::Error),
    /// The updated `f64` mean could not be converted back to the payout type.
    #[error("Error converting an f64 to a MCTS payout: {0}")]
    F64ToPayoutConversion(<f64 as TryInto<T>>::Error),
}
/// Backpropagation strategy that keeps, per player, the running mean of the
/// raw game scores observed through a node.
#[derive(Debug, Default)]
pub struct AverageScore;
impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for AverageScore
where
    T: Copy + Debug + TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
    G::Score: Copy + TryInto<f64>,
    <G::Score as TryInto<f64>>::Error: Error,
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
    G::Action: Debug,
{
    type Error = AverageError<N, G, T>;

    /// Updates `node.payouts` so each entry holds the running mean of the
    /// scores seen for that player, weighted by `node.visits`.
    ///
    /// # Errors
    ///
    /// Returns an [`AverageError`] when a score or payout cannot be
    /// round-tripped through `f64`.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the average score backpropagation strategy from {:?}",
            node
        );
        let prior_visits: f64 = node.visits.into();
        let total_visits = prior_visits + 1.0;
        for (payout_slot, score) in node.payouts.iter_mut().zip(result.iter()) {
            let score_f64: f64 = (*score)
                .try_into()
                .map_err(AverageError::GameScoreToF64Conversion)?;
            let running_mean: f64 = (*payout_slot)
                .try_into()
                .map_err(AverageError::PayoutToF64Conversion)?;
            // Incremental mean: (mean * n + new_sample) / (n + 1).
            let updated = (running_mean * prior_visits + score_f64) / total_visits;
            *payout_slot = updated
                .try_into()
                .map_err(AverageError::F64ToPayoutConversion)?;
        }
        Ok(())
    }
}
/// Errors produced by the [`WinRate`] backpropagation strategy, which
/// round-trips payouts through `f64` to maintain a running win rate.
#[derive(Debug, thiserror::Error)]
pub enum WinRateError<T>
where
    T: TryInto<f64>,
    <T as TryInto<f64>>::Error: Error,
    f64: TryInto<T>,
    <f64 as TryInto<T>>::Error: Error,
{
    /// A stored MCTS payout could not be converted to `f64`.
    #[error("Error converting MCTS payout to a f64: {0}")]
    PayoutToF64Conversion(<T as TryInto<f64>>::Error),
    /// The updated `f64` win rate could not be converted back to the payout type.
    #[error("Error converting an f64 to a MCTS payout: {0}")]
    F64ToPayoutConversion(<f64 as TryInto<T>>::Error),
}
/// Backpropagation strategy that tracks a running win rate: the playout's
/// top scorer is credited 1.0 (0.5 each on a tie, regardless of how many
/// players tie) and everyone else 0.0.
#[derive(Debug, Default)]
pub struct WinRate;
impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for WinRate
where
    T: Debug + Copy + TryInto<f64> + TryFrom<f64>,
    <T as TryInto<f64>>::Error: Error,
    <f64 as TryInto<T>>::Error: Error,
    G::Score: PartialEq + PartialOrd,
{
    type Error = WinRateError<T>;

    /// Folds win/loss credit for this playout into `node.payouts` as a
    /// running average.
    ///
    /// # Panics
    ///
    /// Panics if two scores cannot be ordered (e.g. NaN-like values) or if
    /// `N == 0`.
    ///
    /// # Errors
    ///
    /// Returns a [`WinRateError`] when a payout cannot be round-tripped
    /// through `f64`.
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the win rate backpropagation strategy from {:?}",
            node
        );
        let prior_visits: f64 = node.visits.into();
        let total_visits = prior_visits + 1.0;
        // Find the best score; ordering must be total over the observed values.
        let best = result
            .iter()
            .max_by(|a, b| a.partial_cmp(b).expect("could not determine order"))
            .expect("0-player game");
        let winners: [bool; N] = array::from_fn(|i| &result[i] == best);
        let tied = winners.iter().filter(|w| **w).count() > 1;
        // Credit per player: 1.0 for a sole winner, 0.5 for each tied winner,
        // 0.0 for everyone else.
        let credit: [f64; N] = array::from_fn(|i| {
            if winners[i] {
                if tied { 0.5 } else { 1.0 }
            } else {
                0.0
            }
        });
        for (payout_slot, share) in node.payouts.iter_mut().zip(credit) {
            let running_rate: f64 = (*payout_slot)
                .try_into()
                .map_err(WinRateError::PayoutToF64Conversion)?;
            // Incremental mean: (rate * n + credit) / (n + 1).
            *payout_slot = ((running_rate * prior_visits + share) / total_visits)
                .try_into()
                .map_err(WinRateError::F64ToPayoutConversion)?;
        }
        Ok(())
    }
}
/// Backpropagation strategy that averages each player's inverted finishing
/// rank (`N - rank - 1`, so larger stored payouts mean better finishes)
/// into the node's payouts.
#[derive(Debug, Clone, Copy, Default)]
pub struct AverageRank;
impl<const N: usize, G: Game<N>, T> Backpropagation<N, G, T> for AverageRank
where
    T: Debug + Copy + TryInto<f64> + TryFrom<f64>,
    <T as TryInto<f64>>::Error: Error,
    <f64 as TryInto<T>>::Error: Error,
    G::Score: Rank + Clone,
{
    // NOTE(review): the error type is fixed to `Infallible`, yet the payout
    // conversions below are fallible — failures therefore surface as panics.
    // The `expect` messages state the assumed invariant (previously these
    // were bare `unwrap()`s with no diagnostic).
    type Error = std::convert::Infallible;

    /// Folds the players' inverted ranks for `result` into `node.payouts`
    /// as a running average, weighted by `node.visits`.
    ///
    /// # Panics
    ///
    /// Panics if a payout cannot be round-tripped through `f64`, or if the
    /// rank iterator does not yield exactly `N` items (`zip_eq`).
    fn backpropagate(
        &self,
        node: &mut Node<N, G::State, T>,
        result: &[G::Score; N],
    ) -> Result<(), Self::Error> {
        log::debug!(
            "running the average rank backpropagation strategy from {:?}",
            node
        );
        let prior_visits: f64 = node.visits.into();
        let total_visits = prior_visits + 1.0;
        // Invert ranks so the best-ranked player receives the largest credit.
        // assumes `rank()` yields 0-based ranks in 0..N — TODO confirm
        // against iterstats; a 1-based rank of N would underflow `N - r - 1`.
        let inverted_ranks = result
            .iter()
            .cloned()
            .rank()
            .map(|r| N - r - 1)
            .collect::<Vec<_>>();
        for (payout_slot, inv_rank) in node.payouts.iter_mut().zip_eq(inverted_ranks) {
            let running_mean: f64 = (*payout_slot)
                .try_into()
                .expect("MCTS payout must be representable as f64");
            // Incremental mean: (mean * n + new_rank) / (n + 1).
            *payout_slot = ((running_mean * prior_visits + inv_rank as f64) / total_visits)
                .try_into()
                .expect("averaged rank must be representable as the MCTS payout type");
        }
        Ok(())
    }
}