1use crate::models::DecisionSnapshot;
2
3#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
4use crate::storage::StorageBackend;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio;
9
10#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
11pub mod sync;
12#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
13pub use sync::SyncReplayEngine;
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
16pub enum ReplayMode {
17 Strict, Tolerant, ValidationOnly, }
21
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum ReplayStatus {
24 Pending,
25 Running,
26 Success,
27 Failed,
28 Partial, }
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReplayResult {
33 pub status: ReplayStatus,
34 pub original_snapshot: DecisionSnapshot,
35 pub replay_output: Option<serde_json::Value>,
36 pub outputs_match: bool,
37 pub diff: Option<SnapshotDiff>,
38 pub policy_violations: Vec<PolicyViolation>,
39 pub execution_time_ms: f64,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct SnapshotDiff {
44 pub inputs_changed: bool,
45 pub outputs_changed: bool,
46 pub model_params_changed: bool,
47 pub execution_time_delta_ms: f64,
48 pub changes: Vec<FieldChange>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct FieldChange {
53 pub field_path: String,
54 pub old_value: serde_json::Value,
55 pub new_value: serde_json::Value,
56 pub change_type: ChangeType,
57}
58
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub enum ChangeType {
61 Added,
62 Removed,
63 Modified,
64}
65
66#[derive(Debug, Clone)]
67pub struct ReplayPolicy {
68 pub name: String,
69 pub rules: Vec<ValidationRule>,
70}
71
72impl ReplayPolicy {
73 pub fn new(name: impl Into<String>) -> Self {
74 Self {
75 name: name.into(),
76 rules: Vec::new(),
77 }
78 }
79
80 pub fn add_rule(mut self, rule: ValidationRule) -> Self {
81 self.rules.push(rule);
82 self
83 }
84
85 pub fn with_exact_match(mut self, field: impl Into<String>) -> Self {
86 self.rules.push(ValidationRule {
87 field: field.into(),
88 comparator: Comparator::ExactMatch,
89 threshold: 1.0,
90 });
91 self
92 }
93
94 pub fn with_similarity_threshold(mut self, field: impl Into<String>, threshold: f64) -> Self {
95 self.rules.push(ValidationRule {
96 field: field.into(),
97 comparator: Comparator::SemanticSimilarity,
98 threshold,
99 });
100 self
101 }
102}
103
104#[derive(Debug, Clone)]
105pub struct ValidationRule {
106 pub field: String,
107 pub comparator: Comparator,
108 pub threshold: f64,
109}
110
/// Comparison applied by a [`ValidationRule`].
///
/// Also derives `Eq` and `Hash` so comparators can be used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Comparator {
    /// The field must match exactly.
    ExactMatch,
    /// The field must be at least `threshold`-similar.
    SemanticSimilarity,
    /// The field may increase by at most `threshold` percent.
    MaxIncreasePercent,
    /// The field may decrease by at most `threshold` percent.
    MaxDecreasePercent,
    /// The field must lie within a range.
    WithinRange,
}
119
120#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
121pub struct PolicyViolation {
122 pub rule_name: String,
123 pub field: String,
124 pub expected: String,
125 pub actual: String,
126 pub message: String,
127}
128
/// Replays, validates and diffs [`DecisionSnapshot`]s loaded from a storage backend `S`.
///
/// The `StorageBackend` requirement is stated on the `impl` blocks that use the
/// backend, not on the type itself.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
#[derive(Clone)]
pub struct ReplayEngine<S> {
    /// Backend snapshots are loaded from.
    storage: S,
    /// Mode used when a call passes `None` for its mode.
    default_mode: ReplayMode,
}
135
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl<S: StorageBackend> ReplayEngine<S> {
    /// Creates an engine over `storage` whose default mode is `ReplayMode::Tolerant`.
    pub fn new(storage: S) -> Self {
        Self::with_mode(storage, ReplayMode::Tolerant)
    }

    /// Creates an engine over `storage` with an explicit default replay mode.
    pub fn with_mode(storage: S, mode: ReplayMode) -> Self {
        Self {
            storage,
            default_mode: mode,
        }
    }

    /// Returns the mode used when a call passes `None` for its mode.
    pub fn default_mode(&self) -> &ReplayMode {
        &self.default_mode
    }

    /// Loads snapshot `snapshot_id` and replays it.
    ///
    /// * `mode` overrides the engine's default mode for this call.
    /// * `_context_overrides` is accepted for API compatibility but currently ignored.
    ///
    /// In `ReplayMode::ValidationOnly` the snapshot is only loaded.
    /// The result is `Success` with no `replay_output`.
    /// `Strict` and `Tolerant` go through the simulated replay.
    ///
    /// # Errors
    ///
    /// Any backend failure while loading is reported as [`ReplayError::SnapshotNotFound`].
    pub async fn replay(
        &self,
        snapshot_id: &str,
        mode: Option<ReplayMode>,
        _context_overrides: Option<std::collections::HashMap<String, serde_json::Value>>,
    ) -> Result<ReplayResult, ReplayError> {
        let start_time = std::time::Instant::now();
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let original_snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| {
                ReplayError::SnapshotNotFound(format!(
                    "Failed to load snapshot {}: {}",
                    snapshot_id, e
                ))
            })?;

        match replay_mode {
            ReplayMode::ValidationOnly => Ok(ReplayResult {
                status: ReplayStatus::Success,
                original_snapshot,
                replay_output: None,
                outputs_match: true,
                diff: None,
                policy_violations: Vec::new(),
                execution_time_ms: Self::elapsed_ms(start_time),
            }),
            ReplayMode::Strict | ReplayMode::Tolerant => {
                self.simulate_replay(&original_snapshot, replay_mode, start_time)
                    .await
            }
        }
    }

    /// Replays `snapshot_id` and then checks the original snapshot against `policy`.
    ///
    /// Any violation sets the result's status to [`ReplayStatus::Failed`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::replay`].
    pub async fn replay_with_policy(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
        mode: Option<ReplayMode>,
    ) -> Result<ReplayResult, ReplayError> {
        let mut result = self.replay(snapshot_id, mode, None).await?;

        result.policy_violations = self
            .validate_against_policy(&result.original_snapshot, policy)
            .await?;

        if !result.policy_violations.is_empty() {
            result.status = ReplayStatus::Failed;
        }

        Ok(result)
    }

    /// Loads two snapshots and reports how `new_id` differs from `original_id`.
    ///
    /// # Errors
    ///
    /// [`ReplayError::SnapshotNotFound`] if either snapshot cannot be loaded.
    pub async fn diff(&self, original_id: &str, new_id: &str) -> Result<SnapshotDiff, ReplayError> {
        let original = self.storage.load_decision(original_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!("Original snapshot not found: {}", e))
        })?;

        let new = self.storage.load_decision(new_id).await.map_err(|e| {
            ReplayError::SnapshotNotFound(format!("New snapshot not found: {}", e))
        })?;

        Ok(self.calculate_diff(&original, &new))
    }

    /// Loads `snapshot_id` and checks it against `policy` without replaying it.
    ///
    /// # Errors
    ///
    /// [`ReplayError::SnapshotNotFound`] (carrying the backend's error text) if the
    /// snapshot cannot be loaded.
    pub async fn validate(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        let snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| ReplayError::SnapshotNotFound(e.to_string()))?;

        self.validate_against_policy(&snapshot, policy).await
    }

    /// Replays every id in `snapshot_ids`, with at most `concurrency` replays in flight.
    ///
    /// Results are returned in the same order as `snapshot_ids`. A `concurrency`
    /// of `0` is treated as `1`: a zero-permit semaphore never grants a permit and
    /// every replay would wait forever.
    pub async fn replay_batch(
        &self,
        snapshot_ids: &[String],
        mode: Option<ReplayMode>,
        concurrency: usize,
    ) -> Vec<Result<ReplayResult, ReplayError>> {
        let semaphore = tokio::sync::Semaphore::new(concurrency.max(1));
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let tasks = snapshot_ids.iter().map(|id| {
            let mode = replay_mode.clone();
            let semaphore = &semaphore;
            async move {
                // `acquire` only fails on a closed semaphore, and this one is never closed.
                let _permit = semaphore
                    .acquire()
                    .await
                    .expect("replay semaphore is never closed");
                self.replay(id, Some(mode), None).await
            }
        });

        futures::future::join_all(tasks).await
    }

    /// Replays `snapshot_ids` in the default mode (at most 4 at a time) and
    /// aggregates the outcomes.
    ///
    /// * `successful_replays` counts replays that returned `Ok`, whatever their `status`.
    /// * `exact_matches` and `mismatches` split those successful replays by `outputs_match`.
    ///
    /// Currently always returns `Ok`.
    pub async fn get_replay_stats(
        &self,
        snapshot_ids: &[String],
    ) -> Result<ReplayStats, ReplayError> {
        let results = self.replay_batch(snapshot_ids, None, 4).await;

        let mut stats = ReplayStats {
            total_replays: results.len(),
            ..Default::default()
        };

        for result in results {
            match result {
                Ok(replay_result) => {
                    stats.successful_replays += 1;
                    stats.total_execution_time_ms += replay_result.execution_time_ms;

                    if replay_result.outputs_match {
                        stats.exact_matches += 1;
                    } else {
                        stats.mismatches += 1;
                    }
                }
                Err(_) => stats.failed_replays += 1,
            }
        }

        stats.average_execution_time_ms = if stats.successful_replays > 0 {
            stats.total_execution_time_ms / stats.successful_replays as f64
        } else {
            0.0
        };

        Ok(stats)
    }

    /// Milliseconds elapsed since `start`, with sub-millisecond precision.
    ///
    /// `as_millis()` would truncate fast operations to `0.0`.
    fn elapsed_ms(start: std::time::Instant) -> f64 {
        start.elapsed().as_secs_f64() * 1000.0
    }

    /// Simulated stand-in for re-executing `original`.
    ///
    /// The reported output is the value of the snapshot's first recorded output,
    /// so outputs always match in every mode. The reported time is the snapshot's
    /// recorded execution time, or the time elapsed since `start_time` when none
    /// was recorded.
    async fn simulate_replay(
        &self,
        original: &DecisionSnapshot,
        mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        let execution_time = original
            .execution_time_ms
            .unwrap_or_else(|| Self::elapsed_ms(start_time));

        // NOTE(review): presumably stands in for the latency of a real re-execution.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        // Every mode matches while the replay is simulated.
        let outputs_match = match mode {
            ReplayMode::Strict | ReplayMode::Tolerant | ReplayMode::ValidationOnly => true,
        };

        Ok(ReplayResult {
            status: if outputs_match {
                ReplayStatus::Success
            } else {
                ReplayStatus::Failed
            },
            original_snapshot: original.clone(),
            replay_output: original.outputs.first().map(|o| o.value.clone()),
            outputs_match,
            diff: None,
            policy_violations: Vec::new(),
            execution_time_ms: execution_time,
        })
    }

    /// Evaluates every rule in `policy` against `snapshot` and collects the violations.
    async fn validate_against_policy(
        &self,
        snapshot: &DecisionSnapshot,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        Ok(policy
            .rules
            .iter()
            .filter_map(|rule| self.check_rule(snapshot, rule))
            .collect())
    }

    /// Checks a single rule; returns `None` when it passes.
    ///
    /// A rule on a field that `extract_field_value` does not know, or one whose value
    /// is absent, never produces a violation.
    ///
    /// This is a stub. `ExactMatch` always passes. `SemanticSimilarity` assumes a
    /// fixed observed similarity of `0.4`, and only the two comparators above are
    /// evaluated.
    ///
    /// NOTE(review): the similarity check reports a violation when
    /// `threshold <= 0.5`. This is the inverse of `0.4 < threshold`, but the tests
    /// in this file pin the current behaviour — confirm before changing.
    fn check_rule(
        &self,
        snapshot: &DecisionSnapshot,
        rule: &ValidationRule,
    ) -> Option<PolicyViolation> {
        let _field_value = self.extract_field_value(snapshot, &rule.field)?;

        match rule.comparator {
            Comparator::SemanticSimilarity if rule.threshold <= 0.5 => Some(PolicyViolation {
                rule_name: rule.field.clone(),
                field: rule.field.clone(),
                expected: format!("Similarity >= {}", rule.threshold),
                actual: "0.4".to_string(),
                message: format!(
                    "Semantic similarity below threshold of {}",
                    rule.threshold
                ),
            }),
            _ => None,
        }
    }

    /// Reads a supported top-level field of `snapshot` as JSON.
    ///
    /// Supported fields are `function_name`, `execution_time_ms` and `output`
    /// (the first recorded output's value). Unknown fields return `None`.
    /// Non-finite execution times (NaN/∞) also return `None`, because JSON cannot
    /// represent them; previously this case panicked.
    fn extract_field_value(
        &self,
        snapshot: &DecisionSnapshot,
        field_path: &str,
    ) -> Option<serde_json::Value> {
        match field_path {
            "function_name" => Some(serde_json::Value::String(snapshot.function_name.clone())),
            "execution_time_ms" => snapshot
                .execution_time_ms
                .and_then(serde_json::Number::from_f64)
                .map(serde_json::Value::Number),
            "output" => snapshot.outputs.first().map(|output| output.value.clone()),
            _ => None,
        }
    }

    /// Builds a `Modified` change record for `field_path` by serialising both sides.
    ///
    /// A value that fails to serialise is recorded as JSON `null`. This keeps the
    /// diff diagnostic available instead of panicking inside library code.
    fn modified_change<T: Serialize>(field_path: &str, old: &T, new: &T) -> FieldChange {
        let to_json = |value: &T| serde_json::to_value(value).unwrap_or(serde_json::Value::Null);
        FieldChange {
            field_path: field_path.to_string(),
            old_value: to_json(old),
            new_value: to_json(new),
            change_type: ChangeType::Modified,
        }
    }

    /// Compares the top-level fields of two snapshots.
    fn calculate_diff(&self, original: &DecisionSnapshot, new: &DecisionSnapshot) -> SnapshotDiff {
        let mut changes = Vec::new();

        if original.function_name != new.function_name {
            changes.push(Self::modified_change(
                "function_name",
                &original.function_name,
                &new.function_name,
            ));
        }

        let inputs_changed = original.inputs != new.inputs;
        if inputs_changed {
            changes.push(Self::modified_change("inputs", &original.inputs, &new.inputs));
        }

        let outputs_changed = original.outputs != new.outputs;
        if outputs_changed {
            changes.push(Self::modified_change("outputs", &original.outputs, &new.outputs));
        }

        let model_params_changed = original.model_parameters != new.model_parameters;
        if model_params_changed {
            changes.push(Self::modified_change(
                "model_parameters",
                &original.model_parameters,
                &new.model_parameters,
            ));
        }

        // New minus original; 0.0 when either side has no recorded time.
        let execution_time_delta_ms = match (original.execution_time_ms, new.execution_time_ms) {
            (Some(before), Some(after)) => after - before,
            _ => 0.0,
        };

        SnapshotDiff {
            inputs_changed,
            outputs_changed,
            model_params_changed,
            execution_time_delta_ms,
            changes,
        }
    }
}
491
/// Aggregate counters over a batch of replays (see `get_replay_stats`).
///
/// `Default` yields all counters at zero and both times at `0.0`.
#[derive(Debug, Clone, Default)]
pub struct ReplayStats {
    /// Number of snapshot ids that were replayed.
    pub total_replays: usize,
    /// Replays that returned `Ok`.
    pub successful_replays: usize,
    /// Replays that returned `Err`.
    pub failed_replays: usize,
    /// Successful replays whose outputs matched.
    pub exact_matches: usize,
    /// Successful replays whose outputs did not match.
    pub mismatches: usize,
    /// Sum of `execution_time_ms` over successful replays.
    pub total_execution_time_ms: f64,
    /// `total_execution_time_ms / successful_replays`, or `0.0` when none succeeded.
    pub average_execution_time_ms: f64,
}
502
503#[derive(Error, Debug, Clone, PartialEq)]
504pub enum ReplayError {
505 #[error("Snapshot not found: {0}")]
506 SnapshotNotFound(String),
507 #[error("Storage error: {0}")]
508 StorageError(String),
509 #[error("Execution error: {0}")]
510 ExecutionError(String),
511 #[error("Policy violations: {0:?}")]
512 PolicyViolation(Vec<PolicyViolation>),
513}
514
/// Converts a storage-layer error into [`ReplayError::StorageError`].
///
/// Only the error's `Display` text is kept.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl From<crate::storage::StorageError> for ReplayError {
    fn from(source: crate::storage::StorageError) -> Self {
        let text = source.to_string();
        Self::StorageError(text)
    }
}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use crate::storage::SqliteBackend;
    use serde_json::json;

    /// Decision fixture: one string input and output, GPT-4 parameters, and a
    /// recorded execution time of 100 ms.
    fn sample_decision() -> DecisionSnapshot {
        let fixture_input = Input::new("test_input", json!("hello"), "string");
        let fixture_output = Output::new("test_output", json!("world"), "string");
        let fixture_params = ModelParameters::new("gpt-4");

        DecisionSnapshot::new("test_function")
            .add_input(fixture_input)
            .add_output(fixture_output)
            .with_model_parameters(fixture_params)
            .with_execution_time(100.0)
    }

    /// Engine over a fresh in-memory SQLite store.
    fn fresh_engine() -> ReplayEngine<SqliteBackend> {
        let backend = SqliteBackend::in_memory().unwrap();
        ReplayEngine::new(backend)
    }

    /// Saves `decision` through the engine's backend and returns the assigned id.
    async fn persist(engine: &ReplayEngine<SqliteBackend>, decision: &DecisionSnapshot) -> String {
        engine.storage.save_decision(decision).await.unwrap()
    }

    /// Saves `count` fixtures named `test_function_<i>` and returns their ids in order.
    async fn persist_numbered(engine: &ReplayEngine<SqliteBackend>, count: usize) -> Vec<String> {
        let mut ids = Vec::new();
        for index in 0..count {
            let mut decision = sample_decision();
            decision.function_name = format!("test_function_{}", index);
            ids.push(persist(engine, &decision).await);
        }
        ids
    }

    #[tokio::test]
    async fn test_replay_engine_creation() {
        let engine = fresh_engine();
        assert_eq!(engine.default_mode(), &ReplayMode::Tolerant);
    }

    #[tokio::test]
    async fn test_replay_validation_only() {
        let engine = fresh_engine();
        let id = persist(&engine, &sample_decision()).await;

        let outcome = engine
            .replay(&id, Some(ReplayMode::ValidationOnly), None)
            .await
            .unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.outputs_match);
        assert!(outcome.replay_output.is_none());
    }

    #[tokio::test]
    async fn test_replay_tolerant_mode() {
        let engine = fresh_engine();
        let id = persist(&engine, &sample_decision()).await;

        let outcome = engine
            .replay(&id, Some(ReplayMode::Tolerant), None)
            .await
            .unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.outputs_match);
        assert!(outcome.replay_output.is_some());
    }

    #[tokio::test]
    async fn test_replay_with_policy() {
        let engine = fresh_engine();
        let id = persist(&engine, &sample_decision()).await;

        let policy = ReplayPolicy::new("test_policy")
            .with_exact_match("function_name")
            .with_similarity_threshold("output", 0.9);

        let outcome = engine.replay_with_policy(&id, &policy, None).await.unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.policy_violations.is_empty());
    }

    #[tokio::test]
    async fn test_diff_calculation() {
        let engine = fresh_engine();

        let first = sample_decision();
        let mut second = sample_decision();
        second.function_name = "different_function".to_string();

        let first_id = persist(&engine, &first).await;
        let second_id = persist(&engine, &second).await;

        let diff = engine.diff(&first_id, &second_id).await.unwrap();

        assert!(!diff.changes.is_empty());
        assert!(diff
            .changes
            .iter()
            .any(|change| change.field_path == "function_name"));
    }

    #[tokio::test]
    async fn test_batch_replay() {
        let engine = fresh_engine();
        let ids = persist_numbered(&engine, 3).await;

        let outcomes = engine.replay_batch(&ids, None, 2).await;

        assert_eq!(outcomes.len(), 3);
        assert!(outcomes.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_replay_stats() {
        let engine = fresh_engine();
        let ids = persist_numbered(&engine, 5).await;

        let stats = engine.get_replay_stats(&ids).await.unwrap();

        assert_eq!(stats.total_replays, 5);
        assert_eq!(stats.successful_replays, 5);
        assert_eq!(stats.failed_replays, 0);
        assert!(stats.average_execution_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_nonexistent_snapshot() {
        let engine = fresh_engine();

        let outcome = engine.replay("nonexistent-id", None, None).await;
        assert!(matches!(outcome, Err(ReplayError::SnapshotNotFound(_))));
    }

    #[tokio::test]
    async fn test_policy_validation() {
        let engine = fresh_engine();
        let id = persist(&engine, &sample_decision()).await;

        let policy = ReplayPolicy::new("strict_policy").with_similarity_threshold("output", 0.3);
        let violations = engine.validate(&id, &policy).await.unwrap();

        assert!(!violations.is_empty());
    }
}