use std::cell::{BorrowMutError, RefCell};
use std::cmp::Ordering;
use std::error::Error;
use std::fmt::Debug;
use std::hash::Hash;
use petgraph::graph::NodeIndex;
use petgraph::{Direction, Graph};
use rand::rngs::ThreadRng;
use rand::seq::IteratorRandom;
use super::{is_fully_expanded, Node};
use crate::prelude::{Game, GameError};
/// A strategy for choosing which node of a search tree to continue from.
///
/// * `N` - const parameter forwarded to [`Game`] and [`Node`].
/// * `G` - the game whose states and actions label the tree.
/// * `T` - third type parameter of the tree's [`Node`] weights.
pub trait Selection<const N: usize, G, T>: Debug
where
    G: Game<N>,
{
    /// Error reported when a selection attempt fails.
    type Error: Error;

    /// Picks a node of `tree`, starting the search at `node_idx`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementation cannot complete the
    /// selection.
    fn select(
        &self,
        game: &G,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
        node_idx: NodeIndex,
    ) -> Result<NodeIndex, Self::Error>;
}
/// Errors produced by the [`UCT`] selection strategy (it is the `Error` type
/// of `UCT`'s [`Selection`] implementation).
#[derive(Debug, thiserror::Error)]
pub enum UCTError<const N: usize, G>
where
G: Game<N>,
G::Action: Debug,
{
/// An error raised by the game itself; displayed exactly as the wrapped
/// error displays.
#[error(transparent)]
GameError(G::Error),
/// A payout could not be converted to `f64`; carries the `Debug` rendering
/// of the conversion error.
#[error("{0}")]
NumberConversionError(String),
/// The node's payouts had no entry at the active player's index.
#[error("No payout for requested player index")]
NoPlayerPayout,
/// No child was found for the given action.
// NOTE(review): nothing in this part of the file constructs this variant.
#[error("Could not find child associated with action {0:?}")]
UnknownChildAction(G::Action),
/// The strategy's `RefCell`-guarded random number generator was already
/// mutably borrowed.
#[error(transparent)]
BorrowMutError(#[from] BorrowMutError),
}
impl<const N: usize, G, E> From<E> for UCTError<N, G>
where
G: Game<N, Error = E>,
G::Action: Debug,
E: GameError,
{
fn from(value: G::Error) -> Self {
Self::GameError(value)
}
}
/// Configuration of the UCT selection strategy.
#[derive(Debug)]
pub struct UCT {
/// Multiplier applied to the square-root (exploration) term of a child's
/// score.
c: f64,
/// When present: the visit count below which a child is picked at random
/// instead of by score, paired with the random number generator used for
/// that pick. The generator sits in a `RefCell` because selection only has
/// `&self` but needs to borrow it mutably.
threshold: Option<(u32, RefCell<ThreadRng>)>,
}
impl Default for UCT {
fn default() -> Self {
Self {
c: 0.7,
threshold: Some((30, RefCell::new(ThreadRng::default()))),
}
}
}
impl UCT {
    /// Creates a UCT selection strategy.
    ///
    /// * `c` - multiplier of the exploration term computed by `value`.
    /// * `threshold` - when `Some(t)`, `t` is stored together with a fresh
    ///   `ThreadRng`; selection picks children at random while a node has
    ///   fewer than `t` visits. `None` stores no threshold.
    pub fn new(c: f64, threshold: Option<u32>) -> Self {
        Self {
            c,
            threshold: threshold.map(|t| (t, RefCell::new(ThreadRng::default()))),
        }
    }

    /// Scores `node` as a child of `parent`:
    /// `payout + c * sqrt(ln(parent.visits) / node.visits)`.
    ///
    /// `payout` is the entry of `node.payouts` at the index of the player
    /// that is active in `parent.state`, converted to `f64`.
    ///
    /// Degenerate visit counts are handled explicitly so that the score is
    /// never NaN for a finite payout and a finite `c`:
    ///
    /// * `node.visits == 0` gives an infinite exploration term;
    /// * `ln(parent.visits)` is clamped at zero (`ln(0)` is `-inf`);
    /// * `c == 0.0` disables the exploration bonus altogether.
    ///
    /// # Errors
    ///
    /// * [`UCTError::GameError`] if the game cannot report the active player.
    /// * [`UCTError::NoPlayerPayout`] if `node.payouts` has no entry for
    ///   that player.
    /// * [`UCTError::NumberConversionError`] if the payout cannot be
    ///   converted to `f64`.
    fn value<const N: usize, G: Game<N>, T>(
        &self,
        game: &G,
        node: &Node<N, G::State, T>,
        parent: &Node<N, G::State, T>,
    ) -> Result<f64, UCTError<N, G>>
    where
        T: Copy + TryInto<f64>,
        <T as TryInto<f64>>::Error: Debug,
        G::Action: Debug,
    {
        let player = game.active_player_index(&parent.state)?;
        let raw: T = *node
            .payouts
            .get(player)
            .ok_or(UCTError::NoPlayerPayout)?;
        let payout: f64 = raw
            .try_into()
            .map_err(|err| UCTError::NumberConversionError(format!("{:?}", err)))?;

        // With c == 0 the bonus is skipped entirely: multiplying zero by an
        // infinite exploration term would otherwise yield NaN.
        if self.c == 0.0 {
            return Ok(payout);
        }

        let exploration = if node.visits == 0 {
            // Dividing by zero visits is +inf for parent.visits >= 2 but NaN
            // for parent.visits <= 1; make it +inf consistently.
            f64::INFINITY
        } else {
            // ln(0) = -inf would put a negative value under the square root.
            (f64::from(parent.visits).ln().max(0.0) / f64::from(node.visits)).sqrt()
        };
        Ok(payout + self.c * exploration)
    }
}
/// UCT-based node selection.
impl<const N: usize, G: Game<N>, T> Selection<N, G, T> for UCT
where
    T: Copy + TryInto<f64> + Debug,
    <T as TryInto<f64>>::Error: Debug,
    G::Action: Clone + Hash + Eq + Debug,
    Node<N, G::State, T>: Clone,
{
    type Error = UCTError<N, G>;

    /// Walks down `tree` from `node_idx` and returns the node to continue
    /// from.
    ///
    /// At each step:
    ///
    /// * a node that is not fully expanded is returned as is;
    /// * if a threshold is configured and the node has fewer visits than it,
    ///   one of its children is picked uniformly at random and returned;
    /// * otherwise the walk moves to the child with the highest UCT score
    ///   (the last such child on ties) and repeats.
    ///
    /// A node without children is returned as is. Scores that are NaN rank
    /// below every other score, so they never displace a real candidate.
    ///
    /// # Errors
    ///
    /// Propagates errors from `is_fully_expanded` and from scoring a child,
    /// and returns [`UCTError::BorrowMutError`] if the random number
    /// generator is already borrowed.
    fn select(
        &self,
        game: &G,
        tree: &Graph<Node<N, G::State, T>, G::Action>,
        node_idx: NodeIndex,
    ) -> Result<NodeIndex, UCTError<N, G>> {
        // Iterative descent: same visiting order as a recursive walk, but
        // without growing the call stack with the depth of the tree.
        let mut current = node_idx;
        loop {
            let node = &tree[current];
            log::debug!("running the UCT selection strategy from {:?}", node);
            if !is_fully_expanded(current, game, tree)? {
                log::debug!("Node is not fully expanded, so not using UCT");
                return Ok(current);
            }

            let children = tree.neighbors_directed(current, Direction::Outgoing);
            match &self.threshold {
                Some((threshold, rng)) if node.visits < *threshold => {
                    log::debug!("Selecting randomly since below UCT threshold");
                    let mut rng = rng.try_borrow_mut()?;
                    // NOTE(review): the randomly chosen child is returned
                    // without descending further, unlike the UCT branch;
                    // presumably intentional - confirm.
                    return Ok(children.choose(&mut *rng).unwrap_or(current));
                }
                _ => {
                    log::debug!("Using UCT strategy since above (or no) threshold");
                    let mut best: Option<(NodeIndex, f64)> = None;
                    for child_idx in children {
                        let score = self.value(game, &tree[child_idx], node)?;
                        let incumbent_wins = match best {
                            // NaN is ordered below every number so that a NaN
                            // score can neither win nor evict the best child.
                            Some((_, top)) => {
                                top.partial_cmp(&score)
                                    .unwrap_or_else(|| score.is_nan().cmp(&top.is_nan()))
                                    == Ordering::Greater
                            }
                            None => false,
                        };
                        // Equal scores replace the incumbent: the last of
                        // several best children is the one selected.
                        if !incumbent_wins {
                            best = Some((child_idx, score));
                        }
                    }
                    match best {
                        Some((best_child_idx, _)) => current = best_child_idx,
                        None => return Ok(current),
                    }
                }
            }
        }
    }
}