use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use tracing::{instrument, trace};
use crate::arena::{GameState, action::Action};
use super::{Historian, HistorianError, HistorianLock};
/// One entry in the log kept by a `VecHistorian`: an action together with
/// the game-state snapshots recorded around it.
#[derive(Debug, Clone)]
pub struct HistoryRecord {
/// The game state handed to the historian on its previous `record_action`
/// call, or `None` when this record is the first one that historian wrote.
pub before_game_state: Option<GameState>,
/// A clone of the action that was reported to the historian.
pub action: Action,
/// A clone of the game state that was reported together with `action`
/// (presumably the state after the action was applied — confirm with caller).
pub after_game_state: GameState,
}
/// Reference-counted, mutex-guarded vector of [`HistoryRecord`]s. Cloning the
/// handle yields another reference to the same underlying vector.
pub type SharedHistoryStorage = Arc<Mutex<Vec<HistoryRecord>>>;
/// A historian that appends every recorded action to an in-memory vector
/// held behind a [`SharedHistoryStorage`] handle.
pub struct VecHistorian {
// Game state seen on the most recent `record_action` call; `None` until the
// first action has been recorded.
previous: Option<GameState>,
// Shared storage the records are pushed into.
records: SharedHistoryStorage,
}
impl VecHistorian {
pub fn get_storage(&self) -> SharedHistoryStorage {
self.records.clone()
}
pub fn new_with_actions(actions: SharedHistoryStorage) -> Self {
Self {
records: actions,
previous: None,
}
}
pub fn new() -> Self {
VecHistorian::new_with_actions(Arc::new(Mutex::new(vec![])))
}
}
impl Default for VecHistorian {
fn default() -> Self {
VecHistorian::new()
}
}
#[async_trait]
impl Historian for VecHistorian {
    /// Appends a [`HistoryRecord`] for `action` to the shared storage.
    ///
    /// The record pairs the game state remembered from the previous call
    /// (`None` on the first call) with clones of `action` and `game_state`.
    /// Afterwards `game_state` becomes the remembered state for the next call.
    /// The `_id` argument is not used by this historian.
    ///
    /// # Errors
    ///
    /// Returns [`HistorianError::LockPoisoned`] tagged with
    /// [`HistorianLock::VecRecords`] if the storage mutex is poisoned. In that
    /// case nothing is pushed and the remembered previous state is unchanged.
    // NOTE(review): `record_count` is declared on the span via `fields(...)`,
    // but this body only attaches the value to the `trace!` event below —
    // confirm whether the span field is meant to be recorded as well.
    #[instrument(level = "trace", skip(self, game_state), fields(record_count))]
    async fn record_action(
        &mut self,
        _id: u128,
        game_state: &GameState,
        action: &Action,
    ) -> Result<(), HistorianError> {
        let record_count = {
            // Acquire the lock first so that a poisoned mutex returns early
            // without touching `self.previous`.
            let mut records = self
                .records
                .lock()
                .map_err(|_| HistorianError::LockPoisoned {
                    lock: HistorianLock::VecRecords,
                })?;

            // Move the remembered state into the record instead of cloning it:
            // it is about to be overwritten with the current state anyway, so
            // `replace` saves one full `GameState` clone per recorded action.
            let before_game_state = self.previous.replace(game_state.clone());

            records.push(HistoryRecord {
                before_game_state,
                action: action.clone(),
                after_game_state: game_state.clone(),
            });
            records.len()
            // The guard is dropped here, before the trace event is emitted.
        };

        trace!(record_count, "Recorded action to VecHistorian");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{
        Agent, GameStateBuilder, HoldemSimulationBuilder,
        agent::{CallingAgent, RandomAgent},
    };

    /// Builds the pair of `CallingAgent`s used by the heads-up replay test.
    fn heads_up_callers() -> Vec<Box<dyn Agent>> {
        vec![
            Box::<CallingAgent>::default(),
            Box::<CallingAgent>::default(),
        ]
    }

    /// Running a five-player simulation should leave more than ten records
    /// in the historian's shared storage.
    #[tokio::test]
    async fn test_vec_historian() {
        let historian = Box::new(VecHistorian::default());
        // Keep a handle to the storage; the historian itself is moved into
        // the simulation below.
        let storage = historian.get_storage();

        let starting_state = GameStateBuilder::new()
            .num_players_with_stack(5, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();
        let players: Vec<Box<dyn Agent>> = (0..5)
            .map(|_| Box::<RandomAgent>::default() as Box<dyn Agent>)
            .collect();

        let mut simulation = HoldemSimulationBuilder::default()
            .game_state(starting_state)
            .agents(players)
            .historians(vec![historian])
            .build()
            .unwrap();
        simulation.run().await;

        assert!(storage.lock().unwrap().len() > 10);
    }

    /// Restarting a simulation from any recorded `before_game_state` must
    /// begin with a played action by the same player index as the original.
    #[tokio::test]
    async fn test_restarting_simulations() {
        let historian = Box::new(VecHistorian::default());
        let storage = historian.get_storage();

        let starting_state = GameStateBuilder::new()
            .num_players_with_stack(2, 100.0)
            .blinds(10.0, 5.0)
            .build()
            .unwrap();

        let mut simulation = HoldemSimulationBuilder::default()
            .game_state(starting_state)
            .agents(heads_up_callers())
            .historians(vec![historian])
            .build()
            .unwrap();
        simulation.run().await;

        // Copy the records out so the lock is not held while replaying.
        let snapshot = storage.lock().unwrap().clone();

        // Only played actions that have a known prior state can be replayed.
        let replay_points = snapshot.iter().filter_map(|record| {
            match (&record.action, &record.before_game_state) {
                (Action::PlayedAction(played), Some(before)) => Some((played, before)),
                _ => None,
            }
        });

        for (played_action, before_game_state) in replay_points {
            let replay_historian = Box::new(VecHistorian::default());
            let replay_storage = replay_historian.get_storage();

            let mut replay = HoldemSimulationBuilder::default()
                .game_state(before_game_state.clone())
                .agents(heads_up_callers())
                .historians(vec![replay_historian])
                .build()
                .unwrap();
            replay.run().await;

            let first_record = replay_storage.lock().unwrap().first().unwrap().clone();
            match &first_record.action {
                Action::PlayedAction(replayed_action) => {
                    assert_eq!(played_action.idx, replayed_action.idx);
                }
                other => panic!(
                    "The first action should be a played action, found {:?}",
                    other
                ),
            }
        }
    }
}