use std::sync::Arc;
use crate::arena::{GameState, errors::CFRStateError};
use super::action_bit_set::ActionBitSet;
use super::node_arena::NodeArena;
use super::{ActionIndexMapperConfig, Node, NodeData};
/// Builds an [`ActionBitSet`] with every index in `0..NUM_ACTION_INDICES` set.
///
/// Used as the "all actions active" fallback when a node has no regret
/// matcher to derive an active set from.
fn full_action_bitset() -> ActionBitSet {
    (0..super::NUM_ACTION_INDICES).fold(ActionBitSet::new(), |mut set, action| {
        set.insert(action);
        set
    })
}
/// Shared state of a CFR game tree.
///
/// Bundles the node arena, the game state the tree was rooted at, and the
/// action-index mapper configuration derived from that game state. The arena
/// and the starting game state are held behind `Arc`, so a cloned `CFRState`
/// shares the same arena: nodes added through one clone are visible through
/// the others.
#[derive(Debug, Clone)]
pub struct CFRState {
    // Arena holding every node of the tree; `CFRState::new` pushes the root first, so it is index 0.
    arena: Arc<NodeArena>,
    // Game state the tree was rooted at.
    starting_game_state: Arc<GameState>,
    // Action-index mapping configuration, built from the starting game state in `CFRState::new`.
    mapper_config: ActionIndexMapperConfig,
}
impl CFRState {
    /// Creates a CFR state for a tree rooted at `game_state`.
    ///
    /// The arena starts out holding a single root node at index 0, and the
    /// action-index mapper configuration is derived from `game_state`.
    pub fn new(game_state: GameState) -> Self {
        let mapper_config = ActionIndexMapperConfig::from_game_state(&game_state);
        let arena = NodeArena::new();
        arena.push(Node::new_root());
        CFRState {
            arena: Arc::new(arena),
            starting_game_state: Arc::new(game_state),
            mapper_config,
        }
    }

    /// Returns the game state this tree was rooted at.
    pub fn starting_game_state(&self) -> &GameState {
        &self.starting_game_state
    }

    /// Returns the action-index mapper configuration derived from the starting game state.
    pub fn mapper_config(&self) -> &ActionIndexMapperConfig {
        &self.mapper_config
    }

    /// Panics with a descriptive message unless `parent_idx` names a node
    /// that is already in the arena.
    #[track_caller]
    fn assert_parent_exists(&self, parent_idx: usize) {
        let node_count = self.arena.len();
        assert!(
            parent_idx < node_count,
            "Parent node not found: parent_idx={parent_idx}, node_count={node_count}"
        );
    }

    /// Appends a node holding `data` as child slot `child_idx` of
    /// `parent_idx` and returns the new node's arena index.
    ///
    /// # Panics
    ///
    /// Panics if `parent_idx` is not an existing node, or if the parent's
    /// `child_idx` slot is already occupied. Use [`Self::ensure_child`] when
    /// several threads may race to create the same child.
    pub fn add(&self, parent_idx: usize, child_idx: usize, data: NodeData) -> usize {
        // Validate the parent *before* pushing. New nodes are appended at the
        // end of the arena, so a `parent_idx` equal to the current node count
        // would otherwise resolve to the freshly pushed node and register it
        // as its own child, a cycle instead of an error.
        self.assert_parent_exists(parent_idx);
        let node = Node::new(parent_idx, child_idx, data);
        let idx = self.arena.push(node);
        self.arena
            .get(parent_idx)
            .try_set_child(child_idx, idx)
            .unwrap_or_else(|existing| {
                panic!(
                    "Child already set at parent_idx={parent_idx}, child_idx={child_idx}: \
                     existing node index={existing}. Use ensure_child() for concurrent access."
                )
            });
        idx
    }

    /// Returns a clone of the data stored at `idx`, or `None` if `idx` is
    /// past the end of the arena.
    pub fn get_node_data(&self, idx: usize) -> Option<NodeData> {
        if idx < self.arena.len() {
            Some(self.arena.get(idx).read_data().clone())
        } else {
            None
        }
    }

    /// Runs `f` on the data stored at `idx` without cloning it and returns
    /// `f`'s result.
    ///
    /// `f` runs while the node's read guard is held. It should therefore be
    /// short and must not try to write the same node (e.g. through
    /// [`Self::update_node`]). Depending on the lock implementation, doing so
    /// may deadlock or panic.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a node in the arena. Use
    /// [`Self::get_node_data`] for a checked lookup.
    pub fn with_node_data<F, R>(&self, idx: usize, f: F) -> R
    where
        F: FnOnce(&NodeData) -> R,
    {
        let guard = self.arena.get(idx).read_data();
        f(&guard)
    }

    /// Returns the arena index of child slot `child_idx` of `parent_idx`.
    ///
    /// Returns `None` if the slot is empty or `parent_idx` is past the end
    /// of the arena.
    pub fn get_child(&self, parent_idx: usize, child_idx: usize) -> Option<usize> {
        if parent_idx < self.arena.len() {
            self.arena.get(parent_idx).get_child(child_idx)
        } else {
            None
        }
    }

    /// Mutates the data stored at `node_idx` in place under its write guard.
    ///
    /// # Errors
    ///
    /// Returns [`CFRStateError::NodeNotFound`] if `node_idx` is past the end
    /// of the arena. In that case `f` is not called.
    pub fn update_node<F>(&self, node_idx: usize, f: F) -> Result<(), CFRStateError>
    where
        F: FnOnce(&mut NodeData),
    {
        if node_idx < self.arena.len() {
            let mut guard = self.arena.get(node_idx).write_data();
            f(&mut guard);
            Ok(())
        } else {
            Err(CFRStateError::NodeNotFound)
        }
    }

    /// Returns the shared node arena backing this tree.
    pub fn arena(&self) -> &Arc<NodeArena> {
        &self.arena
    }

    /// Returns the number of nodes currently in the arena, including the root.
    pub fn node_count(&self) -> usize {
        self.arena.len()
    }

    /// Returns the actions that have positive weight in the node's current
    /// strategy, together with the node's regret-matcher update count.
    ///
    /// Non-player nodes, and player nodes without a regret matcher, report
    /// every action index as active and an update count of 0.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is not a node in the arena.
    pub fn get_pruning_info(
        &self,
        node_idx: usize,
    ) -> (super::action_bit_set::ActionBitSet, usize) {
        use little_sorry::RegretMinimizer;
        self.with_node_data(node_idx, |data| match data {
            NodeData::Player(pd) => match pd.regret_matcher.as_ref() {
                Some(rm) => {
                    let strategy = rm.current_strategy();
                    let mut active = super::action_bit_set::ActionBitSet::new();
                    for (i, &w) in strategy.iter().enumerate() {
                        if w > 0.0 {
                            active.insert(i);
                        }
                    }
                    (active, rm.num_updates())
                }
                None => (full_action_bitset(), 0),
            },
            _ => (full_action_bitset(), 0),
        })
    }

    /// Copies the current strategy of player node `node_idx` into `out`.
    ///
    /// Copies `min(out.len(), strategy.len())` entries and returns `true`.
    /// Any trailing entries of `out` beyond the strategy length are left
    /// untouched. Returns `false`, leaving `out` unchanged, for non-player
    /// nodes and for player nodes without a regret matcher.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is not a node in the arena.
    pub fn node_current_strategy_into(&self, node_idx: usize, out: &mut [f32]) -> bool {
        use little_sorry::RegretMinimizer;
        self.with_node_data(node_idx, |data| match data {
            NodeData::Player(pd) => match pd.regret_matcher.as_ref() {
                Some(rm) => {
                    let strategy = rm.current_strategy();
                    let n = out.len().min(strategy.len());
                    out[..n].copy_from_slice(&strategy[..n]);
                    true
                }
                None => false,
            },
            _ => false,
        })
    }

    /// Returns the average regret of the node's regret matcher.
    ///
    /// Returns `None` for non-player nodes, for player nodes without a regret
    /// matcher, and for matchers that have not been updated yet.
    ///
    /// # Panics
    ///
    /// Panics if `node_idx` is not a node in the arena.
    pub fn node_avg_regret(&self, node_idx: usize) -> Option<f32> {
        use little_sorry::RegretMinimizer;
        self.with_node_data(node_idx, |data| match data {
            NodeData::Player(pd) => pd
                .regret_matcher
                .as_ref()
                .and_then(|rm| (rm.num_updates() > 0).then(|| rm.average_regret())),
            _ => None,
        })
    }

    /// Checks that the existing child at `existing_idx` has the same
    /// `NodeData` variant as `expected_data` and returns `existing_idx`.
    ///
    /// Only the variant is compared, not the payload. On a mismatch, the
    /// node's data is replaced with `expected_data` when `allow_mutation` is
    /// set (logged at debug level). Otherwise this panics.
    fn verify_existing_child(
        &self,
        existing_idx: usize,
        expected_data: NodeData,
        parent_idx: usize,
        child_idx: usize,
        allow_mutation: bool,
    ) -> usize {
        let expected_kind = std::mem::discriminant(&expected_data);
        {
            let data_guard = self.arena.get(existing_idx).read_data();
            if std::mem::discriminant(&*data_guard) == expected_kind {
                return existing_idx;
            }
            if !allow_mutation {
                // Panic while still holding the guard we compared against.
                // Releasing it and re-reading could report data that another
                // thread wrote in between.
                panic!(
                    "Node type mismatch at parent_idx={}, child_idx={}: \
                     expected {:?}, found {:?}. This can occur when different bet \
                     amounts map to the same index. Set allow_node_mutation=true \
                     to handle this case.",
                    parent_idx, child_idx, expected_data, *data_guard
                );
            }
        } // The read guard must be released before taking the write guard below.

        let mut data_guard = self.arena.get(existing_idx).write_data();
        // Re-check under the write lock: another thread may already have
        // replaced the data with the expected variant.
        if std::mem::discriminant(&*data_guard) != expected_kind {
            tracing::debug!(
                parent_idx,
                child_idx,
                existing_idx,
                ?expected_data,
                "Node type mismatch - updating node type. This occurs when different \
                 bet amounts map to the same index but lead to different outcomes."
            );
            *data_guard = expected_data;
        }
        existing_idx
    }

    /// Returns the index of child slot `child_idx` of `parent_idx`, creating
    /// the child with `expected_data` if the slot is empty.
    ///
    /// Concurrent callers for the same slot all receive the same index. A
    /// thread that loses the race to link its node falls back to the winner's
    /// node; the node it pushed stays in the arena but is never linked. When
    /// the slot is already filled, the existing child's variant is checked
    /// against `expected_data` (see `allow_mutation`).
    ///
    /// # Panics
    ///
    /// Panics if `parent_idx` is not an existing node. Also panics if the
    /// existing child's variant differs from `expected_data` and
    /// `allow_mutation` is `false`.
    pub fn ensure_child(
        &self,
        parent_idx: usize,
        child_idx: usize,
        expected_data: NodeData,
        allow_mutation: bool,
    ) -> usize {
        self.assert_parent_exists(parent_idx);
        if let Some(existing_idx) = self.arena.get(parent_idx).get_child(child_idx) {
            return self.verify_existing_child(
                existing_idx,
                expected_data,
                parent_idx,
                child_idx,
                allow_mutation,
            );
        }
        // Clone because `expected_data` is still needed if another thread
        // links its own child first and we must verify against it.
        let node = Node::new(parent_idx, child_idx, expected_data.clone());
        let idx = self.arena.push(node);
        match self.arena.get(parent_idx).try_set_child(child_idx, idx) {
            Ok(()) => idx,
            Err(existing) => self.verify_existing_child(
                existing,
                expected_data,
                parent_idx,
                child_idx,
                allow_mutation,
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::CFRState;
    use crate::arena::cfr::{NodeData, PlayerData};
    use crate::arena::GameStateBuilder;

    /// Fresh state for a two-player game (stacks of 100.0, `blinds(10.0, 5.0)`).
    fn two_player_state() -> CFRState {
        let game = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        CFRState::new(game)
    }

    /// Fresh state for a three-player game (stacks of 100.0, `blinds(10.0, 5.0)`).
    fn three_player_state() -> CFRState {
        let game = GameStateBuilder::new()
            .num_players_with_stack(3, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        CFRState::new(game)
    }

    #[test]
    fn test_add_get_node() {
        let state = three_player_state();
        let player_node = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 0,
        });
        let added: usize = state.add(0, 0, player_node);

        let fetched = state.get_node_data(added).unwrap();
        if let NodeData::Player(pd) = &fetched {
            assert!(pd.regret_matcher.is_none());
        } else {
            panic!("Expected player data");
        }
        assert_eq!(state.get_child(0, 0), Some(added));
    }

    #[test]
    fn test_node_get_not_exist() {
        let state = three_player_state();
        assert!(state.get_node_data(0).is_some());
        assert!(state.get_node_data(100).is_none());
    }

    #[test]
    #[should_panic]
    fn test_with_node_data_panics_out_of_bounds() {
        let state = two_player_state();
        state.with_node_data(100, |_| {});
    }

    #[test]
    fn test_with_node_data_reads_correctly() {
        let state = two_player_state();
        assert!(state.with_node_data(0, |data| data.is_root()));

        let child = state.add(
            0,
            0,
            NodeData::Player(PlayerData {
                regret_matcher: None,
                player_idx: 1,
            }),
        );
        assert!(state.with_node_data(child, |data| data.is_player()));
    }

    #[test]
    fn test_update_node() {
        let state = two_player_state();
        let terminal = state.add(
            0,
            0,
            NodeData::Terminal(crate::arena::cfr::TerminalData::default()),
        );

        let outcome = state.update_node(terminal, |data| {
            if let NodeData::Terminal(td) = data {
                td.total_utility = 42.0;
            }
        });
        outcome.unwrap();

        let utility = state.with_node_data(terminal, |data| {
            if let NodeData::Terminal(td) = data {
                td.total_utility
            } else {
                panic!("Expected terminal")
            }
        });
        assert_eq!(utility, 42.0);
    }

    #[test]
    fn test_update_node_not_found() {
        let state = two_player_state();
        let outcome = state.update_node(999, |_| {});
        assert!(outcome.is_err());
    }

    #[test]
    fn test_ensure_child_creates_new() {
        let state = two_player_state();
        assert!(state.get_child(0, 5).is_none());

        let created = state.ensure_child(0, 5, NodeData::Chance, false);
        assert!(created > 0);
        assert_eq!(state.get_child(0, 5), Some(created));

        let reused = state.ensure_child(0, 5, NodeData::Chance, false);
        assert_eq!(created, reused);
    }

    #[test]
    fn test_ensure_child_concurrent_race() {
        use std::sync::Arc;

        const THREADS: usize = 8;
        let state = Arc::new(two_player_state());

        let mut workers = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let shared = Arc::clone(&state);
            workers.push(std::thread::spawn(move || {
                shared.ensure_child(0, 3, NodeData::Chance, false)
            }));
        }

        let results: Vec<usize> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();
        let winner = results[0];
        for seen in &results {
            assert_eq!(*seen, winner, "All threads must see the same child node");
        }
        assert_eq!(state.get_child(0, 3), Some(winner));
    }

    #[test]
    fn test_ensure_child_type_mismatch_with_mutation() {
        let state = two_player_state();
        let first = state.ensure_child(0, 0, NodeData::Chance, true);
        let replacement = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 0,
        });
        let second = state.ensure_child(0, 0, replacement, true);

        assert_eq!(first, second);
        assert!(state.with_node_data(first, |data| data.is_player()));
    }

    #[test]
    #[should_panic(expected = "Node type mismatch")]
    fn test_ensure_child_type_mismatch_without_mutation_panics() {
        let state = two_player_state();
        state.ensure_child(0, 0, NodeData::Chance, false);
        let conflicting = NodeData::Player(PlayerData {
            regret_matcher: None,
            player_idx: 0,
        });
        state.ensure_child(0, 0, conflicting, false);
    }
}