1use super::JsonValue;
2use super::json_to_string;
3use super::session::{EffectRecord, RecordedOutcome};
4
/// How an [`EffectReplayState`] is currently operating.
///
/// The mode is selected by `EffectReplayState::set_normal`,
/// `EffectReplayState::start_recording`, and `EffectReplayState::start_replay`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EffectReplayMode {
    /// Neither recording nor replaying; this is the default mode.
    #[default]
    Normal,
    /// Selected by `start_recording`.
    Record,
    /// Selected by `start_replay`.
    Replay,
}
12
/// Reasons an effect could not be answered from the replay log.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReplayFailure {
    /// An effect was requested after every replay record was used up.
    Exhausted {
        /// Effect type that was requested.
        effect_type: String,
        /// 1-based position of the missing record.
        position: usize,
    },
    /// The requested effect type did not match what the log expected.
    Mismatch {
        /// `seq` of the record (or first record of the group run) being matched.
        seq: u32,
        /// Expected effect type, or a description of the group being matched.
        expected: String,
        /// Effect type that was requested.
        got: String,
    },
    /// Effect types matched but arguments differed while argument validation was on.
    ArgsMismatch {
        /// `seq` of the recorded effect.
        seq: u32,
        /// Effect type of the call.
        effect_type: String,
        /// Recorded arguments, rendered as a JSON array string.
        expected: String,
        /// Supplied arguments, rendered as a JSON array string.
        got: String,
    },
    /// Replay finished with records still left in the log.
    Unconsumed {
        /// Number of records at or past the replay position.
        remaining: usize,
    },
}
34
35#[derive(Debug, Clone, Default)]
36pub struct EffectReplayState {
37 mode: EffectReplayMode,
38 recorded_effects: Vec<EffectRecord>,
39 replay_effects: Vec<EffectRecord>,
40 replay_pos: usize,
41 validate_replay_args: bool,
42 args_diff_count: usize,
43 group_stack: Vec<u32>,
45 branch_stack: Vec<u32>,
48 effect_count_stack: Vec<u32>,
50 next_group_id: u32,
52 group_consumed: Vec<usize>,
54 record_cap: Option<usize>,
60}
61
62impl EffectReplayState {
63 pub fn mode(&self) -> EffectReplayMode {
64 self.mode
65 }
66
67 pub fn set_normal(&mut self) {
68 self.mode = EffectReplayMode::Normal;
69 self.recorded_effects.clear();
70 self.replay_effects.clear();
71 self.replay_pos = 0;
72 self.validate_replay_args = false;
73 self.args_diff_count = 0;
74 self.reset_group_state();
75 }
76
77 pub fn start_recording(&mut self) {
78 self.mode = EffectReplayMode::Record;
79 self.recorded_effects.clear();
80 self.replay_effects.clear();
81 self.replay_pos = 0;
82 self.validate_replay_args = false;
83 self.args_diff_count = 0;
84 self.reset_group_state();
85 }
86
87 pub fn set_record_cap(&mut self, cap: Option<usize>) {
88 self.record_cap = cap;
89 }
90
91 pub fn record_full(&self) -> bool {
92 matches!(self.record_cap, Some(cap) if self.recorded_effects.len() >= cap)
93 }
94
95 pub fn start_replay(&mut self, effects: Vec<EffectRecord>, validate_args: bool) {
96 self.mode = EffectReplayMode::Replay;
97 self.replay_effects = effects;
98 self.replay_pos = 0;
99 self.validate_replay_args = validate_args;
100 self.recorded_effects.clear();
101 self.args_diff_count = 0;
102 self.reset_group_state();
103 }
104
105 pub fn take_recorded_effects(&mut self) -> Vec<EffectRecord> {
106 std::mem::take(&mut self.recorded_effects)
107 }
108
109 pub fn recorded_effects(&self) -> &[EffectRecord] {
110 &self.recorded_effects
111 }
112
113 pub fn replay_progress(&self) -> (usize, usize) {
114 (self.replay_pos, self.replay_effects.len())
115 }
116
117 pub fn args_diff_count(&self) -> usize {
118 self.args_diff_count
119 }
120
121 pub fn ensure_replay_consumed(&self) -> Result<(), ReplayFailure> {
122 if self.mode == EffectReplayMode::Replay && self.replay_pos < self.replay_effects.len() {
123 return Err(ReplayFailure::Unconsumed {
124 remaining: self.replay_effects.len() - self.replay_pos,
125 });
126 }
127 Ok(())
128 }
129
130 pub fn reset_scope(&mut self) {
137 self.next_group_id = 0;
138 self.group_stack.clear();
139 self.branch_stack.clear();
140 self.effect_count_stack.clear();
141 }
142
143 pub fn enter_group(&mut self) -> u32 {
145 self.next_group_id += 1;
146 let id = self.next_group_id;
147 self.group_stack.push(id);
148 self.branch_stack.push(0); self.effect_count_stack.push(0);
150 id
151 }
152
153 pub fn exit_group(&mut self) {
155 self.group_stack.pop();
156 self.branch_stack.pop();
157 self.effect_count_stack.pop();
158 }
159
160 pub fn current_group_id(&self) -> Option<u32> {
164 self.group_stack.last().copied()
165 }
166
167 pub fn current_branch_idx(&self) -> Option<u32> {
171 self.branch_stack.last().copied()
172 }
173
174 pub fn set_branch(&mut self, index: u32) {
176 if let Some(last) = self.branch_stack.last_mut() {
177 *last = index;
178 }
179 if let Some(last) = self.effect_count_stack.last_mut() {
180 *last = 0;
181 }
182 }
183
184 pub fn record_effect(
185 &mut self,
186 effect_type: &str,
187 args: Vec<JsonValue>,
188 outcome: RecordedOutcome,
189 caller_fn: &str,
190 source_line: usize,
191 ) {
192 let seq = self.recorded_effects.len() as u32 + 1;
193 self.recorded_effects.push(EffectRecord {
194 seq,
195 effect_type: effect_type.to_string(),
196 args,
197 outcome,
198 caller_fn: caller_fn.to_string(),
199 source_line,
200 group_id: self.group_stack.last().copied(),
201 branch_path: if self.branch_stack.is_empty() {
202 None
203 } else {
204 Some(self.current_branch_path())
205 },
206 effect_occurrence: if self.branch_stack.is_empty() {
207 None
208 } else {
209 self.current_effect_occurrence()
210 },
211 });
212 self.bump_effect_occurrence();
213 }
214
215 pub fn replay_effect(
216 &mut self,
217 effect_type: &str,
218 got_args: Option<Vec<JsonValue>>,
219 ) -> Result<RecordedOutcome, ReplayFailure> {
220 if self.replay_pos < self.replay_effects.len()
223 && let Some(gid) = self.replay_effects[self.replay_pos].group_id
224 {
225 return self.replay_effect_in_group(gid, effect_type, got_args);
226 }
227
228 if self.replay_pos >= self.replay_effects.len() {
230 return Err(ReplayFailure::Exhausted {
231 effect_type: effect_type.to_string(),
232 position: self.replay_pos + 1,
233 });
234 }
235
236 let record = self.replay_effects[self.replay_pos].clone();
237 if record.effect_type != effect_type {
238 return Err(ReplayFailure::Mismatch {
239 seq: record.seq,
240 expected: record.effect_type,
241 got: effect_type.to_string(),
242 });
243 }
244
245 if let Some(got_args) = got_args
246 && got_args != record.args
247 {
248 if self.validate_replay_args {
249 return Err(ReplayFailure::ArgsMismatch {
250 seq: record.seq,
251 effect_type: effect_type.to_string(),
252 expected: json_to_string(&JsonValue::Array(record.args.clone())),
253 got: json_to_string(&JsonValue::Array(got_args)),
254 });
255 }
256 self.args_diff_count += 1;
257 }
258
259 self.replay_pos += 1;
260 Ok(record.outcome)
261 }
262
263 fn replay_effect_in_group(
266 &mut self,
267 group_id: u32,
268 effect_type: &str,
269 got_args: Option<Vec<JsonValue>>,
270 ) -> Result<RecordedOutcome, ReplayFailure> {
271 let group_start = self.replay_pos;
273 let group_end = self.replay_effects[group_start..]
274 .iter()
275 .position(|e| e.group_id != Some(group_id))
276 .map(|offset| group_start + offset)
277 .unwrap_or(self.replay_effects.len());
278
279 let current_bp = if self.branch_stack.is_empty() {
282 None
283 } else {
284 Some(self.current_branch_path())
285 };
286
287 let mut fallback_idx: Option<usize> = None;
288 for idx in group_start..group_end {
289 if self.group_consumed.contains(&idx) {
290 continue;
291 }
292 let record = &self.replay_effects[idx];
293 if record.effect_type != effect_type {
294 continue;
295 }
296
297 let args_ok = match (&got_args, self.validate_replay_args) {
299 (Some(got), true) if *got != record.args => false,
300 (Some(got), false) if *got != record.args => {
301 self.args_diff_count += 1;
302 true
303 }
304 _ => true,
305 };
306 if !args_ok {
307 continue;
308 }
309
310 let bp_match = match (¤t_bp, &record.branch_path) {
313 (Some(got), Some(rec)) => {
314 if got != rec {
315 continue; }
317 true
318 }
319 _ => false, };
321 if bp_match {
322 let current_occ = self.current_effect_occurrence();
324 match (current_occ, record.effect_occurrence) {
325 (Some(got), Some(rec)) if got == rec => {
326 return self.consume_group_match(idx, group_start, group_end);
327 }
328 (Some(_), Some(_)) => continue, _ => {
330 if fallback_idx.is_none() {
332 fallback_idx = Some(idx);
333 }
334 }
335 }
336 } else if fallback_idx.is_none() {
337 fallback_idx = Some(idx);
338 }
339 }
340
341 if let Some(idx) = fallback_idx {
343 return self.consume_group_match(idx, group_start, group_end);
344 }
345
346 Err(ReplayFailure::Mismatch {
348 seq: self.replay_effects[group_start].seq,
349 expected: format!("one of group {} effects", group_id),
350 got: effect_type.to_string(),
351 })
352 }
353
354 fn consume_group_match(
355 &mut self,
356 idx: usize,
357 group_start: usize,
358 group_end: usize,
359 ) -> Result<RecordedOutcome, ReplayFailure> {
360 let outcome = self.replay_effects[idx].outcome.clone();
361 self.bump_effect_occurrence();
362 self.group_consumed.push(idx);
363 let group_size = group_end - group_start;
364 if self.group_consumed.len() >= group_size {
365 self.replay_pos = group_end;
366 self.group_consumed.clear();
367 }
368 Ok(outcome)
369 }
370
371 fn reset_group_state(&mut self) {
372 self.group_stack.clear();
373 self.branch_stack.clear();
374 self.effect_count_stack.clear();
375 self.next_group_id = 0;
376 self.group_consumed.clear();
377 }
378
379 fn current_branch_path(&self) -> String {
380 self.branch_stack
381 .iter()
382 .map(|i| i.to_string())
383 .collect::<Vec<_>>()
384 .join(".")
385 }
386
387 fn current_effect_occurrence(&self) -> Option<u32> {
388 self.effect_count_stack.last().copied()
389 }
390
391 fn bump_effect_occurrence(&mut self) {
392 if let Some(last) = self.effect_count_stack.last_mut() {
393 *last += 1;
394 }
395 }
396
397 pub fn oracle_path_string(&self) -> String {
400 self.current_branch_path()
401 }
402
403 pub fn oracle_branch_counter(&self) -> Option<u32> {
406 self.current_effect_occurrence()
407 }
408
409 pub fn bump_oracle_branch_counter(&mut self) {
412 self.bump_effect_occurrence();
413 }
414
415 pub fn is_inside_group(&self) -> bool {
418 !self.branch_stack.is_empty()
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `text` as a recorded JSON string outcome.
    fn recorded_value(text: &str) -> RecordedOutcome {
        RecordedOutcome::Value(JsonValue::String(text.to_string()))
    }

    /// Records a no-argument `Console.print` effect with a `null` outcome.
    fn record_print(state: &mut EffectReplayState) {
        state.record_effect(
            "Console.print",
            vec![],
            RecordedOutcome::Value(JsonValue::Null),
            "",
            0,
        );
    }

    /// The argument list shared by every call in the replay-matching test.
    fn same_args() -> Vec<JsonValue> {
        vec![JsonValue::String("same".to_string())]
    }

    /// A `Console.print` record in group 1, branch `"0"`, carrying the given
    /// sequence number, outcome text, and effect occurrence.
    fn grouped_print(seq: u32, outcome: &str, occurrence: u32) -> EffectRecord {
        EffectRecord {
            seq,
            effect_type: "Console.print".to_string(),
            args: same_args(),
            outcome: recorded_value(outcome),
            caller_fn: String::new(),
            source_line: 0,
            group_id: Some(1),
            branch_path: Some("0".to_string()),
            effect_occurrence: Some(occurrence),
        }
    }

    #[test]
    fn nested_groups_preserve_outer_effect_occurrence() {
        let mut state = EffectReplayState::default();
        state.start_recording();

        // Outer group on branch 0: one effect before and one after a nested group.
        state.enter_group();
        state.set_branch(0);
        record_print(&mut state);

        state.enter_group();
        state.set_branch(1);
        record_print(&mut state);
        state.exit_group();

        record_print(&mut state);

        let effects = state.take_recorded_effects();
        assert_eq!(effects.len(), 3);

        let observed: Vec<(Option<&str>, Option<u32>)> = effects
            .iter()
            .map(|effect| (effect.branch_path.as_deref(), effect.effect_occurrence))
            .collect();
        let expected: [(Option<&str>, Option<u32>); 3] = [
            (Some("0"), Some(0)),
            (Some("0.1"), Some(0)),
            (Some("0"), Some(1)),
        ];
        assert_eq!(observed, expected);
    }

    #[test]
    fn start_replay_clears_group_state() {
        let mut state = EffectReplayState::default();
        state.start_recording();
        state.enter_group();
        state.set_branch(3);
        record_print(&mut state);

        state.start_replay(Vec::new(), true);

        assert!(state.group_stack.is_empty());
        assert!(state.branch_stack.is_empty());
        assert!(state.effect_count_stack.is_empty());
        assert!(state.group_consumed.is_empty());
        assert_eq!(state.next_group_id, 0);
        assert_eq!(state.args_diff_count, 0);
    }

    #[test]
    fn replay_group_matching_uses_effect_occurrence() {
        let mut state = EffectReplayState::default();
        state.start_replay(
            vec![grouped_print(1, "first", 0), grouped_print(2, "second", 1)],
            true,
        );

        state.enter_group();
        state.set_branch(0);

        let first = state
            .replay_effect("Console.print", Some(same_args()))
            .expect("first replay should match");
        let second = state
            .replay_effect("Console.print", Some(same_args()))
            .expect("second replay should match");

        assert_eq!(first, recorded_value("first"));
        assert_eq!(second, recorded_value("second"));
    }
}