1use super::types::{ActionType, ReplaySession};
8use crate::error::Result;
9use crate::memvid::lifecycle::Memvid;
10use crate::MemvidError;
11use serde::{Deserialize, Serialize};
12use std::time::Instant;
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct ActionReplayResult {
18 pub sequence: u64,
20 pub matched: bool,
22 pub diff: Option<String>,
24 pub duration_ms: u64,
26 pub action_type: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReplayResult {
33 pub session_id: Uuid,
35 pub total_actions: usize,
37 pub matched_actions: usize,
39 pub mismatched_actions: usize,
41 pub skipped_actions: usize,
43 pub action_results: Vec<ActionReplayResult>,
45 pub total_duration_ms: u64,
47 pub from_checkpoint: Option<u64>,
49}
50
51impl ReplayResult {
52 pub fn is_success(&self) -> bool {
54 self.mismatched_actions == 0
55 }
56
57 pub fn match_rate(&self) -> f64 {
59 if self.total_actions == 0 {
60 100.0
61 } else {
62 (self.matched_actions as f64 / self.total_actions as f64) * 100.0
63 }
64 }
65}
66
/// Options controlling which recorded actions are replayed and how they are compared.
#[derive(Debug, Clone)]
pub struct ReplayExecutionConfig {
    /// Skip verification of `Put`, `PutMany`, `Update` and `Delete` actions.
    pub skip_puts: bool,
    /// Skip re-running `Find` actions.
    pub skip_finds: bool,
    /// Report `Ask` actions as plain "skipped". They are skipped either way;
    /// this flag only changes the recorded reason.
    pub skip_asks: bool,
    /// Stop replaying at the first action that mismatches.
    pub stop_on_mismatch: bool,
    /// Log a summary line via `tracing` when the replay finishes.
    pub verbose: bool,
    /// Override the `top_k` used when re-running `Find` actions. When set, any
    /// difference from the recorded hit count is reported as a mismatch with details.
    pub top_k: Option<usize>,
    /// Count only `Find` hits whose score is at least `min_relevancy`.
    pub adaptive: bool,
    /// Score threshold applied to `Find` hits when `adaptive` is enabled.
    pub min_relevancy: f32,
}
88
89impl Default for ReplayExecutionConfig {
90 fn default() -> Self {
91 Self {
92 skip_puts: false,
93 skip_finds: false,
94 skip_asks: false,
95 stop_on_mismatch: false,
96 verbose: false,
97 top_k: None,
98 adaptive: false,
99 min_relevancy: 0.5,
100 }
101 }
102}
103
104pub struct ReplayEngine<'a> {
106 mem: &'a mut Memvid,
108 config: ReplayExecutionConfig,
110}
111
112impl<'a> ReplayEngine<'a> {
113 pub fn new(mem: &'a mut Memvid, config: ReplayExecutionConfig) -> Self {
115 Self { mem, config }
116 }
117
118 pub fn replay_session(&mut self, session: &ReplaySession) -> Result<ReplayResult> {
120 self.replay_session_from(session, None)
121 }
122
123 pub fn replay_session_from(
125 &mut self,
126 session: &ReplaySession,
127 from_checkpoint: Option<u64>,
128 ) -> Result<ReplayResult> {
129 let start_time = Instant::now();
130 let mut result = ReplayResult {
131 session_id: session.session_id,
132 total_actions: 0,
133 matched_actions: 0,
134 mismatched_actions: 0,
135 skipped_actions: 0,
136 action_results: Vec::new(),
137 total_duration_ms: 0,
138 from_checkpoint,
139 };
140
141 let start_sequence = if let Some(checkpoint_id) = from_checkpoint {
143 let checkpoint = session
144 .checkpoints
145 .iter()
146 .find(|c| c.id == checkpoint_id)
147 .ok_or_else(|| MemvidError::InvalidQuery {
148 reason: format!("Checkpoint {} not found in session", checkpoint_id),
149 })?;
150 checkpoint.at_sequence
151 } else {
152 0
153 };
154
155 let actions_to_replay: Vec<_> = session
157 .actions
158 .iter()
159 .filter(|a| a.sequence >= start_sequence)
160 .collect();
161
162 result.total_actions = actions_to_replay.len();
163
164 for action in actions_to_replay {
165 let action_start = Instant::now();
166 let mut action_result = ActionReplayResult {
167 sequence: action.sequence,
168 matched: false,
169 diff: None,
170 duration_ms: 0,
171 action_type: action.action_type.name().to_string(),
172 };
173
174 match &action.action_type {
175 ActionType::Put { frame_id } => {
176 if self.config.skip_puts {
177 result.skipped_actions += 1;
178 action_result.diff = Some("skipped".to_string());
179 } else {
180 let frame_count = self.mem.toc.frames.len();
184 if frame_count > 0 {
185 action_result.matched = true;
186 action_result.diff = Some(format!(
187 "Put verified (seq {}, {} frames total)",
188 frame_id, frame_count
189 ));
190 result.matched_actions += 1;
191 } else {
192 action_result.matched = false;
193 action_result.diff = Some("No frames found".to_string());
194 result.mismatched_actions += 1;
195 }
196 }
197 }
198
199 ActionType::Find {
200 query,
201 mode: _,
202 result_count,
203 } => {
204 if self.config.skip_finds {
205 result.skipped_actions += 1;
206 action_result.diff = Some("skipped".to_string());
207 } else {
208 let replay_top_k = self.config.top_k.unwrap_or(*result_count);
212
213 let search_request = crate::types::SearchRequest {
216 query: query.clone(),
217 top_k: replay_top_k,
218 snippet_chars: 120,
219 uri: None,
220 scope: None,
221 cursor: None,
222 #[cfg(feature = "temporal_track")]
223 temporal: None,
224 as_of_frame: None,
225 as_of_ts: None,
226 };
227 match self.mem.search(search_request) {
228 Ok(response) => {
229 let replay_count = if self.config.adaptive {
230 response
232 .hits
233 .iter()
234 .filter(|h| h.score.unwrap_or(0.0) >= self.config.min_relevancy)
235 .count()
236 } else {
237 response.hits.len()
238 };
239
240 if self.config.top_k.is_some() && replay_count != *result_count {
242 let mut doc_details = String::new();
244
245 if replay_count > *result_count {
246 let extra_count = replay_count - *result_count;
248 doc_details.push_str("\n Documents discovered with higher top-k:");
249 for (i, hit) in response.hits.iter().enumerate() {
250 let score = hit.score.unwrap_or(0.0);
251 let uri = &hit.uri;
252 let marker = if i >= *result_count { " [NEW]" } else { "" };
253 doc_details.push_str(&format!(
254 "\n [{}] {} (score: {:.2}){}",
255 i + 1, uri, score, marker
256 ));
257 }
258 action_result.matched = false;
259 action_result.diff = Some(format!(
260 "DISCOVERY: original found {}, replay with top-k={} found {} (+{} docs). Query: \"{}\"{}",
261 result_count,
262 replay_top_k,
263 replay_count,
264 extra_count,
265 query,
266 doc_details
267 ));
268 } else {
269 let missed_count = *result_count - replay_count;
271 doc_details.push_str("\n With lower top-k, only these would be found:");
272 for (i, hit) in response.hits.iter().enumerate() {
273 let score = hit.score.unwrap_or(0.0);
274 let uri = &hit.uri;
275 doc_details.push_str(&format!(
276 "\n [{}] {} (score: {:.2})",
277 i + 1, uri, score
278 ));
279 }
280 doc_details.push_str(&format!(
281 "\n {} document(s) would be MISSED with top-k={}",
282 missed_count, replay_top_k
283 ));
284 action_result.matched = false;
285 action_result.diff = Some(format!(
286 "FILTER: original found {}, replay with top-k={} would only find {} (-{} docs). Query: \"{}\"{}",
287 result_count,
288 replay_top_k,
289 replay_count,
290 missed_count,
291 query,
292 doc_details
293 ));
294 }
295
296 result.mismatched_actions += 1;
297 } else if replay_count == *result_count {
298 action_result.matched = true;
299 if self.config.adaptive {
300 action_result.diff = Some(format!(
301 "Matched with adaptive (min_relevancy={})",
302 self.config.min_relevancy
303 ));
304 }
305 result.matched_actions += 1;
306 } else {
307 action_result.matched = false;
308 action_result.diff = Some(format!(
309 "Result count mismatch: expected {}, got {}",
310 result_count,
311 replay_count
312 ));
313 result.mismatched_actions += 1;
314 }
315 }
316 Err(e) => {
317 action_result.matched = false;
318 action_result.diff =
319 Some(format!("Search failed: {}", e));
320 result.mismatched_actions += 1;
321 }
322 }
323 }
324 }
325
326 ActionType::Ask {
327 query: _,
328 provider: _,
329 model: _,
330 } => {
331 if self.config.skip_asks {
332 result.skipped_actions += 1;
333 action_result.diff = Some("skipped".to_string());
334 } else {
335 result.skipped_actions += 1;
338 action_result.diff =
339 Some("LLM responses are non-deterministic, skipped".to_string());
340 }
341 }
342
343 ActionType::Checkpoint { checkpoint_id } => {
344 action_result.matched = true;
346 action_result.diff = Some(format!("Checkpoint {} verified", checkpoint_id));
347 result.matched_actions += 1;
348 }
349
350 ActionType::PutMany { frame_ids, count } => {
351 if self.config.skip_puts {
352 result.skipped_actions += 1;
353 action_result.diff = Some("skipped".to_string());
354 } else {
355 let existing: Vec<_> = frame_ids
357 .iter()
358 .filter(|id| self.mem.frame_by_id(**id).is_ok())
359 .collect();
360 if existing.len() == *count {
361 action_result.matched = true;
362 result.matched_actions += 1;
363 } else {
364 action_result.matched = false;
365 action_result.diff = Some(format!(
366 "Expected {} frames, found {}",
367 count,
368 existing.len()
369 ));
370 result.mismatched_actions += 1;
371 }
372 }
373 }
374
375 ActionType::Update { frame_id } => {
376 if self.config.skip_puts {
377 result.skipped_actions += 1;
378 action_result.diff = Some("skipped".to_string());
379 } else {
380 if self.mem.frame_by_id(*frame_id).is_ok() {
382 action_result.matched = true;
383 result.matched_actions += 1;
384 } else {
385 action_result.matched = false;
386 action_result.diff = Some(format!("Frame {} not found", frame_id));
387 result.mismatched_actions += 1;
388 }
389 }
390 }
391
392 ActionType::Delete { frame_id } => {
393 if self.config.skip_puts {
394 result.skipped_actions += 1;
395 action_result.diff = Some("skipped".to_string());
396 } else {
397 if self.mem.frame_by_id(*frame_id).is_err() {
399 action_result.matched = true;
400 result.matched_actions += 1;
401 } else {
402 action_result.matched = false;
403 action_result.diff =
404 Some(format!("Frame {} still exists", frame_id));
405 result.mismatched_actions += 1;
406 }
407 }
408 }
409
410 ActionType::ToolCall { name, args_hash: _ } => {
411 result.skipped_actions += 1;
413 action_result.diff = Some(format!("Tool call '{}' skipped", name));
414 }
415 }
416
417 action_result.duration_ms = action_start.elapsed().as_millis() as u64;
418 result.action_results.push(action_result);
419
420 if self.config.stop_on_mismatch
422 && result.mismatched_actions > 0
423 && result.action_results.last().map_or(false, |r| !r.matched)
424 {
425 break;
426 }
427 }
428
429 result.total_duration_ms = start_time.elapsed().as_millis() as u64;
430
431 if self.config.verbose {
432 tracing::info!(
433 "Replay completed: {}/{} actions matched ({}%)",
434 result.matched_actions,
435 result.total_actions,
436 result.match_rate()
437 );
438 }
439
440 Ok(result)
441 }
442
443 pub fn compare_sessions(
445 session_a: &ReplaySession,
446 session_b: &ReplaySession,
447 ) -> SessionComparison {
448 let mut comparison = SessionComparison {
449 session_a_id: session_a.session_id,
450 session_b_id: session_b.session_id,
451 actions_only_in_a: Vec::new(),
452 actions_only_in_b: Vec::new(),
453 differing_actions: Vec::new(),
454 matching_actions: 0,
455 };
456
457 let a_actions: std::collections::HashMap<_, _> = session_a
459 .actions
460 .iter()
461 .map(|a| (a.sequence, a))
462 .collect();
463 let b_actions: std::collections::HashMap<_, _> = session_b
464 .actions
465 .iter()
466 .map(|a| (a.sequence, a))
467 .collect();
468
469 for (seq, _action) in &a_actions {
471 if !b_actions.contains_key(seq) {
472 comparison.actions_only_in_a.push(*seq);
473 }
474 }
475
476 for (seq, _action) in &b_actions {
478 if !a_actions.contains_key(seq) {
479 comparison.actions_only_in_b.push(*seq);
480 }
481 }
482
483 for (seq, action_a) in &a_actions {
485 if let Some(action_b) = b_actions.get(seq) {
486 if action_a.action_type.name() != action_b.action_type.name() {
487 comparison.differing_actions.push(ActionDiff {
488 sequence: *seq,
489 action_type_a: action_a.action_type.name().to_string(),
490 action_type_b: action_b.action_type.name().to_string(),
491 description: format!(
492 "Action type mismatch: {} vs {}",
493 action_a.action_type.name(),
494 action_b.action_type.name()
495 ),
496 });
497 } else {
498 let same = match (&action_a.action_type, &action_b.action_type) {
500 (ActionType::Put { frame_id: a }, ActionType::Put { frame_id: b }) => a == b,
501 (
502 ActionType::Find {
503 query: qa,
504 result_count: ra,
505 ..
506 },
507 ActionType::Find {
508 query: qb,
509 result_count: rb,
510 ..
511 },
512 ) => qa == qb && ra == rb,
513 (
514 ActionType::Ask { query: qa, .. },
515 ActionType::Ask { query: qb, .. },
516 ) => qa == qb,
517 (
518 ActionType::Checkpoint { checkpoint_id: a },
519 ActionType::Checkpoint { checkpoint_id: b },
520 ) => a == b,
521 _ => false,
522 };
523
524 if same {
525 comparison.matching_actions += 1;
526 } else {
527 comparison.differing_actions.push(ActionDiff {
528 sequence: *seq,
529 action_type_a: action_a.action_type.name().to_string(),
530 action_type_b: action_b.action_type.name().to_string(),
531 description: "Action details differ".to_string(),
532 });
533 }
534 }
535 }
536 }
537
538 comparison
539 }
540}
541
542#[derive(Debug, Clone, Serialize, Deserialize)]
544pub struct SessionComparison {
545 pub session_a_id: Uuid,
547 pub session_b_id: Uuid,
549 pub actions_only_in_a: Vec<u64>,
551 pub actions_only_in_b: Vec<u64>,
553 pub differing_actions: Vec<ActionDiff>,
555 pub matching_actions: usize,
557}
558
559impl SessionComparison {
560 pub fn is_identical(&self) -> bool {
562 self.actions_only_in_a.is_empty()
563 && self.actions_only_in_b.is_empty()
564 && self.differing_actions.is_empty()
565 }
566}
567
568#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct ActionDiff {
571 pub sequence: u64,
573 pub action_type_a: String,
575 pub action_type_b: String,
577 pub description: String,
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::replay::types::ReplayAction;
    use std::collections::HashMap;

    /// Builds a `ReplayResult` with the given counters and no per-action details.
    fn summary(total: usize, matched: usize, mismatched: usize) -> ReplayResult {
        ReplayResult {
            session_id: Uuid::new_v4(),
            total_actions: total,
            matched_actions: matched,
            mismatched_actions: mismatched,
            skipped_actions: 0,
            action_results: Vec::new(),
            total_duration_ms: 100,
            from_checkpoint: None,
        }
    }

    /// Builds a finished session holding a single lexical `Find` for "test" with 5 hits.
    fn single_find_session(label: &str) -> ReplaySession {
        let find = ActionType::Find {
            query: "test".to_string(),
            mode: "lex".to_string(),
            result_count: 5,
        };
        ReplaySession {
            session_id: Uuid::new_v4(),
            name: Some(label.to_string()),
            created_secs: 0,
            ended_secs: Some(100),
            actions: vec![ReplayAction::new(0, find)],
            checkpoints: Vec::new(),
            metadata: HashMap::new(),
            version: 1,
        }
    }

    #[test]
    fn test_replay_result_success() {
        let outcome = summary(10, 10, 0);
        assert!(outcome.is_success());
        assert_eq!(outcome.match_rate(), 100.0);
    }

    #[test]
    fn test_replay_result_partial() {
        let outcome = summary(10, 7, 3);
        assert!(!outcome.is_success());
        assert_eq!(outcome.match_rate(), 70.0);
    }

    #[test]
    fn test_session_comparison_identical() {
        let left = single_find_session("A");
        let right = single_find_session("B");

        let diff = ReplayEngine::compare_sessions(&left, &right);
        assert!(diff.is_identical());
        assert_eq!(diff.matching_actions, 1);
    }
}