arboriter_mcts/policy/
simulation.rs1use crate::game_state::GameState;
7
8pub trait SimulationPolicy<S: GameState>: Send + Sync {
10 fn simulate(&self, state: &S) -> (f64, Vec<S::Action>);
12
13 fn clone_box(&self) -> Box<dyn SimulationPolicy<S>>;
15}
16
/// Stateless simulation policy that delegates to the game state's own
/// random playout routine.
#[derive(Clone, Debug)]
pub struct RandomPolicy;
22
23impl RandomPolicy {
24 pub fn new() -> Self {
26 RandomPolicy
27 }
28}
29
30impl Default for RandomPolicy {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl<S: GameState> SimulationPolicy<S> for RandomPolicy {
37 fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
38 let player = state.get_current_player();
40 state.simulate_random_playout(&player)
41 }
42
43 fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
44 Box::new(self.clone())
45 }
46}
47
48#[derive(Debug, Clone)]
52pub struct HeuristicPolicy<F, S>
53where
54 F: Fn(&S) -> f64 + Clone + Send + Sync + 'static,
55 S: GameState + 'static,
56{
57 heuristic: F,
59 _phantom: std::marker::PhantomData<S>,
60}
61
62impl<F, S> HeuristicPolicy<F, S>
63where
64 F: Fn(&S) -> f64 + Clone + Send + Sync + 'static,
65 S: GameState + 'static,
66{
67 pub fn new(heuristic: F) -> Self {
69 HeuristicPolicy {
70 heuristic,
71 _phantom: std::marker::PhantomData,
72 }
73 }
74}
75
76impl<F, S> SimulationPolicy<S> for HeuristicPolicy<F, S>
77where
78 F: Fn(&S) -> f64 + Clone + Send + Sync + 'static,
79 S: GameState + 'static,
80{
81 fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
82 if state.is_terminal() {
84 let player = state.get_current_player();
85 return (state.get_result(&player), Vec::new());
86 }
87
88 ((self.heuristic)(state), Vec::new())
90 }
91
92 fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
93 Box::new(self.clone())
94 }
95}
96
97pub struct MixturePolicy<S: GameState> {
102 policies: Vec<(Box<dyn SimulationPolicy<S>>, f64)>,
104}
105
106impl<S: GameState> std::fmt::Debug for MixturePolicy<S> {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("MixturePolicy")
109 .field("policies_count", &self.policies.len())
110 .finish()
111 }
112}
113
114impl<S: GameState> Clone for MixturePolicy<S> {
115 fn clone(&self) -> Self {
116 MixturePolicy {
119 policies: Vec::new(),
120 }
121 }
122}
123
124impl<S: GameState> MixturePolicy<S> {
125 pub fn new() -> Self {
127 MixturePolicy {
128 policies: Vec::new(),
129 }
130 }
131
132 pub fn add_policy<P: SimulationPolicy<S> + 'static>(
134 mut self,
135 policy: P,
136 probability: f64,
137 ) -> Self {
138 self.policies.push((Box::new(policy), probability));
139 self
140 }
141}
142
143impl<S: GameState + 'static> SimulationPolicy<S> for MixturePolicy<S> {
144 fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
145 use rand::Rng;
146
147 if self.policies.is_empty() {
148 let random_policy = RandomPolicy::new();
150 return random_policy.simulate(state);
151 }
152
153 let total: f64 = self.policies.iter().map(|(_, p)| *p).sum();
155
156 let mut rng = rand::thread_rng();
158 let r: f64 = rng.gen_range(0.0..total);
159
160 let mut cumulative = 0.0;
161 for (policy, prob) in &self.policies {
162 cumulative += prob;
163 if r < cumulative {
164 return policy.simulate(state);
165 }
166 }
167
168 self.policies.last().unwrap().0.simulate(state)
170 }
171
172 fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
173 let mut new_policies = Vec::new();
174 for (policy, prob) in &self.policies {
175 new_policies.push((policy.clone_box(), *prob));
176 }
177
178 Box::new(MixturePolicy {
179 policies: new_policies,
180 })
181 }
182}
183
184impl<S: GameState> Default for MixturePolicy<S> {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189impl<S: GameState> SimulationPolicy<S> for Box<dyn SimulationPolicy<S>> {
191 fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
192 (**self).simulate(state)
193 }
194
195 fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
196 (**self).clone_box()
197 }
198}