use crate::reward::Reward;
/// Result of [`Environment::evaluate`]: either the game is still in progress or it has ended
/// with some outcome.
///
/// Every variant other than [`GameState::Ongoing`] is terminal (see [`GameState::is_terminal`]),
/// and every terminal variant maps to a [`Reward`] via [`GameState::reward`].
///
/// NOTE(review): whose perspective the rewards are expressed from (e.g. the player to move vs.
/// a fixed player) is not established in this module — confirm against the search code.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum GameState {
/// The game has not finished; `reward()` returns `None`.
Ongoing,
/// The game has finished with an explicitly supplied reward; displayed as `terminal(<reward>)`.
Terminal(Reward),
/// The game was won with the supplied reward; displayed as `win(<reward>)`.
Win(Reward),
/// The game was lost; `reward()` returns `Reward::LOSS`.
Loss,
/// The game was drawn; `reward()` returns `Reward::DRAW`.
Draw,
}
impl GameState {
    /// Reports whether the game has ended.
    ///
    /// Only [`GameState::Ongoing`] is non-terminal; every other variant counts as finished.
    #[must_use]
    pub fn is_terminal(self) -> bool {
        match self {
            Self::Ongoing => false,
            Self::Terminal(_) | Self::Win(_) | Self::Loss | Self::Draw => true,
        }
    }

    /// Returns the reward associated with a finished game, or `None` while it is ongoing.
    ///
    /// `Terminal` and `Win` yield their carried reward unchanged; `Loss` and `Draw` map to
    /// the fixed `Reward::LOSS` and `Reward::DRAW` constants respectively.
    #[must_use]
    pub fn reward(self) -> Option<Reward> {
        let outcome = match self {
            // No reward exists yet for an unfinished game.
            Self::Ongoing => return None,
            Self::Terminal(carried) | Self::Win(carried) => carried,
            Self::Loss => Reward::LOSS,
            Self::Draw => Reward::DRAW,
        };
        Some(outcome)
    }
}
/// Human-readable rendering: `ongoing`, `loss`, `draw`, or `terminal(<reward>)` /
/// `win(<reward>)` where the reward is formatted with its own `Display` impl.
impl std::fmt::Display for GameState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Split each variant into its lowercase label and an optional reward payload,
        // then format once instead of repeating the write call per arm.
        let (label, payload) = match *self {
            Self::Ongoing => ("ongoing", None),
            Self::Terminal(value) => ("terminal", Some(value)),
            Self::Win(value) => ("win", Some(value)),
            Self::Loss => ("loss", None),
            Self::Draw => ("draw", None),
        };
        if let Some(value) = payload {
            write!(f, "{label}({value})")
        } else {
            f.write_str(label)
        }
    }
}
/// Optional evaluation of the current position, returned by [`Environment::heuristic`].
///
/// The derived `Default` carries no estimate (`value: None`); use [`Heuristic::from_reward`]
/// to supply one.
#[derive(Debug, Clone, Copy, PartialEq, Default, serde::Serialize, serde::Deserialize)]
pub struct Heuristic {
/// Estimated reward for the position, or `None` when no estimate is available.
pub value: Option<Reward>,
}
impl Heuristic {
    /// Builds a heuristic that carries `value` as its reward estimate.
    #[must_use]
    pub fn from_reward(value: Reward) -> Self {
        // Wrap first so the struct can be built with field-init shorthand.
        let value = Some(value);
        Self { value }
    }
}
pub trait Environment: Clone + Send + Sync {
type Action: Clone + Send + Sync + std::fmt::Debug + PartialEq;
fn legal_actions(&self) -> Vec<Self::Action>;
fn apply(&mut self, action: &Self::Action);
fn evaluate(&self) -> GameState;
fn heuristic(&self) -> Heuristic {
Heuristic::default()
}
fn max_depth(&self) -> Option<usize> {
None
}
fn action_priors(&self, _actions: &[Self::Action]) -> Option<Vec<f64>> {
None
}
}
#[cfg(test)]
mod tests {
    use super::{GameState, Heuristic, Reward};

    /// Only `Ongoing` is non-terminal; every other variant must report terminal.
    #[test]
    fn game_state_terminal_detection() {
        assert!(!GameState::Ongoing.is_terminal());
        for state in [
            GameState::Terminal(Reward::WIN),
            GameState::Win(Reward::WIN),
            GameState::Loss,
            GameState::Draw,
        ] {
            assert!(state.is_terminal(), "{state} should be terminal");
        }
    }

    /// Every variant maps to the expected reward (or none while ongoing).
    #[test]
    fn game_state_reward_mapping() {
        assert_eq!(GameState::Ongoing.reward(), None);
        assert_eq!(GameState::Terminal(Reward::WIN).reward(), Some(Reward::WIN));
        assert_eq!(
            GameState::Win(Reward::new(0.25)).reward(),
            Some(Reward::new(0.25))
        );
        assert_eq!(GameState::Loss.reward(), Some(Reward::LOSS));
        assert_eq!(GameState::Draw.reward(), Some(Reward::DRAW));
    }

    #[test]
    fn heuristic_default_and_constructor() {
        assert_eq!(Heuristic::default(), Heuristic { value: None });
        let h = Heuristic::from_reward(Reward::new(0.25));
        assert_eq!(h.value, Some(Reward::new(0.25)));
    }

    /// `Display` output for every variant; reward payloads use `Reward`'s own formatting.
    #[test]
    fn format_all_states() {
        assert_eq!(GameState::Ongoing.to_string(), "ongoing");
        assert_eq!(GameState::Terminal(Reward::WIN).to_string(), "terminal(1)");
        assert_eq!(GameState::Win(Reward::WIN).to_string(), "win(1)");
        assert_eq!(GameState::Loss.to_string(), "loss");
        assert_eq!(GameState::Draw.to_string(), "draw");
    }
}