1use super::types::{ActionType, ReplaySession};
8use crate::MemvidError;
9use crate::error::Result;
10use crate::memvid::lifecycle::Memvid;
11use serde::{Deserialize, Serialize};
12use std::time::Instant;
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ActionReplayResult {
18 pub sequence: u64,
20 pub matched: bool,
22 pub diff: Option<String>,
24 pub duration_ms: u64,
26 pub action_type: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReplayResult {
33 pub session_id: Uuid,
35 pub total_actions: usize,
37 pub matched_actions: usize,
39 pub mismatched_actions: usize,
41 pub skipped_actions: usize,
43 pub action_results: Vec<ActionReplayResult>,
45 pub total_duration_ms: u64,
47 pub from_checkpoint: Option<u64>,
49}
50
51impl ReplayResult {
52 #[must_use]
54 pub fn is_success(&self) -> bool {
55 self.mismatched_actions == 0
56 }
57
58 #[must_use]
60 pub fn match_rate(&self) -> f64 {
61 if self.total_actions == 0 {
62 100.0
63 } else {
64 (self.matched_actions as f64 / self.total_actions as f64) * 100.0
65 }
66 }
67}
68
/// Options controlling how a recorded session is replayed.
#[derive(Debug, Clone)]
pub struct ReplayExecutionConfig {
    /// Skip write-style actions instead of verifying them.
    pub skip_puts: bool,
    /// Skip recorded search actions instead of re-running them.
    pub skip_finds: bool,
    /// Skip recorded ask actions instead of reporting on them.
    pub skip_asks: bool,
    /// Stop the replay as soon as an action is counted as a mismatch.
    pub stop_on_mismatch: bool,
    /// Emit a summary log line when the replay completes.
    pub verbose: bool,
    /// Overrides the number of results requested when re-running a search;
    /// `None` reuses the recorded result count.
    pub top_k: Option<usize>,
    /// Count only search hits whose score reaches `min_relevancy`.
    pub adaptive: bool,
    /// Score threshold applied when `adaptive` is enabled.
    pub min_relevancy: f32,
    /// Report ask actions in audit form rather than the plain summary form.
    pub audit_mode: bool,
    /// Model name shown as the override in audit output, if any.
    pub use_model: Option<String>,
    /// NOTE(review): not read by any code visible in this file; presumably
    /// consumed by a caller. Confirm before relying on it.
    pub generate_diff: bool,
}
98
99impl Default for ReplayExecutionConfig {
100 fn default() -> Self {
101 Self {
102 skip_puts: false,
103 skip_finds: false,
104 skip_asks: false,
105 stop_on_mismatch: false,
106 verbose: false,
107 top_k: None,
108 adaptive: false,
109 min_relevancy: 0.5,
110 audit_mode: false,
111 use_model: None,
112 generate_diff: false,
113 }
114 }
115}
116
117pub struct ReplayEngine<'a> {
119 mem: &'a mut Memvid,
121 config: ReplayExecutionConfig,
123}
124
125impl<'a> ReplayEngine<'a> {
126 pub fn new(mem: &'a mut Memvid, config: ReplayExecutionConfig) -> Self {
128 Self { mem, config }
129 }
130
131 pub fn replay_session(&mut self, session: &ReplaySession) -> Result<ReplayResult> {
133 self.replay_session_from(session, None)
134 }
135
136 pub fn replay_session_from(
138 &mut self,
139 session: &ReplaySession,
140 from_checkpoint: Option<u64>,
141 ) -> Result<ReplayResult> {
142 let start_time = Instant::now();
143 let mut result = ReplayResult {
144 session_id: session.session_id,
145 total_actions: 0,
146 matched_actions: 0,
147 mismatched_actions: 0,
148 skipped_actions: 0,
149 action_results: Vec::new(),
150 total_duration_ms: 0,
151 from_checkpoint,
152 };
153
154 let start_sequence = if let Some(checkpoint_id) = from_checkpoint {
156 let checkpoint = session
157 .checkpoints
158 .iter()
159 .find(|c| c.id == checkpoint_id)
160 .ok_or_else(|| MemvidError::InvalidQuery {
161 reason: format!("Checkpoint {checkpoint_id} not found in session"),
162 })?;
163 checkpoint.at_sequence
164 } else {
165 0
166 };
167
168 let actions_to_replay: Vec<_> = session
170 .actions
171 .iter()
172 .filter(|a| a.sequence >= start_sequence)
173 .collect();
174
175 result.total_actions = actions_to_replay.len();
176
177 for action in actions_to_replay {
178 let action_start = Instant::now();
179 let mut action_result = ActionReplayResult {
180 sequence: action.sequence,
181 matched: false,
182 diff: None,
183 duration_ms: 0,
184 action_type: action.action_type.name().to_string(),
185 };
186
187 match &action.action_type {
188 ActionType::Put { frame_id } => {
189 if self.config.skip_puts {
190 result.skipped_actions += 1;
191 action_result.diff = Some("skipped".to_string());
192 } else {
193 let frame_count = self.mem.toc.frames.len();
197 if frame_count > 0 {
198 action_result.matched = true;
199 action_result.diff = Some(format!(
200 "Put verified (seq {frame_id}, {frame_count} frames total)"
201 ));
202 result.matched_actions += 1;
203 } else {
204 action_result.matched = false;
205 action_result.diff = Some("No frames found".to_string());
206 result.mismatched_actions += 1;
207 }
208 }
209 }
210
211 ActionType::Find {
212 query,
213 mode: _,
214 result_count,
215 } => {
216 if self.config.skip_finds {
217 result.skipped_actions += 1;
218 action_result.diff = Some("skipped".to_string());
219 } else {
220 let replay_top_k = self.config.top_k.unwrap_or(*result_count);
224
225 let search_request = crate::types::SearchRequest {
228 query: query.clone(),
229 top_k: replay_top_k,
230 snippet_chars: 120,
231 uri: None,
232 scope: None,
233 cursor: None,
234 #[cfg(feature = "temporal_track")]
235 temporal: None,
236 as_of_frame: None,
237 as_of_ts: None,
238 no_sketch: false,
239 acl_context: None,
240 acl_enforcement_mode: crate::types::AclEnforcementMode::Audit,
241 };
242 match self.mem.search(search_request) {
243 Ok(response) => {
244 let replay_count = if self.config.adaptive {
245 response
247 .hits
248 .iter()
249 .filter(|h| {
250 h.score.unwrap_or(0.0) >= self.config.min_relevancy
251 })
252 .count()
253 } else {
254 response.hits.len()
255 };
256
257 if self.config.top_k.is_some() && replay_count != *result_count {
259 let mut doc_details = String::new();
261
262 if replay_count > *result_count {
263 let extra_count = replay_count - *result_count;
265 doc_details.push_str(
266 "\n Documents discovered with higher top-k:",
267 );
268 for (i, hit) in response.hits.iter().enumerate() {
269 let score = hit.score.unwrap_or(0.0);
270 let uri = &hit.uri;
271 let marker =
272 if i >= *result_count { " [NEW]" } else { "" };
273 doc_details.push_str(&format!(
274 "\n [{}] {} (score: {:.2}){}",
275 i + 1,
276 uri,
277 score,
278 marker
279 ));
280 }
281 action_result.matched = false;
282 action_result.diff = Some(format!(
283 "DISCOVERY: original found {result_count}, replay with top-k={replay_top_k} found {replay_count} (+{extra_count} docs). Query: \"{query}\"{doc_details}"
284 ));
285 } else {
286 let missed_count = *result_count - replay_count;
288 doc_details.push_str(
289 "\n With lower top-k, only these would be found:",
290 );
291 for (i, hit) in response.hits.iter().enumerate() {
292 let score = hit.score.unwrap_or(0.0);
293 let uri = &hit.uri;
294 doc_details.push_str(&format!(
295 "\n [{}] {} (score: {:.2})",
296 i + 1,
297 uri,
298 score
299 ));
300 }
301 doc_details.push_str(&format!(
302 "\n {missed_count} document(s) would be MISSED with top-k={replay_top_k}"
303 ));
304 action_result.matched = false;
305 action_result.diff = Some(format!(
306 "FILTER: original found {result_count}, replay with top-k={replay_top_k} would only find {replay_count} (-{missed_count} docs). Query: \"{query}\"{doc_details}"
307 ));
308 }
309
310 result.mismatched_actions += 1;
311 } else if replay_count == *result_count {
312 action_result.matched = true;
313 if self.config.adaptive {
314 action_result.diff = Some(format!(
315 "Matched with adaptive (min_relevancy={})",
316 self.config.min_relevancy
317 ));
318 }
319 result.matched_actions += 1;
320 } else {
321 action_result.matched = false;
322 action_result.diff = Some(format!(
323 "Result count mismatch: expected {result_count}, got {replay_count}"
324 ));
325 result.mismatched_actions += 1;
326 }
327 }
328 Err(e) => {
329 action_result.matched = false;
330 action_result.diff = Some(format!("Search failed: {e}"));
331 result.mismatched_actions += 1;
332 }
333 }
334 }
335 }
336
337 ActionType::Ask {
338 query,
339 provider,
340 model,
341 } => {
342 if self.config.skip_asks {
343 result.skipped_actions += 1;
344 action_result.diff = Some("skipped".to_string());
345 } else if self.config.audit_mode {
346 let frames_str = if action.affected_frames.is_empty() {
348 "none recorded".to_string()
349 } else {
350 action
351 .affected_frames
352 .iter()
353 .map(std::string::ToString::to_string)
354 .collect::<Vec<_>>()
355 .join(", ")
356 };
357
358 let original_answer = if action.output_preview.is_empty() {
359 "(no answer recorded)".to_string()
360 } else {
361 action.output_preview.clone()
362 };
363
364 let mut details = format!(
366 "Question: \"{query}\"\n Mode: AUDIT (frozen retrieval)\n Original Model: {provider}:{model}\n Frozen frames: [{frames_str}]"
367 );
368
369 if let Some(ref override_model) = self.config.use_model {
371 details
372 .push_str(&format!("\n Override Model: {override_model}"));
373 }
374
375 let answer_preview = if original_answer.len() > 200 {
377 format!("{}...", &original_answer[..200])
378 } else {
379 original_answer.clone()
380 };
381 details
382 .push_str(&format!("\n Original Answer: \"{answer_preview}\""));
383
384 if action.affected_frames.is_empty() {
386 details.push_str("\n Context: MISSING (no frames recorded - session recorded before Phase 1)");
387 action_result.matched = false;
388 result.mismatched_actions += 1;
389 } else {
390 details.push_str("\n Context: VERIFIED (frames frozen)");
391 action_result.matched = true;
392 result.matched_actions += 1;
393 }
394
395 action_result.diff = Some(details);
396 } else {
397 let frames_str = if action.affected_frames.is_empty() {
400 "none recorded".to_string()
401 } else {
402 action
403 .affected_frames
404 .iter()
405 .map(std::string::ToString::to_string)
406 .collect::<Vec<_>>()
407 .join(", ")
408 };
409
410 let answer_preview = if action.output_preview.is_empty() {
411 "(no answer recorded)".to_string()
412 } else {
413 let preview = &action.output_preview;
415 if preview.len() > 200 {
416 format!("{}...", &preview[..200])
417 } else {
418 preview.clone()
419 }
420 };
421
422 let details = format!(
424 "Question: \"{query}\"\n Model: {provider}:{model}\n Retrieved frames: [{frames_str}]\n Answer: \"{answer_preview}\""
425 );
426
427 action_result.matched = true;
428 action_result.diff = Some(details);
429 result.matched_actions += 1;
430 }
431 }
432
433 ActionType::Checkpoint { checkpoint_id } => {
434 action_result.matched = true;
436 action_result.diff = Some(format!("Checkpoint {checkpoint_id} verified"));
437 result.matched_actions += 1;
438 }
439
440 ActionType::PutMany { frame_ids, count } => {
441 if self.config.skip_puts {
442 result.skipped_actions += 1;
443 action_result.diff = Some("skipped".to_string());
444 } else {
445 let existing: Vec<_> = frame_ids
447 .iter()
448 .filter(|id| self.mem.frame_by_id(**id).is_ok())
449 .collect();
450 if existing.len() == *count {
451 action_result.matched = true;
452 result.matched_actions += 1;
453 } else {
454 action_result.matched = false;
455 action_result.diff = Some(format!(
456 "Expected {} frames, found {}",
457 count,
458 existing.len()
459 ));
460 result.mismatched_actions += 1;
461 }
462 }
463 }
464
465 ActionType::Update { frame_id } => {
466 if self.config.skip_puts {
467 result.skipped_actions += 1;
468 action_result.diff = Some("skipped".to_string());
469 } else {
470 if self.mem.frame_by_id(*frame_id).is_ok() {
472 action_result.matched = true;
473 result.matched_actions += 1;
474 } else {
475 action_result.matched = false;
476 action_result.diff = Some(format!("Frame {frame_id} not found"));
477 result.mismatched_actions += 1;
478 }
479 }
480 }
481
482 ActionType::Delete { frame_id } => {
483 if self.config.skip_puts {
484 result.skipped_actions += 1;
485 action_result.diff = Some("skipped".to_string());
486 } else {
487 if self.mem.frame_by_id(*frame_id).is_err() {
489 action_result.matched = true;
490 result.matched_actions += 1;
491 } else {
492 action_result.matched = false;
493 action_result.diff = Some(format!("Frame {frame_id} still exists"));
494 result.mismatched_actions += 1;
495 }
496 }
497 }
498
499 ActionType::ToolCall { name, args_hash: _ } => {
500 result.skipped_actions += 1;
502 action_result.diff = Some(format!("Tool call '{name}' skipped"));
503 }
504 }
505
506 action_result.duration_ms = action_start
507 .elapsed()
508 .as_millis()
509 .try_into()
510 .unwrap_or(u64::MAX);
511 result.action_results.push(action_result);
512
513 if self.config.stop_on_mismatch
515 && result.mismatched_actions > 0
516 && result.action_results.last().is_some_and(|r| !r.matched)
517 {
518 break;
519 }
520 }
521
522 result.total_duration_ms = start_time
523 .elapsed()
524 .as_millis()
525 .try_into()
526 .unwrap_or(u64::MAX);
527
528 if self.config.verbose {
529 tracing::info!(
530 "Replay completed: {}/{} actions matched ({}%)",
531 result.matched_actions,
532 result.total_actions,
533 result.match_rate()
534 );
535 }
536
537 Ok(result)
538 }
539
540 #[must_use]
542 pub fn compare_sessions(
543 session_a: &ReplaySession,
544 session_b: &ReplaySession,
545 ) -> SessionComparison {
546 let mut comparison = SessionComparison {
547 session_a_id: session_a.session_id,
548 session_b_id: session_b.session_id,
549 actions_only_in_a: Vec::new(),
550 actions_only_in_b: Vec::new(),
551 differing_actions: Vec::new(),
552 matching_actions: 0,
553 };
554
555 let a_actions: std::collections::HashMap<_, _> =
557 session_a.actions.iter().map(|a| (a.sequence, a)).collect();
558 let b_actions: std::collections::HashMap<_, _> =
559 session_b.actions.iter().map(|a| (a.sequence, a)).collect();
560
561 for seq in a_actions.keys() {
563 if !b_actions.contains_key(seq) {
564 comparison.actions_only_in_a.push(*seq);
565 }
566 }
567
568 for seq in b_actions.keys() {
570 if !a_actions.contains_key(seq) {
571 comparison.actions_only_in_b.push(*seq);
572 }
573 }
574
575 for (seq, action_a) in &a_actions {
577 if let Some(action_b) = b_actions.get(seq) {
578 if action_a.action_type.name() == action_b.action_type.name() {
579 let same = match (&action_a.action_type, &action_b.action_type) {
581 (ActionType::Put { frame_id: a }, ActionType::Put { frame_id: b }) => {
582 a == b
583 }
584 (
585 ActionType::Find {
586 query: qa,
587 result_count: ra,
588 ..
589 },
590 ActionType::Find {
591 query: qb,
592 result_count: rb,
593 ..
594 },
595 ) => qa == qb && ra == rb,
596 (ActionType::Ask { query: qa, .. }, ActionType::Ask { query: qb, .. }) => {
597 qa == qb
598 }
599 (
600 ActionType::Checkpoint { checkpoint_id: a },
601 ActionType::Checkpoint { checkpoint_id: b },
602 ) => a == b,
603 _ => false,
604 };
605
606 if same {
607 comparison.matching_actions += 1;
608 } else {
609 comparison.differing_actions.push(ActionDiff {
610 sequence: *seq,
611 action_type_a: action_a.action_type.name().to_string(),
612 action_type_b: action_b.action_type.name().to_string(),
613 description: "Action details differ".to_string(),
614 });
615 }
616 } else {
617 comparison.differing_actions.push(ActionDiff {
618 sequence: *seq,
619 action_type_a: action_a.action_type.name().to_string(),
620 action_type_b: action_b.action_type.name().to_string(),
621 description: format!(
622 "Action type mismatch: {} vs {}",
623 action_a.action_type.name(),
624 action_b.action_type.name()
625 ),
626 });
627 }
628 }
629 }
630
631 comparison
632 }
633}
634
635#[derive(Debug, Clone, Serialize, Deserialize)]
637pub struct SessionComparison {
638 pub session_a_id: Uuid,
640 pub session_b_id: Uuid,
642 pub actions_only_in_a: Vec<u64>,
644 pub actions_only_in_b: Vec<u64>,
646 pub differing_actions: Vec<ActionDiff>,
648 pub matching_actions: usize,
650}
651
652impl SessionComparison {
653 #[must_use]
655 pub fn is_identical(&self) -> bool {
656 self.actions_only_in_a.is_empty()
657 && self.actions_only_in_b.is_empty()
658 && self.differing_actions.is_empty()
659 }
660}
661
662#[derive(Debug, Clone, Serialize, Deserialize)]
664pub struct ActionDiff {
665 pub sequence: u64,
667 pub action_type_a: String,
669 pub action_type_b: String,
671 pub description: String,
673}
674
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replay::types::ReplayAction;

    /// Builds a ten-action result with the given matched/mismatched split.
    fn ten_action_result(matched: usize, mismatched: usize) -> ReplayResult {
        ReplayResult {
            session_id: Uuid::new_v4(),
            total_actions: 10,
            matched_actions: matched,
            mismatched_actions: mismatched,
            skipped_actions: 0,
            action_results: Vec::new(),
            total_duration_ms: 100,
            from_checkpoint: None,
        }
    }

    /// Builds a session holding a single lexical `Find` action at sequence 0.
    fn single_find_session(name: &str) -> ReplaySession {
        use std::collections::HashMap;

        let find = ActionType::Find {
            query: "test".to_string(),
            mode: "lex".to_string(),
            result_count: 5,
        };
        ReplaySession {
            session_id: Uuid::new_v4(),
            name: Some(name.to_string()),
            created_secs: 0,
            ended_secs: Some(100),
            actions: vec![ReplayAction::new(0, find)],
            checkpoints: Vec::new(),
            metadata: HashMap::new(),
            version: 1,
        }
    }

    #[test]
    fn test_replay_result_success() {
        let result = ten_action_result(10, 0);
        assert!(result.is_success());
        assert_eq!(result.match_rate(), 100.0);
    }

    #[test]
    fn test_replay_result_partial() {
        let result = ten_action_result(7, 3);
        assert!(!result.is_success());
        assert_eq!(result.match_rate(), 70.0);
    }

    #[test]
    fn test_session_comparison_identical() {
        let session_a = single_find_session("A");
        let session_b = single_find_session("B");

        let comparison = ReplayEngine::compare_sessions(&session_a, &session_b);
        assert!(comparison.is_identical());
        assert_eq!(comparison.matching_actions, 1);
    }
}