use std::hash::Hash;
use super::Reward;
/// A state that actions are applied to (see `StateSpace::apply`).
///
/// The `Hash` supertrait is what the default [`State::state_hash`] builds on;
/// `Clone + Eq` let states be copied and compared.
pub trait State: Clone + Eq + Hash {
    /// Reports whether this state is terminal.
    ///
    /// NOTE(review): presumably the search stops expanding at terminal states —
    /// confirm against the search implementation.
    fn is_terminal(&self) -> bool;

    /// Returns a 64-bit digest of this state, computed by feeding its `Hash`
    /// impl into a freshly constructed std `DefaultHasher`.
    ///
    /// Every `DefaultHasher::new()` starts from the same internal state, so equal
    /// states yield equal digests within one build. The std docs do not guarantee
    /// that the algorithm is stable across Rust releases, so do not persist these values.
    fn state_hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher as _;

        let mut digest = DefaultHasher::new();
        Hash::hash(self, &mut digest);
        digest.finish()
    }
}
/// An action that can be applied to a [`State`] (see `StateSpace::apply`).
pub trait Action: Clone + Eq + Hash {
    /// Returns a textual label for this action.
    ///
    /// NOTE(review): presumably used for logging/debug output — confirm with callers.
    fn name(&self) -> &str;

    /// Returns this action's prior weight.
    ///
    /// The default gives every action the same weight of `1.0`. Implementors
    /// override it to bias the search toward particular actions.
    fn prior(&self) -> f64 {
        1.0
    }
}
/// Transition and scoring model: how an action turns one state into another,
/// and how a state is evaluated.
pub trait StateSpace<S, A>
where
    S: State,
    A: Action,
{
    /// Returns the state reached by applying `action` to `state`.
    ///
    /// `state` is only borrowed; the successor is returned as a new value.
    fn apply(&self, state: &S, action: &A) -> S;

    /// Scores `state`, producing a `Reward`.
    fn evaluate(&self, state: &S) -> Reward;

    /// Returns a boxed copy of this state space that is `Send + Sync`.
    ///
    /// A `Clone` supertrait would make `dyn StateSpace` impossible (since `Clone`
    /// requires `Sized`), so cloning is exposed through this object-safe method instead.
    fn clone_space(&self) -> Box<dyn StateSpace<S, A> + Send + Sync>;
}
/// Enumerates the actions that may be taken from a given state.
pub trait ActionSpace<S, A>
where
    S: State,
    A: Action,
{
    /// Returns every action that is legal in `state`.
    fn legal_actions(&self, state: &S) -> Vec<A>;

    /// Returns `true` when `state` has no legal actions.
    ///
    /// The default materialises the full list via [`ActionSpace::legal_actions`]
    /// and checks whether it is empty. Implementors that can answer more cheaply
    /// may override it.
    fn is_empty(&self, state: &S) -> bool {
        let actions = self.legal_actions(state);
        actions.is_empty()
    }
}
/// A model that proposes weighted actions for a state and can estimate the
/// state's value. Must be `Send + Sync` so it can be shared across threads.
pub trait PolicyNetwork<S, A>: Send + Sync
where
    S: State,
    A: Action,
{
    /// Returns candidate actions for `state`, each paired with an `f64` weight.
    ///
    /// NOTE(review): presumably these weights are prior probabilities. Nothing here
    /// enforces that they sum to 1 — confirm with callers.
    fn predict(&self, state: &S) -> Vec<(A, f64)>;

    /// Returns a scalar value estimate for a state.
    ///
    /// The default ignores its argument and always returns `0.5`.
    fn value(&self, _state: &S) -> f64 {
        0.5
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state: an integer payload plus an explicit terminal flag.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal action carrying a step size; only its identity matters here.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Offers a single step until `value` reaches `max` or the state is terminal.
    /// Deliberately does not override `is_empty`, so the trait default is exercised.
    struct BoundedSpace {
        max: i32,
    }

    impl ActionSpace<TestState, TestAction> for BoundedSpace {
        fn legal_actions(&self, state: &TestState) -> Vec<TestAction> {
            if state.terminal || state.value >= self.max {
                Vec::new()
            } else {
                vec![TestAction { delta: 1 }]
            }
        }
    }

    /// Implements only `predict`, so `value` falls back to the trait default.
    struct SingleActionPolicy;

    impl PolicyNetwork<TestState, TestAction> for SingleActionPolicy {
        fn predict(&self, _state: &TestState) -> Vec<(TestAction, f64)> {
            vec![(TestAction { delta: 1 }, 1.0)]
        }
    }

    #[test]
    fn test_state_trait_implementation() {
        let state = TestState { value: 5, terminal: false };
        assert!(!state.is_terminal());
        let terminal = TestState { value: 10, terminal: true };
        assert!(terminal.is_terminal());
    }

    #[test]
    fn test_state_hash_equal_for_equal_states() {
        // Two independently built but structurally equal states must agree, since
        // the default `state_hash` is derived purely from the `Hash` impl.
        let a = TestState { value: 7, terminal: false };
        let b = TestState { value: 7, terminal: false };
        assert_eq!(a, b);
        assert_eq!(a.state_hash(), b.state_hash());
        // Hashing the same value twice is deterministic.
        assert_eq!(a.state_hash(), a.state_hash());
    }

    #[test]
    fn test_action_trait_implementation() {
        let action = TestAction { delta: 1 };
        assert_eq!(action.name(), "test_action");
        assert_eq!(action.prior(), 1.0);
    }

    #[test]
    fn test_action_space_default_is_empty() {
        let space = BoundedSpace { max: 3 };
        assert!(!space.is_empty(&TestState { value: 0, terminal: false }));
        assert!(space.is_empty(&TestState { value: 3, terminal: false }));
        assert!(space.is_empty(&TestState { value: 0, terminal: true }));
    }

    #[test]
    fn test_policy_network_default_value() {
        let policy = SingleActionPolicy;
        let state = TestState { value: 0, terminal: false };
        assert_eq!(policy.value(&state), 0.5);
        let predictions = policy.predict(&state);
        assert_eq!(predictions.len(), 1);
        assert_eq!(predictions[0].0, TestAction { delta: 1 });
        assert_eq!(predictions[0].1, 1.0);
    }
}