use std::cell::RefCell;
use brink_format::Value;
use crate::story::{ExternalFnHandler, ExternalResult};
/// Maximum number of entries a [`ReplayRecorder`] keeps in its log.
///
/// Once the log holds this many entries, `ReplayRecorder::record` drops any
/// further call instead of appending it.
pub const RECORDING_CAP: usize = 16_384;
/// One external-function call captured while recording.
///
/// It holds the function name, the arguments it was called with, and the
/// value it resolved to.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordedExternal {
/// Name of the external function that was called.
pub name: String,
/// Arguments passed to the call, in order; replay only matches on an exact match.
pub args: Vec<Value>,
/// Value the call resolved to; handed back again when replay matches this entry.
pub result: Value,
}
/// Replay mode selector; `Recorded` is the default variant.
///
/// NOTE(review): this enum is not referenced in this file. Presumably
/// `Recorded` answers external calls from a [`ReplayRecorder`] log and `Live`
/// calls the real handler — confirm against the callers that consume it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ReplayMode {
#[default]
Recorded,
Live,
}
/// Ordered log of resolved external calls, plus a cursor for replaying them.
///
/// Recording appends to the log, up to [`RECORDING_CAP`] entries. Replay walks
/// the log from the cursor and latches a "diverged" state on the first call
/// that does not match.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ReplayRecorder {
/// Recorded calls, in the order they were made.
log: Vec<RecordedExternal>,
/// Index into `log` of the next entry replay expects to see.
cursor: usize,
/// Set when replay sees a mismatching call or runs past the end of `log`;
/// cleared only by `reset_cursor`.
diverged: bool,
}
impl ReplayRecorder {
    /// Creates an empty recorder whose replay cursor is at the start.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one resolved external call to the log.
    ///
    /// Once the log holds [`RECORDING_CAP`] entries, further calls are silently
    /// dropped, so the log cannot grow without bound. A later replay of a
    /// truncated log diverges when it reaches the cap.
    pub fn record(&mut self, name: &str, args: &[Value], result: &Value) {
        if self.log.len() >= RECORDING_CAP {
            return;
        }
        self.log.push(RecordedExternal {
            name: name.to_owned(),
            args: args.to_vec(),
            result: result.clone(),
        });
    }

    /// Returns the recorded result for the next expected call and advances the cursor.
    ///
    /// Returns `None` and marks the recorder as diverged in two cases:
    /// - `name`/`args` differ from the entry at the cursor;
    /// - the log is exhausted.
    ///
    /// Once diverged, every call returns `None` until
    /// [`reset_cursor`](Self::reset_cursor). Use [`is_diverged`](Self::is_diverged)
    /// to tell a mismatch apart from a successful replay.
    pub fn take_recorded(&mut self, name: &str, args: &[Value]) -> Option<Value> {
        if self.diverged {
            return None;
        }
        match self.log.get(self.cursor) {
            Some(entry) if entry.name == name && entry.args.as_slice() == args => {
                self.cursor += 1;
                Some(entry.result.clone())
            }
            _ => {
                // Latch divergence: once the live call sequence departs from the
                // log, later entries can no longer be trusted to line up.
                self.diverged = true;
                None
            }
        }
    }

    /// Rewinds replay to the first logged call and clears any divergence.
    pub fn reset_cursor(&mut self) {
        self.cursor = 0;
        self.diverged = false;
    }

    /// Reports whether replay has diverged from the recorded log.
    ///
    /// Becomes `true` after [`take_recorded`](Self::take_recorded) sees a
    /// mismatching call or runs past the end of the log. Becomes `false` again
    /// after [`reset_cursor`](Self::reset_cursor).
    #[must_use]
    pub fn is_diverged(&self) -> bool {
        self.diverged
    }

    /// Number of calls currently stored in the log.
    #[must_use]
    pub fn len(&self) -> usize {
        self.log.len()
    }

    /// Whether the log contains no recorded calls.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.log.is_empty()
    }
}
/// [`ExternalFnHandler`] adapter that forwards every call to `inner` and
/// appends each `Resolved` result to a [`ReplayRecorder`].
pub struct RecordingHandler<'a, H: ExternalFnHandler + ?Sized> {
/// Handler that actually services each call.
inner: &'a H,
/// `call` only receives `&self`, but recording needs `&mut ReplayRecorder`,
/// hence the `RefCell`.
recorder: RefCell<&'a mut ReplayRecorder>,
}
impl<'a, H: ExternalFnHandler + ?Sized> RecordingHandler<'a, H> {
    /// Wraps `inner` so that every call it resolves is appended to `recorder`.
    ///
    /// The recorder stays exclusively borrowed for as long as the handler lives.
    pub fn new(inner: &'a H, recorder: &'a mut ReplayRecorder) -> Self {
        let recorder = RefCell::new(recorder);
        Self { inner, recorder }
    }
}
impl<H: ExternalFnHandler + ?Sized> ExternalFnHandler for RecordingHandler<'_, H> {
    /// Delegates to the wrapped handler and logs the call if it resolved.
    ///
    /// Results other than `Resolved` are returned untouched and not recorded.
    /// The recorder is borrowed only after `inner` has returned, never across
    /// the inner call.
    fn call(&self, name: &str, args: &[Value]) -> ExternalResult {
        let outcome = self.inner.call(name, args);
        if let ExternalResult::Resolved(resolved) = &outcome {
            let mut log = self.recorder.borrow_mut();
            log.record(name, args, resolved);
        }
        outcome
    }
}
/// [`ExternalFnHandler`] that answers calls from a [`ReplayRecorder`] log.
///
/// It never calls a live handler.
pub struct ReplayHandler<'a> {
/// `call` only receives `&self`, but consuming log entries advances the
/// recorder's cursor, hence the `RefCell`.
recorder: RefCell<&'a mut ReplayRecorder>,
}
impl<'a> ReplayHandler<'a> {
    /// Builds a handler that replays `recorder`'s log.
    ///
    /// The cursor is rewound and any earlier divergence cleared before the
    /// handler is built, so replay always starts at the first recorded call.
    pub fn new(recorder: &'a mut ReplayRecorder) -> Self {
        recorder.reset_cursor();
        let recorder = RefCell::new(recorder);
        Self { recorder }
    }
}
impl ExternalFnHandler for ReplayHandler<'_> {
    /// Answers the call from the recorded log.
    ///
    /// Yields `Resolved` with the logged value when the call matches the next
    /// log entry. Yields `Fallback` otherwise: on a mismatch, when the log is
    /// exhausted, or after an earlier divergence.
    fn call(&self, name: &str, args: &[Value]) -> ExternalResult {
        let replayed = self.recorder.borrow_mut().take_recorded(name, args);
        replayed.map_or(ExternalResult::Fallback, ExternalResult::Resolved)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument list of `Value::Int`s from plain integers.
    fn args(xs: &[i32]) -> Vec<Value> {
        xs.iter().map(|&x| Value::Int(x)).collect()
    }

    #[test]
    fn records_and_replays_in_order() {
        let mut r = ReplayRecorder::new();
        r.record("get_switch", &args(&[1]), &Value::Bool(true));
        r.record("get_var", &args(&[2]), &Value::Int(42));
        assert_eq!(r.len(), 2);
        assert_eq!(
            r.take_recorded("get_switch", &args(&[1])),
            Some(Value::Bool(true))
        );
        assert_eq!(
            r.take_recorded("get_var", &args(&[2])),
            Some(Value::Int(42))
        );
        // Log exhausted: no further entries to hand out.
        assert_eq!(r.take_recorded("get_var", &args(&[2])), None);
    }

    #[test]
    fn diverges_on_mismatch_and_stays_diverged() {
        let mut r = ReplayRecorder::new();
        r.record("a", &args(&[1]), &Value::Int(1));
        r.record("b", &args(&[2]), &Value::Int(2));
        assert_eq!(r.take_recorded("x", &args(&[1])), None);
        // Even the correct first call is refused once diverged.
        assert_eq!(r.take_recorded("a", &args(&[1])), None);
    }

    #[test]
    fn arg_mismatch_diverges() {
        let mut r = ReplayRecorder::new();
        r.record("get_switch", &args(&[1]), &Value::Bool(true));
        assert_eq!(r.take_recorded("get_switch", &args(&[2])), None);
    }

    #[test]
    fn reset_cursor_replays_again() {
        let mut r = ReplayRecorder::new();
        r.record("a", &args(&[1]), &Value::Int(7));
        assert_eq!(r.take_recorded("a", &args(&[1])), Some(Value::Int(7)));
        r.reset_cursor();
        assert_eq!(r.take_recorded("a", &args(&[1])), Some(Value::Int(7)));
    }

    #[test]
    fn reset_cursor_clears_divergence() {
        let mut r = ReplayRecorder::new();
        r.record("a", &args(&[1]), &Value::Int(7));
        assert_eq!(r.take_recorded("b", &args(&[1])), None);
        assert_eq!(r.take_recorded("a", &args(&[1])), None);
        r.reset_cursor();
        assert_eq!(r.take_recorded("a", &args(&[1])), Some(Value::Int(7)));
    }

    #[test]
    fn empty_recorder_diverges_immediately() {
        let mut r = ReplayRecorder::new();
        assert!(r.is_empty());
        assert_eq!(r.len(), 0);
        assert_eq!(r.take_recorded("a", &[]), None);
    }

    #[test]
    fn cap_drops_beyond_limit() {
        let mut r = ReplayRecorder::new();
        for _ in 0..RECORDING_CAP + 10 {
            r.record("a", &[], &Value::Null);
        }
        assert_eq!(r.len(), RECORDING_CAP);
    }

    /// Handler that resolves a fixed set of names and falls back for the rest.
    struct Stub(Vec<(&'static str, Value)>);

    impl ExternalFnHandler for Stub {
        fn call(&self, name: &str, _args: &[Value]) -> ExternalResult {
            self.0
                .iter()
                .find(|(n, _)| *n == name)
                .map_or(ExternalResult::Fallback, |(_, v)| {
                    ExternalResult::Resolved(v.clone())
                })
        }
    }

    #[test]
    fn recording_captures_resolved_passes_through_fallback() {
        let mut rec = ReplayRecorder::new();
        let inner = Stub(vec![("get", Value::Int(5))]);
        {
            let h = RecordingHandler::new(&inner, &mut rec);
            assert!(matches!(h.call("get", &[]), ExternalResult::Resolved(_)));
            assert!(matches!(h.call("nope", &[]), ExternalResult::Fallback));
        }
        assert_eq!(rec.len(), 1);
    }

    #[test]
    fn replay_returns_recorded_then_fallback() {
        let mut rec = ReplayRecorder::new();
        rec.record("get", &[], &Value::Int(5));
        let h = ReplayHandler::new(&mut rec);
        assert!(matches!(
            h.call("get", &[]),
            ExternalResult::Resolved(Value::Int(5))
        ));
        assert!(matches!(h.call("get", &[]), ExternalResult::Fallback));
    }

    #[test]
    fn replay_handler_new_rewinds_consumed_log() {
        let mut rec = ReplayRecorder::new();
        rec.record("a", &[], &Value::Int(1));
        // Consume and then overrun the log so the recorder is diverged.
        assert_eq!(rec.take_recorded("a", &[]), Some(Value::Int(1)));
        assert_eq!(rec.take_recorded("a", &[]), None);
        let h = ReplayHandler::new(&mut rec);
        assert!(matches!(
            h.call("a", &[]),
            ExternalResult::Resolved(Value::Int(1))
        ));
    }

    #[test]
    fn record_then_replay_roundtrip() {
        let mut rec = ReplayRecorder::new();
        let inner = Stub(vec![("a", Value::Int(1)), ("b", Value::Bool(true))]);
        {
            let h = RecordingHandler::new(&inner, &mut rec);
            let _ = h.call("a", &[]);
            let _ = h.call("b", &[]);
        }
        let h = ReplayHandler::new(&mut rec);
        assert!(matches!(
            h.call("a", &[]),
            ExternalResult::Resolved(Value::Int(1))
        ));
        assert!(matches!(
            h.call("b", &[]),
            ExternalResult::Resolved(Value::Bool(true))
        ));
    }

    #[test]
    fn replay_diverges_to_fallback_on_mismatch() {
        let mut rec = ReplayRecorder::new();
        rec.record("a", &[], &Value::Int(1));
        let h = ReplayHandler::new(&mut rec);
        assert!(matches!(h.call("x", &[]), ExternalResult::Fallback));
        assert!(matches!(h.call("a", &[]), ExternalResult::Fallback));
    }
}