//! Core traits for the MCTS search (`entrenar/search/mcts/traits.rs`).
use std::hash::Hash;

use super::Reward;

/// A node state in the search.
///
/// Implementors must be cloneable and hashable. As required by [`Hash`],
/// equal states must produce equal hashes, which is what makes
/// [`State::state_hash`] usable as a state fingerprint.
pub trait State: Hash + Eq + Clone {
    /// Returns `true` when this state is terminal.
    ///
    /// NOTE(review): presumably the search stops expanding at terminal
    /// states — confirm at the call sites.
    fn is_terminal(&self) -> bool;

    /// Returns a 64-bit fingerprint of this state.
    ///
    /// The value is produced by feeding the state's [`Hash`] impl into a
    /// freshly created `DefaultHasher`. Every `DefaultHasher::new()` starts
    /// from the same keys, so the result is stable within one build of the
    /// program. The algorithm itself is unspecified and may change between
    /// Rust releases, so do not persist these values.
    fn state_hash(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut fingerprint = DefaultHasher::new();
        Hash::hash(self, &mut fingerprint);
        fingerprint.finish()
    }
}

/// An action that can be taken from a [`State`].
///
/// Implementors must be cloneable, comparable and hashable.
pub trait Action: Hash + Eq + Clone {
    /// Returns a name for this action (presumably for display/logging).
    fn name(&self) -> &str;

    /// Returns this action's prior weight.
    ///
    /// The default is `1.0` for every action, so all actions carry the same
    /// weight unless an implementor overrides this method.
    fn prior(&self) -> f64 {
        1.0
    }
}

35pub trait StateSpace<S: State, A: Action> {
37 fn apply(&self, state: &S, action: &A) -> S;
39
40 fn evaluate(&self, state: &S) -> Reward;
42
43 fn clone_space(&self) -> Box<dyn StateSpace<S, A> + Send + Sync>;
45}
46
47pub trait ActionSpace<S: State, A: Action> {
49 fn legal_actions(&self, state: &S) -> Vec<A>;
51
52 fn is_empty(&self, state: &S) -> bool {
54 self.legal_actions(state).is_empty()
55 }
56}
57
58pub trait PolicyNetwork<S: State, A: Action>: Send + Sync {
60 fn predict(&self, state: &S) -> Vec<(A, f64)>;
62
63 fn value(&self, _state: &S) -> f64 {
65 0.5 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state used to exercise the [`State`] trait and its defaults.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal action used to exercise the [`Action`] trait and its defaults.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Offers two actions from any non-terminal state and none from a
    /// terminal one, leaving `is_empty` at its default implementation.
    struct TestActionSpace;

    impl ActionSpace<TestState, TestAction> for TestActionSpace {
        fn legal_actions(&self, state: &TestState) -> Vec<TestAction> {
            if state.terminal {
                Vec::new()
            } else {
                vec![TestAction { delta: -1 }, TestAction { delta: 1 }]
            }
        }
    }

    /// Implements only `predict`, leaving `value` at its default.
    struct TestPolicy;

    impl PolicyNetwork<TestState, TestAction> for TestPolicy {
        fn predict(&self, _state: &TestState) -> Vec<(TestAction, f64)> {
            vec![(TestAction { delta: 1 }, 1.0)]
        }
    }

    #[test]
    fn test_state_trait_implementation() {
        let state = TestState { value: 5, terminal: false };
        assert!(!state.is_terminal());

        let terminal = TestState { value: 10, terminal: true };
        assert!(terminal.is_terminal());
    }

    #[test]
    fn test_state_hash_consistent_with_eq() {
        let a = TestState { value: 5, terminal: false };
        let b = a.clone();
        assert_eq!(a, b);
        // Equal states must hash equally; repeated calls are stable within a build.
        assert_eq!(a.state_hash(), b.state_hash());
        assert_eq!(a.state_hash(), a.state_hash());

        // A 64-bit collision between these two inputs is astronomically unlikely;
        // this guards against a degenerate (e.g. constant) hash.
        let other = TestState { value: 6, terminal: false };
        assert_ne!(a.state_hash(), other.state_hash());
    }

    #[test]
    fn test_action_trait_implementation() {
        let action = TestAction { delta: 1 };
        assert_eq!(action.name(), "test_action");
        assert_eq!(action.prior(), 1.0);
    }

    #[test]
    fn test_action_space_default_is_empty() {
        let space = TestActionSpace;
        let open = TestState { value: 0, terminal: false };
        let closed = TestState { value: 0, terminal: true };

        assert_eq!(space.legal_actions(&open).len(), 2);
        assert!(!space.is_empty(&open));
        assert!(space.is_empty(&closed));
    }

    #[test]
    fn test_policy_network_default_value() {
        let policy = TestPolicy;
        let state = TestState { value: 3, terminal: false };

        assert_eq!(policy.predict(&state).len(), 1);
        assert_eq!(policy.value(&state), 0.5);
    }
}