arboriter_mcts/policy/simulation.rs

1//! Simulation policies for the MCTS algorithm
2//!
3//! Simulation policies determine how to play out a game from a given state
4//! to estimate the value of that state.
5
6use crate::game_state::GameState;
7
8/// Trait for policies that simulate games
/// Trait for policies that simulate (play out) games from a given state.
///
/// `Send + Sync` is required so a policy can be shared across search threads.
pub trait SimulationPolicy<S: GameState>: Send + Sync {
    /// Simulates a game from the given state.
    ///
    /// Returns the simulation result as an `f64` score (presumably from the
    /// perspective of the player to move — confirm against
    /// `GameState::get_result`) together with the trace of actions taken
    /// during the playout.
    fn simulate(&self, state: &S) -> (f64, Vec<S::Action>);

    /// Create a boxed clone of this policy.
    ///
    /// Needed because `Clone` cannot be required on an object-safe trait;
    /// this lets `Box<dyn SimulationPolicy<S>>` be duplicated.
    fn clone_box(&self) -> Box<dyn SimulationPolicy<S>>;
}
16
/// Random simulation policy.
///
/// Plays random legal moves until the game ends, by delegating to the
/// state's own `simulate_random_playout`. Stateless unit struct, so it is
/// trivially `Clone` and thread-safe.
#[derive(Debug, Clone)]
pub struct RandomPolicy;
22
23impl RandomPolicy {
24    /// Creates a new random policy
25    pub fn new() -> Self {
26        RandomPolicy
27    }
28}
29
30impl Default for RandomPolicy {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<S: GameState> SimulationPolicy<S> for RandomPolicy {
37    fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
38        // Use the built-in random playout method
39        let player = state.get_current_player();
40        state.simulate_random_playout(&player)
41    }
42
43    fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
44        Box::new(self.clone())
45    }
46}
47
/// Heuristic simulation policy.
///
/// Instead of playing a full game out, this policy scores a non-terminal
/// state with a user-supplied heuristic function `F: Fn(&S) -> f64`.
///
/// Trait bounds are intentionally left off the struct definition and live
/// only on the `impl` blocks that need them (Rust API guideline
/// C-STRUCT-BOUNDS); this is strictly more permissive than before and keeps
/// the type usable in bound-free contexts.
#[derive(Debug, Clone)]
pub struct HeuristicPolicy<F, S> {
    /// The heuristic evaluation function.
    heuristic: F,
    /// Marker tying the policy to the state type `S` without storing one.
    _phantom: std::marker::PhantomData<S>,
}
61
62impl<F, S> HeuristicPolicy<F, S>
63where
64    F: Fn(&S) -> f64 + Clone + Send + Sync + 'static,
65    S: GameState + 'static,
66{
67    /// Creates a new heuristic policy with the given function
68    pub fn new(heuristic: F) -> Self {
69        HeuristicPolicy {
70            heuristic,
71            _phantom: std::marker::PhantomData,
72        }
73    }
74}
75
76impl<F, S> SimulationPolicy<S> for HeuristicPolicy<F, S>
77where
78    F: Fn(&S) -> f64 + Clone + Send + Sync + 'static,
79    S: GameState + 'static,
80{
81    fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
82        // If terminal, return the actual result
83        if state.is_terminal() {
84            let player = state.get_current_player();
85            return (state.get_result(&player), Vec::new());
86        }
87
88        // Otherwise, use the heuristic function
89        ((self.heuristic)(state), Vec::new())
90    }
91
92    fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
93        Box::new(self.clone())
94    }
95}
96
/// Mixture simulation policy
///
/// Combines multiple simulation policies: on each `simulate` call, one
/// policy is chosen at random with probability proportional to its
/// associated weight (weights need not sum to 1.0 — see `simulate`).
pub struct MixturePolicy<S: GameState> {
    /// Policies paired with their (unnormalized) selection probabilities.
    policies: Vec<(Box<dyn SimulationPolicy<S>>, f64)>,
}
105
106impl<S: GameState> std::fmt::Debug for MixturePolicy<S> {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        f.debug_struct("MixturePolicy")
109            .field("policies_count", &self.policies.len())
110            .finish()
111    }
112}
113
114impl<S: GameState> Clone for MixturePolicy<S> {
115    fn clone(&self) -> Self {
116        // We can't clone the policies directly, so we return a new empty MixturePolicy
117        // This is not ideal, but it's a reasonable fallback for the Clone requirement
118        MixturePolicy {
119            policies: Vec::new(),
120        }
121    }
122}
123
124impl<S: GameState> MixturePolicy<S> {
125    /// Creates a new mixture policy
126    pub fn new() -> Self {
127        MixturePolicy {
128            policies: Vec::new(),
129        }
130    }
131
132    /// Adds a policy with the given probability
133    pub fn add_policy<P: SimulationPolicy<S> + 'static>(
134        mut self,
135        policy: P,
136        probability: f64,
137    ) -> Self {
138        self.policies.push((Box::new(policy), probability));
139        self
140    }
141}
142
143impl<S: GameState + 'static> SimulationPolicy<S> for MixturePolicy<S> {
144    fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
145        use rand::Rng;
146
147        if self.policies.is_empty() {
148            // Fallback to random policy
149            let random_policy = RandomPolicy::new();
150            return random_policy.simulate(state);
151        }
152
153        // Calculate total probability
154        let total: f64 = self.policies.iter().map(|(_, p)| *p).sum();
155
156        // Select a policy based on probabilities
157        let mut rng = rand::thread_rng();
158        let r: f64 = rng.gen_range(0.0..total);
159
160        let mut cumulative = 0.0;
161        for (policy, prob) in &self.policies {
162            cumulative += prob;
163            if r < cumulative {
164                return policy.simulate(state);
165            }
166        }
167
168        // Fallback to the last policy
169        self.policies.last().unwrap().0.simulate(state)
170    }
171
172    fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
173        let mut new_policies = Vec::new();
174        for (policy, prob) in &self.policies {
175            new_policies.push((policy.clone_box(), *prob));
176        }
177
178        Box::new(MixturePolicy {
179            policies: new_policies,
180        })
181    }
182}
183
184impl<S: GameState> Default for MixturePolicy<S> {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189// Implement SimulationPolicy for Box<dyn SimulationPolicy>
190impl<S: GameState> SimulationPolicy<S> for Box<dyn SimulationPolicy<S>> {
191    fn simulate(&self, state: &S) -> (f64, Vec<S::Action>) {
192        (**self).simulate(state)
193    }
194
195    fn clone_box(&self) -> Box<dyn SimulationPolicy<S>> {
196        (**self).clone_box()
197    }
198}