use crate::{SimTrace, TraceEvent};
/// Replays a simulation from a fixed seed.
///
/// The wrapped `sim_fn` is called as `sim_fn(seed, ticks)` and returns the
/// resulting state together with its trace.
///
/// Trait bounds are declared on the `impl` block rather than on the struct,
/// so the type can be named in other generic signatures without repeating
/// them.
pub struct ReplayRunner<S, F> {
    /// Seed passed as the first argument to every `sim_fn` call.
    seed: u64,
    /// Simulation entry point, invoked as `sim_fn(seed, ticks)`.
    sim_fn: F,
    /// The runner produces values of `S` but never holds one, so the marker
    /// uses `fn() -> S` rather than `S`. That keeps `S` in the type without
    /// tying the runner's `Send`/`Sync`/drop-check behaviour to `S`.
    _phantom: std::marker::PhantomData<fn() -> S>,
}
impl<S, F> ReplayRunner<S, F>
where
    F: Fn(u64, u64) -> (S, SimTrace),
{
    /// Builds a runner that will invoke `sim_fn` with `seed` on every replay.
    pub fn new(seed: u64, sim_fn: F) -> Self {
        Self {
            sim_fn,
            seed,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Calls `sim_fn(seed, target_tick)` and returns its state and full trace.
    pub fn replay_to_tick(&self, target_tick: u64) -> (S, SimTrace) {
        let Self { seed, sim_fn, .. } = self;
        sim_fn(*seed, target_tick)
    }

    /// Replays up to `end_tick` and returns the final state plus clones of
    /// the trace events whose `tick` lies in `start_tick..=end_tick`.
    ///
    /// When `start_tick > end_tick` the range is empty and so is the result.
    pub fn replay_window(&self, start_tick: u64, end_tick: u64) -> (S, Vec<TraceEvent>) {
        let window = start_tick..=end_tick;
        let (state, trace) = self.replay_to_tick(end_tick);

        let mut window_events: Vec<TraceEvent> = Vec::new();
        for event in trace.events().iter() {
            if window.contains(&event.tick) {
                window_events.push(event.clone());
            }
        }

        (state, window_events)
    }

    /// Seed this runner passes to the simulation.
    pub fn seed(&self) -> u64 {
        let seed = self.seed;
        seed
    }
}
/// Replays the simulation until `predicate` holds and returns the first tick
/// at which it does, or `None` if it never holds within `max_ticks`.
///
/// The search is coarse-then-fine. `sim_fn(seed, t)` is evaluated at `step`,
/// `2 * step`, …, with the final probe clamped to `max_ticks` so the tail of
/// the range is never skipped. Once a probe satisfies `predicate`, the window
/// since the previous (failing) probe is bisected to find the earliest
/// satisfying tick. This holds for the first window too.
///
/// The bisection assumes `predicate` is monotone in the tick count: once it
/// holds at tick `t`, it holds at every later tick. If it is not monotone,
/// the result is *a* satisfying tick inside the window, not necessarily the
/// first one.
///
/// Tick `0` is never evaluated, so the smallest possible result is `1`.
/// A `step` of `0` is treated as `1`, which gives a plain linear scan.
pub fn replay_until<S, F, P>(
    seed: u64,
    max_ticks: u64,
    step: u64,
    sim_fn: F,
    predicate: P,
) -> Option<u64>
where
    F: Fn(u64, u64) -> (S, SimTrace),
    P: Fn(&S, &SimTrace) -> bool,
{
    let holds_at = |tick: u64| {
        let (state, trace) = sim_fn(seed, tick);
        predicate(&state, &trace)
    };

    // A zero step would never advance the probe and would loop forever.
    let step = step.max(1);

    // `lo` is the last tick known (or, for the initial 0, assumed) not to
    // satisfy the predicate.
    let mut lo = 0u64;
    while lo < max_ticks {
        // Saturate so a huge `step` cannot overflow. Clamp so the last probe
        // lands exactly on `max_ticks` even when it is not a multiple of `step`.
        let probe = lo.saturating_add(step).min(max_ticks);
        if holds_at(probe) {
            // Invariant: predicate fails at `lo` and holds at `hi`.
            let mut hi = probe;
            while hi - lo > 1 {
                let mid = lo + (hi - lo) / 2;
                if holds_at(mid) {
                    hi = mid;
                } else {
                    lo = mid;
                }
            }
            return Some(hi);
        }
        lo = probe;
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TraceEventKind;

    /// Toy simulation whose state after `ticks` ticks is `ticks` itself.
    /// On every tenth tick it records a `"counter"` event carrying the count
    /// so far. The seed is ignored.
    fn counter_sim(_seed: u64, ticks: u64) -> (u64, SimTrace) {
        let mut trace = SimTrace::new();
        // The count grows by one per tick starting at tick 1, so at tick `t`
        // it equals `t`; only every tenth tick is recorded.
        for t in (1..=ticks).filter(|t| t % 10 == 0) {
            let kind = TraceEventKind::Custom {
                tag: "counter".into(),
                data: format!("{t}"),
            };
            trace.record(t, 1, kind);
        }
        (ticks, trace)
    }

    #[test]
    fn test_replay_to_tick() {
        let (state, _) = ReplayRunner::new(42, counter_sim).replay_to_tick(100);
        assert_eq!(state, 100);
    }

    #[test]
    fn test_replay_deterministic() {
        let runner = ReplayRunner::new(42, counter_sim);
        let first = runner.replay_to_tick(500);
        let second = runner.replay_to_tick(500);
        assert_eq!(first.0, second.0);
        assert_eq!(first.1.len(), second.1.len());
    }

    #[test]
    fn test_replay_window() {
        let runner = ReplayRunner::new(42, counter_sim);
        let (state, events) = runner.replay_window(50, 100);
        assert_eq!(state, 100);
        assert_eq!(events.len(), 6);
        assert!(events.iter().all(|e| (50..=100).contains(&e.tick)));
    }

    #[test]
    fn test_replay_until_found() {
        let found = replay_until(42, 1000, 10, counter_sim, |state, _trace| *state >= 75);
        assert_eq!(found, Some(75));
    }

    #[test]
    fn test_replay_until_not_found() {
        let found = replay_until(42, 50, 10, counter_sim, |state, _trace| *state >= 100);
        assert_eq!(found, None);
    }

    #[test]
    fn test_replay_seed() {
        assert_eq!(ReplayRunner::new(123, counter_sim).seed(), 123);
    }
}