Skip to main content

entrenar/search/mcts/
traits.rs

1//! Trait definitions for MCTS components.
2//!
3//! This module defines the core traits that must be implemented
4//! to use MCTS for a particular domain.
5
6use std::hash::Hash;
7
8use super::Reward;
9
/// A node in the search space, such as a partially built AST.
///
/// Implementors must be cloneable, comparable and hashable so that states
/// can be copied and deduplicated during the search.
pub trait State: Clone + Eq + Hash {
    /// Returns `true` when this state is terminal (for code search: the
    /// program is complete).
    fn is_terminal(&self) -> bool;

    /// Produces a 64-bit fingerprint of this state, used for deduplication.
    ///
    /// The default feeds this state's `Hash` impl into std's
    /// `DefaultHasher`. The result is deterministic within one build, but
    /// std leaves that hasher's algorithm unspecified, so these values
    /// should not be persisted or compared across Rust releases.
    fn state_hash(&self) -> u64 {
        let mut fingerprint = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(self, &mut fingerprint);
        std::hash::Hasher::finish(&fingerprint)
    }
}
23
/// A move in the search space, such as an AST transformation.
pub trait Action: Clone + Eq + Hash {
    /// Name that identifies this action.
    fn name(&self) -> &str;

    /// Prior probability of this action, used for policy-network guidance.
    ///
    /// Defaults to `1.0` for every action, i.e. a uniform prior.
    fn prior(&self) -> f64 {
        const UNIFORM_PRIOR: f64 = 1.0;
        UNIFORM_PRIOR
    }
}
34
35/// Trait for defining the state space (how states transition)
36pub trait StateSpace<S: State, A: Action> {
37    /// Apply an action to a state, returning the new state
38    fn apply(&self, state: &S, action: &A) -> S;
39
40    /// Evaluate a terminal state, returning the reward (0.0 to 1.0)
41    fn evaluate(&self, state: &S) -> Reward;
42
43    /// Clone the state space (for parallel simulations)
44    fn clone_space(&self) -> Box<dyn StateSpace<S, A> + Send + Sync>;
45}
46
47/// Trait for defining the action space (available actions from a state)
48pub trait ActionSpace<S: State, A: Action> {
49    /// Returns all legal actions from the given state
50    fn legal_actions(&self, state: &S) -> Vec<A>;
51
52    /// Returns true if there are no legal actions from this state
53    fn is_empty(&self, state: &S) -> bool {
54        self.legal_actions(state).is_empty()
55    }
56}
57
58/// Trait for policy networks that guide the search
59pub trait PolicyNetwork<S: State, A: Action>: Send + Sync {
60    /// Returns (action, prior probability) pairs for the given state
61    fn predict(&self, state: &S) -> Vec<(A, f64)>;
62
63    /// Returns the value estimate for a state (optional, for AlphaZero-style)
64    fn value(&self, _state: &S) -> f64 {
65        0.5 // Neutral value by default
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state: an integer payload plus an explicit terminal flag.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal action carrying a step size.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Action space offering `+1` and `-1` until the state is terminal;
    /// relies on the default `ActionSpace::is_empty`.
    struct TestActionSpace;

    impl ActionSpace<TestState, TestAction> for TestActionSpace {
        fn legal_actions(&self, state: &TestState) -> Vec<TestAction> {
            if state.is_terminal() {
                Vec::new()
            } else {
                vec![TestAction { delta: 1 }, TestAction { delta: -1 }]
            }
        }
    }

    /// Policy that spreads probability evenly over the legal actions and
    /// keeps the default `PolicyNetwork::value`.
    struct UniformPolicy;

    impl PolicyNetwork<TestState, TestAction> for UniformPolicy {
        fn predict(&self, state: &TestState) -> Vec<(TestAction, f64)> {
            let actions = TestActionSpace.legal_actions(state);
            // `max(1)` avoids dividing by zero for terminal states.
            let share = 1.0 / actions.len().max(1) as f64;
            actions.into_iter().map(|action| (action, share)).collect()
        }
    }

    #[test]
    fn test_state_trait_implementation() {
        let state = TestState { value: 5, terminal: false };
        assert!(!state.is_terminal());

        let terminal = TestState { value: 10, terminal: true };
        assert!(terminal.is_terminal());
    }

    #[test]
    fn test_state_hash_is_consistent_for_equal_states() {
        let a = TestState { value: 7, terminal: false };
        let b = a.clone();
        assert_eq!(a, b);
        assert_eq!(a.state_hash(), b.state_hash());
    }

    #[test]
    fn test_action_trait_implementation() {
        let action = TestAction { delta: 1 };
        assert_eq!(action.name(), "test_action");
        assert_eq!(action.prior(), 1.0);
    }

    #[test]
    fn test_action_space_default_is_empty() {
        let space = TestActionSpace;
        let open = TestState { value: 0, terminal: false };
        let done = TestState { value: 0, terminal: true };
        assert_eq!(space.legal_actions(&open).len(), 2);
        assert!(!space.is_empty(&open));
        assert!(space.is_empty(&done));
    }

    #[test]
    fn test_policy_network_defaults() {
        let policy = UniformPolicy;
        let state = TestState { value: 3, terminal: false };
        let priors = policy.predict(&state);
        assert_eq!(priors.len(), 2);
        assert!(priors.iter().all(|(_, p)| (*p - 0.5).abs() < f64::EPSILON));
        assert_eq!(policy.value(&state), 0.5);
    }
}
113}