1use std::cell::RefCell;
18
19use brink_format::Value;
20
21use crate::story::{ExternalFnHandler, ExternalResult};
22
/// Upper bound on the number of entries a [`ReplayRecorder`] keeps in its log.
///
/// Once the log already holds this many entries, [`ReplayRecorder::record`]
/// returns without storing anything further.
pub const RECORDING_CAP: usize = 16_384;
27
/// One logged external-function call: what was asked for and what came back.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordedExternal {
    /// Name of the external function that was called.
    pub name: String,
    /// Arguments the call was made with.
    pub args: Vec<Value>,
    /// Value stored as the call's result; this is what replay hands back.
    pub result: Value,
}
39
/// Two-way switch between recorded and live operation.
///
/// None of the code visible alongside this enum reads it.
/// NOTE(review): presumably `Recorded` means "answer external calls from the
/// log" and `Live` means "call the real handler" — confirm against callers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ReplayMode {
    /// The variant produced by `ReplayMode::default()`.
    #[default]
    Recorded,
    /// The non-default alternative.
    Live,
}
53
/// An ordered log of external-call results plus a cursor for replaying them.
///
/// `record` appends to the log; `take_recorded` hands entries back in the
/// order they were recorded, as long as each replayed call matches the entry
/// under the cursor.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ReplayRecorder {
    /// Recorded calls, oldest first. `record` stops appending once the length
    /// reaches `RECORDING_CAP`.
    log: Vec<RecordedExternal>,
    /// Index into `log` of the entry the next `take_recorded` call compares
    /// against.
    cursor: usize,
    /// Set when a replayed call did not match the entry under the cursor, or
    /// when the log ran out. While set, `take_recorded` returns `None`;
    /// `reset_cursor` clears it.
    diverged: bool,
}
62
63impl ReplayRecorder {
64 #[must_use]
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 pub fn record(&mut self, name: &str, args: &[Value], result: &Value) {
73 if self.log.len() >= RECORDING_CAP {
74 return;
75 }
76 self.log.push(RecordedExternal {
77 name: name.to_owned(),
78 args: args.to_vec(),
79 result: result.clone(),
80 });
81 }
82
83 pub fn take_recorded(&mut self, name: &str, args: &[Value]) -> Option<Value> {
89 if self.diverged {
90 return None;
91 }
92 match self.log.get(self.cursor) {
93 Some(entry) if entry.name == name && entry.args.as_slice() == args => {
94 self.cursor += 1;
95 Some(entry.result.clone())
96 }
97 _ => {
98 self.diverged = true;
99 None
100 }
101 }
102 }
103
104 pub fn reset_cursor(&mut self) {
107 self.cursor = 0;
108 self.diverged = false;
109 }
110
111 #[must_use]
113 pub fn len(&self) -> usize {
114 self.log.len()
115 }
116
117 #[must_use]
119 pub fn is_empty(&self) -> bool {
120 self.log.is_empty()
121 }
122}
123
/// An [`ExternalFnHandler`] wrapper that forwards every call to an inner
/// handler and logs each resolved result into a [`ReplayRecorder`].
pub struct RecordingHandler<'a, H: ExternalFnHandler + ?Sized> {
    /// The handler that actually services calls.
    inner: &'a H,
    /// Destination log. Held in a `RefCell` because `call` only receives
    /// `&self` while `ReplayRecorder::record` needs `&mut self`.
    recorder: RefCell<&'a mut ReplayRecorder>,
}
136
137impl<'a, H: ExternalFnHandler + ?Sized> RecordingHandler<'a, H> {
138 pub fn new(inner: &'a H, recorder: &'a mut ReplayRecorder) -> Self {
140 Self {
141 inner,
142 recorder: RefCell::new(recorder),
143 }
144 }
145}
146
147impl<H: ExternalFnHandler + ?Sized> ExternalFnHandler for RecordingHandler<'_, H> {
148 fn call(&self, name: &str, args: &[Value]) -> ExternalResult {
149 let result = self.inner.call(name, args);
150 if let ExternalResult::Resolved(value) = &result {
151 self.recorder.borrow_mut().record(name, args, value);
152 }
153 result
154 }
155}
156
/// An [`ExternalFnHandler`] that answers calls from a [`ReplayRecorder`]'s log
/// instead of calling a real handler.
pub struct ReplayHandler<'a> {
    /// Source log. Held in a `RefCell` because `call` only receives `&self`
    /// while `ReplayRecorder::take_recorded` needs `&mut self`.
    recorder: RefCell<&'a mut ReplayRecorder>,
}
169
170impl<'a> ReplayHandler<'a> {
171 pub fn new(recorder: &'a mut ReplayRecorder) -> Self {
174 recorder.reset_cursor();
175 Self {
176 recorder: RefCell::new(recorder),
177 }
178 }
179}
180
181impl ExternalFnHandler for ReplayHandler<'_> {
182 fn call(&self, name: &str, args: &[Value]) -> ExternalResult {
183 match self.recorder.borrow_mut().take_recorded(name, args) {
184 Some(value) => ExternalResult::Resolved(value),
185 None => ExternalResult::Fallback,
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an argument list of `Value::Int`s from plain integers.
    fn args(xs: &[i32]) -> Vec<Value> {
        xs.iter().copied().map(Value::Int).collect()
    }

    #[test]
    fn records_and_replays_in_order() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("get_switch", &args(&[1]), &Value::Bool(true));
        recorder.record("get_var", &args(&[2]), &Value::Int(42));
        assert_eq!(recorder.len(), 2);

        let first = recorder.take_recorded("get_switch", &args(&[1]));
        assert_eq!(first, Some(Value::Bool(true)));

        let second = recorder.take_recorded("get_var", &args(&[2]));
        assert_eq!(second, Some(Value::Int(42)));

        // Both entries are consumed, so a third request finds nothing.
        assert!(recorder.take_recorded("get_var", &args(&[2])).is_none());
    }

    #[test]
    fn diverges_on_mismatch_and_stays_diverged() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("a", &args(&[1]), &Value::Int(1));
        recorder.record("b", &args(&[2]), &Value::Int(2));

        // Wrong name at the cursor: divergence.
        assert!(recorder.take_recorded("x", &args(&[1])).is_none());
        // Even the call that would have matched is now refused.
        assert!(recorder.take_recorded("a", &args(&[1])).is_none());
    }

    #[test]
    fn arg_mismatch_diverges() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("get_switch", &args(&[1]), &Value::Bool(true));

        let replayed = recorder.take_recorded("get_switch", &args(&[2]));
        assert!(replayed.is_none());
    }

    #[test]
    fn reset_cursor_replays_again() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("a", &args(&[1]), &Value::Int(7));

        let expected = Some(Value::Int(7));
        assert_eq!(recorder.take_recorded("a", &args(&[1])), expected);
        recorder.reset_cursor();
        assert_eq!(recorder.take_recorded("a", &args(&[1])), expected);
    }

    #[test]
    fn cap_drops_beyond_limit() {
        let mut recorder = ReplayRecorder::new();
        let attempts = RECORDING_CAP + 10;
        for _ in 0..attempts {
            recorder.record("a", &[], &Value::Null);
        }
        assert_eq!(recorder.len(), RECORDING_CAP);
    }

    /// Test handler: resolves the names it was given, falls back on the rest.
    struct Stub(Vec<(&'static str, Value)>);

    impl ExternalFnHandler for Stub {
        fn call(&self, name: &str, _args: &[Value]) -> ExternalResult {
            for (known, value) in &self.0 {
                if *known == name {
                    return ExternalResult::Resolved(value.clone());
                }
            }
            ExternalResult::Fallback
        }
    }

    #[test]
    fn recording_captures_resolved_passes_through_fallback() {
        let mut recorder = ReplayRecorder::new();
        let stub = Stub(vec![("get", Value::Int(5))]);
        {
            let handler = RecordingHandler::new(&stub, &mut recorder);
            let known = handler.call("get", &[]);
            assert!(matches!(known, ExternalResult::Resolved(_)));
            let unknown = handler.call("nope", &[]);
            assert!(matches!(unknown, ExternalResult::Fallback));
        }
        // Only the resolved call is logged.
        assert_eq!(recorder.len(), 1);
    }

    #[test]
    fn replay_returns_recorded_then_fallback() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("get", &[], &Value::Int(5));

        let handler = ReplayHandler::new(&mut recorder);
        let first = handler.call("get", &[]);
        assert!(matches!(first, ExternalResult::Resolved(Value::Int(5))));
        let second = handler.call("get", &[]);
        assert!(matches!(second, ExternalResult::Fallback));
    }

    #[test]
    fn record_then_replay_roundtrip() {
        let mut recorder = ReplayRecorder::new();
        let stub = Stub(vec![("a", Value::Int(1)), ("b", Value::Bool(true))]);
        {
            let recording = RecordingHandler::new(&stub, &mut recorder);
            let _ = recording.call("a", &[]);
            let _ = recording.call("b", &[]);
        }

        let replaying = ReplayHandler::new(&mut recorder);
        let first = replaying.call("a", &[]);
        assert!(matches!(first, ExternalResult::Resolved(Value::Int(1))));
        let second = replaying.call("b", &[]);
        assert!(matches!(
            second,
            ExternalResult::Resolved(Value::Bool(true))
        ));
    }

    #[test]
    fn replay_diverges_to_fallback_on_mismatch() {
        let mut recorder = ReplayRecorder::new();
        recorder.record("a", &[], &Value::Int(1));

        let handler = ReplayHandler::new(&mut recorder);
        let mismatched = handler.call("x", &[]);
        assert!(matches!(mismatched, ExternalResult::Fallback));
        let after_divergence = handler.call("a", &[]);
        assert!(matches!(after_divergence, ExternalResult::Fallback));
    }
}