use crate::arena::action::Action;
use crate::arena::game_state::Round;
use crate::arena::action::AgentAction;
use crate::arena::Historian;
use crate::core::Card;
use crate::arena::GameState;
use crate::arena::HistorianError;
use super::ActionGenerator;
use super::CFRState;
use super::NodeData;
use super::PlayerData;
use super::TerminalData;
use super::TraversalState;
/// A [`Historian`] that mirrors game events into a CFR node tree.
///
/// As events arrive, it walks (and, when needed, grows) the tree stored in
/// `cfr_state`, advancing `traversal_state` to the node reached by each event.
pub struct CFRHistorian<T: ActionGenerator> {
    /// Traversal cursor.
    ///
    /// It is queried for the current node (`node_idx`), the child chosen at that
    /// node (`chosen_child_idx`) and the traversing player (`player_idx`). It is
    /// advanced with `move_to`.
    pub traversal_state: TraversalState,
    /// Node storage.
    ///
    /// Nodes are read with `get`/`get_mut` by index and created with `add`.
    pub cfr_state: CFRState,
    /// Converts an agent action into the child index taken at a player node
    /// (via `action_to_idx`).
    pub action_generator: T,
}
impl<T: ActionGenerator> CFRHistorian<T> {
    /// Builds a historian that records into `cfr_state` while following
    /// `traversal_state`.
    ///
    /// The action generator is constructed from clones of both values.
    /// NOTE(review): this presumably relies on those clones sharing state with the
    /// originals — confirm against `CFRState`/`TraversalState`.
    pub(crate) fn new(traversal_state: TraversalState, cfr_state: CFRState) -> Self {
        Self {
            action_generator: T::new(cfr_state.clone(), traversal_state.clone()),
            traversal_state,
            cfr_state,
        }
    }

    /// Counts a visit to the currently chosen child of the current node, then
    /// returns that child's node index.
    ///
    /// If the child does not exist yet, a new node holding `node_data` is added.
    /// If it already exists, `node_data` is discarded; the existing node's kind
    /// is not checked.
    ///
    /// # Errors
    ///
    /// Returns [`HistorianError::CFRNodeNotFound`] if the traversal's current
    /// node index is not present in `cfr_state`.
    pub(crate) fn ensure_target_node(
        &mut self,
        node_data: NodeData,
    ) -> Result<usize, HistorianError> {
        let parent_idx = self.traversal_state.node_idx();
        let child_slot = self.traversal_state.chosen_child_idx();

        // Keep this as its own statement. Whatever `get_mut` returns is a
        // temporary dropped here, before the lookup below (and any `add`) touches
        // `cfr_state` again.
        self.cfr_state
            .get_mut(parent_idx)
            .ok_or(HistorianError::CFRNodeNotFound)?
            .increment_count(child_slot);

        let existing_child = self
            .cfr_state
            .get(parent_idx)
            .ok_or(HistorianError::CFRNodeNotFound)?
            .get_child(child_slot);

        Ok(existing_child
            .unwrap_or_else(|| self.cfr_state.add(parent_idx, child_slot, node_data)))
    }

    /// Records a dealt card as a step through a chance node.
    ///
    /// The card's `u8` value is used as the child index when moving past the
    /// chance node.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::ensure_target_node`].
    pub(crate) fn record_card(
        &mut self,
        _game_state: &GameState,
        card: Card,
    ) -> Result<(), HistorianError> {
        let card_byte: u8 = card.into();
        let chance_idx = self.ensure_target_node(NodeData::Chance)?;
        self.traversal_state.move_to(chance_idx, usize::from(card_byte));
        Ok(())
    }

    /// Records `action` taken by `player_idx` as a step through a player node.
    ///
    /// A newly created player node starts with no regret matcher. The child index
    /// taken is the one the action generator assigns to `action` in `game_state`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::ensure_target_node`].
    pub(crate) fn record_action(
        &mut self,
        game_state: &GameState,
        action: AgentAction,
        player_idx: usize,
    ) -> Result<(), HistorianError> {
        let action_idx = self.action_generator.action_to_idx(game_state, &action);
        let player_node = NodeData::Player(PlayerData {
            player_idx,
            regret_matcher: None,
        });
        let target_idx = self.ensure_target_node(player_node)?;
        self.traversal_state.move_to(target_idx, action_idx);
        Ok(())
    }

    /// Records the end of the hand at a terminal node.
    ///
    /// Moves onto the terminal node (child index 0) and counts a visit there. It
    /// then adds the traversing player's reward from `game_state` to the node's
    /// `total_utility`.
    ///
    /// # Errors
    ///
    /// - Propagates errors from [`Self::ensure_target_node`].
    /// - Returns [`HistorianError::CFRNodeNotFound`] if the terminal node cannot
    ///   be fetched.
    /// - Returns [`HistorianError::CFRUnexpectedNode`] if the node reached is not
    ///   a terminal node.
    pub(crate) fn record_terminal(
        &mut self,
        game_state: &GameState,
    ) -> Result<(), HistorianError> {
        let terminal_idx =
            self.ensure_target_node(NodeData::Terminal(TerminalData::default()))?;
        self.traversal_state.move_to(terminal_idx, 0);

        let reward = game_state.player_reward(self.traversal_state.player_idx());

        let mut node = self
            .cfr_state
            .get_mut(terminal_idx)
            .ok_or(HistorianError::CFRNodeNotFound)?;
        node.increment_count(0);

        // An existing child may have been created with a different kind; only
        // terminal nodes can accumulate utility.
        match &mut node.data {
            NodeData::Terminal(td) => {
                td.total_utility += reward;
                Ok(())
            }
            _ => Err(HistorianError::CFRUnexpectedNode(
                "Expected terminal node".to_string(),
            )),
        }
    }
}
impl<T: ActionGenerator> Historian for CFRHistorian<T> {
    /// Routes each game event to the matching CFR recorder.
    ///
    /// Routing:
    /// - Completing the final round records a terminal node.
    /// - Community cards and the traversing player's own hole cards record chance
    ///   steps.
    /// - Played and failed actions record player steps.
    /// - Every other event is accepted without touching the tree.
    fn record_action(
        &mut self,
        _id: u128,
        game_state: &GameState,
        action: Action,
    ) -> Result<(), HistorianError> {
        match action {
            // Must precede the generic `RoundAdvance(_)` arm below.
            Action::RoundAdvance(Round::Complete) => self.record_terminal(game_state),
            // Only the traversing player's own hole cards enter the tree.
            Action::DealStartingHand(payload)
                if payload.idx == self.traversal_state.player_idx() =>
            {
                self.record_card(game_state, payload.card)
            }
            Action::DealCommunity(card) => self.record_card(game_state, card),
            Action::PlayedAction(payload) => {
                self.record_action(game_state, payload.action, payload.idx)
            }
            // A failed action is recorded using the action carried in its `result`.
            Action::FailedAction(failed) => {
                self.record_action(game_state, failed.result.action, failed.result.idx)
            }
            // Events that never create or visit tree nodes.
            Action::GameStart(_)
            | Action::ForcedBet(_)
            | Action::PlayerSit(_)
            | Action::RoundAdvance(_)
            | Action::Award(_)
            | Action::DealStartingHand(_) => Ok(()),
        }
    }
}