1use rand::prelude::IteratorRandom;
2use std::fmt;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use crate::game_state::GameState;
6
7pub struct MCTSNode<S: GameState> {
13 pub state: S,
15
16 pub action: Option<S::Action>,
18
19 pub visits: AtomicU64,
22
23 pub total_reward: AtomicU64,
26
27 pub sum_squared_reward: AtomicU64,
29
30 pub rave_visits: AtomicU64,
32
33 pub rave_reward: AtomicU64,
35
36 pub prior: AtomicU64,
39
40 pub children: Vec<MCTSNode<S>>,
42
43 pub unexpanded_actions: Vec<S::Action>,
46
47 pub depth: usize,
49
50 pub player: S::Player,
53}
54
/// Fixed-point scale used to keep `f64` statistics in `AtomicU64` fields.
///
/// Six decimal digits are kept: `1.0` is stored as `1_000_000`.
const REWARD_SCALE: f64 = 1_000_000.0;

/// Encodes `value` into the fixed-point form stored in the atomic statistic fields.
///
/// The value is multiplied by [`REWARD_SCALE`] and rounded to the nearest integer. Rounding
/// avoids the systematic downward bias truncation would add on every accumulation. The
/// result is the two's-complement bit pattern of that `i64`, so negative rewards are
/// representable. Because `AtomicU64::fetch_add` wraps on overflow, adding the encoding of
/// a negative number subtracts it, and mixed-sign sums accumulate correctly.
///
/// Out-of-range inputs saturate at the `i64` bounds (the float-to-int `as` cast saturates),
/// and `NaN` encodes as `0`.
fn float_to_scaled_u64(value: f64) -> u64 {
    ((value * REWARD_SCALE).round() as i64) as u64
}

/// Decodes a value produced by [`float_to_scaled_u64`] (or a wrapping sum of such values).
///
/// The bits are reinterpreted as a signed `i64` before scaling back down, so sums that went
/// negative decode as negative numbers.
fn scaled_u64_to_float(value: u64) -> f64 {
    (value as i64) as f64 / REWARD_SCALE
}
68
69impl<S: GameState> MCTSNode<S> {
70 pub fn new(
72 state: S,
73 action: Option<S::Action>,
74 parent_player: Option<S::Player>,
75 depth: usize,
76 ) -> Self {
77 let player = parent_player.unwrap_or_else(|| state.get_current_player());
78 let unexpanded_actions = state.get_legal_actions();
79
80 MCTSNode {
81 state,
82 action,
83 visits: AtomicU64::new(0),
84 total_reward: AtomicU64::new(0),
85 sum_squared_reward: AtomicU64::new(0),
86 rave_visits: AtomicU64::new(0),
87 rave_reward: AtomicU64::new(0),
88 prior: AtomicU64::new(float_to_scaled_u64(1.0)), children: Vec::new(),
90 unexpanded_actions,
91 depth,
92 player,
93 }
94 }
95
96 pub fn visits(&self) -> u64 {
98 self.visits.load(Ordering::Relaxed)
99 }
100
101 pub fn total_reward(&self) -> f64 {
103 scaled_u64_to_float(self.total_reward.load(Ordering::Relaxed))
104 }
105
106 pub fn prior(&self) -> f64 {
108 scaled_u64_to_float(self.prior.load(Ordering::Relaxed))
109 }
110
111 pub fn set_prior(&self, prior: f64) {
113 self.prior.store(float_to_scaled_u64(prior), Ordering::Relaxed);
114 }
115
116 pub fn value(&self) -> f64 {
118 let visits = self.visits();
119 if visits == 0 {
120 return 0.0;
121 }
122 self.total_reward() / visits as f64
123 }
124
125 pub fn increment_visits(&self) {
127 self.visits.fetch_add(1, Ordering::Relaxed);
128 }
129
130 pub fn add_reward(&self, reward: f64) {
132 self.total_reward
133 .fetch_add(float_to_scaled_u64(reward), Ordering::Relaxed);
134 }
135
136 pub fn add_squared_reward(&self, reward: f64) {
138 self.sum_squared_reward
139 .fetch_add(float_to_scaled_u64(reward * reward), Ordering::Relaxed);
140 }
141
142 pub fn sum_squared_reward(&self) -> f64 {
144 scaled_u64_to_float(self.sum_squared_reward.load(Ordering::Relaxed))
145 }
146
147 pub fn increment_rave_visits(&self) {
149 self.rave_visits.fetch_add(1, Ordering::Relaxed);
150 }
151
152 pub fn add_rave_reward(&self, reward: f64) {
154 self.rave_reward
155 .fetch_add(float_to_scaled_u64(reward), Ordering::Relaxed);
156 }
157
158 pub fn rave_visits(&self) -> u64 {
160 self.rave_visits.load(Ordering::Relaxed)
161 }
162
163 pub fn rave_value(&self) -> f64 {
165 let visits = self.rave_visits();
166 if visits == 0 {
167 return 0.0;
168 }
169 scaled_u64_to_float(self.rave_reward.load(Ordering::Relaxed)) / visits as f64
170 }
171
172 pub fn is_fully_expanded(&self) -> bool {
174 self.unexpanded_actions.is_empty()
175 }
176
177 pub fn is_leaf(&self) -> bool {
179 self.children.is_empty()
180 }
181
182 pub fn expand(&mut self, action_index: usize) -> Option<&mut MCTSNode<S>> {
203 if action_index >= self.unexpanded_actions.len() {
204 return None;
205 }
206
207 let action = self.unexpanded_actions.swap_remove(action_index);
208 let next_state = self.state.apply_action(&action);
209 let current_player = self.state.get_current_player();
210
211 let child = MCTSNode::new(
212 next_state,
213 Some(action),
214 Some(current_player),
215 self.depth + 1,
216 );
217
218 self.children.push(child);
219 self.children.last_mut()
220 }
221
222 pub fn expand_with_pool(
227 &mut self,
228 action_index: usize,
229 pool: &mut NodePool<S>,
230 ) -> Option<&mut MCTSNode<S>> {
231 if action_index >= self.unexpanded_actions.len() {
232 return None;
233 }
234
235 let action = self.unexpanded_actions.swap_remove(action_index);
236 let next_state = self.state.apply_action(&action);
237 let current_player = self.state.get_current_player();
238
239 let node = pool.create_node(
241 next_state,
242 Some(action),
243 Some(current_player),
244 self.depth + 1,
245 );
246
247 self.children.push(node);
248 self.children.last_mut()
249 }
250
251 pub fn expand_random(&mut self) -> Option<&mut MCTSNode<S>> {
253 if self.unexpanded_actions.is_empty() {
254 return None;
255 }
256
257 let mut rng = rand::thread_rng();
259 let index = (0..self.unexpanded_actions.len()).choose(&mut rng).unwrap();
260
261 self.expand(index)
262 }
263
264 pub fn expand_random_with_pool(&mut self, pool: &mut NodePool<S>) -> Option<&mut MCTSNode<S>> {
266 if self.unexpanded_actions.is_empty() {
267 return None;
268 }
269
270 let mut rng = rand::thread_rng();
272 let index = (0..self.unexpanded_actions.len()).choose(&mut rng).unwrap();
273
274 self.expand_with_pool(index, pool)
275 }
276}
277
278pub struct NodePool<S: GameState> {
284 template_state: S,
286
287 free_nodes: Vec<MCTSNode<S>>,
289
290 stats: NodePoolStats,
292}
293
/// Counters describing how a [`NodePool`] has been used.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly (e.g. in tests or
/// when checking whether a pool has been touched).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct NodePoolStats {
    /// Nodes built from scratch: preallocated ones plus `create_node` calls that found the
    /// free list empty.
    pub total_created: usize,

    /// Calls to `create_node`, whether served from the free list or freshly built.
    pub total_allocations: usize,

    /// Nodes handed back via `recycle_node`.
    pub total_recycled: usize,
}
306
307impl<S: GameState> NodePool<S> {
308 pub fn new(template_state: S, initial_size: usize) -> Self {
315 let mut pool = NodePool {
316 template_state,
317 free_nodes: Vec::with_capacity(initial_size),
318 stats: NodePoolStats::default(),
319 };
320
321 if initial_size > 0 {
323 pool.preallocate(initial_size);
324 }
325
326 pool
327 }
328
329 fn preallocate(&mut self, count: usize) {
331 for _ in 0..count {
332 let node = MCTSNode {
333 state: self.template_state.clone(),
334 action: None,
335 visits: AtomicU64::new(0),
336 total_reward: AtomicU64::new(0),
337 sum_squared_reward: AtomicU64::new(0),
338 rave_visits: AtomicU64::new(0),
339 rave_reward: AtomicU64::new(0),
340 prior: AtomicU64::new(float_to_scaled_u64(1.0)),
341 children: Vec::new(),
342 unexpanded_actions: Vec::new(),
343 depth: 0,
344 player: self.template_state.get_current_player(),
345 };
346
347 self.free_nodes.push(node);
348 self.stats.total_created += 1;
349 }
350 }
351
352 pub fn create_node(
354 &mut self,
355 state: S,
356 action: Option<S::Action>,
357 parent_player: Option<S::Player>,
358 depth: usize,
359 ) -> MCTSNode<S> {
360 self.stats.total_allocations += 1;
361
362 if let Some(mut node) = self.free_nodes.pop() {
363 let player = match &parent_player {
365 Some(p) => p.clone(),
366 None => state.get_current_player(),
367 };
368
369 let legal_actions = state.get_legal_actions();
371
372 node.state = state;
374 node.action = action;
375 node.visits = AtomicU64::new(0);
376 node.total_reward = AtomicU64::new(0);
377 node.sum_squared_reward = AtomicU64::new(0);
378 node.rave_visits = AtomicU64::new(0);
379 node.rave_reward = AtomicU64::new(0);
380 node.prior = AtomicU64::new(float_to_scaled_u64(1.0));
381 node.children.clear();
382 node.depth = depth;
383 node.player = player;
384 node.unexpanded_actions = legal_actions;
385
386 node
387 } else {
388 self.stats.total_created += 1;
390 MCTSNode::new(state, action, parent_player, depth)
391 }
392 }
393
394 pub fn recycle_node(&mut self, mut node: MCTSNode<S>) {
396 self.stats.total_recycled += 1;
397
398 node.children.clear();
400 node.unexpanded_actions.clear();
401
402 self.free_nodes.push(node);
404 }
405
406 pub fn recycle_tree(&mut self, mut root: MCTSNode<S>) {
408 let mut children = std::mem::take(&mut root.children);
410 for child in children.drain(..) {
411 self.recycle_tree(child);
412 }
413
414 self.recycle_node(root);
416 }
417
418 pub fn get_stats(&self) -> &NodePoolStats {
420 &self.stats
421 }
422
423 pub fn available_nodes(&self) -> usize {
425 self.free_nodes.len()
426 }
427}
428
429impl<S: GameState> Clone for NodePool<S> {
431 fn clone(&self) -> Self {
432 NodePool {
436 template_state: self.template_state.clone(),
437 free_nodes: Vec::new(), stats: self.stats.clone(),
439 }
440 }
441}
442
/// A route from some node down the tree, expressed as successive child indices.
///
/// `PartialEq`, `Eq` and `Hash` are derived so paths can be compared and used as map or
/// set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodePath {
    /// Child index taken at each step, in order.
    pub indices: Vec<usize>,
}
452
453impl NodePath {
454 pub fn new() -> Self {
456 NodePath {
457 indices: Vec::new(),
458 }
459 }
460
461 pub fn from_indices(indices: Vec<usize>) -> Self {
463 NodePath { indices }
464 }
465
466 pub fn push(&mut self, index: usize) {
468 self.indices.push(index);
469 }
470
471 pub fn len(&self) -> usize {
473 self.indices.len()
474 }
475
476 pub fn is_empty(&self) -> bool {
478 self.indices.is_empty()
479 }
480}
481
482impl Default for NodePath {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
488impl fmt::Display for NodePath {
489 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
490 write!(f, "Path[")?;
491 for (i, idx) in self.indices.iter().enumerate() {
492 if i > 0 {
493 write!(f, " -> ")?;
494 }
495 write!(f, "{}", idx)?;
496 }
497 write!(f, "]")
498 }
499}
500
501pub fn recycle_subtree_recursive<S: GameState>(mut node: MCTSNode<S>, pool: &mut NodePool<S>) {
505 let mut children = std::mem::take(&mut node.children);
507
508 for child in children.drain(..) {
510 recycle_subtree_recursive(child, pool);
511 }
512
513 pool.recycle_node(node);
515}