Skip to main content

briefcase_core/
replay.rs

1use crate::models::DecisionSnapshot;
2
3#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
4use crate::storage::StorageBackend;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio;
9
10#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
11pub mod sync;
12#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
13pub use sync::SyncReplayEngine;
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum ReplayMode {
17    Strict,         // Exact byte-for-byte match required
18    Tolerant,       // Allow minor differences (whitespace, formatting)
19    ValidationOnly, // Validate inputs/context without re-executing
20}
21
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum ReplayStatus {
24    Pending,
25    Running,
26    Success,
27    Failed,
28    Partial, // Some decisions matched, some didn't
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReplayResult {
33    pub status: ReplayStatus,
34    pub original_snapshot: DecisionSnapshot,
35    pub replay_output: Option<serde_json::Value>,
36    pub outputs_match: bool,
37    pub diff: Option<SnapshotDiff>,
38    pub policy_violations: Vec<PolicyViolation>,
39    pub execution_time_ms: f64,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct SnapshotDiff {
44    pub inputs_changed: bool,
45    pub outputs_changed: bool,
46    pub model_params_changed: bool,
47    pub execution_time_delta_ms: f64,
48    pub changes: Vec<FieldChange>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FieldChange {
53    pub field_path: String,
54    pub old_value: serde_json::Value,
55    pub new_value: serde_json::Value,
56    pub change_type: ChangeType,
57}
58
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub enum ChangeType {
61    Added,
62    Removed,
63    Modified,
64}
65
/// A named, ordered collection of [`ValidationRule`]s checked against a snapshot.
#[derive(Debug, Clone)]
pub struct ReplayPolicy {
    /// Human-readable policy name.
    pub name: String,
    /// Rules, evaluated in insertion order.
    pub rules: Vec<ValidationRule>,
}

impl ReplayPolicy {
    /// Creates a policy called `name` with no rules.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name, rules: vec![] }
    }

    /// Appends `rule` and hands the policy back, builder style.
    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
        self.rules.push(rule);
        self
    }

    /// Appends an [`Comparator::ExactMatch`] rule on `field` (threshold fixed at `1.0`).
    pub fn with_exact_match(self, field: impl Into<String>) -> Self {
        let rule = ValidationRule {
            field: field.into(),
            comparator: Comparator::ExactMatch,
            threshold: 1.0,
        };
        self.add_rule(rule)
    }

    /// Appends a [`Comparator::SemanticSimilarity`] rule on `field` with the given threshold.
    pub fn with_similarity_threshold(self, field: impl Into<String>, threshold: f64) -> Self {
        let rule = ValidationRule {
            field: field.into(),
            comparator: Comparator::SemanticSimilarity,
            threshold,
        };
        self.add_rule(rule)
    }
}

/// A single check: compare `field` using `comparator`, parameterised by `threshold`.
#[derive(Debug, Clone)]
pub struct ValidationRule {
    /// Name of the snapshot field the rule inspects.
    pub field: String,
    /// How the field is compared.
    pub comparator: Comparator,
    /// Comparator-specific parameter.
    pub threshold: f64,
}

/// Comparison strategies a [`ValidationRule`] can use.
///
/// NOTE(review): the `Max*Percent` / `WithinRange` semantics are implied by their names
/// only; confirm against the evaluator before relying on them.
#[derive(Debug, Clone, PartialEq)]
pub enum Comparator {
    /// Values must be identical.
    ExactMatch,
    /// Values must be semantically similar, at least `threshold`.
    SemanticSimilarity,
    /// Presumably a cap on relative increase, in percent.
    MaxIncreasePercent,
    /// Presumably a cap on relative decrease, in percent.
    MaxDecreasePercent,
    /// Presumably a bound on the value's range.
    WithinRange,
}
119
120#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
121pub struct PolicyViolation {
122    pub rule_name: String,
123    pub field: String,
124    pub expected: String,
125    pub actual: String,
126    pub message: String,
127}
128
/// Replays, diffs and validates decision snapshots loaded from a storage backend.
///
/// Only available when a storage feature (`sqlite-storage` or `lakefs-storage`) is enabled.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
#[derive(Clone)]
pub struct ReplayEngine<S: StorageBackend> {
    /// Backend snapshots are loaded from.
    storage: S,
    /// Mode used when a replay call does not specify one.
    default_mode: ReplayMode,
}
135
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl<S: StorageBackend> ReplayEngine<S> {
    /// Concurrency level that [`Self::get_replay_stats`] passes to [`Self::replay_batch`].
    const STATS_CONCURRENCY: usize = 4;

    /// Creates an engine over `storage` whose default mode is [`ReplayMode::Tolerant`].
    pub fn new(storage: S) -> Self {
        Self::with_mode(storage, ReplayMode::Tolerant)
    }

    /// Creates an engine over `storage` with an explicit default replay mode.
    pub fn with_mode(storage: S, mode: ReplayMode) -> Self {
        Self {
            storage,
            default_mode: mode,
        }
    }

    /// Get the default replay mode (used when a call passes `mode: None`).
    pub fn default_mode(&self) -> &ReplayMode {
        &self.default_mode
    }

    /// Replay a snapshot by ID.
    ///
    /// `mode` overrides the engine's default mode for this call. `_context_overrides` is
    /// accepted for API compatibility but is currently ignored.
    ///
    /// In [`ReplayMode::ValidationOnly`] the snapshot is only loaded. The result reports
    /// success with `outputs_match == true` and no replay output. `Strict` and `Tolerant`
    /// currently run a simulated replay (see `simulate_replay`).
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] whenever the storage backend fails to load
    /// the snapshot, whatever the underlying storage error was.
    pub async fn replay(
        &self,
        snapshot_id: &str,
        mode: Option<ReplayMode>,
        _context_overrides: Option<std::collections::HashMap<String, serde_json::Value>>,
    ) -> Result<ReplayResult, ReplayError> {
        let start_time = std::time::Instant::now();
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let original_snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| {
                ReplayError::SnapshotNotFound(format!(
                    "Failed to load snapshot {}: {}",
                    snapshot_id, e
                ))
            })?;

        match replay_mode {
            // Nothing is re-executed, so there is nothing to compare: assume a match.
            ReplayMode::ValidationOnly => Ok(ReplayResult {
                status: ReplayStatus::Success,
                original_snapshot,
                replay_output: None,
                outputs_match: true,
                diff: None,
                policy_violations: Vec::new(),
                execution_time_ms: Self::elapsed_ms(start_time),
            }),
            // Real re-execution is not implemented yet; the replay is simulated.
            ReplayMode::Strict | ReplayMode::Tolerant => {
                self.simulate_replay(&original_snapshot, replay_mode, start_time)
                    .await
            }
        }
    }

    /// Replay with policy validation.
    ///
    /// Replays the snapshot, then checks the stored snapshot against `policy`. Any
    /// violation downgrades the result's status to [`ReplayStatus::Failed`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::replay`].
    pub async fn replay_with_policy(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
        mode: Option<ReplayMode>,
    ) -> Result<ReplayResult, ReplayError> {
        let mut result = self.replay(snapshot_id, mode, None).await?;

        let violations = self
            .validate_against_policy(&result.original_snapshot, policy)
            .await?;
        result.policy_violations = violations;

        if !result.policy_violations.is_empty() {
            result.status = ReplayStatus::Failed;
        }

        Ok(result)
    }

    /// Compare two stored snapshots.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] if either snapshot fails to load.
    pub async fn diff(&self, original_id: &str, new_id: &str) -> Result<SnapshotDiff, ReplayError> {
        let original = self.storage.load_decision(original_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!("Original snapshot not found: {}", e))
        })?;

        let new = self.storage.load_decision(new_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!("New snapshot not found: {}", e))
        })?;

        Ok(self.calculate_diff(&original, &new))
    }

    /// Validate a stored snapshot against a policy (without re-executing).
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] if the snapshot fails to load.
    pub async fn validate(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        let snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| ReplayError::SnapshotNotFound(e.to_string()))?;

        self.validate_against_policy(&snapshot, policy).await
    }

    /// Batch replay multiple snapshots with at most `concurrency` replays in flight.
    ///
    /// Results are returned in the same order as `snapshot_ids`. A `concurrency` of `0`
    /// is treated as `1`, because a semaphore with no permits would never let a replay start.
    pub async fn replay_batch(
        &self,
        snapshot_ids: &[String],
        mode: Option<ReplayMode>,
        concurrency: usize,
    ) -> Vec<Result<ReplayResult, ReplayError>> {
        // At least one permit (avoids a deadlock), and never more permits than tasks.
        let permits = concurrency.clamp(1, snapshot_ids.len().max(1));
        let semaphore = tokio::sync::Semaphore::new(permits);
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let tasks: Vec<_> = snapshot_ids
            .iter()
            .map(|id| {
                let mode = replay_mode.clone();
                let semaphore = &semaphore;
                async move {
                    // The semaphore is local and never closed, so `acquire` cannot fail.
                    let _permit = semaphore
                        .acquire()
                        .await
                        .expect("replay_batch semaphore is never closed");
                    self.replay(id, Some(mode), None).await
                }
            })
            .collect();

        futures::future::join_all(tasks).await
    }

    /// Get replay statistics for a set of snapshots.
    ///
    /// Individual replay failures are counted in `failed_replays` rather than returned as
    /// an error. This method currently always returns `Ok`.
    pub async fn get_replay_stats(
        &self,
        snapshot_ids: &[String],
    ) -> Result<ReplayStats, ReplayError> {
        let results = self
            .replay_batch(snapshot_ids, None, Self::STATS_CONCURRENCY)
            .await;

        let mut stats = ReplayStats {
            total_replays: results.len(),
            ..Default::default()
        };

        for result in results {
            match result {
                Ok(replay_result) => {
                    stats.successful_replays += 1;
                    stats.total_execution_time_ms += replay_result.execution_time_ms;

                    if replay_result.outputs_match {
                        stats.exact_matches += 1;
                    } else {
                        stats.mismatches += 1;
                    }
                }
                Err(_) => stats.failed_replays += 1,
            }
        }

        stats.average_execution_time_ms = if stats.successful_replays > 0 {
            stats.total_execution_time_ms / stats.successful_replays as f64
        } else {
            0.0
        };

        Ok(stats)
    }

    // Private helper methods

    /// Milliseconds elapsed since `start`, keeping sub-millisecond precision.
    /// (`Duration::as_millis` would truncate fast operations to `0`.)
    fn elapsed_ms(start: std::time::Instant) -> f64 {
        start.elapsed().as_secs_f64() * 1000.0
    }

    /// Simulated replay: reports the stored first output as the replay output.
    ///
    /// A real implementation would recreate the execution environment, call the original
    /// function with the same inputs and compare outputs.
    async fn simulate_replay(
        &self,
        original: &DecisionSnapshot,
        mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        // Report the original execution time from the snapshot when present; only fall
        // back to the time spent so far in this replay.
        let execution_time = original
            .execution_time_ms
            .unwrap_or_else(|| Self::elapsed_ms(start_time));

        // Simulate some processing time.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        // Exhaustive on purpose: adding a mode forces a decision here.
        let outputs_match = match mode {
            ReplayMode::Strict => true,   // In simulation, always match
            ReplayMode::Tolerant => true, // In simulation, always match
            ReplayMode::ValidationOnly => true,
        };

        Ok(ReplayResult {
            status: if outputs_match {
                ReplayStatus::Success
            } else {
                ReplayStatus::Failed
            },
            original_snapshot: original.clone(),
            replay_output: original.outputs.first().map(|o| o.value.clone()),
            outputs_match,
            diff: None,
            policy_violations: Vec::new(),
            execution_time_ms: execution_time,
        })
    }

    /// Checks every rule of `policy`, in order, and collects the violations.
    async fn validate_against_policy(
        &self,
        snapshot: &DecisionSnapshot,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        Ok(policy
            .rules
            .iter()
            .filter_map(|rule| self.check_rule(snapshot, rule))
            .collect())
    }

    /// Evaluates one rule against the snapshot (simulated).
    ///
    /// Returns `None` (no violation) when the rule's field cannot be extracted.
    ///
    /// NOTE(review): the simulated `SemanticSimilarity` check flags thresholds `<= 0.5` and
    /// passes higher ones, which is inverted relative to a fixed simulated score of `0.4`.
    /// The tests in this file rely on that behaviour, so it is preserved.
    fn check_rule(
        &self,
        snapshot: &DecisionSnapshot,
        rule: &ValidationRule,
    ) -> Option<PolicyViolation> {
        // Unknown or absent fields short-circuit to "no violation".
        let _field_value = self.extract_field_value(snapshot, &rule.field)?;

        match rule.comparator {
            // For simulation, assume all exact matches pass.
            Comparator::ExactMatch => None,
            Comparator::SemanticSimilarity if rule.threshold <= 0.5 => Some(PolicyViolation {
                rule_name: rule.field.clone(),
                field: rule.field.clone(),
                expected: format!("Similarity >= {}", rule.threshold),
                actual: "0.4".to_string(),
                message: format!(
                    "Semantic similarity below threshold of {}",
                    rule.threshold
                ),
            }),
            Comparator::SemanticSimilarity => None,
            _ => None, // Other comparators not implemented in simulation
        }
    }

    /// Looks up a supported field (`function_name`, `execution_time_ms`, `output`) as JSON.
    ///
    /// Returns `None` for unsupported paths, missing values, and non-finite execution
    /// times (JSON cannot represent NaN or infinity).
    fn extract_field_value(
        &self,
        snapshot: &DecisionSnapshot,
        field_path: &str,
    ) -> Option<serde_json::Value> {
        match field_path {
            "function_name" => Some(serde_json::Value::String(snapshot.function_name.clone())),
            "execution_time_ms" => snapshot
                .execution_time_ms
                .and_then(serde_json::Number::from_f64)
                .map(serde_json::Value::Number),
            // The first output, if any.
            "output" => snapshot.outputs.first().map(|output| output.value.clone()),
            _ => None,
        }
    }

    /// Field-by-field comparison of two snapshots. Every recorded change is `Modified`.
    fn calculate_diff(&self, original: &DecisionSnapshot, new: &DecisionSnapshot) -> SnapshotDiff {
        let mut changes = Vec::new();

        if original.function_name != new.function_name {
            changes.push(FieldChange {
                field_path: "function_name".to_string(),
                old_value: serde_json::Value::String(original.function_name.clone()),
                new_value: serde_json::Value::String(new.function_name.clone()),
                change_type: ChangeType::Modified,
            });
        }

        let inputs_changed = original.inputs != new.inputs;
        if inputs_changed {
            changes.push(FieldChange {
                field_path: "inputs".to_string(),
                old_value: serde_json::to_value(&original.inputs).unwrap(),
                new_value: serde_json::to_value(&new.inputs).unwrap(),
                change_type: ChangeType::Modified,
            });
        }

        let outputs_changed = original.outputs != new.outputs;
        if outputs_changed {
            changes.push(FieldChange {
                field_path: "outputs".to_string(),
                old_value: serde_json::to_value(&original.outputs).unwrap(),
                new_value: serde_json::to_value(&new.outputs).unwrap(),
                change_type: ChangeType::Modified,
            });
        }

        let model_params_changed = original.model_parameters != new.model_parameters;
        if model_params_changed {
            changes.push(FieldChange {
                field_path: "model_parameters".to_string(),
                old_value: serde_json::to_value(&original.model_parameters).unwrap(),
                new_value: serde_json::to_value(&new.model_parameters).unwrap(),
                change_type: ChangeType::Modified,
            });
        }

        // Delta is new - old; 0 when either side has no recorded time.
        let execution_time_delta_ms = match (original.execution_time_ms, new.execution_time_ms) {
            (Some(before), Some(after)) => after - before,
            _ => 0.0,
        };

        SnapshotDiff {
            inputs_changed,
            outputs_changed,
            model_params_changed,
            execution_time_delta_ms,
            changes,
        }
    }
}
491
/// Aggregate counters over a batch of replays.
#[derive(Debug, Clone, Default)]
pub struct ReplayStats {
    /// Number of replays attempted.
    pub total_replays: usize,
    /// Replays that returned `Ok`.
    pub successful_replays: usize,
    /// Replays that returned an error.
    pub failed_replays: usize,
    /// Successful replays whose outputs matched.
    pub exact_matches: usize,
    /// Successful replays whose outputs did not match.
    pub mismatches: usize,
    /// Sum of execution times of successful replays, in milliseconds.
    pub total_execution_time_ms: f64,
    /// `total_execution_time_ms / successful_replays`, or `0.0` with no successes.
    pub average_execution_time_ms: f64,
}
502
503#[derive(Error, Debug, Clone, PartialEq)]
504pub enum ReplayError {
505    #[error("Snapshot not found: {0}")]
506    SnapshotNotFound(String),
507    #[error("Storage error: {0}")]
508    StorageError(String),
509    #[error("Execution error: {0}")]
510    ExecutionError(String),
511    #[error("Policy violations: {0:?}")]
512    PolicyViolation(Vec<PolicyViolation>),
513}
514
/// Wraps a storage-layer error as [`ReplayError::StorageError`], keeping its display text.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl From<crate::storage::StorageError> for ReplayError {
    fn from(err: crate::storage::StorageError) -> Self {
        let text = err.to_string();
        Self::StorageError(text)
    }
}
521
// These tests build `ReplayEngine` on `SqliteBackend`, which only exists with the
// `sqlite-storage` feature. Gating on `test` alone broke `cargo test` for builds that
// enable only `lakefs-storage` (or no storage feature at all).
#[cfg(all(test, feature = "sqlite-storage"))]
mod tests {
    use super::*;
    use crate::models::*;
    use crate::storage::SqliteBackend;
    use serde_json::json;

    /// Builds a snapshot with one input, one output, model params and a 100 ms runtime.
    async fn create_test_decision() -> DecisionSnapshot {
        let input = Input::new("test_input", json!("hello"), "string");
        let output = Output::new("test_output", json!("world"), "string");
        let model_params = ModelParameters::new("gpt-4");

        DecisionSnapshot::new("test_function")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .with_execution_time(100.0)
    }

    #[tokio::test]
    async fn test_replay_engine_creation() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);
        assert!(matches!(engine.default_mode, ReplayMode::Tolerant));
    }

    #[tokio::test]
    async fn test_replay_validation_only() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);
        let decision = create_test_decision().await;

        // Save the decision first
        let decision_id = engine.storage.save_decision(&decision).await.unwrap();

        // Replay in validation-only mode
        let result = engine
            .replay(&decision_id, Some(ReplayMode::ValidationOnly), None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.outputs_match);
        assert!(result.replay_output.is_none());
    }

    #[tokio::test]
    async fn test_replay_tolerant_mode() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);
        let decision = create_test_decision().await;

        // Save the decision first
        let decision_id = engine.storage.save_decision(&decision).await.unwrap();

        // Replay in tolerant mode
        let result = engine
            .replay(&decision_id, Some(ReplayMode::Tolerant), None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.outputs_match);
        assert!(result.replay_output.is_some());
    }

    #[tokio::test]
    async fn test_replay_with_policy() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);
        let decision = create_test_decision().await;

        // Save the decision first
        let decision_id = engine.storage.save_decision(&decision).await.unwrap();

        // Create a policy
        let policy = ReplayPolicy::new("test_policy")
            .with_exact_match("function_name")
            .with_similarity_threshold("output", 0.9);

        // Replay with policy
        let result = engine
            .replay_with_policy(&decision_id, &policy, None)
            .await
            .unwrap();

        assert_eq!(result.status, ReplayStatus::Success);
        assert!(result.policy_violations.is_empty());
    }

    #[tokio::test]
    async fn test_diff_calculation() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);

        let decision1 = create_test_decision().await;
        let mut decision2 = create_test_decision().await;
        decision2.function_name = "different_function".to_string();

        // Save both decisions
        let id1 = engine.storage.save_decision(&decision1).await.unwrap();
        let id2 = engine.storage.save_decision(&decision2).await.unwrap();

        // Calculate diff
        let diff = engine.diff(&id1, &id2).await.unwrap();

        assert!(!diff.changes.is_empty());
        assert!(diff.changes.iter().any(|c| c.field_path == "function_name"));
    }

    #[tokio::test]
    async fn test_batch_replay() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);

        let mut snapshot_ids = Vec::new();

        // Save multiple decisions
        for i in 0..3 {
            let mut decision = create_test_decision().await;
            decision.function_name = format!("test_function_{}", i);
            let id = engine.storage.save_decision(&decision).await.unwrap();
            snapshot_ids.push(id);
        }

        // Batch replay
        let results = engine.replay_batch(&snapshot_ids, None, 2).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_replay_stats() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);

        let mut snapshot_ids = Vec::new();

        // Save multiple decisions
        for i in 0..5 {
            let mut decision = create_test_decision().await;
            decision.function_name = format!("test_function_{}", i);
            let id = engine.storage.save_decision(&decision).await.unwrap();
            snapshot_ids.push(id);
        }

        // Get stats
        let stats = engine.get_replay_stats(&snapshot_ids).await.unwrap();

        assert_eq!(stats.total_replays, 5);
        assert_eq!(stats.successful_replays, 5);
        assert_eq!(stats.failed_replays, 0);
        assert!(stats.average_execution_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_nonexistent_snapshot() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);

        let result = engine.replay("nonexistent-id", None, None).await;
        assert!(matches!(result, Err(ReplayError::SnapshotNotFound(_))));
    }

    #[tokio::test]
    async fn test_policy_validation() {
        let storage = SqliteBackend::in_memory().unwrap();
        let engine = ReplayEngine::new(storage);
        let decision = create_test_decision().await;

        // Save the decision first
        let decision_id = engine.storage.save_decision(&decision).await.unwrap();

        // Create a policy that should fail: the simulated check flags low thresholds.
        let policy = ReplayPolicy::new("strict_policy").with_similarity_threshold("output", 0.3);

        let violations = engine.validate(&decision_id, &policy).await.unwrap();
        assert!(!violations.is_empty());
    }
}