arboriter_mcts/
tree.rs

1use rand::prelude::IteratorRandom;
2use std::fmt;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use crate::game_state::GameState;
6
7/// Represents a node in the MCTS tree
8///
9/// Each node contains the game state, the action that led to it,
10/// statistics about visits and rewards, and references to child nodes.
11/// The tree is built incrementally during the search process.
12pub struct MCTSNode<S: GameState> {
13    /// The game state at this node
14    pub state: S,
15
16    /// The action that led to this state (None for root)
17    pub action: Option<S::Action>,
18
19    /// Number of times this node has been visited
20    /// Uses atomic operations to support potential future parallelization
21    pub visits: AtomicU64,
22
23    /// Total reward accumulated from simulations through this node
24    /// Uses atomic operations and fixed-point representation internally
25    pub total_reward: AtomicU64,
26
27    /// Sum of squared rewards (for variance calculation in UCB1-Tuned)
28    pub sum_squared_reward: AtomicU64,
29
30    /// Number of RAVE visits (AMAF)
31    pub rave_visits: AtomicU64,
32
33    /// Total RAVE reward
34    pub rave_reward: AtomicU64,
35
36    /// Prior probability for this node (P(s,a))
37    /// Used by PUCT policy. Defaults to 1.0 if not set.
38    pub prior: AtomicU64,
39
40    /// Children nodes representing states reachable from this one
41    pub children: Vec<MCTSNode<S>>,
42
43    /// Actions that have not yet been expanded into child nodes
44    /// As the search progresses, actions are moved from this list to children
45    pub unexpanded_actions: Vec<S::Action>,
46
47    /// Depth of this node in the tree (root = 0)
48    pub depth: usize,
49
50    /// Player who made the move to reach this state
51    /// For the root node, this is the starting player
52    pub player: S::Player,
53}
54
/// Fixed-point scale for rewards stored in `AtomicU64` fields.
///
/// Rewards are multiplied by this factor and truncated toward zero, which keeps
/// six decimal digits. The result can then be accumulated with atomic integer
/// `fetch_add`.
const REWARD_SCALE: f64 = 1_000_000.0;

/// Converts a floating-point reward to its fixed-point `u64` encoding.
///
/// The scaled value is stored as the two's-complement bit pattern of an `i64`.
/// Negative rewards therefore survive the round trip, and
/// `AtomicU64::fetch_add` (which wraps on overflow) sums signed values
/// correctly.
///
/// Edge cases:
/// * Values outside the `i64` range saturate.
/// * `NaN` encodes as `0`.
/// * Non-negative inputs encode exactly as before: a `u64` capped at
///   `u64::MAX / 2` (== `i64::MAX`).
fn float_to_scaled_u64(value: f64) -> u64 {
    // `f64 as i64` truncates toward zero and saturates (NaN -> 0).
    // `i64 as u64` reinterprets the bits unchanged.
    (value * REWARD_SCALE) as i64 as u64
}

/// Decodes a value produced by [`float_to_scaled_u64`] back to a reward.
///
/// The input may also be a wrapping sum of such values.
fn scaled_u64_to_float(value: u64) -> f64 {
    // Reinterpret as two's-complement so that accumulated negative totals
    // decode as negative numbers.
    (value as i64) as f64 / REWARD_SCALE
}
68
69impl<S: GameState> MCTSNode<S> {
70    /// Creates a new node with the given state and action
71    pub fn new(
72        state: S,
73        action: Option<S::Action>,
74        parent_player: Option<S::Player>,
75        depth: usize,
76    ) -> Self {
77        let player = parent_player.unwrap_or_else(|| state.get_current_player());
78        let unexpanded_actions = state.get_legal_actions();
79
80        MCTSNode {
81            state,
82            action,
83            visits: AtomicU64::new(0),
84            total_reward: AtomicU64::new(0),
85            sum_squared_reward: AtomicU64::new(0),
86            rave_visits: AtomicU64::new(0),
87            rave_reward: AtomicU64::new(0),
88            prior: AtomicU64::new(float_to_scaled_u64(1.0)), // Default prior is 1.0
89            children: Vec::new(),
90            unexpanded_actions,
91            depth,
92            player,
93        }
94    }
95
96    /// Returns the number of visits to this node
97    pub fn visits(&self) -> u64 {
98        self.visits.load(Ordering::Relaxed)
99    }
100
101    /// Returns the total reward accumulated at this node
102    pub fn total_reward(&self) -> f64 {
103        scaled_u64_to_float(self.total_reward.load(Ordering::Relaxed))
104    }
105
106    /// Returns the prior probability of this node
107    pub fn prior(&self) -> f64 {
108        scaled_u64_to_float(self.prior.load(Ordering::Relaxed))
109    }
110
111    /// Sets the prior probability of this node
112    pub fn set_prior(&self, prior: f64) {
113        self.prior.store(float_to_scaled_u64(prior), Ordering::Relaxed);
114    }
115
116    /// Returns the average reward (value) of this node
117    pub fn value(&self) -> f64 {
118        let visits = self.visits();
119        if visits == 0 {
120            return 0.0;
121        }
122        self.total_reward() / visits as f64
123    }
124
125    /// Increments the visit count
126    pub fn increment_visits(&self) {
127        self.visits.fetch_add(1, Ordering::Relaxed);
128    }
129
130    /// Adds reward to the total
131    pub fn add_reward(&self, reward: f64) {
132        self.total_reward
133            .fetch_add(float_to_scaled_u64(reward), Ordering::Relaxed);
134    }
135
136    /// Adds squared reward (for UCB1-Tuned)
137    pub fn add_squared_reward(&self, reward: f64) {
138        self.sum_squared_reward
139            .fetch_add(float_to_scaled_u64(reward * reward), Ordering::Relaxed);
140    }
141
142    /// Returns the sum of squared rewards
143    pub fn sum_squared_reward(&self) -> f64 {
144        scaled_u64_to_float(self.sum_squared_reward.load(Ordering::Relaxed))
145    }
146
147    /// Increments the RAVE visit count
148    pub fn increment_rave_visits(&self) {
149        self.rave_visits.fetch_add(1, Ordering::Relaxed);
150    }
151
152    /// Adds RAVE reward
153    pub fn add_rave_reward(&self, reward: f64) {
154        self.rave_reward
155            .fetch_add(float_to_scaled_u64(reward), Ordering::Relaxed);
156    }
157
158    /// Returns the number of RAVE visits
159    pub fn rave_visits(&self) -> u64 {
160        self.rave_visits.load(Ordering::Relaxed)
161    }
162
163    /// Returns the RAVE value (average RAVE reward)
164    pub fn rave_value(&self) -> f64 {
165        let visits = self.rave_visits();
166        if visits == 0 {
167            return 0.0;
168        }
169        scaled_u64_to_float(self.rave_reward.load(Ordering::Relaxed)) / visits as f64
170    }
171
172    /// Returns true if this node is fully expanded
173    pub fn is_fully_expanded(&self) -> bool {
174        self.unexpanded_actions.is_empty()
175    }
176
177    /// Returns true if this node is a leaf (has no children)
178    pub fn is_leaf(&self) -> bool {
179        self.children.is_empty()
180    }
181
182    /// Expands the node by creating a child for an unexpanded action
183    ///
184    /// This method takes an action from the unexpanded actions list,
185    /// applies it to create a new game state, and creates a child node
186    /// for this new state.
187    ///
188    /// # Arguments
189    ///
190    /// * `action_index` - Index into the `unexpanded_actions` list
191    ///
192    /// # Returns
193    ///
194    /// * `Some(&mut MCTSNode<S>)` - Reference to the newly created child node
195    /// * `None` - If the action index is out of bounds
196    ///
197    /// # Note
198    ///
199    /// This method uses `swap_remove` on the unexpanded actions list, which
200    /// changes the order of the remaining unexpanded actions. If order
201    /// matters to your application, be aware of this side effect.
202    pub fn expand(&mut self, action_index: usize) -> Option<&mut MCTSNode<S>> {
203        if action_index >= self.unexpanded_actions.len() {
204            return None;
205        }
206
207        let action = self.unexpanded_actions.swap_remove(action_index);
208        let next_state = self.state.apply_action(&action);
209        let current_player = self.state.get_current_player();
210
211        let child = MCTSNode::new(
212            next_state,
213            Some(action),
214            Some(current_player),
215            self.depth + 1,
216        );
217
218        self.children.push(child);
219        self.children.last_mut()
220    }
221
222    /// Expands the node using a node pool for better performance
223    ///
224    /// This version of expand uses a node pool to reduce allocation overhead.
225    /// It's recommended for performance-critical applications.
226    pub fn expand_with_pool(
227        &mut self,
228        action_index: usize,
229        pool: &mut NodePool<S>,
230    ) -> Option<&mut MCTSNode<S>> {
231        if action_index >= self.unexpanded_actions.len() {
232            return None;
233        }
234
235        let action = self.unexpanded_actions.swap_remove(action_index);
236        let next_state = self.state.apply_action(&action);
237        let current_player = self.state.get_current_player();
238
239        // Create a new node using the pool
240        let node = pool.create_node(
241            next_state,
242            Some(action),
243            Some(current_player),
244            self.depth + 1,
245        );
246
247        self.children.push(node);
248        self.children.last_mut()
249    }
250
251    /// Expands a random unexpanded action
252    pub fn expand_random(&mut self) -> Option<&mut MCTSNode<S>> {
253        if self.unexpanded_actions.is_empty() {
254            return None;
255        }
256
257        // Use IteratorRandom trait for choose method on range
258        let mut rng = rand::thread_rng();
259        let index = (0..self.unexpanded_actions.len()).choose(&mut rng).unwrap();
260
261        self.expand(index)
262    }
263
264    /// Expands a random unexpanded action using a node pool
265    pub fn expand_random_with_pool(&mut self, pool: &mut NodePool<S>) -> Option<&mut MCTSNode<S>> {
266        if self.unexpanded_actions.is_empty() {
267            return None;
268        }
269
270        // Use IteratorRandom trait for choose method on range
271        let mut rng = rand::thread_rng();
272        let index = (0..self.unexpanded_actions.len()).choose(&mut rng).unwrap();
273
274        self.expand_with_pool(index, pool)
275    }
276}
277
278/// Pool for efficient node allocation in MCTS
279///
280/// This implementation provides memory reuse by creating and recycling nodes
281/// instead of frequently allocating and deallocating them. This can significantly
282/// improve performance in large MCTS searches.
283pub struct NodePool<S: GameState> {
284    /// Template state used for creating new nodes
285    template_state: S,
286
287    /// Preallocated, reusable nodes for efficient reuse
288    free_nodes: Vec<MCTSNode<S>>,
289
290    /// Statistics about allocations
291    stats: NodePoolStats,
292}
293
/// Counters describing how a node pool has been used.
///
/// Derives `PartialEq`/`Eq` so snapshots of the counters can be compared,
/// e.g. in tests or when checking whether a search touched the pool.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct NodePoolStats {
    /// Total nodes constructed by the pool: preallocated ones plus those
    /// allocated because the free list was empty.
    pub total_created: usize,

    /// Total node requests served, whether by reuse or by fresh allocation.
    pub total_allocations: usize,

    /// Total nodes handed back to the pool.
    pub total_recycled: usize,
}
306
307impl<S: GameState> NodePool<S> {
308    /// Creates a new node pool with the given template state
309    ///
310    /// # Arguments
311    ///
312    /// * `template_state` - A template state that can be cloned when creating new nodes
313    /// * `initial_size` - Number of nodes to preallocate
314    pub fn new(template_state: S, initial_size: usize) -> Self {
315        let mut pool = NodePool {
316            template_state,
317            free_nodes: Vec::with_capacity(initial_size),
318            stats: NodePoolStats::default(),
319        };
320
321        // Preallocate nodes if requested
322        if initial_size > 0 {
323            pool.preallocate(initial_size);
324        }
325
326        pool
327    }
328
329    /// Preallocate nodes to reduce allocation pressure during search
330    fn preallocate(&mut self, count: usize) {
331        for _ in 0..count {
332            let node = MCTSNode {
333                state: self.template_state.clone(),
334                action: None,
335                visits: AtomicU64::new(0),
336                total_reward: AtomicU64::new(0),
337                sum_squared_reward: AtomicU64::new(0),
338                rave_visits: AtomicU64::new(0),
339                rave_reward: AtomicU64::new(0),
340                prior: AtomicU64::new(float_to_scaled_u64(1.0)),
341                children: Vec::new(),
342                unexpanded_actions: Vec::new(),
343                depth: 0,
344                player: self.template_state.get_current_player(),
345            };
346
347            self.free_nodes.push(node);
348            self.stats.total_created += 1;
349        }
350    }
351
352    /// Creates a new node, either from the pool or by allocating a new one
353    pub fn create_node(
354        &mut self,
355        state: S,
356        action: Option<S::Action>,
357        parent_player: Option<S::Player>,
358        depth: usize,
359    ) -> MCTSNode<S> {
360        self.stats.total_allocations += 1;
361
362        if let Some(mut node) = self.free_nodes.pop() {
363            // Get player before moving state
364            let player = match &parent_player {
365                Some(p) => p.clone(),
366                None => state.get_current_player(),
367            };
368
369            // Get legal actions before moving state
370            let legal_actions = state.get_legal_actions();
371
372            // Reuse an existing node
373            node.state = state;
374            node.action = action;
375            node.visits = AtomicU64::new(0);
376            node.total_reward = AtomicU64::new(0);
377            node.sum_squared_reward = AtomicU64::new(0);
378            node.rave_visits = AtomicU64::new(0);
379            node.rave_reward = AtomicU64::new(0);
380            node.prior = AtomicU64::new(float_to_scaled_u64(1.0));
381            node.children.clear();
382            node.depth = depth;
383            node.player = player;
384            node.unexpanded_actions = legal_actions;
385
386            node
387        } else {
388            // Create a new node if the pool is empty
389            self.stats.total_created += 1;
390            MCTSNode::new(state, action, parent_player, depth)
391        }
392    }
393
394    /// Recycles a node back to the pool for future reuse
395    pub fn recycle_node(&mut self, mut node: MCTSNode<S>) {
396        self.stats.total_recycled += 1;
397
398        // Clear any large data structures to prevent memory bloat
399        node.children.clear();
400        node.unexpanded_actions.clear();
401
402        // Add the node back to the free list
403        self.free_nodes.push(node);
404    }
405
406    /// Recycles all nodes in a tree by recursively adding them to the pool
407    pub fn recycle_tree(&mut self, mut root: MCTSNode<S>) {
408        // First, recursively recycle all children
409        let mut children = std::mem::take(&mut root.children);
410        for child in children.drain(..) {
411            self.recycle_tree(child);
412        }
413
414        // Then recycle the root node itself
415        self.recycle_node(root);
416    }
417
418    /// Get statistics about pool utilization
419    pub fn get_stats(&self) -> &NodePoolStats {
420        &self.stats
421    }
422
423    /// Get current pool size (available nodes)
424    pub fn available_nodes(&self) -> usize {
425        self.free_nodes.len()
426    }
427}
428
429// Manual Clone implementation for NodePool
430impl<S: GameState> Clone for NodePool<S> {
431    fn clone(&self) -> Self {
432        // Create a new pool with the same template state and stats
433        // We don't clone the free_nodes as they cannot be shared between instances
434        // Instead, we'll create new nodes when needed
435        NodePool {
436            template_state: self.template_state.clone(),
437            free_nodes: Vec::new(), // Start with empty free_nodes
438            stats: self.stats.clone(),
439        }
440    }
441}
442
/// A route from the root of an MCTS tree to one of its nodes.
///
/// Each entry selects a child to follow, starting from the root. An empty path
/// refers to the root itself.
///
/// Equality and hashing compare the index sequence, so paths can be compared
/// directly or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodePath {
    /// Indices of children to follow from the root, in order.
    pub indices: Vec<usize>,
}

impl NodePath {
    /// Creates an empty path, i.e. one pointing at the root.
    pub fn new() -> Self {
        NodePath {
            indices: Vec::new(),
        }
    }

    /// Creates a path from an existing sequence of child indices.
    pub fn from_indices(indices: Vec<usize>) -> Self {
        NodePath { indices }
    }

    /// Appends one more child index to the end of the path.
    pub fn push(&mut self, index: usize) {
        self.indices.push(index);
    }

    /// Returns the number of steps in the path.
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// Returns true if the path has no steps (points at the root).
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }
}

impl Default for NodePath {
    /// The default path is the empty path (the root).
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Display for NodePath {
    /// Formats the path as `Path[i0 -> i1 -> ...]`, or `Path[]` when empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Path[")?;
        let mut steps = self.indices.iter();
        if let Some(first) = steps.next() {
            write!(f, "{}", first)?;
            for idx in steps {
                write!(f, " -> {}", idx)?;
            }
        }
        write!(f, "]")
    }
}
500
501/// Standalone helper function for tree recycling
502///
503/// This needs to be outside the MCTS impl to avoid borrow checker issues
504pub fn recycle_subtree_recursive<S: GameState>(mut node: MCTSNode<S>, pool: &mut NodePool<S>) {
505    // First take all children
506    let mut children = std::mem::take(&mut node.children);
507
508    // Recursively recycle each child
509    for child in children.drain(..) {
510        recycle_subtree_recursive(child, pool);
511    }
512
513    // Now recycle the node itself
514    pool.recycle_node(node);
515}