use std::{
cell::{Ref, RefCell, RefMut},
rc::Rc,
};
use crate::arena::GameState;
use super::{Node, NodeData};
#[derive(Debug)]
pub struct CFRStateInternal {
pub nodes: Vec<Node>,
pub starting_game_state: GameState,
next_node_idx: usize,
}
#[derive(Debug, Clone)]
pub struct CFRState {
inner_state: Rc<RefCell<CFRStateInternal>>,
}
impl CFRState {
pub fn new(game_state: GameState) -> Self {
CFRState {
inner_state: Rc::new(RefCell::new(CFRStateInternal {
nodes: vec![Node::new_root()],
starting_game_state: game_state.clone(),
next_node_idx: 1,
})),
}
}
pub fn starting_game_state(&self) -> GameState {
self.inner_state.borrow().starting_game_state.clone()
}
pub fn add(&mut self, parent_idx: usize, child_idx: usize, data: NodeData) -> usize {
let mut state = self.inner_state.borrow_mut();
let idx = state.next_node_idx;
state.next_node_idx += 1;
let node = Node::new(idx, parent_idx, child_idx, data);
state.nodes.push(node);
state.nodes[parent_idx].set_child(child_idx, idx);
idx
}
pub fn get(&self, idx: usize) -> Option<Ref<Node>> {
let inner_ref = self.inner_state.borrow();
Ref::filter_map(inner_ref, |state| state.nodes.get(idx)).ok()
}
pub fn get_mut(&mut self, idx: usize) -> Option<RefMut<Node>> {
let inner_ref = self.inner_state.borrow_mut();
RefMut::filter_map(inner_ref, |state| state.nodes.get_mut(idx)).ok()
}
pub fn internal_state(&self) -> &Rc<RefCell<CFRStateInternal>> {
&self.inner_state
}
}
#[derive(Debug, PartialEq)]
pub struct TraversalStateInternal {
pub node_idx: usize,
pub chosen_child_idx: usize,
pub player_idx: usize,
}
#[derive(Debug, Clone, PartialEq)]
pub struct TraversalState {
inner_state: Rc<RefCell<TraversalStateInternal>>,
}
impl TraversalState {
pub fn new(node_idx: usize, chosen_child_idx: usize, player_idx: usize) -> Self {
TraversalState {
inner_state: Rc::new(RefCell::new(TraversalStateInternal {
node_idx,
chosen_child_idx,
player_idx,
})),
}
}
pub fn new_root(player_idx: usize) -> Self {
TraversalState::new(0, 0, player_idx)
}
pub fn node_idx(&self) -> usize {
self.inner_state.borrow().node_idx
}
pub fn player_idx(&self) -> usize {
self.inner_state.borrow().player_idx
}
pub fn chosen_child_idx(&self) -> usize {
self.inner_state.borrow().chosen_child_idx
}
pub fn move_to(&mut self, node_idx: usize, chosen_child_idx: usize) {
self.inner_state.borrow_mut().node_idx = node_idx;
self.inner_state.borrow_mut().chosen_child_idx = chosen_child_idx;
}
}
#[cfg(test)]
mod tests {
use crate::arena::cfr::{NodeData, PlayerData, TraversalState};
use crate::arena::GameState;
use super::CFRState;
#[test]
fn test_add_get_node() {
let mut state = CFRState::new(GameState::new_starting(vec![100.0; 3], 10.0, 5.0, 0.0, 0));
let new_data = NodeData::Player(PlayerData {
regret_matcher: None,
player_idx: 0,
});
let player_idx: usize = state.add(0, 0, new_data);
let node = state.get(player_idx).unwrap();
match &node.data {
NodeData::Player(pd) => assert!(pd.regret_matcher.is_none()),
_ => panic!("Expected player data"),
}
assert_eq!(node.parent, Some(0));
assert_eq!(node.parent_child_idx, Some(0));
let parent = state.get(0).unwrap();
assert_eq!(parent.get_child(0), Some(player_idx));
}
#[test]
fn test_node_get_not_exist() {
let state = CFRState::new(GameState::new_starting(vec![100.0; 3], 10.0, 5.0, 0.0, 0));
let root = state.get(0);
assert!(root.is_some());
let node = state.get(100);
assert!(node.is_none());
}
#[test]
fn test_cloned_traversal_share_loc() {
let mut traversal = TraversalState::new(0, 0, 0);
let cloned = traversal.clone();
assert_eq!(traversal.node_idx(), 0);
assert_eq!(traversal.player_idx(), 0);
assert_eq!(traversal.chosen_child_idx(), 0);
assert_eq!(cloned.node_idx(), 0);
assert_eq!(cloned.player_idx(), 0);
assert_eq!(cloned.chosen_child_idx(), 0);
traversal.move_to(2, 42);
assert_eq!(traversal.node_idx(), 2);
assert_eq!(traversal.chosen_child_idx(), 42);
assert_eq!(cloned.node_idx(), 2);
assert_eq!(cloned.chosen_child_idx(), 42);
}
}