use std::cell::RefCell;
use rand::rngs::ThreadRng;
use rand::seq::IteratorRandom;
use super::SearchError;
#[cfg(doc)]
use crate::game::Action;
use crate::game::{Game, GameError};
use crate::search::SearchChoose;
/// Errors that the [`Random`] search strategy can report.
///
/// Both variants are `#[error(transparent)]`, so their `Display` output and
/// `source()` are forwarded unchanged from the wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum RandomError<const N: usize, G: Game<N>> {
    /// The strategy's internal random number generator could not be mutably
    /// borrowed because a mutable borrow was already outstanding.
    #[error(transparent)]
    RngBorrowError(#[from] std::cell::BorrowMutError),
    /// The game itself failed, e.g. while listing the available actions.
    #[error(transparent)]
    GameError(G::Error),
}
/// Registers [`RandomError`] as a [`SearchError`].
///
/// The body is empty, so this relies on `SearchError` having no required
/// items (or only defaulted ones).
impl<const N: usize, G> SearchError for RandomError<N, G>
where
    G: Game<N>,
{
}
impl<const N: usize, G: Game<N, Error = E>, E: GameError> From<E> for RandomError<N, G> {
fn from(value: G::Error) -> Self {
Self::GameError(value)
}
}
/// A search strategy that picks one of the game's available actions at random.
///
/// The generator is held in a [`RefCell`] so that choosing an action, which only
/// has `&self`, can still advance it. As a consequence `Random` is not `Sync`.
#[derive(Debug, Clone, Default)]
pub struct Random {
    /// Random number generator (`rand`'s `ThreadRng`), borrowed mutably for
    /// each choice.
    rng: RefCell<ThreadRng>,
}
impl Random {
pub fn new() -> Self {
Random {
rng: RefCell::new(ThreadRng::default()),
}
}
}
impl<const N: usize, G: Game<N>> SearchChoose<N, G> for Random
where
G::Action: Clone,
{
type Error = RandomError<N, G>;
fn choose_action(
&self,
game: &G,
state: &<G as Game<N>>::State,
) -> Result<Option<<G as Game<N>>::Action>, RandomError<N, G>> {
let actions = game.available_actions(state)?;
let mut rng = self.rng.try_borrow_mut()?;
Ok(actions.iter().choose(&mut *rng).cloned())
}
}