// briefcase_core/replay.rs

1use crate::models::DecisionSnapshot;
2
3#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
4use crate::storage::StorageBackend;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio;
9
10pub mod executor;
11#[cfg(feature = "async")]
12pub use executor::ModelExecutor;
13pub use executor::{
14    ComparisonResult, ExecutionConfig, ExecutionResult, FieldComparison, SyncModelExecutor,
15};
16
17#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
18pub mod sync;
19#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
20pub use sync::SyncReplayEngine;
21
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum ReplayMode {
24    Strict,         // Exact byte-for-byte match required
25    Tolerant,       // Allow minor differences (whitespace, formatting)
26    ValidationOnly, // Validate inputs/context without re-executing
27}
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub enum ReplayStatus {
31    Pending,
32    Running,
33    Success,
34    Failed,
35    Partial, // Some decisions matched, some didn't
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ReplayResult {
40    pub status: ReplayStatus,
41    pub original_snapshot: DecisionSnapshot,
42    pub replay_output: Option<serde_json::Value>,
43    pub outputs_match: bool,
44    pub diff: Option<SnapshotDiff>,
45    pub policy_violations: Vec<PolicyViolation>,
46    pub execution_time_ms: f64,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SnapshotDiff {
51    pub inputs_changed: bool,
52    pub outputs_changed: bool,
53    pub model_params_changed: bool,
54    pub execution_time_delta_ms: f64,
55    pub changes: Vec<FieldChange>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct FieldChange {
60    pub field_path: String,
61    pub old_value: serde_json::Value,
62    pub new_value: serde_json::Value,
63    pub change_type: ChangeType,
64}
65
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub enum ChangeType {
68    Added,
69    Removed,
70    Modified,
71}
72
73#[derive(Debug, Clone)]
74pub struct ReplayPolicy {
75    pub name: String,
76    pub rules: Vec<ValidationRule>,
77}
78
79impl ReplayPolicy {
80    pub fn new(name: impl Into<String>) -> Self {
81        Self {
82            name: name.into(),
83            rules: Vec::new(),
84        }
85    }
86
87    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
88        self.rules.push(rule);
89        self
90    }
91
92    pub fn with_exact_match(mut self, field: impl Into<String>) -> Self {
93        self.rules.push(ValidationRule {
94            field: field.into(),
95            comparator: Comparator::ExactMatch,
96            threshold: 1.0,
97        });
98        self
99    }
100
101    pub fn with_similarity_threshold(mut self, field: impl Into<String>, threshold: f64) -> Self {
102        self.rules.push(ValidationRule {
103            field: field.into(),
104            comparator: Comparator::SemanticSimilarity,
105            threshold,
106        });
107        self
108    }
109}
110
111#[derive(Debug, Clone)]
112pub struct ValidationRule {
113    pub field: String,
114    pub comparator: Comparator,
115    pub threshold: f64,
116}
117
/// Comparison strategy used by a [`ValidationRule`].
///
/// NOTE(review): the policy checker in this module only acts on `ExactMatch` and
/// `SemanticSimilarity`; the remaining variants currently never produce violations.
#[derive(Debug, Clone, PartialEq)]
pub enum Comparator {
    /// Values must be identical.
    ExactMatch,
    /// Values must be semantically similar above a threshold.
    SemanticSimilarity,
    /// Value may not grow by more than a percentage.
    MaxIncreasePercent,
    /// Value may not shrink by more than a percentage.
    MaxDecreasePercent,
    /// Value must fall within a range.
    WithinRange,
}

127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128pub struct PolicyViolation {
129    pub rule_name: String,
130    pub field: String,
131    pub expected: String,
132    pub actual: String,
133    pub message: String,
134}
135
/// Replays [`DecisionSnapshot`]s loaded from a storage backend and compares the
/// results with what was originally recorded.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
#[derive(Clone)]
pub struct ReplayEngine<S: StorageBackend> {
    /// Backend snapshots are loaded from.
    storage: S,
    /// Mode used when a call does not specify one.
    default_mode: ReplayMode,
    /// Optional `ModelExecutor` used to re-run `Strict`/`Tolerant` replays;
    /// `None` means those modes are simulated.
    #[cfg(feature = "async")]
    executor: Option<std::sync::Arc<dyn ModelExecutor>>,
}

#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl<S: StorageBackend> ReplayEngine<S> {
    /// Create an engine over `storage` that uses [`ReplayMode::Tolerant`] by default.
    pub fn new(storage: S) -> Self {
        Self::with_mode(storage, ReplayMode::Tolerant)
    }

    /// Create an engine over `storage` with an explicit default replay mode.
    pub fn with_mode(storage: S, mode: ReplayMode) -> Self {
        Self {
            storage,
            default_mode: mode,
            #[cfg(feature = "async")]
            executor: None,
        }
    }

    /// Set the model executor used to re-run `Strict`/`Tolerant` replays.
    ///
    /// Without an executor those modes fall back to simulation.
    #[cfg(feature = "async")]
    pub fn with_executor(mut self, executor: std::sync::Arc<dyn ModelExecutor>) -> Self {
        self.executor = Some(executor);
        self
    }

    /// Get a reference to the current executor, if any.
    #[cfg(feature = "async")]
    pub fn executor(&self) -> Option<&dyn ModelExecutor> {
        self.executor.as_deref()
    }

    /// Get the default replay mode.
    pub fn default_mode(&self) -> &ReplayMode {
        &self.default_mode
    }

    /// Replay a snapshot by ID.
    ///
    /// `mode` overrides the engine's default mode. `_context_overrides` is accepted
    /// for API compatibility but is currently ignored.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] if the snapshot cannot be loaded,
    /// or any error raised while re-executing it.
    pub async fn replay(
        &self,
        snapshot_id: &str,
        mode: Option<ReplayMode>,
        _context_overrides: Option<std::collections::HashMap<String, serde_json::Value>>,
    ) -> Result<ReplayResult, ReplayError> {
        let start_time = std::time::Instant::now();
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let original_snapshot = self.load_snapshot(snapshot_id).await?;

        match replay_mode {
            // Nothing is re-executed, so the outputs are treated as matching.
            ReplayMode::ValidationOnly => Ok(ReplayResult {
                status: ReplayStatus::Success,
                original_snapshot,
                replay_output: None,
                outputs_match: true,
                diff: None,
                policy_violations: Vec::new(),
                execution_time_ms: Self::elapsed_ms(start_time),
            }),
            ReplayMode::Strict | ReplayMode::Tolerant => {
                // Use the executor if available, otherwise fall back to simulation.
                #[cfg(feature = "async")]
                {
                    self.execute_replay(&original_snapshot, replay_mode, start_time)
                        .await
                }
                #[cfg(not(feature = "async"))]
                {
                    self.simulate_replay(&original_snapshot, replay_mode, start_time)
                        .await
                }
            }
        }
    }

    /// Replay a snapshot and then check it against `policy`.
    ///
    /// Any violation downgrades the result's status to [`ReplayStatus::Failed`].
    pub async fn replay_with_policy(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
        mode: Option<ReplayMode>,
    ) -> Result<ReplayResult, ReplayError> {
        let mut result = self.replay(snapshot_id, mode, None).await?;

        result.policy_violations = self
            .validate_against_policy(&result.original_snapshot, policy)
            .await?;

        if !result.policy_violations.is_empty() {
            result.status = ReplayStatus::Failed;
        }

        Ok(result)
    }

    /// Compare two stored snapshots field by field.
    pub async fn diff(&self, original_id: &str, new_id: &str) -> Result<SnapshotDiff, ReplayError> {
        let original = self.storage.load_decision(original_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!("Original snapshot not found: {}", e))
        })?;

        let new =
            self.storage.load_decision(new_id).await.map_err(|e| {
                ReplayError::SnapshotNotFound(format!("New snapshot not found: {}", e))
            })?;

        Ok(self.calculate_diff(&original, &new))
    }

    /// Validate a stored snapshot against a policy (without re-executing).
    pub async fn validate(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        let snapshot = self.load_snapshot(snapshot_id).await?;
        self.validate_against_policy(&snapshot, policy).await
    }

    /// Replay every id in `snapshot_ids`, running at most `concurrency` replays at once.
    ///
    /// Results are returned in the same order as `snapshot_ids`. A `concurrency` of
    /// `0` is treated as `1`.
    pub async fn replay_batch(
        &self,
        snapshot_ids: &[String],
        mode: Option<ReplayMode>,
        concurrency: usize,
    ) -> Vec<Result<ReplayResult, ReplayError>> {
        // A semaphore with zero permits would make every `acquire` wait forever,
        // so always allow at least one replay to make progress.
        let semaphore = tokio::sync::Semaphore::new(concurrency.max(1));
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let tasks = snapshot_ids.iter().map(|id| {
            let mode = replay_mode.clone();
            let semaphore = &semaphore;
            async move {
                let _permit = semaphore
                    .acquire()
                    .await
                    .expect("replay semaphore is never closed");
                self.replay(id, Some(mode), None).await
            }
        });

        futures::future::join_all(tasks).await
    }

    /// Replay `snapshot_ids` (default mode, up to 4 at a time) and aggregate the outcomes.
    ///
    /// Individual replay errors are counted in `failed_replays` rather than returned.
    pub async fn get_replay_stats(
        &self,
        snapshot_ids: &[String],
    ) -> Result<ReplayStats, ReplayError> {
        let results = self.replay_batch(snapshot_ids, None, 4).await;

        let mut stats = ReplayStats {
            total_replays: results.len(),
            ..Default::default()
        };

        for result in results {
            match result {
                Ok(replay_result) => {
                    stats.successful_replays += 1;
                    stats.total_execution_time_ms += replay_result.execution_time_ms;

                    if replay_result.outputs_match {
                        stats.exact_matches += 1;
                    } else {
                        stats.mismatches += 1;
                    }
                }
                Err(_) => stats.failed_replays += 1,
            }
        }

        stats.average_execution_time_ms = if stats.successful_replays > 0 {
            stats.total_execution_time_ms / stats.successful_replays as f64
        } else {
            0.0
        };

        Ok(stats)
    }

    // Private helper methods

    /// Milliseconds elapsed since `start`, keeping sub-millisecond precision
    /// (`as_millis` would truncate fast replays to `0`).
    fn elapsed_ms(start: std::time::Instant) -> f64 {
        start.elapsed().as_secs_f64() * 1000.0
    }

    /// Load a snapshot, mapping any storage failure to [`ReplayError::SnapshotNotFound`]
    /// with the requested id included for context.
    async fn load_snapshot(&self, snapshot_id: &str) -> Result<DecisionSnapshot, ReplayError> {
        self.storage.load_decision(snapshot_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!(
                "Failed to load snapshot {}: {}",
                snapshot_id, e
            ))
        })
    }

    /// Re-execute `original` through the configured executor and compare outputs.
    ///
    /// Falls back to simulation when no executor is set.
    #[cfg(feature = "async")]
    async fn execute_replay(
        &self,
        original: &DecisionSnapshot,
        mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        if let Some(ref executor) = self.executor {
            // Refuse early if the executor cannot run the recorded model.
            if let Some(ref params) = original.model_parameters {
                if !executor.supports_model(&params.model_name) {
                    return Err(ReplayError::ExecutionError(format!(
                        "Executor '{}' does not support model '{}'",
                        executor.executor_name(),
                        params.model_name
                    )));
                }
            }

            // Execute with the default config (could be customized per replay).
            let config = ExecutionConfig::default();
            let exec_result = executor
                .execute(
                    &original.inputs,
                    original.model_parameters.as_ref(),
                    &original.context,
                    &config,
                )
                .await?;

            let execution_time = Self::elapsed_ms(start_time);

            let tolerance = match mode {
                ReplayMode::Strict => 1.0,         // Exact match
                ReplayMode::Tolerant => 0.8,       // 80% similarity
                ReplayMode::ValidationOnly => 0.0, // Don't care about outputs
            };

            let comparison =
                executor.compare_outputs(&original.outputs, &exec_result.outputs, tolerance);

            let replay_output = serde_json::to_value(&exec_result.outputs).ok();

            let changes = comparison
                .field_comparisons
                .iter()
                .filter(|c| !c.is_match)
                .map(|c| FieldChange {
                    field_path: format!("output.{}", c.field_name),
                    old_value: c.original_value.clone(),
                    new_value: c.replayed_value.clone(),
                    change_type: ChangeType::Modified,
                })
                .collect();

            Ok(ReplayResult {
                status: if comparison.is_match {
                    ReplayStatus::Success
                } else {
                    ReplayStatus::Failed
                },
                original_snapshot: original.clone(),
                replay_output,
                outputs_match: comparison.is_match,
                diff: Some(SnapshotDiff {
                    inputs_changed: false,
                    outputs_changed: !comparison.is_match,
                    model_params_changed: false,
                    // Same convention as `calculate_diff`: no delta without a baseline.
                    execution_time_delta_ms: original
                        .execution_time_ms
                        .map_or(0.0, |original_ms| execution_time - original_ms),
                    changes,
                }),
                policy_violations: Vec::new(),
                execution_time_ms: execution_time,
            })
        } else {
            self.simulate_replay(original, mode, start_time).await
        }
    }

    /// Simulated replay: reports the original snapshot as reproduced exactly.
    ///
    /// A real implementation would recreate the execution environment, call the
    /// original function with the same inputs, and compare outputs.
    async fn simulate_replay(
        &self,
        original: &DecisionSnapshot,
        _mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        // Report the original execution time from the snapshot when recorded,
        // otherwise the time spent so far in this replay.
        let execution_time = original
            .execution_time_ms
            .unwrap_or_else(|| Self::elapsed_ms(start_time));

        // Simulate some processing time.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        // In simulation every mode reproduces the original outputs.
        let outputs_match = true;

        Ok(ReplayResult {
            status: if outputs_match {
                ReplayStatus::Success
            } else {
                ReplayStatus::Failed
            },
            original_snapshot: original.clone(),
            replay_output: original.outputs.first().map(|o| o.value.clone()),
            outputs_match,
            diff: None,
            policy_violations: Vec::new(),
            execution_time_ms: execution_time,
        })
    }

    /// Evaluate every rule in `policy` against `snapshot`, collecting violations.
    async fn validate_against_policy(
        &self,
        snapshot: &DecisionSnapshot,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        Ok(policy
            .rules
            .iter()
            .filter_map(|rule| self.check_rule(snapshot, rule))
            .collect())
    }

    /// Check one rule; returns `None` when the rule passes or its field is not extractable.
    ///
    /// Simulation placeholder: `ExactMatch` always passes, and `SemanticSimilarity`
    /// reports a violation (with a fixed actual of `0.4`) when `threshold <= 0.5`.
    /// NOTE(review): that direction looks inverted relative to the fixed `0.4`
    /// "actual" value, but existing tests depend on it. Confirm before changing.
    fn check_rule(
        &self,
        snapshot: &DecisionSnapshot,
        rule: &ValidationRule,
    ) -> Option<PolicyViolation> {
        // Rules on fields we cannot extract are skipped.
        let _field_value = self.extract_field_value(snapshot, &rule.field)?;

        match rule.comparator {
            Comparator::ExactMatch => None,
            Comparator::SemanticSimilarity if rule.threshold <= 0.5 => Some(PolicyViolation {
                rule_name: rule.field.clone(),
                field: rule.field.clone(),
                expected: format!("Similarity >= {}", rule.threshold),
                actual: "0.4".to_string(),
                message: format!(
                    "Semantic similarity below threshold of {}",
                    rule.threshold
                ),
            }),
            Comparator::SemanticSimilarity => None,
            _ => None, // Other comparators not implemented in simulation
        }
    }

    /// Extract a supported field (`function_name`, `execution_time_ms`, `output`) as JSON.
    ///
    /// Returns `None` for unknown fields, missing values, and non-finite execution
    /// times (JSON numbers cannot represent NaN or infinity).
    fn extract_field_value(
        &self,
        snapshot: &DecisionSnapshot,
        field_path: &str,
    ) -> Option<serde_json::Value> {
        match field_path {
            "function_name" => Some(serde_json::Value::String(snapshot.function_name.clone())),
            "execution_time_ms" => snapshot
                .execution_time_ms
                .and_then(serde_json::Number::from_f64)
                .map(serde_json::Value::Number),
            // The first output stands in for "the" output.
            "output" => snapshot.outputs.first().map(|output| output.value.clone()),
            _ => None,
        }
    }

    /// Build a field-level diff between two snapshots.
    ///
    /// Serialization failures are reported as JSON `null` instead of panicking.
    fn calculate_diff(&self, original: &DecisionSnapshot, new: &DecisionSnapshot) -> SnapshotDiff {
        let mut changes = Vec::new();

        if original.function_name != new.function_name {
            changes.push(FieldChange {
                field_path: "function_name".to_string(),
                old_value: serde_json::Value::String(original.function_name.clone()),
                new_value: serde_json::Value::String(new.function_name.clone()),
                change_type: ChangeType::Modified,
            });
        }

        let inputs_changed = original.inputs != new.inputs;
        if inputs_changed {
            changes.push(FieldChange {
                field_path: "inputs".to_string(),
                old_value: serde_json::to_value(&original.inputs).unwrap_or_default(),
                new_value: serde_json::to_value(&new.inputs).unwrap_or_default(),
                change_type: ChangeType::Modified,
            });
        }

        let outputs_changed = original.outputs != new.outputs;
        if outputs_changed {
            changes.push(FieldChange {
                field_path: "outputs".to_string(),
                old_value: serde_json::to_value(&original.outputs).unwrap_or_default(),
                new_value: serde_json::to_value(&new.outputs).unwrap_or_default(),
                change_type: ChangeType::Modified,
            });
        }

        let model_params_changed = original.model_parameters != new.model_parameters;
        if model_params_changed {
            changes.push(FieldChange {
                field_path: "model_parameters".to_string(),
                old_value: serde_json::to_value(&original.model_parameters).unwrap_or_default(),
                new_value: serde_json::to_value(&new.model_parameters).unwrap_or_default(),
                change_type: ChangeType::Modified,
            });
        }

        // Only meaningful when both snapshots recorded a time.
        let execution_time_delta_ms = match (original.execution_time_ms, new.execution_time_ms) {
            (Some(old_ms), Some(new_ms)) => new_ms - old_ms,
            _ => 0.0,
        };

        SnapshotDiff {
            inputs_changed,
            outputs_changed,
            model_params_changed,
            execution_time_delta_ms,
            changes,
        }
    }
}

/// Aggregate counters over a batch of replays (see `ReplayEngine::get_replay_stats`).
#[derive(Debug, Clone, Default)]
pub struct ReplayStats {
    /// Number of replays attempted.
    pub total_replays: usize,
    /// Replays that returned `Ok`.
    pub successful_replays: usize,
    /// Replays that returned `Err`.
    pub failed_replays: usize,
    /// Successful replays whose outputs matched.
    pub exact_matches: usize,
    /// Successful replays whose outputs did not match.
    pub mismatches: usize,
    /// Sum of execution times over successful replays, in milliseconds.
    pub total_execution_time_ms: f64,
    /// Mean execution time over successful replays (`0.0` when none succeeded).
    pub average_execution_time_ms: f64,
}

616#[derive(Error, Debug, Clone, PartialEq)]
617pub enum ReplayError {
618    #[error("Snapshot not found: {0}")]
619    SnapshotNotFound(String),
620    #[error("Storage error: {0}")]
621    StorageError(String),
622    #[error("Execution error: {0}")]
623    ExecutionError(String),
624    #[error("Policy violations: {0:?}")]
625    PolicyViolation(Vec<PolicyViolation>),
626}
627
/// Wrap storage-layer failures as [`ReplayError::StorageError`], keeping only their message.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl From<crate::storage::StorageError> for ReplayError {
    fn from(source: crate::storage::StorageError) -> Self {
        let message = source.to_string();
        Self::StorageError(message)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use crate::storage::SqliteBackend;
    use serde_json::json;

    /// Canonical fixture: `test_function` with one string input and output,
    /// `gpt-4` model parameters and a recorded 100 ms execution time.
    async fn create_test_decision() -> DecisionSnapshot {
        let input = Input::new("test_input", json!("hello"), "string");
        let output = Output::new("test_output", json!("world"), "string");
        let params = ModelParameters::new("gpt-4");

        DecisionSnapshot::new("test_function")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(params)
            .with_execution_time(100.0)
    }

    /// Engine over a fresh in-memory SQLite backend.
    fn fresh_engine() -> ReplayEngine<SqliteBackend> {
        let backend = SqliteBackend::in_memory().unwrap();
        ReplayEngine::new(backend)
    }

    /// Save `decision` through the engine's backend and return the stored id.
    async fn persist(engine: &ReplayEngine<SqliteBackend>, decision: &DecisionSnapshot) -> String {
        engine.storage.save_decision(decision).await.unwrap()
    }

    /// Save `count` fixtures named `test_function_<n>` and return their ids in order.
    async fn persist_numbered(engine: &ReplayEngine<SqliteBackend>, count: usize) -> Vec<String> {
        let mut ids = Vec::with_capacity(count);
        for n in 0..count {
            let mut decision = create_test_decision().await;
            decision.function_name = format!("test_function_{}", n);
            ids.push(persist(engine, &decision).await);
        }
        ids
    }

    #[tokio::test]
    async fn test_replay_engine_creation() {
        let engine = fresh_engine();
        assert_eq!(engine.default_mode, ReplayMode::Tolerant);
    }

    #[tokio::test]
    async fn test_replay_validation_only() {
        let engine = fresh_engine();
        let decision = create_test_decision().await;
        let id = persist(&engine, &decision).await;

        let result = engine
            .replay(&id, Some(ReplayMode::ValidationOnly), None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.outputs_match);
        assert!(result.replay_output.is_none());
    }

    #[tokio::test]
    async fn test_replay_tolerant_mode() {
        let engine = fresh_engine();
        let decision = create_test_decision().await;
        let id = persist(&engine, &decision).await;

        let result = engine
            .replay(&id, Some(ReplayMode::Tolerant), None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.outputs_match);
        assert!(result.replay_output.is_some());
    }

    #[tokio::test]
    async fn test_replay_with_policy() {
        let engine = fresh_engine();
        let decision = create_test_decision().await;
        let id = persist(&engine, &decision).await;

        let policy = ReplayPolicy::new("test_policy")
            .with_exact_match("function_name")
            .with_similarity_threshold("output", 0.9);

        let result = engine
            .replay_with_policy(&id, &policy, None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.policy_violations.is_empty());
    }

    #[tokio::test]
    async fn test_diff_calculation() {
        let engine = fresh_engine();

        let original = create_test_decision().await;
        let mut renamed = create_test_decision().await;
        renamed.function_name = "different_function".to_string();

        let original_id = persist(&engine, &original).await;
        let renamed_id = persist(&engine, &renamed).await;

        let diff = engine.diff(&original_id, &renamed_id).await.unwrap();

        assert!(!diff.changes.is_empty());
        assert!(diff
            .changes
            .iter()
            .any(|change| change.field_path == "function_name"));
    }

    #[tokio::test]
    async fn test_batch_replay() {
        let engine = fresh_engine();
        let ids = persist_numbered(&engine, 3).await;

        let results = engine.replay_batch(&ids, None, 2).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_replay_stats() {
        let engine = fresh_engine();
        let ids = persist_numbered(&engine, 5).await;

        let stats = engine.get_replay_stats(&ids).await.unwrap();

        assert_eq!(stats.total_replays, 5);
        assert_eq!(stats.successful_replays, 5);
        assert_eq!(stats.failed_replays, 0);
        assert!(stats.average_execution_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_nonexistent_snapshot() {
        let engine = fresh_engine();
        let outcome = engine.replay("nonexistent-id", None, None).await;
        assert!(matches!(outcome, Err(ReplayError::SnapshotNotFound(_))));
    }

    #[tokio::test]
    async fn test_policy_validation() {
        let engine = fresh_engine();
        let decision = create_test_decision().await;
        let id = persist(&engine, &decision).await;

        // A threshold this low triggers a violation in the simulated checker.
        let policy = ReplayPolicy::new("strict_policy").with_similarity_threshold("output", 0.3);

        let violations = engine.validate(&id, &policy).await.unwrap();
        assert!(!violations.is_empty());
    }
}
815}