1use super::JsonValue;
2use super::json_to_string;
3use super::session::{EffectRecord, RecordedOutcome};
4
/// Operating mode of an [`EffectReplayState`].
///
/// `Normal` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EffectReplayMode {
    /// Default mode; `set_normal` switches back to it and clears recorded and replay effects.
    #[default]
    Normal,
    /// Mode entered by `start_recording`.
    Record,
    /// Mode entered by `start_replay`.
    Replay,
}
12
/// Reasons a replayed effect cannot be served from the recording.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplayFailure {
    /// An effect was requested after every recorded effect had been replayed.
    Exhausted {
        /// Type of the effect that was requested.
        effect_type: String,
        /// 1-based position of the requested effect in the replay sequence.
        position: usize,
    },
    /// The requested effect type does not match the recording.
    Mismatch {
        /// `seq` of the recorded effect (for a grouped record, the first record of the group).
        seq: u32,
        /// Recorded effect type, or `"one of group N effects"` for a grouped record.
        expected: String,
        /// Effect type that was requested.
        got: String,
    },
    /// Effect arguments differ from the recording while argument validation is enabled.
    ArgsMismatch {
        /// `seq` of the recorded effect whose arguments differ.
        seq: u32,
        /// Type of the effect being replayed.
        effect_type: String,
        /// Recorded arguments rendered as a JSON array.
        expected: String,
        /// Supplied arguments rendered as a JSON array.
        got: String,
    },
    /// `ensure_replay_consumed` found recorded effects that were never replayed.
    Unconsumed {
        /// Number of recorded effects left unreplayed.
        remaining: usize,
    },
}
34
/// Record/replay bookkeeping for effects.
///
/// `record_effect` appends to the recording; `replay_effect` consumes a previously loaded
/// recording. The group/branch stacks tag every recorded effect with the group, branch path and
/// per-branch occurrence it ran in, which replay uses to match grouped records out of order.
#[derive(Debug, Clone, Default)]
pub struct EffectReplayState {
    /// Current mode; `Normal` by default.
    mode: EffectReplayMode,
    /// Effects appended by `record_effect`.
    recorded_effects: Vec<EffectRecord>,
    /// Effects loaded by `start_replay`.
    replay_effects: Vec<EffectRecord>,
    /// Index into `replay_effects` of the next record to replay (or of the first record of a
    /// partially replayed group).
    replay_pos: usize,
    /// When true, argument mismatches during replay are errors; otherwise they are counted.
    validate_replay_args: bool,
    /// Count of argument mismatches tolerated during replay because validation was off.
    args_diff_count: usize,
    /// Ids of the currently open groups, innermost last.
    group_stack: Vec<u32>,
    /// Selected branch index of each open group (pushed/popped together with `group_stack`).
    branch_stack: Vec<u32>,
    /// Effects seen so far on the current branch of each open group (pushed/popped together
    /// with `group_stack`).
    effect_count_stack: Vec<u32>,
    /// Last group id handed out by `enter_group`; 0 means none yet.
    next_group_id: u32,
    /// Indices into `replay_effects` already consumed from the group starting at `replay_pos`.
    group_consumed: Vec<usize>,
}
55
56impl EffectReplayState {
57 pub fn mode(&self) -> EffectReplayMode {
58 self.mode
59 }
60
61 pub fn set_normal(&mut self) {
62 self.mode = EffectReplayMode::Normal;
63 self.recorded_effects.clear();
64 self.replay_effects.clear();
65 self.replay_pos = 0;
66 self.validate_replay_args = false;
67 self.args_diff_count = 0;
68 self.reset_group_state();
69 }
70
71 pub fn start_recording(&mut self) {
72 self.mode = EffectReplayMode::Record;
73 self.recorded_effects.clear();
74 self.replay_effects.clear();
75 self.replay_pos = 0;
76 self.validate_replay_args = false;
77 self.args_diff_count = 0;
78 self.reset_group_state();
79 }
80
81 pub fn start_replay(&mut self, effects: Vec<EffectRecord>, validate_args: bool) {
82 self.mode = EffectReplayMode::Replay;
83 self.replay_effects = effects;
84 self.replay_pos = 0;
85 self.validate_replay_args = validate_args;
86 self.recorded_effects.clear();
87 self.args_diff_count = 0;
88 self.reset_group_state();
89 }
90
91 pub fn take_recorded_effects(&mut self) -> Vec<EffectRecord> {
92 std::mem::take(&mut self.recorded_effects)
93 }
94
95 pub fn recorded_effects(&self) -> &[EffectRecord] {
96 &self.recorded_effects
97 }
98
99 pub fn replay_progress(&self) -> (usize, usize) {
100 (self.replay_pos, self.replay_effects.len())
101 }
102
103 pub fn args_diff_count(&self) -> usize {
104 self.args_diff_count
105 }
106
107 pub fn ensure_replay_consumed(&self) -> Result<(), ReplayFailure> {
108 if self.mode == EffectReplayMode::Replay && self.replay_pos < self.replay_effects.len() {
109 return Err(ReplayFailure::Unconsumed {
110 remaining: self.replay_effects.len() - self.replay_pos,
111 });
112 }
113 Ok(())
114 }
115
116 pub fn enter_group(&mut self) -> u32 {
118 self.next_group_id += 1;
119 let id = self.next_group_id;
120 self.group_stack.push(id);
121 self.branch_stack.push(0); self.effect_count_stack.push(0);
123 id
124 }
125
126 pub fn exit_group(&mut self) {
128 self.group_stack.pop();
129 self.branch_stack.pop();
130 self.effect_count_stack.pop();
131 }
132
133 pub fn set_branch(&mut self, index: u32) {
135 if let Some(last) = self.branch_stack.last_mut() {
136 *last = index;
137 }
138 if let Some(last) = self.effect_count_stack.last_mut() {
139 *last = 0;
140 }
141 }
142
143 pub fn record_effect(
144 &mut self,
145 effect_type: &str,
146 args: Vec<JsonValue>,
147 outcome: RecordedOutcome,
148 caller_fn: &str,
149 source_line: usize,
150 ) {
151 let seq = self.recorded_effects.len() as u32 + 1;
152 self.recorded_effects.push(EffectRecord {
153 seq,
154 effect_type: effect_type.to_string(),
155 args,
156 outcome,
157 caller_fn: caller_fn.to_string(),
158 source_line,
159 group_id: self.group_stack.last().copied(),
160 branch_path: if self.branch_stack.is_empty() {
161 None
162 } else {
163 Some(self.current_branch_path())
164 },
165 effect_occurrence: if self.branch_stack.is_empty() {
166 None
167 } else {
168 self.current_effect_occurrence()
169 },
170 });
171 self.bump_effect_occurrence();
172 }
173
174 pub fn replay_effect(
175 &mut self,
176 effect_type: &str,
177 got_args: Option<Vec<JsonValue>>,
178 ) -> Result<RecordedOutcome, ReplayFailure> {
179 if self.replay_pos < self.replay_effects.len()
182 && let Some(gid) = self.replay_effects[self.replay_pos].group_id
183 {
184 return self.replay_effect_in_group(gid, effect_type, got_args);
185 }
186
187 if self.replay_pos >= self.replay_effects.len() {
189 return Err(ReplayFailure::Exhausted {
190 effect_type: effect_type.to_string(),
191 position: self.replay_pos + 1,
192 });
193 }
194
195 let record = self.replay_effects[self.replay_pos].clone();
196 if record.effect_type != effect_type {
197 return Err(ReplayFailure::Mismatch {
198 seq: record.seq,
199 expected: record.effect_type,
200 got: effect_type.to_string(),
201 });
202 }
203
204 if let Some(got_args) = got_args
205 && got_args != record.args
206 {
207 if self.validate_replay_args {
208 return Err(ReplayFailure::ArgsMismatch {
209 seq: record.seq,
210 effect_type: effect_type.to_string(),
211 expected: json_to_string(&JsonValue::Array(record.args.clone())),
212 got: json_to_string(&JsonValue::Array(got_args)),
213 });
214 }
215 self.args_diff_count += 1;
216 }
217
218 self.replay_pos += 1;
219 Ok(record.outcome)
220 }
221
222 fn replay_effect_in_group(
225 &mut self,
226 group_id: u32,
227 effect_type: &str,
228 got_args: Option<Vec<JsonValue>>,
229 ) -> Result<RecordedOutcome, ReplayFailure> {
230 let group_start = self.replay_pos;
232 let group_end = self.replay_effects[group_start..]
233 .iter()
234 .position(|e| e.group_id != Some(group_id))
235 .map(|offset| group_start + offset)
236 .unwrap_or(self.replay_effects.len());
237
238 let current_bp = if self.branch_stack.is_empty() {
241 None
242 } else {
243 Some(self.current_branch_path())
244 };
245
246 let mut fallback_idx: Option<usize> = None;
247 for idx in group_start..group_end {
248 if self.group_consumed.contains(&idx) {
249 continue;
250 }
251 let record = &self.replay_effects[idx];
252 if record.effect_type != effect_type {
253 continue;
254 }
255
256 let args_ok = match (&got_args, self.validate_replay_args) {
258 (Some(got), true) if *got != record.args => false,
259 (Some(got), false) if *got != record.args => {
260 self.args_diff_count += 1;
261 true
262 }
263 _ => true,
264 };
265 if !args_ok {
266 continue;
267 }
268
269 let bp_match = match (¤t_bp, &record.branch_path) {
272 (Some(got), Some(rec)) => {
273 if got != rec {
274 continue; }
276 true
277 }
278 _ => false, };
280 if bp_match {
281 let current_occ = self.current_effect_occurrence();
283 match (current_occ, record.effect_occurrence) {
284 (Some(got), Some(rec)) if got == rec => {
285 return self.consume_group_match(idx, group_start, group_end);
286 }
287 (Some(_), Some(_)) => continue, _ => {
289 if fallback_idx.is_none() {
291 fallback_idx = Some(idx);
292 }
293 }
294 }
295 } else if fallback_idx.is_none() {
296 fallback_idx = Some(idx);
297 }
298 }
299
300 if let Some(idx) = fallback_idx {
302 return self.consume_group_match(idx, group_start, group_end);
303 }
304
305 Err(ReplayFailure::Mismatch {
307 seq: self.replay_effects[group_start].seq,
308 expected: format!("one of group {} effects", group_id),
309 got: effect_type.to_string(),
310 })
311 }
312
313 fn consume_group_match(
314 &mut self,
315 idx: usize,
316 group_start: usize,
317 group_end: usize,
318 ) -> Result<RecordedOutcome, ReplayFailure> {
319 let outcome = self.replay_effects[idx].outcome.clone();
320 self.bump_effect_occurrence();
321 self.group_consumed.push(idx);
322 let group_size = group_end - group_start;
323 if self.group_consumed.len() >= group_size {
324 self.replay_pos = group_end;
325 self.group_consumed.clear();
326 }
327 Ok(outcome)
328 }
329
330 fn reset_group_state(&mut self) {
331 self.group_stack.clear();
332 self.branch_stack.clear();
333 self.effect_count_stack.clear();
334 self.next_group_id = 0;
335 self.group_consumed.clear();
336 }
337
338 fn current_branch_path(&self) -> String {
339 self.branch_stack
340 .iter()
341 .map(|i| i.to_string())
342 .collect::<Vec<_>>()
343 .join(".")
344 }
345
346 fn current_effect_occurrence(&self) -> Option<u32> {
347 self.effect_count_stack.last().copied()
348 }
349
350 fn bump_effect_occurrence(&mut self) {
351 if let Some(last) = self.effect_count_stack.last_mut() {
352 *last += 1;
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `text` in a string-valued recorded outcome.
    fn recorded_value(text: &str) -> RecordedOutcome {
        RecordedOutcome::Value(JsonValue::String(String::from(text)))
    }

    /// Records an argument-less `Console.print` with a null outcome at the current position.
    fn record_null_print(state: &mut EffectReplayState) {
        state.record_effect(
            "Console.print",
            Vec::new(),
            RecordedOutcome::Value(JsonValue::Null),
            "",
            0,
        );
    }

    /// Arguments shared by every grouped print in the replay test.
    fn same_args() -> Vec<JsonValue> {
        vec![JsonValue::String(String::from("same"))]
    }

    /// A `Console.print` record in group 1 on branch `"0"` at the given occurrence.
    fn grouped_print(seq: u32, occurrence: u32, outcome: &str) -> EffectRecord {
        EffectRecord {
            seq,
            effect_type: String::from("Console.print"),
            args: same_args(),
            outcome: recorded_value(outcome),
            caller_fn: String::new(),
            source_line: 0,
            group_id: Some(1),
            branch_path: Some(String::from("0")),
            effect_occurrence: Some(occurrence),
        }
    }

    #[test]
    fn nested_groups_preserve_outer_effect_occurrence() {
        let mut state = EffectReplayState::default();
        state.start_recording();

        // Outer group, branch 0: first effect.
        state.enter_group();
        state.set_branch(0);
        record_null_print(&mut state);

        // Inner group, branch 1: has its own occurrence counter.
        state.enter_group();
        state.set_branch(1);
        record_null_print(&mut state);
        state.exit_group();

        // Back in the outer group: its counter continues from where it left off.
        record_null_print(&mut state);

        let effects = state.take_recorded_effects();
        assert_eq!(effects.len(), 3);
        let tags: Vec<_> = effects
            .iter()
            .map(|effect| (effect.branch_path.as_deref(), effect.effect_occurrence))
            .collect();
        assert_eq!(
            tags,
            vec![
                (Some("0"), Some(0)),
                (Some("0.1"), Some(0)),
                (Some("0"), Some(1)),
            ]
        );
    }

    #[test]
    fn start_replay_clears_group_state() {
        let mut state = EffectReplayState::default();
        state.start_recording();
        state.enter_group();
        state.set_branch(3);
        record_null_print(&mut state);

        state.start_replay(Vec::new(), true);

        assert!(state.group_stack.is_empty());
        assert!(state.branch_stack.is_empty());
        assert!(state.effect_count_stack.is_empty());
        assert!(state.group_consumed.is_empty());
        assert_eq!((state.next_group_id, state.args_diff_count), (0, 0));
    }

    #[test]
    fn replay_group_matching_uses_effect_occurrence() {
        let mut state = EffectReplayState::default();
        state.start_replay(
            vec![grouped_print(1, 0, "first"), grouped_print(2, 1, "second")],
            true,
        );

        state.enter_group();
        state.set_branch(0);

        let first = state
            .replay_effect("Console.print", Some(same_args()))
            .expect("first replay should match");
        let second = state
            .replay_effect("Console.print", Some(same_args()))
            .expect("second replay should match");

        assert_eq!(first, recorded_value("first"));
        assert_eq!(second, recorded_value("second"));
    }
}