use crate::reward::Reward;
/// Status value returned by `Environment::evaluate`.
///
/// `Ongoing` is the only variant for which `is_terminal()` returns `false`
/// and `reward()` returns `None`; every other variant is terminal and maps to
/// a reward.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum Outcome {
/// Not finished yet; `reward()` yields `None`.
Ongoing,
/// Finished, carrying an explicitly supplied reward. Displayed as `terminal(<reward>)`.
Terminal(Reward),
/// Finished, carrying an explicitly supplied reward. Displayed as `success(<reward>)`.
/// NOTE(review): `reward()` treats this exactly like `Terminal`; the intended
/// semantic distinction is not visible here — confirm with callers.
Success(Reward),
/// Finished without a payload; `reward()` reports `Reward::LOSS`.
Failure,
/// Finished without a payload; `reward()` reports `Reward::DRAW`.
Neutral,
}
impl Outcome {
    /// Returns `true` for every variant except [`Outcome::Ongoing`].
    pub fn is_terminal(self) -> bool {
        match self {
            Self::Ongoing => false,
            Self::Terminal(_) | Self::Success(_) | Self::Failure | Self::Neutral => true,
        }
    }

    /// Returns the reward associated with this outcome.
    ///
    /// * `Ongoing` has no reward and yields `None`.
    /// * `Terminal` and `Success` yield the reward they carry.
    /// * `Failure` yields `Reward::LOSS`; `Neutral` yields `Reward::DRAW`.
    pub fn reward(self) -> Option<Reward> {
        let payoff = match self {
            // Nothing to report while still in progress.
            Self::Ongoing => return None,
            Self::Failure => Reward::LOSS,
            Self::Neutral => Reward::DRAW,
            Self::Terminal(explicit) | Self::Success(explicit) => explicit,
        };
        Some(payoff)
    }
}
impl std::fmt::Display for Outcome {
    /// Writes a lowercase label for the outcome.
    ///
    /// Payload-carrying variants append the reward in parentheses using the
    /// reward's own `Display` output, e.g. `terminal(<reward>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Outcome` is `Copy`, so match on the value and bind payloads directly.
        let label = match *self {
            Self::Terminal(reward) => return write!(f, "terminal({reward})"),
            Self::Success(reward) => return write!(f, "success({reward})"),
            Self::Ongoing => "ongoing",
            Self::Failure => "failure",
            Self::Neutral => "neutral",
        };
        // Payload-free variants are written as a plain static label.
        f.write_str(label)
    }
}
/// Optional reward estimate; this is the type returned by `Environment::heuristic`.
///
/// The derived `Default` leaves `value` as `None`.
#[derive(Debug, Clone, Copy, PartialEq, Default, serde::Serialize, serde::Deserialize)]
pub struct Heuristic {
/// The estimated reward, or `None` when no estimate has been set.
pub value: Option<Reward>,
}
impl Heuristic {
    /// Builds a heuristic whose `value` field is `Some(value)`.
    pub fn from_reward(value: Reward) -> Self {
        let value = Some(value);
        Self { value }
    }
}
/// Interface a state type must implement: enumerate actions, apply one, and
/// report an `Outcome`.
///
/// Implementors must be `Clone + Send + Sync`. Only `legal_actions`, `apply`
/// and `evaluate` are required; every other method has a default.
pub trait Environment: Clone + Send + Sync {
/// Action type accepted by `apply` and produced by `legal_actions`.
type Action: Clone + Send + Sync + std::fmt::Debug + PartialEq;
/// Returns the available actions as an owned vector.
/// NOTE(review): presumably the actions legal in the current state — confirm.
fn legal_actions(&self) -> Vec<Self::Action>;
/// Mutates this environment in place by applying `action`.
/// NOTE(review): behaviour for an action not in `legal_actions()` is not
/// specified here — confirm with implementors.
fn apply(&mut self, action: &Self::Action);
/// Reports the `Outcome` of this environment.
fn evaluate(&self) -> Outcome;
/// Index of the player to act; the default is `0`.
/// NOTE(review): presumably in `0..num_players()` — confirm.
fn current_player(&self) -> usize {
0
}
/// Number of players; the default is `1`.
fn num_players(&self) -> usize {
1
}
/// Optional reward estimate; the default is `Heuristic::default()`, whose
/// `value` is `None`.
fn heuristic(&self) -> Heuristic {
Heuristic::default()
}
/// Optional depth limit; the default is `None`.
/// NOTE(review): units and how callers enforce the limit are not visible here.
fn max_depth(&self) -> Option<usize> {
None
}
/// Optional prior weights for the given actions; the default ignores
/// `_actions` and returns `None`.
/// NOTE(review): presumably one `f64` per entry of `_actions`, in the same
/// order — confirm against callers.
fn action_priors(&self, _actions: &[Self::Action]) -> Option<Vec<f64>> {
None
}
/// Optional 64-bit hash of this environment; the default is `None`.
/// NOTE(review): presumably used to detect repeated states — confirm.
fn state_hash(&self) -> Option<u64> {
None
}
}
/// Maps an environment of type `E` to a `Reward`.
///
/// Implementors must be `Send + Sync`.
pub trait Evaluator<E: Environment>: Send + Sync {
/// Returns a `Reward` for `env` without mutating it.
/// NOTE(review): whether this is intended for non-terminal states only is
/// not visible here — confirm with callers.
fn evaluate(&self, env: &E) -> Reward;
}
#[cfg(test)]
mod tests {
    use super::{Heuristic, Outcome, Reward};

    /// A `Terminal` outcome is terminal, `Ongoing` is not, and `Failure`
    /// reports `Reward::LOSS`.
    #[test]
    fn outcome_terminal_detection() {
        let finished = Outcome::Terminal(Reward::WIN);
        let running = Outcome::Ongoing;

        assert!(finished.is_terminal());
        assert!(!running.is_terminal());
        assert_eq!(Outcome::Failure.reward(), Some(Reward::LOSS));
    }

    /// The default heuristic is empty; `from_reward` stores the given reward.
    #[test]
    fn heuristic_default_and_constructor() {
        let empty = Heuristic::default();
        assert_eq!(empty, Heuristic { value: None });

        let estimate = Heuristic::from_reward(Reward::new(0.25));
        let expected = Some(Reward::new(0.25));
        assert_eq!(estimate.value, expected);
    }

    /// `Display` renders each finished variant with its expected label.
    #[test]
    fn format_terminal_states() {
        let cases = [
            (Outcome::Success(Reward::WIN), "success(1)"),
            (Outcome::Failure, "failure"),
            (Outcome::Neutral, "neutral"),
        ];
        for (outcome, rendered) in cases {
            assert_eq!(outcome.to_string(), rendered);
        }
    }
}