1use std::collections::HashMap;
7
8use super::node::{Node, NodeId};
9use super::traits::{Action, State};
10
11#[derive(Debug)]
13pub struct SearchTree<S: State, A: Action> {
14 nodes: Vec<Node<S, A>>,
16 state_map: HashMap<u64, NodeId>,
18 pub(crate) root_id: NodeId,
20}
21
22impl<S: State, A: Action> SearchTree<S, A> {
23 #[must_use]
25 pub fn new(root_state: S, root_actions: Vec<A>) -> Self {
26 let root = Node::root(root_state.clone(), root_actions);
27 let state_hash = root_state.state_hash();
28 let mut state_map = HashMap::new();
29 state_map.insert(state_hash, NodeId::new(0));
30
31 Self { nodes: vec![root], state_map, root_id: NodeId::new(0) }
32 }
33
34 #[must_use]
36 pub fn root(&self) -> &Node<S, A> {
37 &self.nodes[self.root_id.0]
38 }
39
40 #[must_use]
42 pub fn get(&self, id: NodeId) -> Option<&Node<S, A>> {
43 self.nodes.get(id.0)
44 }
45
46 pub fn get_mut(&mut self, id: NodeId) -> Option<&mut Node<S, A>> {
48 self.nodes.get_mut(id.0)
49 }
50
51 pub fn add_child(
53 &mut self,
54 parent_id: NodeId,
55 state: S,
56 action: A,
57 untried_actions: Vec<A>,
58 prior: f64,
59 ) -> NodeId {
60 let state_hash = state.state_hash();
61
62 if let Some(&existing_id) = self.state_map.get(&state_hash) {
64 if let Some(parent) = self.nodes.get_mut(parent_id.0) {
66 if !parent.children.contains(&existing_id) {
67 parent.children.push(existing_id);
68 }
69 }
70 return existing_id;
71 }
72
73 let child_id = NodeId::new(self.nodes.len());
74 let child = Node::child(child_id, state, action, parent_id, untried_actions, prior);
75
76 self.nodes.push(child);
77 self.state_map.insert(state_hash, child_id);
78
79 if let Some(parent) = self.nodes.get_mut(parent_id.0) {
80 parent.children.push(child_id);
81 }
82
83 child_id
84 }
85
86 #[must_use]
88 pub fn size(&self) -> usize {
89 self.nodes.len()
90 }
91
92 #[must_use]
94 pub fn children(&self, id: NodeId) -> Vec<&Node<S, A>> {
95 self.nodes
96 .get(id.0)
97 .map(|n| n.children.iter().filter_map(|&cid| self.get(cid)).collect())
98 .unwrap_or_default()
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Minimal `State` implementation used by the tests.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal `Action` implementation used by the tests.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    #[test]
    fn test_search_tree_creation() {
        let state = TestState { value: 0, terminal: false };
        let actions = vec![TestAction { delta: 1 }, TestAction { delta: -1 }];
        let tree = SearchTree::new(state.clone(), actions);

        assert_eq!(tree.size(), 1);
        assert_eq!(tree.root().state, state);
    }

    #[test]
    fn test_search_tree_add_child() {
        let state = TestState { value: 0, terminal: false };
        let actions = vec![TestAction { delta: 1 }];
        let mut tree = SearchTree::new(state, actions);

        let child_state = TestState { value: 1, terminal: false };
        let child_action = TestAction { delta: 1 };
        let child_id =
            tree.add_child(NodeId::new(0), child_state.clone(), child_action, vec![], 0.5);

        assert_eq!(tree.size(), 2);
        let child = tree.get(child_id).expect("key should exist");
        assert_eq!(child.state, child_state);
    }

    #[test]
    fn test_tree_children() {
        let state = TestState { value: 0, terminal: false };
        let mut tree = SearchTree::new(state, vec![]);

        let child1 = TestState { value: 1, terminal: false };
        let child2 = TestState { value: 2, terminal: false };

        tree.add_child(NodeId::new(0), child1, TestAction { delta: 1 }, vec![], 0.5);
        tree.add_child(NodeId::new(0), child2, TestAction { delta: 2 }, vec![], 0.5);

        let children = tree.children(NodeId::new(0));
        assert_eq!(children.len(), 2);
    }

    #[test]
    fn test_transposition_table() {
        let state = TestState { value: 0, terminal: false };
        let mut tree = SearchTree::new(state, vec![]);

        let child_state = TestState { value: 1, terminal: false };

        let id1 = tree.add_child(
            NodeId::new(0),
            child_state.clone(),
            TestAction { delta: 1 },
            vec![],
            0.5,
        );

        // Adding the same state again must reuse the existing node.
        let id2 = tree.add_child(NodeId::new(0), child_state, TestAction { delta: 1 }, vec![], 0.5);

        assert_eq!(id1, id2);
        assert_eq!(tree.size(), 2);
        // The reused node must not be linked under the root a second time.
        assert_eq!(tree.children(NodeId::new(0)).len(), 1);
    }

    proptest! {
        #[test]
        fn test_tree_size_increases_monotonically(num_children in 1usize..10) {
            let state = TestState { value: 0, terminal: false };
            let mut tree = SearchTree::new(state, vec![]);

            let mut prev_size = tree.size();

            for i in 0..num_children {
                // `i == 0` builds a state equal to the root's, which (assuming
                // `state_hash` depends only on the fields) resolves to the root via
                // the transposition table instead of growing the tree.
                let value = i32::try_from(i).expect("num_children < 10 always fits in i32");
                let child_state = TestState { value, terminal: false };
                tree.add_child(
                    NodeId::new(0),
                    child_state,
                    TestAction { delta: 1 },
                    vec![],
                    0.5,
                );

                prop_assert!(tree.size() >= prev_size);
                prev_size = tree.size();
            }
        }
    }
}