use std::collections::HashMap;
use std::convert::Infallible;
use std::error::Error;
use petgraph::graph::{Neighbors, NodeIndex};
use petgraph::Graph;
#[cfg(doc)]
use super::Backpropagation;
use super::Node;
use crate::game::Game;
/// Assigns a comparable score to each child node of a search tree.
///
/// Implementors read whatever statistics they need from the [`Node`]s stored
/// in `tree`; the resulting scores only need to support [`PartialOrd`].
pub trait FinalScorer<const N: usize, G, T>
where
    G: Game<N>,
{
    /// Error returned when scoring fails.
    type Error: Error;

    /// Score attached to each child; only a partial order is required.
    type Score: PartialOrd;

    /// Computes a score for every node yielded by `children`.
    ///
    /// * `player_index` – index of the player the scores are computed for
    ///   (an implementation may ignore it).
    /// * `children` – the nodes to score, as a petgraph neighbor iterator.
    /// * `tree` – the graph that the indices from `children` refer to.
    ///
    /// Returns a map from each child's [`NodeIndex`] to its score.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the implementation cannot score the children.
    fn node_scores(
        &self,
        player_index: usize,
        children: Neighbors<G::Action, u32>,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
    ) -> Result<HashMap<NodeIndex, Self::Score>, Self::Error>;
}
/// Final scorer that scores each child by its payout for the given player.
///
/// The score is a clone of `payouts[player_index]` taken from the child's
/// [`Node`], so `Score` is the payout type `T` itself.
#[derive(Default, Debug)]
pub struct PayoutValue;

impl<const N: usize, G, T> FinalScorer<N, G, T> for PayoutValue
where
    G: Game<N>,
    T: PartialOrd + Clone,
{
    type Error = Infallible;
    type Score = T;

    /// Maps every child to a clone of its payout for `player_index`.
    ///
    /// Never fails; the error type is [`Infallible`].
    ///
    /// # Panics
    ///
    /// Panics if a child index is not present in `tree`, or if `player_index`
    /// is out of range for that child's `payouts`.
    fn node_scores(
        &self,
        player_index: usize,
        children: Neighbors<G::Action, u32>,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
    ) -> Result<HashMap<NodeIndex, Self::Score>, Self::Error> {
        let mut scores = HashMap::new();
        for child in children {
            // The payout is owned by the tree, so hand the caller a copy.
            let payout = &tree[child].payouts[player_index];
            scores.insert(child, payout.clone());
        }
        Ok(scores)
    }
}
#[derive(Debug, Default)]
pub struct NumVisits;
impl<const N: usize, G, T> FinalScorer<N, G, T> for NumVisits
where
G: Game<N>,
T: PartialOrd,
{
type Error = Infallible;
type Score = u32;
fn node_scores(
&self,
_player_index: usize,
children: Neighbors<G::Action, u32>,
tree: &Graph<Node<N, G::State, T>, G::Action>,
) -> Result<HashMap<NodeIndex, u32>, Self::Error> {
Ok(children.map(|c| (c, tree[c].visits)).collect())
}
}