use tracing::event;
use crate::arena::action::Action;
use crate::arena::action::PlayedActionPayload;
use crate::arena::game_state::Round;
use crate::arena::Historian;
use crate::arena::GameState;
use crate::arena::HistorianError;
use super::ActionIndexMapper;
use super::CFRState;
use super::NodeData;
use super::PlayerData;
use super::TerminalData;
use super::TraversalSet;
/// A [`Historian`] that mirrors arena game events into a CFR tree.
///
/// Every tree-relevant event (a dealt card, a player action, or the end of the hand) makes sure
/// the corresponding node exists in `cfr_state`. It then advances every traversal in
/// `traversal_set` to that node.
pub struct CFRHistorian {
    /// The CFR tree that nodes are looked up in (and, via `ensure_child`, possibly added to).
    cfr_state: CFRState,
    /// Translates concrete player actions into child indices of player nodes.
    action_index_mapper: ActionIndexMapper,
    /// Per-player traversal positions; recorded events move all of them together.
    traversal_set: TraversalSet,
    /// Forwarded unchanged to `CFRState::ensure_child` on every node lookup.
    allow_node_mutation: bool,
}
impl CFRHistorian {
    /// Creates a historian that records events into a clone of `cfr_state`.
    ///
    /// An [`ActionIndexMapper`] is built from the state's mapper configuration. It translates
    /// player actions into child indices. `allow_node_mutation` is forwarded unchanged to every
    /// `CFRState::ensure_child` call this historian makes.
    pub(crate) fn new(
        cfr_state: &CFRState,
        traversal_set: TraversalSet,
        allow_node_mutation: bool,
    ) -> Self {
        let action_index_mapper = ActionIndexMapper::new(*cfr_state.mapper_config());
        CFRHistorian {
            cfr_state: cfr_state.clone(),
            traversal_set,
            action_index_mapper,
            allow_node_mutation,
        }
    }

    /// Returns the index of the node below the current traversal position.
    ///
    /// The node is obtained from `CFRState::ensure_child` using `node_data`. Presumably it is
    /// created when absent and `allow_node_mutation` is set; confirm against `ensure_child`.
    ///
    /// The position is read from the traversal state at index 0. Every recorded event moves all
    /// players together, so player 0's position stands for everyone.
    ///
    /// # Errors
    ///
    /// None today. The `Result` is kept so lookup failures can be surfaced without changing
    /// callers.
    fn ensure_target_node(&self, node_data: NodeData) -> Result<usize, HistorianError> {
        let traversal_state = self.traversal_set.get(0);
        let (from_node_idx, from_child_idx) = traversal_state.get_position();
        Ok(self.cfr_state.ensure_child(
            from_node_idx,
            from_child_idx,
            node_data,
            self.allow_node_mutation,
        ))
    }

    /// Moves every player's traversal to child `child_idx` of node `to_node_idx`.
    fn move_all_players_to(&self, to_node_idx: usize, child_idx: usize) {
        self.traversal_set.move_all_to(to_node_idx, child_idx);
    }

    /// Shared implementation for all dealt cards.
    ///
    /// Steps through a chance node, using the card's `u8` encoding as the child index. Each
    /// distinct card therefore leads to its own subtree.
    fn record_card(&self, card: crate::core::Card) -> Result<(), HistorianError> {
        let card_value: u8 = card.into();
        let to_node_idx = self.ensure_target_node(NodeData::Chance)?;
        // `u8 -> usize` is lossless; `From` makes that explicit instead of an `as` cast.
        self.move_all_players_to(to_node_idx, usize::from(card_value));
        Ok(())
    }

    /// Records a community (board) card as a chance event.
    ///
    /// # Errors
    ///
    /// Propagates any error from locating the chance node.
    pub(crate) fn record_community_card(
        &self,
        card: crate::core::Card,
    ) -> Result<(), HistorianError> {
        self.record_card(card)
    }

    /// Records a card dealt into a player's starting hand as a chance event.
    ///
    /// # Errors
    ///
    /// Propagates any error from locating the chance node.
    pub(crate) fn record_starting_hand_card(
        &self,
        card: crate::core::Card,
    ) -> Result<(), HistorianError> {
        self.record_card(card)
    }

    /// Records a player's action by stepping through that player's decision node.
    ///
    /// The action is mapped to a child index using the state *before* the action:
    /// - the round bet,
    /// - the player's bet,
    /// - the player's stack, reconstructed by adding back the chips this action committed.
    ///
    /// # Errors
    ///
    /// Propagates any error from locating the player node.
    pub(crate) fn record_action(
        &self,
        _game_state: &GameState,
        payload: &PlayedActionPayload,
    ) -> Result<(), HistorianError> {
        let pre_action_round_bet = payload.starting_bet;
        let pre_action_player_bet = payload.starting_player_bet;
        // Chips committed by this action; adding them back yields the pre-action stack.
        let amount_bet = payload.final_player_bet - payload.starting_player_bet;
        let pre_action_stack = payload.player_stack + amount_bet;
        let action_idx = self.action_index_mapper.action_to_idx_raw(
            &payload.action,
            pre_action_round_bet,
            pre_action_player_bet,
            pre_action_stack,
        );
        event!(
            tracing::Level::TRACE,
            acting_player_idx = payload.idx,
            ?payload.action,
            action_idx,
            num_players = self.traversal_set.num_players(),
            starting_pot = payload.starting_pot,
            starting_bet = payload.starting_bet,
            "Recording action for all players"
        );
        let to_node_idx = self.ensure_target_node(NodeData::Player(PlayerData {
            regret_matcher: Option::default(),
            // NOTE(review): `as u8` silently truncates seat indices above 255. Presumably tables
            // are far smaller, but confirm if `payload.idx` can ever exceed `u8::MAX`.
            player_idx: payload.idx as u8,
        }))?;
        self.move_all_players_to(to_node_idx, action_idx);
        Ok(())
    }

    /// Records the end of the hand.
    ///
    /// Moves every traversal into a terminal node (child 0). The sum of all players' rewards is
    /// added to that node's `total_utility`.
    ///
    /// # Errors
    ///
    /// - Propagates any error from locating the terminal node.
    /// - Returns `HistorianError::CFRNodeNotFound` if the terminal node cannot be updated.
    pub(crate) fn record_terminal(&self, game_state: &GameState) -> Result<(), HistorianError> {
        let to_node_idx = self.ensure_target_node(NodeData::Terminal(TerminalData::default()))?;
        self.move_all_players_to(to_node_idx, 0);
        // NOTE(review): if `player_reward` returns net chip changes, this sum is ~0 for every
        // hand and the update below never runs. Confirm a summed utility is what's intended.
        let total_reward: f32 = (0..self.traversal_set.num_players())
            .map(|idx| game_state.player_reward(idx))
            .sum();
        // Skip the node update when there is nothing to add.
        if total_reward != 0.0 {
            self.cfr_state
                .update_node(to_node_idx, |data| {
                    if let NodeData::Terminal(td) = data {
                        td.total_utility += total_reward;
                    }
                })
                .map_err(|_| HistorianError::CFRNodeNotFound)?;
        }
        Ok(())
    }
}
/// Routes arena [`Action`]s into the CFR tree.
///
/// Only a few events change the tree:
/// - card deals,
/// - player actions (for failed actions, their recorded `result`),
/// - the final `RoundAdvance(Round::Complete)`.
///
/// Every other event is accepted and ignored.
#[async_trait::async_trait]
impl Historian for CFRHistorian {
    async fn record_action(
        &mut self,
        _id: u128,
        game_state: &GameState,
        action: &Action,
    ) -> Result<(), HistorianError> {
        match action {
            // The hand is over: descend into the terminal node. This arm must come before the
            // generic `RoundAdvance(_)` catch below.
            Action::RoundAdvance(Round::Complete) => self.record_terminal(game_state),
            // Dealt cards are chance events.
            Action::DealStartingHand(payload) => self.record_starting_hand_card(payload.card),
            Action::DealCommunity(card) => self.record_community_card(*card),
            // A path call is used on purpose: method syntax on `&mut self` would select this
            // trait's `record_action` rather than the inherent helper.
            Action::PlayedAction(payload) => CFRHistorian::record_action(self, game_state, payload),
            Action::FailedAction(failed) => {
                CFRHistorian::record_action(self, game_state, &failed.result)
            }
            // Bookkeeping events that do not correspond to an edge in the tree.
            Action::GameStart(_)
            | Action::ForcedBet(_)
            | Action::PlayerSit(_)
            | Action::RoundAdvance(_)
            | Action::Award(_) => Ok(()),
        }
    }
}