1use super::traits::{Action, State};
7use super::Reward;
8
/// Identifier of a node, implemented as a thin wrapper around a `usize`.
///
/// The wrapped value is a public tuple field, so it can also be read or
/// constructed directly as `NodeId(n)` / `id.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl NodeId {
    /// Wraps `id` in a `NodeId`.
    #[must_use]
    pub const fn new(id: usize) -> Self {
        NodeId(id)
    }

    /// Returns the wrapped `usize`.
    #[must_use]
    pub const fn value(&self) -> usize {
        let NodeId(raw) = *self;
        raw
    }
}
26
/// Visit and reward statistics attached to a node.
#[derive(Debug, Clone)]
pub struct NodeStats {
    /// Number of recorded visits; incremented by one on every `update` call.
    pub visits: usize,
    /// Sum of all rewards passed to `update`.
    pub total_reward: f64,
    /// `total_reward / visits`, recomputed on every `update` call.
    pub mean_reward: f64,
    /// Multiplier applied to the exploration term in `puct`.
    pub prior: f64,
}

impl Default for NodeStats {
    /// Produces statistics for a node with no recorded visits: all counters
    /// and rewards are zero, and `prior` is `1.0`.
    fn default() -> Self {
        NodeStats {
            prior: 1.0,
            mean_reward: 0.0,
            total_reward: 0.0,
            visits: 0,
        }
    }
}
45
46impl NodeStats {
47 pub fn update(&mut self, reward: Reward) {
49 contract_pre_update!();
50 self.visits += 1;
51 self.total_reward += reward;
52 self.mean_reward = self.total_reward / self.visits as f64;
53 }
54
55 #[must_use]
57 pub fn ucb1(&self, parent_visits: usize, c: f64) -> f64 {
58 if self.visits == 0 {
59 return f64::INFINITY;
60 }
61 let exploitation = self.mean_reward;
62 let exploration = c * ((parent_visits as f64).max(1.0).ln() / self.visits as f64).sqrt();
63 exploitation + exploration
64 }
65
66 #[must_use]
68 pub fn puct(&self, parent_visits: usize, c: f64) -> f64 {
69 let exploitation = self.mean_reward;
70 let exploration =
71 c * self.prior * (parent_visits as f64).sqrt() / (1.0 + self.visits as f64);
72 exploitation + exploration
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct Node<S: State, A: Action> {
79 pub id: NodeId,
81 pub state: S,
83 pub action: Option<A>,
85 pub parent: Option<NodeId>,
87 pub children: Vec<NodeId>,
89 pub stats: NodeStats,
91 pub expanded: bool,
93 pub untried_actions: Vec<A>,
95}
96
97impl<S: State, A: Action> Node<S, A> {
98 #[must_use]
100 pub fn root(state: S, untried_actions: Vec<A>) -> Self {
101 Self {
102 id: NodeId::new(0),
103 state,
104 action: None,
105 parent: None,
106 children: Vec::new(),
107 stats: NodeStats::default(),
108 expanded: false,
109 untried_actions,
110 }
111 }
112
113 #[must_use]
115 pub fn child(
116 id: NodeId,
117 state: S,
118 action: A,
119 parent: NodeId,
120 untried_actions: Vec<A>,
121 prior: f64,
122 ) -> Self {
123 Self {
124 id,
125 state,
126 action: Some(action),
127 parent: Some(parent),
128 children: Vec::new(),
129 stats: NodeStats { prior, ..Default::default() },
130 expanded: false,
131 untried_actions,
132 }
133 }
134
135 #[must_use]
137 pub fn is_leaf(&self) -> bool {
138 self.children.is_empty()
139 }
140
141 #[must_use]
143 pub fn is_fully_expanded(&self) -> bool {
144 self.expanded && self.untried_actions.is_empty()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::f64::consts::SQRT_2;

    /// Minimal `State` fixture.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal `Action` fixture.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// Builds statistics with the given visit count and mean reward, leaving
    /// every other field at its default.
    fn stats_with(visits: usize, mean_reward: f64) -> NodeStats {
        NodeStats { visits, mean_reward, ..NodeStats::default() }
    }

    /// `NodeId::new` and `NodeId::value` round-trip the raw index.
    #[test]
    fn test_node_id_creation() {
        assert_eq!(NodeId::new(42).value(), 42);
    }

    /// Default statistics are zeroed, with a prior of one.
    #[test]
    fn test_node_stats_default() {
        let NodeStats { visits, total_reward, mean_reward, prior } = NodeStats::default();
        assert_eq!(visits, 0);
        assert_eq!(total_reward, 0.0);
        assert_eq!(mean_reward, 0.0);
        assert_eq!(prior, 1.0);
    }

    /// Each `update` call bumps the visit count and refreshes total and mean.
    #[test]
    fn test_node_stats_update() {
        let mut stats = NodeStats::default();

        // (reward, expected visits, expected total, expected mean)
        let steps = [(1.0, 1, 1.0, 1.0), (0.0, 2, 1.0, 0.5)];
        for (reward, visits, total, mean) in steps {
            stats.update(reward);
            assert_eq!(stats.visits, visits);
            assert_eq!(stats.total_reward, total);
            assert_eq!(stats.mean_reward, mean);
        }
    }

    /// A root node has id 0, no action, no parent, no children, and is not
    /// marked expanded.
    #[test]
    fn test_node_root_creation() {
        let initial = TestState { value: 0, terminal: false };
        let root = Node::root(initial.clone(), vec![TestAction { delta: 1 }]);

        assert_eq!(root.id, NodeId::new(0));
        assert_eq!(root.state, initial);
        assert!(root.action.is_none());
        assert!(root.parent.is_none());
        assert!(root.children.is_empty());
        assert!(!root.expanded);
    }

    /// A child node stores the id, state, action, parent and prior it was
    /// given.
    #[test]
    fn test_node_child_creation() {
        let child_state = TestState { value: 1, terminal: false };
        let step = TestAction { delta: 1 };
        let child = Node::child(
            NodeId::new(1),
            child_state.clone(),
            step.clone(),
            NodeId::new(0),
            Vec::new(),
            0.5,
        );

        assert_eq!(child.id, NodeId::new(1));
        assert_eq!(child.state, child_state);
        assert_eq!(child.action, Some(step));
        assert_eq!(child.parent, Some(NodeId::new(0)));
        assert_eq!(child.stats.prior, 0.5);
    }

    /// A freshly built node has no children and is therefore a leaf.
    #[test]
    fn test_node_is_leaf() {
        let fresh = Node::<TestState, TestAction>::root(
            TestState { value: 0, terminal: false },
            Vec::new(),
        );
        assert!(fresh.is_leaf());
    }

    /// A node with zero visits scores infinity under `ucb1`.
    #[test]
    fn test_ucb1_unvisited_node() {
        let score = NodeStats::default().ucb1(10, SQRT_2);
        assert!(score.is_infinite());
    }

    /// A visited node gets a finite score above its mean reward.
    #[test]
    fn test_ucb1_visited_node() {
        let mut stats = NodeStats::default();
        stats.update(0.5);

        let score = stats.ucb1(10, SQRT_2);

        assert!(score > 0.5);
        assert!(score < 5.0);
    }

    /// With equal means, the node with fewer visits scores higher.
    #[test]
    fn test_ucb1_more_visits_lower_exploration() {
        let sparse = NodeStats { total_reward: 5.0, ..stats_with(10, 0.5) };
        let dense = NodeStats { total_reward: 50.0, ..stats_with(100, 0.5) };

        let sparse_score = sparse.ucb1(1000, SQRT_2);
        let dense_score = dense.ucb1(1000, SQRT_2);

        assert!(sparse_score > dense_score);
    }

    /// `puct` matches a hand-computed value.
    #[test]
    fn test_puct_with_prior() {
        let mut stats = NodeStats { prior: 0.5, ..NodeStats::default() };
        stats.update(0.3);

        let score = stats.puct(100, 2.0);

        // 0.3 + 2.0 * 0.5 * sqrt(100) / (1 + 1) = 5.3
        assert!((score - 5.3).abs() < 0.01);
    }

    proptest! {
        // After any sequence of updates, the visit count, total and mean agree
        // with the values computed directly from the reward list.
        #[test]
        fn test_node_stats_update_invariants(rewards in prop::collection::vec(0.0f64..=1.0, 1..100)) {
            let mut stats = NodeStats::default();
            rewards.iter().copied().for_each(|reward| stats.update(reward));

            let expected_total = rewards.iter().sum::<f64>();
            let expected_mean = expected_total / rewards.len() as f64;

            prop_assert_eq!(stats.visits, rewards.len());
            prop_assert!((stats.total_reward - expected_total).abs() < 1e-10);
            prop_assert!((stats.mean_reward - expected_mean).abs() < 1e-10);
        }

        // With equal means, fewer visits always yields the higher UCB1 score.
        #[test]
        fn test_ucb1_exploration_decreases_with_visits(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let few_visits = stats_with(10, 0.5).ucb1(parent_visits, c);
            let many_visits = stats_with(100, 0.5).ucb1(parent_visits, c);

            prop_assert!(few_visits > many_visits, "UCB1 with fewer visits should be higher");
        }

        // With equal visit counts, the higher mean reward always scores higher.
        #[test]
        fn test_ucb1_higher_reward_higher_score(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let low_reward = stats_with(50, 0.3).ucb1(parent_visits, c);
            let high_reward = stats_with(50, 0.7).ucb1(parent_visits, c);

            prop_assert!(high_reward > low_reward, "Higher reward should give higher UCB");
        }

        // Doubling the prior, all else equal, always raises the PUCT score.
        #[test]
        fn test_puct_prior_increases_exploration(prior in 0.1f64..0.9) {
            let base = NodeStats { prior, ..stats_with(10, 0.5) };
            let doubled = NodeStats { prior: prior * 2.0, ..stats_with(10, 0.5) };

            let base_score = base.puct(100, 2.0);
            let doubled_score = doubled.puct(100, 2.0);

            prop_assert!(doubled_score > base_score, "Higher prior should give higher PUCT");
        }
    }
}