// entrenar/search/mcts/node.rs

1//! Node types and statistics for MCTS.
2//!
3//! This module contains the node representation, statistics tracking,
4//! and UCB1/PUCT score calculations.
5
6use super::traits::{Action, State};
7use super::Reward;
8
/// Handle that identifies one node of the search tree.
///
/// This is a newtype over the raw `usize` value; two handles compare equal
/// exactly when their wrapped values are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl NodeId {
    /// Wraps the raw value `id` in a `NodeId`.
    #[must_use]
    pub const fn new(id: usize) -> Self {
        NodeId(id)
    }

    /// Unwraps the handle and returns the raw value it was built from.
    #[must_use]
    pub const fn value(&self) -> usize {
        let NodeId(raw) = *self;
        raw
    }
}
26
27/// Statistics for a node in the search tree
28#[derive(Debug, Clone)]
29pub struct NodeStats {
30    /// Total visits to this node
31    pub visits: usize,
32    /// Total accumulated reward
33    pub total_reward: f64,
34    /// Mean reward (total_reward / visits)
35    pub mean_reward: f64,
36    /// Prior probability from policy network
37    pub prior: f64,
38}
39
40impl Default for NodeStats {
41    fn default() -> Self {
42        Self { visits: 0, total_reward: 0.0, mean_reward: 0.0, prior: 1.0 }
43    }
44}
45
46impl NodeStats {
47    /// Update statistics with a new reward
48    pub fn update(&mut self, reward: Reward) {
49        contract_pre_update!();
50        self.visits += 1;
51        self.total_reward += reward;
52        self.mean_reward = self.total_reward / self.visits as f64;
53    }
54
55    /// Calculate UCB1 score
56    #[must_use]
57    pub fn ucb1(&self, parent_visits: usize, c: f64) -> f64 {
58        if self.visits == 0 {
59            return f64::INFINITY;
60        }
61        let exploitation = self.mean_reward;
62        let exploration = c * ((parent_visits as f64).max(1.0).ln() / self.visits as f64).sqrt();
63        exploitation + exploration
64    }
65
66    /// Calculate PUCT score (Polynomial Upper Confidence Trees) for policy-guided search
67    #[must_use]
68    pub fn puct(&self, parent_visits: usize, c: f64) -> f64 {
69        let exploitation = self.mean_reward;
70        let exploration =
71            c * self.prior * (parent_visits as f64).sqrt() / (1.0 + self.visits as f64);
72        exploitation + exploration
73    }
74}
75
76/// A node in the search tree
77#[derive(Debug, Clone)]
78pub struct Node<S: State, A: Action> {
79    /// Unique identifier
80    pub id: NodeId,
81    /// State at this node
82    pub state: S,
83    /// Action that led to this node (None for root)
84    pub action: Option<A>,
85    /// Parent node id (None for root)
86    pub parent: Option<NodeId>,
87    /// Child node ids
88    pub children: Vec<NodeId>,
89    /// Statistics for this node
90    pub stats: NodeStats,
91    /// Whether this node is fully expanded
92    pub expanded: bool,
93    /// Untried actions from this state
94    pub untried_actions: Vec<A>,
95}
96
97impl<S: State, A: Action> Node<S, A> {
98    /// Create a new root node
99    #[must_use]
100    pub fn root(state: S, untried_actions: Vec<A>) -> Self {
101        Self {
102            id: NodeId::new(0),
103            state,
104            action: None,
105            parent: None,
106            children: Vec::new(),
107            stats: NodeStats::default(),
108            expanded: false,
109            untried_actions,
110        }
111    }
112
113    /// Create a new child node
114    #[must_use]
115    pub fn child(
116        id: NodeId,
117        state: S,
118        action: A,
119        parent: NodeId,
120        untried_actions: Vec<A>,
121        prior: f64,
122    ) -> Self {
123        Self {
124            id,
125            state,
126            action: Some(action),
127            parent: Some(parent),
128            children: Vec::new(),
129            stats: NodeStats { prior, ..Default::default() },
130            expanded: false,
131            untried_actions,
132        }
133    }
134
135    /// Returns true if this node is a leaf (no children)
136    #[must_use]
137    pub fn is_leaf(&self) -> bool {
138        self.children.is_empty()
139    }
140
141    /// Returns true if this node is fully expanded
142    #[must_use]
143    pub fn is_fully_expanded(&self) -> bool {
144        self.expanded && self.untried_actions.is_empty()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::f64::consts::SQRT_2;

    /// Minimal `State` fixture for the unit tests.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal `Action` fixture for the unit tests.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Builds a non-terminal fixture state carrying `value`.
    fn live_state(value: i32) -> TestState {
        TestState { value, terminal: false }
    }

    // ========================================
    // ENT-090: Core types tests
    // ========================================

    #[test]
    fn test_node_id_creation() {
        let raw = 42;
        assert_eq!(NodeId::new(raw).value(), raw);
    }

    #[test]
    fn test_node_stats_default() {
        let NodeStats { visits, total_reward, mean_reward, prior } = NodeStats::default();
        assert_eq!(visits, 0);
        assert_eq!(total_reward, 0.0);
        assert_eq!(mean_reward, 0.0);
        assert_eq!(prior, 1.0);
    }

    #[test]
    fn test_node_stats_update() {
        // (reward fed in, expected visits, expected total, expected mean)
        let steps = [(1.0, 1, 1.0, 1.0), (0.0, 2, 1.0, 0.5)];

        let mut stats = NodeStats::default();
        for &(reward, visits, total, mean) in &steps {
            stats.update(reward);
            assert_eq!(stats.visits, visits);
            assert_eq!(stats.total_reward, total);
            assert_eq!(stats.mean_reward, mean);
        }
    }

    #[test]
    fn test_node_root_creation() {
        let start = live_state(0);
        let root = Node::root(start.clone(), vec![TestAction { delta: 1 }]);

        assert_eq!(root.id, NodeId::new(0));
        assert_eq!(root.state, start);
        assert!(root.action.is_none());
        assert!(root.parent.is_none());
        assert!(root.children.is_empty());
        assert!(!root.expanded);
    }

    #[test]
    fn test_node_child_creation() {
        let reached = live_state(1);
        let step = TestAction { delta: 1 };
        let own_id = NodeId::new(1);
        let parent_id = NodeId::new(0);

        let kid = Node::child(own_id, reached.clone(), step.clone(), parent_id, vec![], 0.5);

        assert_eq!(kid.id, own_id);
        assert_eq!(kid.state, reached);
        assert_eq!(kid.action, Some(step));
        assert_eq!(kid.parent, Some(parent_id));
        assert_eq!(kid.stats.prior, 0.5);
    }

    #[test]
    fn test_node_is_leaf() {
        let lonely = Node::<TestState, TestAction>::root(live_state(0), Vec::new());
        assert!(lonely.is_leaf());
    }

    // ========================================
    // ENT-091: UCB1/UCT tests
    // ========================================

    #[test]
    fn test_ucb1_unvisited_node() {
        let untouched = NodeStats::default();
        assert!(untouched.ucb1(10, SQRT_2).is_infinite());
    }

    #[test]
    fn test_ucb1_visited_node() {
        let mut stats = NodeStats::default();
        stats.update(0.5);

        // exploitation + exploration:
        // 0.5 + sqrt(2) * sqrt(ln(10) / 1) ≈ 0.5 + 2.14 = 2.64
        let score = stats.ucb1(10, SQRT_2);
        assert!(score > 0.5);
        assert!(score < 5.0);
    }

    #[test]
    fn test_ucb1_more_visits_lower_exploration() {
        let rarely_seen = NodeStats {
            visits: 10,
            total_reward: 5.0,
            mean_reward: 0.5,
            ..NodeStats::default()
        };
        let often_seen = NodeStats {
            visits: 100,
            total_reward: 50.0,
            mean_reward: 0.5,
            ..NodeStats::default()
        };

        // Equal means, so the node with more visits must get the smaller bonus.
        assert!(rarely_seen.ucb1(1000, SQRT_2) > often_seen.ucb1(1000, SQRT_2));
    }

    #[test]
    fn test_puct_with_prior() {
        let mut stats = NodeStats { prior: 0.5, ..NodeStats::default() };
        stats.update(0.3);

        // PUCT = mean_reward + c * prior * sqrt(parent_visits) / (1 + visits)
        //      = 0.3 + 2.0 * 0.5 * sqrt(100) / 2 = 0.3 + 5.0 = 5.3
        let score = stats.puct(100, 2.0);
        assert!((score - 5.3).abs() < 0.01);
    }

    // ========================================
    // ENT-095: Property tests
    // ========================================

    proptest! {
        #[test]
        fn test_node_stats_update_invariants(rewards in prop::collection::vec(0.0f64..=1.0, 1..100)) {
            let mut stats = NodeStats::default();
            for &reward in &rewards {
                stats.update(reward);
            }

            let expected_total = rewards.iter().sum::<f64>();
            let expected_mean = expected_total / rewards.len() as f64;

            prop_assert_eq!(stats.visits, rewards.len());
            prop_assert!((stats.total_reward - expected_total).abs() < 1e-10);
            prop_assert!((stats.mean_reward - expected_mean).abs() < 1e-10);
        }

        #[test]
        fn test_ucb1_exploration_decreases_with_visits(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let rarely_seen = NodeStats { visits: 10, mean_reward: 0.5, ..NodeStats::default() };
            let often_seen = NodeStats { visits: 100, mean_reward: 0.5, ..NodeStats::default() };

            let ucb1 = rarely_seen.ucb1(parent_visits, c);
            let ucb2 = often_seen.ucb1(parent_visits, c);

            // Same mean: the exploration bonus shrinks as visits grow.
            prop_assert!(ucb1 > ucb2, "UCB1 with fewer visits should be higher");
        }

        #[test]
        fn test_ucb1_higher_reward_higher_score(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let weaker = NodeStats { visits: 50, mean_reward: 0.3, ..NodeStats::default() };
            let stronger = NodeStats { visits: 50, mean_reward: 0.7, ..NodeStats::default() };

            let ucb1 = weaker.ucb1(parent_visits, c);
            let ucb2 = stronger.ucb1(parent_visits, c);

            // Same visits: the larger mean reward must win.
            prop_assert!(ucb2 > ucb1, "Higher reward should give higher UCB");
        }

        #[test]
        fn test_puct_prior_increases_exploration(prior in 0.1f64..0.9) {
            let low_prior = NodeStats {
                visits: 10,
                mean_reward: 0.5,
                prior,
                ..NodeStats::default()
            };
            let high_prior = NodeStats {
                visits: 10,
                mean_reward: 0.5,
                prior: prior * 2.0,
                ..NodeStats::default()
            };

            let puct1 = low_prior.puct(100, 2.0);
            let puct2 = high_prior.puct(100, 2.0);

            // Everything else equal: doubling the prior must raise the score.
            prop_assert!(puct2 > puct1, "Higher prior should give higher PUCT");
        }
    }
}