1use crate::models::DecisionSnapshot;
2
3#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
4use crate::storage::StorageBackend;
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7#[cfg(feature = "async")]
8use tokio;
9
10pub mod executor;
11#[cfg(feature = "async")]
12pub use executor::ModelExecutor;
13pub use executor::{
14 ComparisonResult, ExecutionConfig, ExecutionResult, FieldComparison, SyncModelExecutor,
15};
16
17#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
18pub mod sync;
19#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
20pub use sync::SyncReplayEngine;
21
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub enum ReplayMode {
24 Strict, Tolerant, ValidationOnly, }
28
29#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
30pub enum ReplayStatus {
31 Pending,
32 Running,
33 Success,
34 Failed,
35 Partial, }
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ReplayResult {
40 pub status: ReplayStatus,
41 pub original_snapshot: DecisionSnapshot,
42 pub replay_output: Option<serde_json::Value>,
43 pub outputs_match: bool,
44 pub diff: Option<SnapshotDiff>,
45 pub policy_violations: Vec<PolicyViolation>,
46 pub execution_time_ms: f64,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SnapshotDiff {
51 pub inputs_changed: bool,
52 pub outputs_changed: bool,
53 pub model_params_changed: bool,
54 pub execution_time_delta_ms: f64,
55 pub changes: Vec<FieldChange>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct FieldChange {
60 pub field_path: String,
61 pub old_value: serde_json::Value,
62 pub new_value: serde_json::Value,
63 pub change_type: ChangeType,
64}
65
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub enum ChangeType {
68 Added,
69 Removed,
70 Modified,
71}
72
73#[derive(Debug, Clone)]
74pub struct ReplayPolicy {
75 pub name: String,
76 pub rules: Vec<ValidationRule>,
77}
78
79impl ReplayPolicy {
80 pub fn new(name: impl Into<String>) -> Self {
81 Self {
82 name: name.into(),
83 rules: Vec::new(),
84 }
85 }
86
87 pub fn add_rule(mut self, rule: ValidationRule) -> Self {
88 self.rules.push(rule);
89 self
90 }
91
92 pub fn with_exact_match(mut self, field: impl Into<String>) -> Self {
93 self.rules.push(ValidationRule {
94 field: field.into(),
95 comparator: Comparator::ExactMatch,
96 threshold: 1.0,
97 });
98 self
99 }
100
101 pub fn with_similarity_threshold(mut self, field: impl Into<String>, threshold: f64) -> Self {
102 self.rules.push(ValidationRule {
103 field: field.into(),
104 comparator: Comparator::SemanticSimilarity,
105 threshold,
106 });
107 self
108 }
109}
110
111#[derive(Debug, Clone)]
112pub struct ValidationRule {
113 pub field: String,
114 pub comparator: Comparator,
115 pub threshold: f64,
116}
117
/// Comparison strategies a [`ValidationRule`] can request.
///
/// Rule checking in this module only inspects `ExactMatch` and `SemanticSimilarity`;
/// the remaining variants never yield a violation there.
#[derive(Debug, Clone, PartialEq)]
pub enum Comparator {
    /// The field must match exactly.
    ExactMatch,
    /// The field must reach a similarity threshold.
    SemanticSimilarity,
    /// Presumably caps a relative increase; not evaluated in this module.
    MaxIncreasePercent,
    /// Presumably caps a relative decrease; not evaluated in this module.
    MaxDecreasePercent,
    /// Presumably requires the value to lie in a range; not evaluated in this module.
    WithinRange,
}
126
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128pub struct PolicyViolation {
129 pub rule_name: String,
130 pub field: String,
131 pub expected: String,
132 pub actual: String,
133 pub message: String,
134}
135
/// Replays stored decisions out of a [`StorageBackend`].
///
/// Only compiled when one of the storage features is enabled.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
#[derive(Clone)]
pub struct ReplayEngine<S: StorageBackend> {
    /// Backend the snapshots are loaded from.
    storage: S,
    /// Mode used when a call does not specify one.
    default_mode: ReplayMode,
    /// Optional executor used to actually re-run a decision (`async` feature only).
    #[cfg(feature = "async")]
    executor: Option<std::sync::Arc<dyn ModelExecutor>>,
}
144
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl<S: StorageBackend> ReplayEngine<S> {
    /// Concurrency used by [`get_replay_stats`](Self::get_replay_stats) for its batch replay.
    const STATS_CONCURRENCY: usize = 4;

    /// Creates an engine over `storage` that defaults to [`ReplayMode::Tolerant`].
    pub fn new(storage: S) -> Self {
        Self::with_mode(storage, ReplayMode::Tolerant)
    }

    /// Creates an engine over `storage` that uses `mode` whenever a call passes no mode.
    ///
    /// The engine starts without an executor.
    pub fn with_mode(storage: S, mode: ReplayMode) -> Self {
        Self {
            storage,
            default_mode: mode,
            #[cfg(feature = "async")]
            executor: None,
        }
    }

    /// Attaches the executor used to re-run decisions and returns the engine for chaining.
    #[cfg(feature = "async")]
    pub fn with_executor(mut self, executor: std::sync::Arc<dyn ModelExecutor>) -> Self {
        self.executor = Some(executor);
        self
    }

    /// Returns the configured executor, if any.
    #[cfg(feature = "async")]
    pub fn executor(&self) -> Option<&dyn ModelExecutor> {
        self.executor.as_deref()
    }

    /// Returns the mode used when a replay call does not specify one.
    pub fn default_mode(&self) -> &ReplayMode {
        &self.default_mode
    }

    /// Loads the snapshot `snapshot_id` and replays it.
    ///
    /// * `mode` - overrides the engine's default mode when `Some`.
    /// * `_context_overrides` - accepted for API compatibility; currently ignored.
    ///
    /// In [`ReplayMode::ValidationOnly`] nothing is executed: the loaded snapshot is returned
    /// with [`ReplayStatus::Success`] and no replay output. In the other modes the replay is
    /// executed (with the `async` feature) or simulated (without it).
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] when the storage backend fails to load the
    /// snapshot for any reason, and propagates any error from the execution path.
    pub async fn replay(
        &self,
        snapshot_id: &str,
        mode: Option<ReplayMode>,
        _context_overrides: Option<std::collections::HashMap<String, serde_json::Value>>,
    ) -> Result<ReplayResult, ReplayError> {
        let start_time = std::time::Instant::now();
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let original_snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| {
                ReplayError::SnapshotNotFound(format!(
                    "Failed to load snapshot {}: {}",
                    snapshot_id, e
                ))
            })?;

        match replay_mode {
            ReplayMode::ValidationOnly => Ok(ReplayResult {
                status: ReplayStatus::Success,
                original_snapshot,
                replay_output: None,
                outputs_match: true,
                diff: None,
                policy_violations: Vec::new(),
                execution_time_ms: Self::elapsed_ms(start_time),
            }),
            ReplayMode::Strict | ReplayMode::Tolerant => {
                #[cfg(feature = "async")]
                {
                    self.execute_replay(&original_snapshot, replay_mode, start_time)
                        .await
                }
                #[cfg(not(feature = "async"))]
                {
                    self.simulate_replay(&original_snapshot, replay_mode, start_time)
                        .await
                }
            }
        }
    }

    /// Replays `snapshot_id` and then validates the *original* snapshot against `policy`.
    ///
    /// Any violation is stored in the result and forces its status to
    /// [`ReplayStatus::Failed`]; the call itself still returns `Ok`.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`replay`](Self::replay).
    pub async fn replay_with_policy(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
        mode: Option<ReplayMode>,
    ) -> Result<ReplayResult, ReplayError> {
        let mut result = self.replay(snapshot_id, mode, None).await?;

        let violations = self
            .validate_against_policy(&result.original_snapshot, policy)
            .await?;
        if !violations.is_empty() {
            result.status = ReplayStatus::Failed;
        }
        result.policy_violations = violations;

        Ok(result)
    }

    /// Loads two stored snapshots and reports how `new_id` differs from `original_id`.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] when either snapshot cannot be loaded.
    pub async fn diff(&self, original_id: &str, new_id: &str) -> Result<SnapshotDiff, ReplayError> {
        let original = self
            .storage
            .load_decision(original_id)
            .await
            .map_err(|e| {
                ReplayError::SnapshotNotFound(format!("Original snapshot not found: {}", e))
            })?;

        let new = self
            .storage
            .load_decision(new_id)
            .await
            .map_err(|e| ReplayError::SnapshotNotFound(format!("New snapshot not found: {}", e)))?;

        Ok(self.calculate_diff(&original, &new))
    }

    /// Checks the stored snapshot `snapshot_id` against `policy` without replaying it.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::SnapshotNotFound`] when the snapshot cannot be loaded.
    pub async fn validate(
        &self,
        snapshot_id: &str,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        let snapshot = self
            .storage
            .load_decision(snapshot_id)
            .await
            .map_err(|e| ReplayError::SnapshotNotFound(e.to_string()))?;

        self.validate_against_policy(&snapshot, policy).await
    }

    /// Replays every id in `snapshot_ids`, running at most `concurrency` replays at a time.
    ///
    /// Results are returned in the same order as `snapshot_ids`; a failure of one replay does
    /// not affect the others. `mode` overrides the engine default for the whole batch.
    /// A `concurrency` of `0` is treated as `1`.
    pub async fn replay_batch(
        &self,
        snapshot_ids: &[String],
        mode: Option<ReplayMode>,
        concurrency: usize,
    ) -> Vec<Result<ReplayResult, ReplayError>> {
        // A semaphore with zero permits would park every task forever, so always hand out
        // at least one permit.
        let semaphore = tokio::sync::Semaphore::new(concurrency.max(1));
        let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());

        let tasks = snapshot_ids.iter().map(|id| {
            let mode = replay_mode.clone();
            let permits = &semaphore;
            async move {
                let _permit = permits
                    .acquire()
                    .await
                    .expect("semaphore is local to this call and never closed");
                self.replay(id, Some(mode), None).await
            }
        });

        futures::future::join_all(tasks).await
    }

    /// Replays `snapshot_ids` with the default mode and aggregates the outcomes.
    ///
    /// Failed replays are only counted; timing figures cover successful replays alone. The
    /// current implementation always returns `Ok`.
    pub async fn get_replay_stats(
        &self,
        snapshot_ids: &[String],
    ) -> Result<ReplayStats, ReplayError> {
        let results = self
            .replay_batch(snapshot_ids, None, Self::STATS_CONCURRENCY)
            .await;

        let mut stats = ReplayStats {
            total_replays: results.len(),
            ..Default::default()
        };

        for outcome in results {
            let replay_result = match outcome {
                Ok(replay_result) => replay_result,
                Err(_) => {
                    stats.failed_replays += 1;
                    continue;
                }
            };

            stats.successful_replays += 1;
            stats.total_execution_time_ms += replay_result.execution_time_ms;
            if replay_result.outputs_match {
                stats.exact_matches += 1;
            } else {
                stats.mismatches += 1;
            }
        }

        // The average stays at its default of 0.0 when nothing succeeded.
        if stats.successful_replays > 0 {
            stats.average_execution_time_ms =
                stats.total_execution_time_ms / stats.successful_replays as f64;
        }

        Ok(stats)
    }

    /// Re-runs `original` through the configured executor and compares the outputs.
    ///
    /// Falls back to [`simulate_replay`](Self::simulate_replay) when no executor is set.
    /// The comparison tolerance is `1.0` for strict and `0.8` for tolerant mode.
    ///
    /// # Errors
    ///
    /// Returns [`ReplayError::ExecutionError`] when the executor does not support the
    /// snapshot's model, and propagates errors from the executor itself.
    #[cfg(feature = "async")]
    async fn execute_replay(
        &self,
        original: &DecisionSnapshot,
        mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        let executor = match self.executor.as_ref() {
            Some(executor) => executor,
            None => return self.simulate_replay(original, mode, start_time).await,
        };

        if let Some(params) = original.model_parameters.as_ref() {
            if !executor.supports_model(&params.model_name) {
                return Err(ReplayError::ExecutionError(format!(
                    "Executor '{}' does not support model '{}'",
                    executor.executor_name(),
                    params.model_name
                )));
            }
        }

        let config = ExecutionConfig::default();
        let exec_result = executor
            .execute(
                &original.inputs,
                original.model_parameters.as_ref(),
                &original.context,
                &config,
            )
            .await?;

        let execution_time = Self::elapsed_ms(start_time);

        let tolerance = match mode {
            ReplayMode::Strict => 1.0,
            ReplayMode::Tolerant => 0.8,
            ReplayMode::ValidationOnly => 0.0,
        };

        let comparison =
            executor.compare_outputs(&original.outputs, &exec_result.outputs, tolerance);

        // Serialization failure only costs the caller the raw output, not the whole replay.
        let replay_output = serde_json::to_value(&exec_result.outputs).ok();

        let changes = comparison
            .field_comparisons
            .iter()
            .filter(|c| !c.is_match)
            .map(|c| FieldChange {
                field_path: format!("output.{}", c.field_name),
                old_value: c.original_value.clone(),
                new_value: c.replayed_value.clone(),
                change_type: ChangeType::Modified,
            })
            .collect();

        // Without a recorded baseline there is nothing to compare against, so the delta is
        // reported as 0.0, the same convention `calculate_diff` uses.
        let execution_time_delta_ms = original
            .execution_time_ms
            .map_or(0.0, |recorded| execution_time - recorded);

        let status = if comparison.is_match {
            ReplayStatus::Success
        } else {
            ReplayStatus::Failed
        };

        Ok(ReplayResult {
            status,
            original_snapshot: original.clone(),
            replay_output,
            outputs_match: comparison.is_match,
            diff: Some(SnapshotDiff {
                inputs_changed: false,
                outputs_changed: !comparison.is_match,
                model_params_changed: false,
                execution_time_delta_ms,
                changes,
            }),
            policy_violations: Vec::new(),
            execution_time_ms: execution_time,
        })
    }

    /// Produces a replay result without executing anything.
    ///
    /// The recorded outputs are treated as matching in every mode, the first recorded output
    /// value is echoed back as the replay output, and the reported execution time is the
    /// recorded one (or the time elapsed so far when none was recorded).
    async fn simulate_replay(
        &self,
        original: &DecisionSnapshot,
        mode: ReplayMode,
        start_time: std::time::Instant,
    ) -> Result<ReplayResult, ReplayError> {
        let execution_time = original
            .execution_time_ms
            .unwrap_or_else(|| Self::elapsed_ms(start_time));

        // NOTE(review): fixed 1 ms pause kept from the original; presumably it stands in for
        // processing time.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;

        let outputs_match = match mode {
            ReplayMode::Strict | ReplayMode::Tolerant | ReplayMode::ValidationOnly => true,
        };

        let status = if outputs_match {
            ReplayStatus::Success
        } else {
            ReplayStatus::Failed
        };

        Ok(ReplayResult {
            status,
            original_snapshot: original.clone(),
            replay_output: original.outputs.first().map(|o| o.value.clone()),
            outputs_match,
            diff: None,
            policy_violations: Vec::new(),
            execution_time_ms: execution_time,
        })
    }

    /// Evaluates every rule of `policy` against `snapshot`, collecting the violations in
    /// rule order. The current implementation always returns `Ok`.
    async fn validate_against_policy(
        &self,
        snapshot: &DecisionSnapshot,
        policy: &ReplayPolicy,
    ) -> Result<Vec<PolicyViolation>, ReplayError> {
        Ok(policy
            .rules
            .iter()
            .filter_map(|rule| self.check_rule(snapshot, rule))
            .collect())
    }

    /// Checks one `rule`, returning the violation it produces, if any.
    ///
    /// A rule whose field cannot be extracted from `snapshot` is skipped (no violation).
    fn check_rule(
        &self,
        snapshot: &DecisionSnapshot,
        rule: &ValidationRule,
    ) -> Option<PolicyViolation> {
        // Only the presence of the field matters here; its value is not inspected yet.
        self.extract_field_value(snapshot, &rule.field)?;

        match rule.comparator {
            // NOTE(review): no similarity is computed. A violation with a hard-coded actual
            // value of "0.4" is reported whenever the threshold is <= 0.5, which contradicts
            // the "below threshold" message. This looks like a placeholder; the tests in this
            // file depend on it, so it is left as is.
            Comparator::SemanticSimilarity if rule.threshold <= 0.5 => Some(PolicyViolation {
                rule_name: rule.field.clone(),
                field: rule.field.clone(),
                expected: format!("Similarity >= {}", rule.threshold),
                actual: "0.4".to_string(),
                message: format!("Semantic similarity below threshold of {}", rule.threshold),
            }),
            Comparator::ExactMatch
            | Comparator::SemanticSimilarity
            | Comparator::MaxIncreasePercent
            | Comparator::MaxDecreasePercent
            | Comparator::WithinRange => None,
        }
    }

    /// Reads the field named `field_path` from `snapshot` as JSON.
    ///
    /// Supported names are `"function_name"`, `"execution_time_ms"` and `"output"` (the value
    /// of the first output). Returns `None` for unknown names, for a missing value, and for
    /// an execution time that JSON cannot represent (NaN or infinite).
    fn extract_field_value(
        &self,
        snapshot: &DecisionSnapshot,
        field_path: &str,
    ) -> Option<serde_json::Value> {
        match field_path {
            "function_name" => Some(serde_json::Value::String(snapshot.function_name.clone())),
            // `Number::from_f64` rejects non-finite floats; treat those as "no value"
            // instead of panicking.
            "execution_time_ms" => snapshot
                .execution_time_ms
                .and_then(serde_json::Number::from_f64)
                .map(serde_json::Value::Number),
            "output" => snapshot.outputs.first().map(|output| output.value.clone()),
            _ => None,
        }
    }

    /// Compares two snapshots field by field.
    ///
    /// The function name, inputs, outputs and model parameters are compared as wholes. Model
    /// parameters that exist on only one side are reported as [`ChangeType::Added`] or
    /// [`ChangeType::Removed`]; every other difference is [`ChangeType::Modified`]. The
    /// execution-time delta is `new - original`, or `0.0` when either time is missing.
    fn calculate_diff(&self, original: &DecisionSnapshot, new: &DecisionSnapshot) -> SnapshotDiff {
        let mut changes = Vec::new();

        if original.function_name != new.function_name {
            changes.push(FieldChange {
                field_path: "function_name".to_string(),
                old_value: serde_json::Value::String(original.function_name.clone()),
                new_value: serde_json::Value::String(new.function_name.clone()),
                change_type: ChangeType::Modified,
            });
        }

        let inputs_changed = original.inputs != new.inputs;
        if inputs_changed {
            changes.push(FieldChange {
                field_path: "inputs".to_string(),
                old_value: Self::json_or_null(&original.inputs),
                new_value: Self::json_or_null(&new.inputs),
                change_type: ChangeType::Modified,
            });
        }

        let outputs_changed = original.outputs != new.outputs;
        if outputs_changed {
            changes.push(FieldChange {
                field_path: "outputs".to_string(),
                old_value: Self::json_or_null(&original.outputs),
                new_value: Self::json_or_null(&new.outputs),
                change_type: ChangeType::Modified,
            });
        }

        let model_params_changed = original.model_parameters != new.model_parameters;
        if model_params_changed {
            let change_type = match (&original.model_parameters, &new.model_parameters) {
                (None, Some(_)) => ChangeType::Added,
                (Some(_), None) => ChangeType::Removed,
                _ => ChangeType::Modified,
            };
            changes.push(FieldChange {
                field_path: "model_parameters".to_string(),
                old_value: Self::json_or_null(&original.model_parameters),
                new_value: Self::json_or_null(&new.model_parameters),
                change_type,
            });
        }

        let execution_time_delta_ms = match (original.execution_time_ms, new.execution_time_ms) {
            (Some(before), Some(after)) => after - before,
            _ => 0.0,
        };

        SnapshotDiff {
            inputs_changed,
            outputs_changed,
            model_params_changed,
            execution_time_delta_ms,
            changes,
        }
    }

    /// Milliseconds elapsed since `start`, keeping sub-millisecond precision.
    fn elapsed_ms(start: std::time::Instant) -> f64 {
        start.elapsed().as_secs_f64() * 1000.0
    }

    /// Serializes `value` to JSON, substituting `null` when serialization fails so that
    /// building a diff can never panic.
    fn json_or_null<T: Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap_or(serde_json::Value::Null)
    }
}
604
/// Aggregate counters produced by a batch of replays.
#[derive(Debug, Clone, Default)]
pub struct ReplayStats {
    /// Number of replays attempted.
    pub total_replays: usize,
    /// Replays that returned `Ok`.
    pub successful_replays: usize,
    /// Replays that returned `Err`.
    pub failed_replays: usize,
    /// Successful replays whose outputs matched.
    pub exact_matches: usize,
    /// Successful replays whose outputs did not match.
    pub mismatches: usize,
    /// Sum of the execution times of successful replays, in milliseconds.
    pub total_execution_time_ms: f64,
    /// Mean execution time of successful replays, in milliseconds (`0.0` when there are none).
    pub average_execution_time_ms: f64,
}
615
616#[derive(Error, Debug, Clone, PartialEq)]
617pub enum ReplayError {
618 #[error("Snapshot not found: {0}")]
619 SnapshotNotFound(String),
620 #[error("Storage error: {0}")]
621 StorageError(String),
622 #[error("Execution error: {0}")]
623 ExecutionError(String),
624 #[error("Policy violations: {0:?}")]
625 PolicyViolation(Vec<PolicyViolation>),
626}
627
/// Lets storage failures be propagated with `?`, keeping only their display message.
#[cfg(any(feature = "sqlite-storage", feature = "lakefs-storage"))]
impl From<crate::storage::StorageError> for ReplayError {
    fn from(err: crate::storage::StorageError) -> Self {
        let message = err.to_string();
        Self::StorageError(message)
    }
}
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use crate::storage::SqliteBackend;
    use serde_json::json;

    /// Builds an engine over a fresh in-memory SQLite backend.
    fn in_memory_engine() -> ReplayEngine<SqliteBackend> {
        ReplayEngine::new(SqliteBackend::in_memory().unwrap())
    }

    /// Builds the decision every test starts from: one input, one output, model
    /// parameters for "gpt-4" and a recorded execution time of 100 ms.
    async fn create_test_decision() -> DecisionSnapshot {
        DecisionSnapshot::new("test_function")
            .add_input(Input::new("test_input", json!("hello"), "string"))
            .add_output(Output::new("test_output", json!("world"), "string"))
            .with_model_parameters(ModelParameters::new("gpt-4"))
            .with_execution_time(100.0)
    }

    /// Stores `decision` in the engine's backend and returns its id.
    async fn persist(engine: &ReplayEngine<SqliteBackend>, decision: &DecisionSnapshot) -> String {
        engine.storage.save_decision(decision).await.unwrap()
    }

    /// Stores `count` decisions named `test_function_<index>` and returns their ids in order.
    async fn persist_numbered(engine: &ReplayEngine<SqliteBackend>, count: usize) -> Vec<String> {
        let mut ids = Vec::with_capacity(count);
        for index in 0..count {
            let mut decision = create_test_decision().await;
            decision.function_name = format!("test_function_{}", index);
            ids.push(persist(engine, &decision).await);
        }
        ids
    }

    #[tokio::test]
    async fn test_replay_engine_creation() {
        let engine = in_memory_engine();
        assert!(matches!(engine.default_mode, ReplayMode::Tolerant));
    }

    #[tokio::test]
    async fn test_replay_validation_only() {
        let engine = in_memory_engine();
        let stored = create_test_decision().await;
        let id = persist(&engine, &stored).await;

        let outcome = engine
            .replay(&id, Some(ReplayMode::ValidationOnly), None)
            .await
            .unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.outputs_match);
        assert!(outcome.replay_output.is_none());
    }

    #[tokio::test]
    async fn test_replay_tolerant_mode() {
        let engine = in_memory_engine();
        let stored = create_test_decision().await;
        let id = persist(&engine, &stored).await;

        let outcome = engine
            .replay(&id, Some(ReplayMode::Tolerant), None)
            .await
            .unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.outputs_match);
        assert!(outcome.replay_output.is_some());
    }

    #[tokio::test]
    async fn test_replay_with_policy() {
        let engine = in_memory_engine();
        let stored = create_test_decision().await;
        let id = persist(&engine, &stored).await;

        let policy = ReplayPolicy::new("test_policy")
            .with_exact_match("function_name")
            .with_similarity_threshold("output", 0.9);

        let outcome = engine
            .replay_with_policy(&id, &policy, None)
            .await
            .unwrap();

        assert_eq!(outcome.status, ReplayStatus::Success);
        assert!(outcome.policy_violations.is_empty());
    }

    #[tokio::test]
    async fn test_diff_calculation() {
        let engine = in_memory_engine();

        let baseline = create_test_decision().await;
        let mut renamed = create_test_decision().await;
        renamed.function_name = "different_function".to_string();

        let baseline_id = persist(&engine, &baseline).await;
        let renamed_id = persist(&engine, &renamed).await;

        let diff = engine.diff(&baseline_id, &renamed_id).await.unwrap();

        assert!(!diff.changes.is_empty());
        assert!(diff.changes.iter().any(|c| c.field_path == "function_name"));
    }

    #[tokio::test]
    async fn test_batch_replay() {
        let engine = in_memory_engine();
        let snapshot_ids = persist_numbered(&engine, 3).await;

        let results = engine.replay_batch(&snapshot_ids, None, 2).await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.is_ok()));
    }

    #[tokio::test]
    async fn test_replay_stats() {
        let engine = in_memory_engine();
        let snapshot_ids = persist_numbered(&engine, 5).await;

        let stats = engine.get_replay_stats(&snapshot_ids).await.unwrap();

        assert_eq!(stats.total_replays, 5);
        assert_eq!(stats.successful_replays, 5);
        assert_eq!(stats.failed_replays, 0);
        assert!(stats.average_execution_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_nonexistent_snapshot() {
        let engine = in_memory_engine();

        let outcome = engine.replay("nonexistent-id", None, None).await;

        assert!(matches!(outcome, Err(ReplayError::SnapshotNotFound(_))));
    }

    #[tokio::test]
    async fn test_policy_validation() {
        let engine = in_memory_engine();
        let stored = create_test_decision().await;
        let id = persist(&engine, &stored).await;

        let policy = ReplayPolicy::new("strict_policy").with_similarity_threshold("output", 0.3);
        let violations = engine.validate(&id, &policy).await.unwrap();

        assert!(!violations.is_empty());
    }
}