Skip to main content

entrenar/search/mcts/
tree.rs

1//! Search tree structure for MCTS.
2//!
3//! This module contains the tree data structure that stores nodes
4//! and handles transposition detection.
5
6use std::collections::HashMap;
7
8use super::node::{Node, NodeId};
9use super::traits::{Action, State};
10
11/// The search tree structure
12#[derive(Debug)]
13pub struct SearchTree<S: State, A: Action> {
14    /// All nodes indexed by NodeId
15    nodes: Vec<Node<S, A>>,
16    /// Map from state hash to node id for deduplication
17    state_map: HashMap<u64, NodeId>,
18    /// Root node id
19    pub(crate) root_id: NodeId,
20}
21
22impl<S: State, A: Action> SearchTree<S, A> {
23    /// Create a new search tree with the given root state
24    #[must_use]
25    pub fn new(root_state: S, root_actions: Vec<A>) -> Self {
26        let root = Node::root(root_state.clone(), root_actions);
27        let state_hash = root_state.state_hash();
28        let mut state_map = HashMap::new();
29        state_map.insert(state_hash, NodeId::new(0));
30
31        Self { nodes: vec![root], state_map, root_id: NodeId::new(0) }
32    }
33
34    /// Get the root node
35    #[must_use]
36    pub fn root(&self) -> &Node<S, A> {
37        &self.nodes[self.root_id.0]
38    }
39
40    /// Get a node by id
41    #[must_use]
42    pub fn get(&self, id: NodeId) -> Option<&Node<S, A>> {
43        self.nodes.get(id.0)
44    }
45
46    /// Get a mutable node by id
47    pub fn get_mut(&mut self, id: NodeId) -> Option<&mut Node<S, A>> {
48        self.nodes.get_mut(id.0)
49    }
50
51    /// Add a child node, returning its id
52    pub fn add_child(
53        &mut self,
54        parent_id: NodeId,
55        state: S,
56        action: A,
57        untried_actions: Vec<A>,
58        prior: f64,
59    ) -> NodeId {
60        let state_hash = state.state_hash();
61
62        // Check for transposition (same state reached via different path)
63        if let Some(&existing_id) = self.state_map.get(&state_hash) {
64            // Add as child but reuse existing node's stats
65            if let Some(parent) = self.nodes.get_mut(parent_id.0) {
66                if !parent.children.contains(&existing_id) {
67                    parent.children.push(existing_id);
68                }
69            }
70            return existing_id;
71        }
72
73        let child_id = NodeId::new(self.nodes.len());
74        let child = Node::child(child_id, state, action, parent_id, untried_actions, prior);
75
76        self.nodes.push(child);
77        self.state_map.insert(state_hash, child_id);
78
79        if let Some(parent) = self.nodes.get_mut(parent_id.0) {
80            parent.children.push(child_id);
81        }
82
83        child_id
84    }
85
86    /// Get number of nodes in the tree
87    #[must_use]
88    pub fn size(&self) -> usize {
89        self.nodes.len()
90    }
91
92    /// Get all children of a node
93    #[must_use]
94    pub fn children(&self, id: NodeId) -> Vec<&Node<S, A>> {
95        self.nodes
96            .get(id.0)
97            .map(|n| n.children.iter().filter_map(|&cid| self.get(cid)).collect())
98            .unwrap_or_default()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Minimal state for the unit tests. Only `is_terminal` is implemented here;
    /// `state_hash` comes from the `State` trait's provided method.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal action for the unit tests.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Non-terminal state carrying `value`.
    fn state(value: i32) -> TestState {
        TestState { value, terminal: false }
    }

    /// Action carrying `delta`.
    fn action(delta: i32) -> TestAction {
        TestAction { delta }
    }

    #[test]
    fn test_search_tree_creation() {
        let root_state = state(0);
        let tree = SearchTree::new(root_state.clone(), vec![action(1), action(-1)]);

        assert_eq!(tree.size(), 1);
        assert_eq!(tree.root().state, root_state);
    }

    #[test]
    fn test_search_tree_add_child() {
        let mut tree = SearchTree::new(state(0), vec![action(1)]);

        let child_state = state(1);
        let child_id = tree.add_child(NodeId::new(0), child_state.clone(), action(1), vec![], 0.5);

        assert_eq!(tree.size(), 2);
        let child = tree.get(child_id).expect("id returned by add_child should resolve to a node");
        assert_eq!(child.state, child_state);
    }

    #[test]
    fn test_tree_children() {
        let mut tree = SearchTree::new(state(0), vec![]);

        tree.add_child(NodeId::new(0), state(1), action(1), vec![], 0.5);
        tree.add_child(NodeId::new(0), state(2), action(2), vec![], 0.5);

        assert_eq!(tree.children(NodeId::new(0)).len(), 2);
    }

    #[test]
    fn test_transposition_table() {
        let mut tree = SearchTree::new(state(0), vec![]);

        // Reach the same state through two genuinely different actions.
        let child_state = state(1);
        let first_id = tree.add_child(NodeId::new(0), child_state.clone(), action(1), vec![], 0.5);
        let second_id = tree.add_child(NodeId::new(0), child_state, action(2), vec![], 0.5);

        // Both insertions must resolve to the same node (transposition).
        assert_eq!(first_id, second_id);
        // Only root + one child are stored.
        assert_eq!(tree.size(), 2);
        // The parent must hold a single edge to the shared node, not a duplicate.
        assert_eq!(tree.children(NodeId::new(0)).len(), 1);
    }

    proptest! {
        #[test]
        fn test_tree_size_increases_monotonically(num_children in 1usize..10) {
            let mut tree = SearchTree::new(state(0), vec![]);
            let mut prev_size = tree.size();

            // Values start at 0, i.e. the first child equals the root state, so the
            // transposition path is exercised alongside the new-node path.
            for value in (0_i32..).take(num_children) {
                tree.add_child(NodeId::new(0), state(value), action(1), vec![], 0.5);

                // Size grows for a new state and stays put for a transposition.
                prop_assert!(tree.size() >= prev_size);
                prev_size = tree.size();
            }
        }
    }
}