use std::sync::Arc;
use parking_lot::RwLock;
/// Plain-data record of one traversal position.
///
/// All fields are plain integers, so the type is cheap to copy and its
/// equality is total; it therefore derives `Clone`, `Copy` and `Eq` in
/// addition to `Debug` and `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TraversalStateInternal {
    /// Index of the node the traversal currently points at (`0` for a root
    /// state). Presumably an index into a tree/graph stored elsewhere —
    /// confirm against the caller.
    pub node_idx: usize,
    /// Index of the child chosen at `node_idx` (`0` for a root state).
    pub chosen_child_idx: usize,
    /// Index of the player this traversal belongs to.
    pub player_idx: u8,
}
/// Shared handle to a single traversal position.
///
/// The position lives in a [`TraversalStateInternal`] behind an
/// `Arc<RwLock<..>>`. The derived `Clone` copies the `Arc`, not the data, so
/// every clone of a handle refers to the same underlying position: a
/// `move_to` through one clone is observed through all the others.
#[derive(Debug, Clone)]
pub struct TraversalState {
    // Reference-counted, lock-guarded position shared by all clones.
    inner_state: Arc<RwLock<TraversalStateInternal>>,
}
impl PartialEq for TraversalState {
fn eq(&self, other: &Self) -> bool {
*self.inner_state.read() == *other.inner_state.read()
}
}
// Marker impl: the compared data consists solely of integer fields
// (`usize`, `usize`, `u8`), so `==` on a given pair of values is a total
// equivalence relation. Note that the values themselves can change over time
// through `move_to`, so equality of two handles is not stable.
impl Eq for TraversalState {}
impl TraversalState {
    /// Builds a fresh handle holding the given position and player index.
    ///
    /// The returned handle owns a new lock; it shares state only with clones
    /// made from it afterwards.
    pub fn new(node_idx: usize, chosen_child_idx: usize, player_idx: u8) -> Self {
        let initial = TraversalStateInternal {
            node_idx,
            chosen_child_idx,
            player_idx,
        };
        Self {
            inner_state: Arc::new(RwLock::new(initial)),
        }
    }

    /// Builds a handle for `player_idx` positioned at node `0`, child `0`.
    pub fn new_root(player_idx: u8) -> Self {
        Self::new(0, 0, player_idx)
    }

    /// Returns the current node index (takes a read lock).
    pub fn node_idx(&self) -> usize {
        let guard = self.inner_state.read();
        guard.node_idx
    }

    /// Returns the player index this state was created with (takes a read lock).
    pub fn player_idx(&self) -> u8 {
        let guard = self.inner_state.read();
        guard.player_idx
    }

    /// Returns the currently chosen child index (takes a read lock).
    pub fn chosen_child_idx(&self) -> usize {
        let guard = self.inner_state.read();
        guard.chosen_child_idx
    }

    /// Overwrites the node index and chosen child index under one write lock.
    ///
    /// The player index is left untouched. Because the lock provides interior
    /// mutability, this works through `&self` and is visible to every clone
    /// of this handle.
    pub fn move_to(&self, node_idx: usize, chosen_child_idx: usize) {
        let mut guard = self.inner_state.write();
        guard.node_idx = node_idx;
        guard.chosen_child_idx = chosen_child_idx;
    }

    /// Returns `(node_idx, chosen_child_idx, player_idx)` read under a single
    /// read lock, so the three values belong to the same instant.
    #[inline]
    pub fn get_all(&self) -> (usize, usize, u8) {
        let guard = self.inner_state.read();
        let node = guard.node_idx;
        let child = guard.chosen_child_idx;
        let player = guard.player_idx;
        (node, child, player)
    }

    /// Returns `(node_idx, chosen_child_idx)` read under a single read lock.
    #[inline]
    pub fn get_position(&self) -> (usize, usize) {
        let guard = self.inner_state.read();
        let node = guard.node_idx;
        let child = guard.chosen_child_idx;
        (node, child)
    }
}
/// A collection of [`TraversalState`] handles, one per player.
///
/// The derived `Clone` clones each contained handle, and cloned handles share
/// their underlying state; a cloned set therefore observes moves made through
/// the original (and vice versa). Use `fork` to obtain an independent copy.
#[derive(Debug, Clone)]
pub struct TraversalSet {
    // Indexed by player index: `states[i]` is created with `player_idx == i`.
    states: Vec<TraversalState>,
}
impl TraversalSet {
    /// Creates a set with one root state per player, where the state at
    /// position `i` carries `player_idx == i`.
    ///
    /// # Panics
    ///
    /// Panics if `num_players` exceeds 256. Player indices are stored as `u8`,
    /// so a larger count would silently wrap and give several states the same
    /// player index.
    pub fn new(num_players: usize) -> Self {
        let max_players = usize::from(u8::MAX) + 1;
        assert!(
            num_players <= max_players,
            "TraversalSet supports at most {} players (player_idx is a u8), got {}",
            max_players,
            num_players
        );

        let states = (0..num_players)
            // Lossless: the assert above guarantees `i <= u8::MAX`.
            .map(|i| TraversalState::new_root(i as u8))
            .collect();
        TraversalSet { states }
    }

    /// Returns a handle to the state stored at `player_idx`.
    ///
    /// The returned handle is a clone and therefore shares its position with
    /// the one kept in the set.
    ///
    /// # Panics
    ///
    /// Panics if `player_idx >= self.num_players()`.
    pub fn get(&self, player_idx: usize) -> TraversalState {
        self.states[player_idx].clone()
    }

    /// Returns a deep copy of the set.
    ///
    /// Unlike `clone`, every state in the result gets its own fresh lock
    /// initialised with the current values, so later moves on either set do
    /// not affect the other.
    pub fn fork(&self) -> TraversalSet {
        let states = self
            .states
            .iter()
            .map(|ts| {
                // One read lock per state: the three values are a consistent snapshot.
                let (node_idx, chosen_child_idx, player_idx) = ts.get_all();
                TraversalState::new(node_idx, chosen_child_idx, player_idx)
            })
            .collect();
        TraversalSet { states }
    }

    /// Returns the number of states (players) in the set.
    pub fn num_players(&self) -> usize {
        self.states.len()
    }

    /// Moves every state in the set to `(to_node_idx, child_idx)`.
    ///
    /// States are updated one after another, each under its own write lock;
    /// the update is not atomic across the whole set.
    pub fn move_all_to(&self, to_node_idx: usize, child_idx: usize) {
        for state in &self.states {
            state.move_to(to_node_idx, child_idx);
        }
    }

    /// Iterates over the contained states in player-index order.
    pub fn iter(&self) -> impl Iterator<Item = &TraversalState> {
        self.states.iter()
    }
}
#[cfg(test)]
mod tests {
    use super::{TraversalSet, TraversalState};

    /// `new` stores exactly the values it is given.
    #[test]
    fn test_new_and_getters() {
        let traversal = TraversalState::new(5, 10, 2);
        assert_eq!(traversal.node_idx(), 5);
        assert_eq!(traversal.chosen_child_idx(), 10);
        assert_eq!(traversal.player_idx(), 2);
    }

    /// `new_root` starts at node 0 / child 0 and keeps the player index.
    #[test]
    fn test_new_root() {
        let traversal = TraversalState::new_root(3);
        assert_eq!(traversal.node_idx(), 0);
        assert_eq!(traversal.chosen_child_idx(), 0);
        assert_eq!(traversal.player_idx(), 3);
    }

    /// `move_to` updates the position but not the player index.
    #[test]
    fn test_move_to() {
        let traversal = TraversalState::new_root(0);
        assert_eq!(traversal.node_idx(), 0);
        assert_eq!(traversal.chosen_child_idx(), 0);
        traversal.move_to(42, 7);
        assert_eq!(traversal.node_idx(), 42);
        assert_eq!(traversal.chosen_child_idx(), 7);
        assert_eq!(traversal.player_idx(), 0);
    }

    /// Clones of a handle share the same underlying position.
    #[test]
    fn test_cloned_traversal_share_loc() {
        let traversal = TraversalState::new(0, 0, 0);
        let cloned = traversal.clone();
        assert_eq!(traversal.node_idx(), 0);
        assert_eq!(traversal.player_idx(), 0);
        assert_eq!(traversal.chosen_child_idx(), 0);
        assert_eq!(cloned.node_idx(), 0);
        assert_eq!(cloned.player_idx(), 0);
        assert_eq!(cloned.chosen_child_idx(), 0);
        traversal.move_to(2, 42);
        assert_eq!(traversal.node_idx(), 2);
        assert_eq!(traversal.chosen_child_idx(), 42);
        assert_eq!(cloned.node_idx(), 2);
        assert_eq!(cloned.chosen_child_idx(), 42);
    }

    /// `get_all` reflects the latest `move_to`.
    #[test]
    fn test_get_all_after_move() {
        let traversal = TraversalState::new(0, 0, 2);
        traversal.move_to(100, 50);
        let (node_idx, chosen_child_idx, player_idx) = traversal.get_all();
        assert_eq!(node_idx, 100);
        assert_eq!(chosen_child_idx, 50);
        assert_eq!(player_idx, 2);
    }

    /// `get_position` returns `(node_idx, chosen_child_idx)` and tracks moves.
    #[test]
    fn test_get_position() {
        let traversal = TraversalState::new(3, 4, 1);
        assert_eq!(traversal.get_position(), (3, 4));
        traversal.move_to(9, 8);
        assert_eq!(traversal.get_position(), (9, 8));
    }

    /// Equality compares held values, not handle identity.
    #[test]
    fn test_equality_compares_values() {
        let a = TraversalState::new(1, 2, 0);
        let b = TraversalState::new(1, 2, 0);
        assert_eq!(a, b);
        b.move_to(1, 3);
        assert_ne!(a, b);
        assert_ne!(TraversalState::new(1, 2, 0), TraversalState::new(1, 2, 1));
    }

    /// A handle always equals its own clone, before and after a move.
    #[test]
    fn test_clone_equals_original() {
        let original = TraversalState::new(4, 5, 6);
        let cloned = original.clone();
        assert_eq!(original, cloned);
        original.move_to(7, 8);
        assert_eq!(original, cloned);
    }

    /// A new set holds one root state per player, indexed by player.
    #[test]
    fn test_traversal_set_new() {
        let set = TraversalSet::new(3);
        assert_eq!(set.num_players(), 3);
        for i in 0..3 {
            let ts = set.get(i);
            assert_eq!(ts.node_idx(), 0);
            assert_eq!(ts.chosen_child_idx(), 0);
            assert_eq!(ts.player_idx(), i as u8);
        }
    }

    /// Cloning a set shares state with the original.
    #[test]
    fn test_clone_shares_state() {
        let set = TraversalSet::new(2);
        let cloned = set.clone();
        set.get(0).move_to(10, 3);
        assert_eq!(cloned.get(0).node_idx(), 10);
        assert_eq!(cloned.get(0).chosen_child_idx(), 3);
    }

    /// Forking a set copies current values into independent states.
    #[test]
    fn test_fork_is_independent() {
        let set = TraversalSet::new(2);
        set.get(0).move_to(5, 2);
        let forked = set.fork();
        assert_eq!(forked.get(0).node_idx(), 5);
        assert_eq!(forked.get(0).chosen_child_idx(), 2);
        forked.get(0).move_to(20, 7);
        assert_eq!(set.get(0).node_idx(), 5);
        assert_eq!(set.get(0).chosen_child_idx(), 2);
        assert_eq!(forked.get(0).node_idx(), 20);
        assert_eq!(forked.get(0).chosen_child_idx(), 7);
    }

    /// `move_all_to` repositions every state and `iter` visits them in
    /// player-index order without altering player indices.
    #[test]
    fn test_move_all_to_and_iter() {
        let set = TraversalSet::new(3);
        set.move_all_to(7, 1);
        assert_eq!(set.iter().count(), 3);
        for (i, ts) in set.iter().enumerate() {
            assert_eq!(ts.get_all(), (7, 1, i as u8));
        }
    }
}