use super::traits::{Action, State};
use super::Reward;
/// Identifier of a node: a thin newtype around a `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub usize);

impl NodeId {
    /// Wraps `id` in a `NodeId`.
    #[must_use]
    pub const fn new(id: usize) -> Self {
        NodeId(id)
    }

    /// Returns the wrapped `usize`.
    #[must_use]
    pub const fn value(&self) -> usize {
        let NodeId(raw) = *self;
        raw
    }
}
/// Per-node statistics: visit count, accumulated reward, running mean and a prior weight.
#[derive(Debug, Clone)]
pub struct NodeStats {
    /// Number of times `update` has been applied (incremented by one per call).
    pub visits: usize,
    /// Sum of every reward passed to `update`.
    pub total_reward: f64,
    /// `total_reward / visits`, recomputed on each `update`.
    pub mean_reward: f64,
    /// Weight multiplied into the exploration term of `puct`.
    pub prior: f64,
}

impl Default for NodeStats {
    /// Produces stats with no visits, zero rewards and a prior of `1.0`.
    fn default() -> Self {
        const NO_REWARD: f64 = 0.0;
        NodeStats {
            visits: 0,
            total_reward: NO_REWARD,
            mean_reward: NO_REWARD,
            prior: 1.0,
        }
    }
}
impl NodeStats {
    /// Records one more visit carrying `reward` and refreshes the running mean.
    ///
    /// After the call `visits` has grown by one, `total_reward` by `reward`, and
    /// `mean_reward` equals `total_reward / visits`.
    pub fn update(&mut self, reward: Reward) {
        // Project-defined macro (declared outside this file); presumably checks this
        // method's preconditions — confirm against its definition.
        contract_pre_update!();
        self.visits += 1;
        self.total_reward += reward;
        let visit_count = self.visits as f64;
        self.mean_reward = self.total_reward / visit_count;
    }

    /// UCB1 score: `mean_reward + c * sqrt(ln(max(parent_visits, 1)) / visits)`.
    ///
    /// Returns `f64::INFINITY` when this node has no visits yet. `parent_visits`
    /// is clamped to at least `1.0` before the logarithm, so the exploration
    /// term is never negative or NaN for a zero parent count.
    #[must_use]
    pub fn ucb1(&self, parent_visits: usize, c: f64) -> f64 {
        match self.visits {
            0 => f64::INFINITY,
            own_visits => {
                let log_parent = (parent_visits as f64).max(1.0).ln();
                let bonus = c * (log_parent / own_visits as f64).sqrt();
                self.mean_reward + bonus
            }
        }
    }

    /// PUCT score: `mean_reward + c * prior * sqrt(parent_visits) / (1 + visits)`.
    ///
    /// Unlike `ucb1`, an unvisited node yields a finite value here because the
    /// denominator is `1 + visits`.
    #[must_use]
    pub fn puct(&self, parent_visits: usize, c: f64) -> f64 {
        let weighted_parent = c * self.prior * (parent_visits as f64).sqrt();
        let damping = 1.0 + self.visits as f64;
        self.mean_reward + weighted_parent / damping
    }
}
#[derive(Debug, Clone)]
pub struct Node<S: State, A: Action> {
pub id: NodeId,
pub state: S,
pub action: Option<A>,
pub parent: Option<NodeId>,
pub children: Vec<NodeId>,
pub stats: NodeStats,
pub expanded: bool,
pub untried_actions: Vec<A>,
}
impl<S: State, A: Action> Node<S, A> {
#[must_use]
pub fn root(state: S, untried_actions: Vec<A>) -> Self {
Self {
id: NodeId::new(0),
state,
action: None,
parent: None,
children: Vec::new(),
stats: NodeStats::default(),
expanded: false,
untried_actions,
}
}
#[must_use]
pub fn child(
id: NodeId,
state: S,
action: A,
parent: NodeId,
untried_actions: Vec<A>,
prior: f64,
) -> Self {
Self {
id,
state,
action: Some(action),
parent: Some(parent),
children: Vec::new(),
stats: NodeStats { prior, ..Default::default() },
expanded: false,
untried_actions,
}
}
#[must_use]
pub fn is_leaf(&self) -> bool {
self.children.is_empty()
}
#[must_use]
pub fn is_fully_expanded(&self) -> bool {
self.expanded && self.untried_actions.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::f64::consts::SQRT_2;

    /// Minimal `State` fixture whose terminal flag is set directly by each test.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestState {
        value: i32,
        terminal: bool,
    }

    impl State for TestState {
        fn is_terminal(&self) -> bool {
            self.terminal
        }
    }

    /// Minimal `Action` fixture; every instance reports the same name.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    struct TestAction {
        delta: i32,
    }

    impl Action for TestAction {
        fn name(&self) -> &'static str {
            "test_action"
        }
    }

    /// `NodeId::new` and `NodeId::value` round-trip the raw index.
    #[test]
    fn test_node_id_creation() {
        let wrapped = NodeId::new(42);
        assert_eq!(wrapped.value(), 42);
    }

    /// Default stats are zeroed, with a prior of one.
    #[test]
    fn test_node_stats_default() {
        let NodeStats { visits, total_reward, mean_reward, prior } = NodeStats::default();
        assert_eq!(visits, 0);
        assert_eq!(total_reward, 0.0);
        assert_eq!(mean_reward, 0.0);
        assert_eq!(prior, 1.0);
    }

    /// Each `update` bumps the visit count, the total, and the running mean.
    #[test]
    fn test_node_stats_update() {
        let mut tracker = NodeStats::default();

        // First sample: mean equals the sample itself.
        tracker.update(1.0);
        assert_eq!(tracker.visits, 1);
        assert_eq!(tracker.total_reward, 1.0);
        assert_eq!(tracker.mean_reward, 1.0);

        // Second sample of zero halves the mean but leaves the total untouched.
        tracker.update(0.0);
        assert_eq!(tracker.visits, 2);
        assert_eq!(tracker.total_reward, 1.0);
        assert_eq!(tracker.mean_reward, 0.5);
    }

    /// A root node has id 0, no action, no parent, no children, and is unexpanded.
    #[test]
    fn test_node_root_creation() {
        let initial = TestState { value: 0, terminal: false };
        let root = Node::root(initial.clone(), vec![TestAction { delta: 1 }]);

        assert_eq!(root.id, NodeId::new(0));
        assert_eq!(root.state, initial);
        assert!(root.action.is_none());
        assert!(root.parent.is_none());
        assert!(root.children.is_empty());
        assert!(!root.expanded);
    }

    /// A child node records its id, state, action, parent, and prior.
    #[test]
    fn test_node_child_creation() {
        let next_state = TestState { value: 1, terminal: false };
        let step = TestAction { delta: 1 };
        let child = Node::child(
            NodeId::new(1),
            next_state.clone(),
            step.clone(),
            NodeId::new(0),
            Vec::new(),
            0.5,
        );

        assert_eq!(child.id, NodeId::new(1));
        assert_eq!(child.state, next_state);
        assert_eq!(child.action, Some(step));
        assert_eq!(child.parent, Some(NodeId::new(0)));
        assert_eq!(child.stats.prior, 0.5);
    }

    /// A freshly built root without children is a leaf.
    #[test]
    fn test_node_is_leaf() {
        let initial = TestState { value: 0, terminal: false };
        let root = Node::<TestState, TestAction>::root(initial, Vec::new());
        assert!(root.is_leaf());
    }

    /// Unvisited nodes get an infinite UCB1 score.
    #[test]
    fn test_ucb1_unvisited_node() {
        let untouched = NodeStats::default();
        assert!(untouched.ucb1(10, SQRT_2).is_infinite());
    }

    /// A visited node scores above its mean reward but stays bounded.
    #[test]
    fn test_ucb1_visited_node() {
        let mut tracker = NodeStats::default();
        tracker.update(0.5);

        let value = tracker.ucb1(10, SQRT_2);
        assert!(value > 0.5);
        assert!(value < 5.0);
    }

    /// With equal means, the less-visited node receives the larger UCB1 score.
    #[test]
    fn test_ucb1_more_visits_lower_exploration() {
        let rarely_visited = NodeStats {
            visits: 10,
            total_reward: 5.0,
            mean_reward: 0.5,
            ..NodeStats::default()
        };
        let often_visited = NodeStats {
            visits: 100,
            total_reward: 50.0,
            mean_reward: 0.5,
            ..NodeStats::default()
        };

        let rare_score = rarely_visited.ucb1(1000, SQRT_2);
        let frequent_score = often_visited.ucb1(1000, SQRT_2);
        assert!(rare_score > frequent_score);
    }

    /// PUCT: 0.3 + 2.0 * 0.5 * sqrt(100) / (1 + 1) = 5.3.
    #[test]
    fn test_puct_with_prior() {
        let mut tracker = NodeStats { prior: 0.5, ..NodeStats::default() };
        tracker.update(0.3);

        let value = tracker.puct(100, 2.0);
        assert!((value - 5.3).abs() < 0.01);
    }

    proptest! {
        // Visits, total, and mean must agree with the raw reward sequence.
        #[test]
        fn test_node_stats_update_invariants(rewards in prop::collection::vec(0.0f64..=1.0, 1..100)) {
            let mut tracker = NodeStats::default();
            for &sample in &rewards {
                tracker.update(sample);
            }

            let sample_count = rewards.len();
            prop_assert_eq!(tracker.visits, sample_count);
            prop_assert!((tracker.total_reward - rewards.iter().sum::<f64>()).abs() < 1e-10);
            prop_assert!(
                (tracker.mean_reward - rewards.iter().sum::<f64>() / sample_count as f64).abs()
                    < 1e-10
            );
        }

        // Same mean, different visit counts: fewer visits means a larger bonus.
        #[test]
        fn test_ucb1_exploration_decreases_with_visits(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let rarely_visited = NodeStats { visits: 10, mean_reward: 0.5, ..NodeStats::default() };
            let often_visited = NodeStats { visits: 100, mean_reward: 0.5, ..NodeStats::default() };

            let rare_score = rarely_visited.ucb1(parent_visits, c);
            let frequent_score = often_visited.ucb1(parent_visits, c);
            prop_assert!(rare_score > frequent_score, "UCB1 with fewer visits should be higher");
        }

        // Same visit count, different means: the better mean wins.
        #[test]
        fn test_ucb1_higher_reward_higher_score(parent_visits in 10usize..1000, c in 0.1f64..5.0) {
            let weaker = NodeStats { visits: 50, mean_reward: 0.3, ..NodeStats::default() };
            let stronger = NodeStats { visits: 50, mean_reward: 0.7, ..NodeStats::default() };

            let weaker_score = weaker.ucb1(parent_visits, c);
            let stronger_score = stronger.ucb1(parent_visits, c);
            prop_assert!(stronger_score > weaker_score, "Higher reward should give higher UCB");
        }

        // Doubling the prior must raise the PUCT score when everything else is equal.
        #[test]
        fn test_puct_prior_increases_exploration(prior in 0.1f64..0.9) {
            let base = NodeStats { visits: 10, mean_reward: 0.5, prior, ..NodeStats::default() };
            let boosted = NodeStats {
                visits: 10,
                mean_reward: 0.5,
                prior: prior * 2.0,
                ..NodeStats::default()
            };

            let base_score = base.puct(100, 2.0);
            let boosted_score = boosted.puct(100, 2.0);
            prop_assert!(boosted_score > base_score, "Higher prior should give higher PUCT");
        }
    }
}