ncpig 0.6.1

Non-Cooperative Perfect Information Games, and algorithms to play them.
Documentation
//! Final MCTS node scoring.
//!
//! After MCTS has visited nodes, it needs to "score" them to determine which ones are better
//! than others. This module provides several different methods for doing so.

use std::collections::HashMap;
use std::convert::Infallible;
use std::error::Error;

use petgraph::graph::{Neighbors, NodeIndex};
use petgraph::Graph;

#[cfg(doc)]
use super::Backpropagation;
use super::Node;
use crate::game::Game;

/// Strategy that MCTS uses to assign a final "score" to each visited child node, immediately
/// before the candidate actions are ranked and one of them is chosen.
///
/// Implementors decide what a node's score is; the only requirement on the result is that
/// scores can be compared with [`PartialOrd`].
pub trait FinalScorer<const N: usize, G: Game<N>, T> {
    /// Comparable value produced for every scored node.
    type Score: PartialOrd;
    /// Error reported when the nodes cannot be scored.
    type Error: Error;

    /// Compute a score for every node yielded by `children`.
    ///
    /// * `player_index` - index of the player on whose behalf the nodes are scored.
    /// * `children` - the nodes to score, as a petgraph neighbor iterator.
    /// * `tree` - the search tree holding the per-node statistics.
    ///
    /// Returns a map from each child's [`NodeIndex`] to its score.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when the implementation is unable to score the nodes.
    fn node_scores(
        &self,
        player_index: usize,
        children: Neighbors<'_, G::Action, u32>,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
    ) -> Result<HashMap<NodeIndex, Self::Score>, Self::Error>;
}

/// Score the nodes based on the value of the payout (defined by the [`Backpropagation`] method).
///
/// Each child's score is a clone of the payout stored on that node for the requesting player,
/// so scores are compared using `T`'s own [`PartialOrd`] implementation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PayoutValue;

impl<const N: usize, G, T> FinalScorer<N, G, T> for PayoutValue
where
    G: Game<N>,
    T: PartialOrd + Clone,
{
    // Reading a stored payout never fails, so no error can be produced.
    type Error = Infallible;
    type Score = T;

    /// Map every child node to a clone of its payout for `player_index`.
    ///
    /// # Panics
    ///
    /// Panics if `player_index` is out of bounds for a node's `payouts` (presumably one entry
    /// per player - confirm against `Node`), or if `children` yields an index that is not a
    /// node of `tree`.
    fn node_scores(
        &self,
        player_index: usize,
        children: Neighbors<G::Action, u32>,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
    ) -> Result<HashMap<NodeIndex, Self::Score>, Self::Error> {
        Ok(children
            .map(|index| (index, tree[index].payouts[player_index].clone()))
            .collect())
    }
}

/// Score the nodes based on the total number of visits they received.
#[derive(Debug, Default)]
pub struct NumVisits;

impl<const N: usize, G, T> FinalScorer<N, G, T> for NumVisits
where
    G: Game<N>,
    T: PartialOrd,
{
    type Error = Infallible;
    type Score = u32;

    fn node_scores(
        &self,
        _player_index: usize,
        children: Neighbors<G::Action, u32>,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
    ) -> Result<HashMap<NodeIndex, u32>, Self::Error> {
        Ok(children.map(|c| (c, tree[c].visits)).collect())
    }
}