use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Sub;
use super::{InternalSearch, SearchError, SearchScore};
#[cfg(doc)]
use crate::game::Player;
use crate::game::{Game, GameError, Heuristic, State};
#[derive(Debug, thiserror::Error)]
pub enum MaxNError<const N: usize, G: Game<N>>
where
G::Player: Heuristic<N, G>,
{
#[error("Unable to convert vector into an array")]
VecToArrayError,
#[error("No available actions")]
NoAvailableActions,
#[error(transparent)]
GameError(G::Error),
#[error(transparent)]
PlayerHeuristicError(<G::Player as Heuristic<N, G>>::Error),
}
impl<const N: usize, G: Game<N>> SearchError for MaxNError<N, G> where G::Player: Heuristic<N, G> {}
impl<const N: usize, G: Game<N, Error = E>, E: GameError> From<E> for MaxNError<N, G>
where
G::Player: Heuristic<N, G>,
{
fn from(value: G::Error) -> Self {
Self::GameError(value)
}
}
/// MaxN search over an `N`-player [`Game`].
///
/// Construct it with [`MaxN::builder`], or with `Default::default()` for a search that has
/// no pruning bound and no depth limit.
#[derive(Debug)]
pub struct MaxN<const N: usize, G>
where
    G: Game<N>,
{
    /// Bound used to cut off sibling branches during the search; `None` disables pruning.
    upper_bound: Option<<G as Game<N>>::Score>,
    /// Positions at a depth greater than this are estimated with player heuristics instead of
    /// being expanded; `None` searches until the game is complete.
    max_depth: Option<usize>,
}

impl<const N: usize, G> Default for MaxN<N, G>
where
    G: Game<N>,
{
    /// A search with pruning disabled and no depth limit.
    fn default() -> Self {
        Self {
            max_depth: None,
            upper_bound: None,
        }
    }
}

impl<const N: usize, G> MaxN<N, G>
where
    G: Game<N>,
{
    /// Starts a [`MaxNBuilder`] with every option unset.
    pub fn builder() -> MaxNBuilder<N, G> {
        MaxNBuilder::<N, G>::default()
    }
}

impl<const N: usize, G> InternalSearch for MaxN<N, G> where G: Game<N> {}
impl<const N: usize, G: Game<N, State = S>, S: State<N, G>> MaxN<N, G>
where
    G::Action: Debug,
    G::Score: Copy + Sub<Output = G::Score> + PartialOrd,
    G::Player: Debug + Heuristic<N, G>,
    G::State: Clone,
{
    /// Applies `action` to `state` and returns the MaxN payoff vector of the resulting position.
    ///
    /// The vector holds one score per entry of `game.players()`, in that order. It is indexed
    /// with the value returned by `game.player_index` for the player to move.
    ///
    /// * Complete positions are scored exactly with `game.score`.
    /// * Positions with `depth > max_depth` are estimated with each player's `heuristic`.
    /// * Otherwise every available action is searched recursively at `depth + 1`. The child that
    ///   maximises the moving player's own score is kept; ties keep the earlier child.
    ///
    /// `upper_bound` is the bound handed down by the parent. Once the moving player's best score
    /// reaches it, the remaining siblings are skipped and the partial best is returned. Every
    /// child after the first is searched with `self.upper_bound - best[moving player]` as its
    /// bound.
    ///
    /// NOTE(review): this is the shallow-pruning scheme. It is only sound if every score is
    /// non-negative and the scores of all players sum to at most `self.upper_bound`. Confirm
    /// this for each `Game`. A pruned vector is not the node's exact value.
    ///
    /// # Errors
    /// Returns game errors, `PlayerHeuristicError` from heuristics, and `VecToArrayError` when
    /// the number of players differs from `N`. Returns `NoAvailableActions` for a position that
    /// is not complete, not depth-limited, and has no moves.
    fn payoff(
        &self,
        action: &<G as Game<N>>::Action,
        game: &G,
        state: &<G as Game<N>>::State,
        upper_bound: Option<&<G as Game<N>>::Score>,
        depth: usize,
    ) -> Result<[<G as Game<N>>::Score; N], MaxNError<N, G>> {
        let state = game.do_action(action, state.clone())?;
        let active_player = game.active_player(&state)?;
        let active_player_index = game.player_index(active_player)?;
        log::trace!(
            "(depth={:?}) testing action {:?} for player {:?}",
            depth,
            action,
            active_player
        );

        let payoffs: [<G as Game<N>>::Score; N] = if game.is_complete(&state)? {
            // Terminal position: exact scores for every player.
            game.players()
                .iter()
                .map(|player| game.score(player, &state))
                .collect::<Result<Vec<_>, _>>()?
                .try_into()
                .map_err(|_| MaxNError::VecToArrayError)?
        } else if let Some(max_depth) = self.max_depth.filter(|&max_depth| depth > max_depth) {
            // This runs once per frontier node, so keep it below `info` to avoid flooding logs.
            log::debug!("max depth ({}) was reached, using heuristic", max_depth);
            game.players()
                .iter()
                .map(|player| player.heuristic(game, &state))
                .collect::<Result<Vec<_>, _>>()
                .map_err(MaxNError::PlayerHeuristicError)?
                .try_into()
                .map_err(|_| MaxNError::VecToArrayError)?
        } else {
            let actions = game.available_actions(&state)?;
            let (first, rest) = actions
                .split_first()
                .ok_or(MaxNError::NoAvailableActions)?;

            // No sibling has been evaluated yet, so the first child gets the global bound.
            let mut best = self.payoff(first, game, &state, self.upper_bound.as_ref(), depth + 1)?;
            for action in rest {
                if let Some(bound) = upper_bound {
                    if best[active_player_index] >= *bound {
                        log::debug!(
                            "pruned some branches since {:?} >= {:?}",
                            &best[active_player_index],
                            bound
                        );
                        break;
                    }
                }
                // Tighten the child's bound by what the moving player is already guaranteed.
                let child_bound = self.upper_bound.map(|ub| ub - best[active_player_index]);
                let current = self.payoff(action, game, &state, child_bound.as_ref(), depth + 1)?;
                if current[active_player_index] > best[active_player_index] {
                    best = current;
                }
            }
            best
        };

        log::trace!(
            "(depth={:?}) payoffs of action {:?} for player {:?}: {:?}",
            depth,
            action,
            active_player,
            payoffs
        );
        Ok(payoffs)
    }
}
impl<const N: usize, G> SearchScore<N, G> for MaxN<N, G>
where
    G: Game<N>,
    G::State: State<N, G> + Clone,
    G::Player: Clone + Debug + Heuristic<N, G>,
    G::Action: Debug + Hash + Eq,
    G::Score: Copy + Sub<Output = G::Score> + PartialOrd,
{
    type Error = MaxNError<N, G>;
    type Score = G::Score;

    /// Scores every action available in `state` for the player to move.
    ///
    /// Each action is evaluated with the MaxN recursion starting at depth 1. It is mapped to the
    /// moving player's entry of the resulting payoff vector. Every root action is searched with
    /// the configured `upper_bound`, not with a bound tightened by its siblings.
    ///
    /// # Errors
    /// Propagates any error from `available_actions`, `active_player`, `player_index`, or the
    /// recursive search. The first failing action aborts the whole call.
    fn score_actions(
        &self,
        game: &G,
        state: &G::State,
    ) -> Result<HashMap<G::Action, Self::Score>, Self::Error> {
        let actions = game.available_actions(state)?;
        let active_player = game.active_player(state)?;
        log::info!(
            "choosing action for player {:?} from {} branches ({} frontier nodes)",
            active_player,
            actions.len(),
            // NOTE(review): this computes max_depth * (max_depth + 1) * ... * branches. It equals
            // branches! only when max_depth is None, and reports 0 for Some(0). Confirm the
            // intended frontier estimate.
            partial_factorial(self.max_depth.unwrap_or(1), actions.len())
                .map_or_else(|| "OVERFLOW".to_string(), |f| f.to_string())
        );
        let active_player_index = game.player_index(active_player)?;

        // Build the map directly; no intermediate Vec of full payoff vectors is needed.
        let mut scores = HashMap::with_capacity(actions.len());
        for action in actions {
            let payoffs = self.payoff(&action, game, state, self.upper_bound.as_ref(), 1)?;
            let score = payoffs[active_player_index];
            log::debug!("action: {:?}, score: {:?}", action, score);
            scores.insert(action, score);
        }
        Ok(scores)
    }
}
/// Builder for [`MaxN`]; every option starts unset.
#[derive(Debug)]
pub struct MaxNBuilder<const N: usize, G>
where
    G: Game<N>,
{
    upper_bound: Option<<G as Game<N>>::Score>,
    max_depth: Option<usize>,
}

impl<const N: usize, G> Default for MaxNBuilder<N, G>
where
    G: Game<N>,
{
    /// A builder with no pruning bound and no depth limit.
    fn default() -> Self {
        Self {
            max_depth: None,
            upper_bound: None,
        }
    }
}

impl<const N: usize, G> MaxNBuilder<N, G>
where
    G: Game<N>,
{
    /// Enables pruning with `upper_bound`.
    ///
    /// The search compares the moving player's best score against this bound. Each child
    /// receives `upper_bound - best` as its bound. Presumably this should be the maximum
    /// possible total of all players' scores (shallow pruning); confirm this for each game.
    pub fn pruning(self, upper_bound: <G as Game<N>>::Score) -> Self {
        Self {
            upper_bound: Some(upper_bound),
            ..self
        }
    }

    /// Limits the search so that positions deeper than `max_depth` are estimated with each
    /// player's heuristic instead of being expanded.
    pub fn max_depth(self, max_depth: usize) -> Self {
        Self {
            max_depth: Some(max_depth),
            ..self
        }
    }

    /// Finishes configuration and returns the search.
    pub fn build(self) -> MaxN<N, G> {
        let Self {
            upper_bound,
            max_depth,
        } = self;
        MaxN {
            upper_bound,
            max_depth,
        }
    }
}
/// Returns the product of every integer in `start..=end`, or `None` if it overflows `usize`.
///
/// An empty range (`start > end`) yields `Some(1)`. A range containing 0 yields `Some(0)`.
fn partial_factorial(start: usize, end: usize) -> Option<usize> {
    let mut product: usize = 1;
    for factor in start..=end {
        product = product.checked_mul(factor)?;
    }
    Some(product)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain factorials: the range starts at 1.
    #[test]
    fn test_factorial() {
        for (n, fact) in [(1, 1), (2, 2), (3, 6), (4, 24), (5, 120)] {
            assert_eq!(partial_factorial(1, n), Some(fact));
        }
    }

    /// `score_actions` passes `max_depth` as `start`, so non-unit starts must work too.
    #[test]
    fn partial_factorial_from_non_unit_start() {
        assert_eq!(partial_factorial(3, 5), Some(3 * 4 * 5));
        assert_eq!(partial_factorial(5, 5), Some(5));
    }

    /// An empty range is the empty product.
    #[test]
    fn partial_factorial_empty_range_is_one() {
        assert_eq!(partial_factorial(4, 3), Some(1));
    }

    /// A range that includes 0 multiplies to 0.
    #[test]
    fn partial_factorial_including_zero_is_zero() {
        assert_eq!(partial_factorial(0, 4), Some(0));
    }

    /// 100! does not fit in any `usize`, so the checked product reports overflow.
    #[test]
    fn partial_factorial_overflow_is_none() {
        assert_eq!(partial_factorial(1, 100), None);
    }
}