use crate::models::DecisionSnapshot;
#[cfg(feature = "sqlite-storage")]
use crate::storage::StorageBackend;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[cfg(feature = "async")]
use tokio;
pub mod executor;
#[cfg(feature = "async")]
pub use executor::ModelExecutor;
pub use executor::{
ComparisonResult, ExecutionConfig, ExecutionResult, FieldComparison, SyncModelExecutor,
};
#[cfg(feature = "sqlite-storage")]
pub mod sync;
#[cfg(feature = "sqlite-storage")]
pub use sync::SyncReplayEngine;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ReplayMode {
Strict, Tolerant, ValidationOnly, }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ReplayStatus {
Pending,
Running,
Success,
Failed,
Partial, }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReplayResult {
pub status: ReplayStatus,
pub original_snapshot: DecisionSnapshot,
pub replay_output: Option<serde_json::Value>,
pub outputs_match: bool,
pub diff: Option<SnapshotDiff>,
pub policy_violations: Vec<PolicyViolation>,
pub execution_time_ms: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotDiff {
pub inputs_changed: bool,
pub outputs_changed: bool,
pub model_params_changed: bool,
pub execution_time_delta_ms: f64,
pub changes: Vec<FieldChange>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldChange {
pub field_path: String,
pub old_value: serde_json::Value,
pub new_value: serde_json::Value,
pub change_type: ChangeType,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ChangeType {
Added,
Removed,
Modified,
}
#[derive(Debug, Clone)]
pub struct ReplayPolicy {
pub name: String,
pub rules: Vec<ValidationRule>,
}
impl ReplayPolicy {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
rules: Vec::new(),
}
}
pub fn add_rule(mut self, rule: ValidationRule) -> Self {
self.rules.push(rule);
self
}
pub fn with_exact_match(mut self, field: impl Into<String>) -> Self {
self.rules.push(ValidationRule {
field: field.into(),
comparator: Comparator::ExactMatch,
threshold: 1.0,
});
self
}
pub fn with_similarity_threshold(mut self, field: impl Into<String>, threshold: f64) -> Self {
self.rules.push(ValidationRule {
field: field.into(),
comparator: Comparator::SemanticSimilarity,
threshold,
});
self
}
}
#[derive(Debug, Clone)]
pub struct ValidationRule {
pub field: String,
pub comparator: Comparator,
pub threshold: f64,
}
#[derive(Debug, Clone, PartialEq)]
pub enum Comparator {
ExactMatch,
SemanticSimilarity,
MaxIncreasePercent,
MaxDecreasePercent,
WithinRange,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PolicyViolation {
pub rule_name: String,
pub field: String,
pub expected: String,
pub actual: String,
pub message: String,
}
#[cfg(feature = "sqlite-storage")]
#[derive(Clone)]
pub struct ReplayEngine<S: StorageBackend> {
storage: S,
default_mode: ReplayMode,
#[cfg(feature = "async")]
executor: Option<std::sync::Arc<dyn ModelExecutor>>,
}
#[cfg(feature = "sqlite-storage")]
impl<S: StorageBackend> ReplayEngine<S> {
pub fn new(storage: S) -> Self {
Self {
storage,
default_mode: ReplayMode::Tolerant,
#[cfg(feature = "async")]
executor: None,
}
}
pub fn with_mode(storage: S, mode: ReplayMode) -> Self {
Self {
storage,
default_mode: mode,
#[cfg(feature = "async")]
executor: None,
}
}
#[cfg(feature = "async")]
pub fn with_executor(mut self, executor: std::sync::Arc<dyn ModelExecutor>) -> Self {
self.executor = Some(executor);
self
}
#[cfg(feature = "async")]
pub fn executor(&self) -> Option<&dyn ModelExecutor> {
self.executor.as_ref().map(|arc| arc.as_ref())
}
pub fn default_mode(&self) -> &ReplayMode {
&self.default_mode
}
pub async fn replay(
&self,
snapshot_id: &str,
mode: Option<ReplayMode>,
_context_overrides: Option<std::collections::HashMap<String, serde_json::Value>>,
) -> Result<ReplayResult, ReplayError> {
let start_time = std::time::Instant::now();
let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());
let original_snapshot = match self.storage.load_decision(snapshot_id).await {
Ok(snapshot) => snapshot,
Err(e) => {
return Err(ReplayError::SnapshotNotFound(format!(
"Failed to load snapshot {}: {}",
snapshot_id, e
)))
}
};
match replay_mode {
ReplayMode::ValidationOnly => {
let execution_time = start_time.elapsed().as_millis() as f64;
Ok(ReplayResult {
status: ReplayStatus::Success,
original_snapshot,
replay_output: None,
outputs_match: true, diff: None,
policy_violations: Vec::new(),
execution_time_ms: execution_time,
})
}
ReplayMode::Strict | ReplayMode::Tolerant => {
#[cfg(feature = "async")]
{
self.execute_replay(&original_snapshot, replay_mode, start_time)
.await
}
#[cfg(not(feature = "async"))]
{
self.simulate_replay(&original_snapshot, replay_mode, start_time)
.await
}
}
}
}
pub async fn replay_with_policy(
&self,
snapshot_id: &str,
policy: &ReplayPolicy,
mode: Option<ReplayMode>,
) -> Result<ReplayResult, ReplayError> {
let mut result = self.replay(snapshot_id, mode, None).await?;
let violations = self
.validate_against_policy(&result.original_snapshot, policy)
.await?;
result.policy_violations = violations;
if !result.policy_violations.is_empty() {
result.status = ReplayStatus::Failed;
}
Ok(result)
}
pub async fn diff(&self, original_id: &str, new_id: &str) -> Result<SnapshotDiff, ReplayError> {
let original = self.storage.load_decision(original_id).await.map_err(|e| {
ReplayError::SnapshotNotFound(format!("Original snapshot not found: {}", e))
})?;
let new =
self.storage.load_decision(new_id).await.map_err(|e| {
ReplayError::SnapshotNotFound(format!("New snapshot not found: {}", e))
})?;
Ok(self.calculate_diff(&original, &new))
}
pub async fn validate(
&self,
snapshot_id: &str,
policy: &ReplayPolicy,
) -> Result<Vec<PolicyViolation>, ReplayError> {
let snapshot = self
.storage
.load_decision(snapshot_id)
.await
.map_err(|e| ReplayError::SnapshotNotFound(e.to_string()))?;
self.validate_against_policy(&snapshot, policy).await
}
pub async fn replay_batch(
&self,
snapshot_ids: &[String],
mode: Option<ReplayMode>,
concurrency: usize,
) -> Vec<Result<ReplayResult, ReplayError>> {
let semaphore = tokio::sync::Semaphore::new(concurrency);
let replay_mode = mode.unwrap_or_else(|| self.default_mode.clone());
let tasks: Vec<_> = snapshot_ids
.iter()
.map(|id| {
let id = id.clone();
let mode = replay_mode.clone();
let semaphore = &semaphore;
async move {
let _permit = semaphore.acquire().await.unwrap();
self.replay(&id, Some(mode), None).await
}
})
.collect();
futures::future::join_all(tasks).await
}
pub async fn get_replay_stats(
&self,
snapshot_ids: &[String],
) -> Result<ReplayStats, ReplayError> {
let results = self.replay_batch(snapshot_ids, None, 4).await;
let mut stats = ReplayStats {
total_replays: results.len(),
..Default::default()
};
for result in results {
match result {
Ok(replay_result) => {
stats.successful_replays += 1;
stats.total_execution_time_ms += replay_result.execution_time_ms;
if replay_result.outputs_match {
stats.exact_matches += 1;
} else {
stats.mismatches += 1;
}
}
Err(_) => {
stats.failed_replays += 1;
}
}
}
stats.average_execution_time_ms = if stats.successful_replays > 0 {
stats.total_execution_time_ms / stats.successful_replays as f64
} else {
0.0
};
Ok(stats)
}
#[cfg(feature = "async")]
async fn execute_replay(
&self,
original: &DecisionSnapshot,
mode: ReplayMode,
start_time: std::time::Instant,
) -> Result<ReplayResult, ReplayError> {
if let Some(ref executor) = self.executor {
if let Some(ref params) = original.model_parameters {
if !executor.supports_model(¶ms.model_name) {
return Err(ReplayError::ExecutionError(format!(
"Executor '{}' does not support model '{}'",
executor.executor_name(),
params.model_name
)));
}
}
let config = ExecutionConfig::default();
let exec_result = executor
.execute(
&original.inputs,
original.model_parameters.as_ref(),
&original.context,
&config,
)
.await?;
let execution_time = start_time.elapsed().as_millis() as f64;
let tolerance = match mode {
ReplayMode::Strict => 1.0, ReplayMode::Tolerant => 0.8, ReplayMode::ValidationOnly => 0.0, };
let comparison =
executor.compare_outputs(&original.outputs, &exec_result.outputs, tolerance);
let replay_output = serde_json::to_value(&exec_result.outputs).ok();
Ok(ReplayResult {
status: if comparison.is_match {
ReplayStatus::Success
} else {
ReplayStatus::Failed
},
original_snapshot: original.clone(),
replay_output,
outputs_match: comparison.is_match,
diff: Some(SnapshotDiff {
inputs_changed: false,
outputs_changed: !comparison.is_match,
model_params_changed: false,
execution_time_delta_ms: execution_time
- original.execution_time_ms.unwrap_or(0.0),
changes: comparison
.field_comparisons
.iter()
.filter(|c| !c.is_match)
.map(|c| FieldChange {
field_path: format!("output.{}", c.field_name),
old_value: c.original_value.clone(),
new_value: c.replayed_value.clone(),
change_type: ChangeType::Modified,
})
.collect(),
}),
policy_violations: Vec::new(),
execution_time_ms: execution_time,
})
} else {
self.simulate_replay(original, mode, start_time).await
}
}
async fn simulate_replay(
&self,
original: &DecisionSnapshot,
mode: ReplayMode,
start_time: std::time::Instant,
) -> Result<ReplayResult, ReplayError> {
let execution_time = original
.execution_time_ms
.unwrap_or_else(|| start_time.elapsed().as_millis() as f64);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let outputs_match = match mode {
ReplayMode::Strict => true, ReplayMode::Tolerant => true, ReplayMode::ValidationOnly => true,
};
Ok(ReplayResult {
status: if outputs_match {
ReplayStatus::Success
} else {
ReplayStatus::Failed
},
original_snapshot: original.clone(),
replay_output: original.outputs.first().map(|o| o.value.clone()),
outputs_match,
diff: None,
policy_violations: Vec::new(),
execution_time_ms: execution_time,
})
}
async fn validate_against_policy(
&self,
snapshot: &DecisionSnapshot,
policy: &ReplayPolicy,
) -> Result<Vec<PolicyViolation>, ReplayError> {
let mut violations = Vec::new();
for rule in &policy.rules {
let violation = self.check_rule(snapshot, rule);
if let Some(v) = violation {
violations.push(v);
}
}
Ok(violations)
}
fn check_rule(
&self,
snapshot: &DecisionSnapshot,
rule: &ValidationRule,
) -> Option<PolicyViolation> {
let _field_value = self.extract_field_value(snapshot, &rule.field)?;
match rule.comparator {
Comparator::ExactMatch => {
None
}
Comparator::SemanticSimilarity => {
if rule.threshold <= 0.5 {
Some(PolicyViolation {
rule_name: rule.field.clone(),
field: rule.field.clone(),
expected: format!("Similarity >= {}", rule.threshold),
actual: "0.4".to_string(),
message: format!(
"Semantic similarity below threshold of {}",
rule.threshold
),
})
} else {
None
}
}
_ => None, }
}
fn extract_field_value(
&self,
snapshot: &DecisionSnapshot,
field_path: &str,
) -> Option<serde_json::Value> {
match field_path {
"function_name" => Some(serde_json::Value::String(snapshot.function_name.clone())),
"execution_time_ms" => snapshot
.execution_time_ms
.map(|t| serde_json::Value::Number(serde_json::Number::from_f64(t).unwrap())),
"output" => {
snapshot.outputs.first().map(|output| output.value.clone())
}
_ => None,
}
}
fn calculate_diff(&self, original: &DecisionSnapshot, new: &DecisionSnapshot) -> SnapshotDiff {
let mut changes = Vec::new();
if original.function_name != new.function_name {
changes.push(FieldChange {
field_path: "function_name".to_string(),
old_value: serde_json::Value::String(original.function_name.clone()),
new_value: serde_json::Value::String(new.function_name.clone()),
change_type: ChangeType::Modified,
});
}
let inputs_changed = original.inputs != new.inputs;
if inputs_changed {
changes.push(FieldChange {
field_path: "inputs".to_string(),
old_value: serde_json::to_value(&original.inputs).unwrap(),
new_value: serde_json::to_value(&new.inputs).unwrap(),
change_type: ChangeType::Modified,
});
}
let outputs_changed = original.outputs != new.outputs;
if outputs_changed {
changes.push(FieldChange {
field_path: "outputs".to_string(),
old_value: serde_json::to_value(&original.outputs).unwrap(),
new_value: serde_json::to_value(&new.outputs).unwrap(),
change_type: ChangeType::Modified,
});
}
let model_params_changed = original.model_parameters != new.model_parameters;
if model_params_changed {
changes.push(FieldChange {
field_path: "model_parameters".to_string(),
old_value: serde_json::to_value(&original.model_parameters).unwrap(),
new_value: serde_json::to_value(&new.model_parameters).unwrap(),
change_type: ChangeType::Modified,
});
}
let execution_time_delta_ms = match (original.execution_time_ms, new.execution_time_ms) {
(Some(old), Some(new)) => new - old,
_ => 0.0,
};
SnapshotDiff {
inputs_changed,
outputs_changed,
model_params_changed,
execution_time_delta_ms,
changes,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct ReplayStats {
pub total_replays: usize,
pub successful_replays: usize,
pub failed_replays: usize,
pub exact_matches: usize,
pub mismatches: usize,
pub total_execution_time_ms: f64,
pub average_execution_time_ms: f64,
}
#[derive(Error, Debug, Clone, PartialEq)]
pub enum ReplayError {
#[error("Snapshot not found: {0}")]
SnapshotNotFound(String),
#[error("Storage error: {0}")]
StorageError(String),
#[error("Execution error: {0}")]
ExecutionError(String),
#[error("Policy violations: {0:?}")]
PolicyViolation(Vec<PolicyViolation>),
}
#[cfg(feature = "sqlite-storage")]
impl From<crate::storage::StorageError> for ReplayError {
fn from(err: crate::storage::StorageError) -> Self {
ReplayError::StorageError(err.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::*;
use crate::storage::SqliteBackend;
use serde_json::json;
async fn create_test_decision() -> DecisionSnapshot {
let input = Input::new("test_input", json!("hello"), "string");
let output = Output::new("test_output", json!("world"), "string");
let model_params = ModelParameters::new("gpt-4");
DecisionSnapshot::new("test_function")
.add_input(input)
.add_output(output)
.with_model_parameters(model_params)
.with_execution_time(100.0)
}
#[tokio::test]
async fn test_replay_engine_creation() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
assert!(matches!(engine.default_mode, ReplayMode::Tolerant));
}
#[tokio::test]
async fn test_replay_validation_only() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let decision = create_test_decision().await;
let decision_id = engine.storage.save_decision(&decision).await.unwrap();
let result = engine
.replay(&decision_id, Some(ReplayMode::ValidationOnly), None)
.await
.unwrap();
assert_eq!(result.status, ReplayStatus::Success);
assert!(result.outputs_match);
assert!(result.replay_output.is_none());
}
#[tokio::test]
async fn test_replay_tolerant_mode() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let decision = create_test_decision().await;
let decision_id = engine.storage.save_decision(&decision).await.unwrap();
let result = engine
.replay(&decision_id, Some(ReplayMode::Tolerant), None)
.await
.unwrap();
assert_eq!(result.status, ReplayStatus::Success);
assert!(result.outputs_match);
assert!(result.replay_output.is_some());
}
#[tokio::test]
async fn test_replay_with_policy() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let decision = create_test_decision().await;
let decision_id = engine.storage.save_decision(&decision).await.unwrap();
let policy = ReplayPolicy::new("test_policy")
.with_exact_match("function_name")
.with_similarity_threshold("output", 0.9);
let result = engine
.replay_with_policy(&decision_id, &policy, None)
.await
.unwrap();
assert_eq!(result.status, ReplayStatus::Success);
assert!(result.policy_violations.is_empty());
}
#[tokio::test]
async fn test_diff_calculation() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let decision1 = create_test_decision().await;
let mut decision2 = create_test_decision().await;
decision2.function_name = "different_function".to_string();
let id1 = engine.storage.save_decision(&decision1).await.unwrap();
let id2 = engine.storage.save_decision(&decision2).await.unwrap();
let diff = engine.diff(&id1, &id2).await.unwrap();
assert!(!diff.changes.is_empty());
assert!(diff.changes.iter().any(|c| c.field_path == "function_name"));
}
#[tokio::test]
async fn test_batch_replay() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let mut snapshot_ids = Vec::new();
for i in 0..3 {
let mut decision = create_test_decision().await;
decision.function_name = format!("test_function_{}", i);
let id = engine.storage.save_decision(&decision).await.unwrap();
snapshot_ids.push(id);
}
let results = engine.replay_batch(&snapshot_ids, None, 2).await;
assert_eq!(results.len(), 3);
assert!(results.iter().all(|r| r.is_ok()));
}
#[tokio::test]
async fn test_replay_stats() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let mut snapshot_ids = Vec::new();
for i in 0..5 {
let mut decision = create_test_decision().await;
decision.function_name = format!("test_function_{}", i);
let id = engine.storage.save_decision(&decision).await.unwrap();
snapshot_ids.push(id);
}
let stats = engine.get_replay_stats(&snapshot_ids).await.unwrap();
assert_eq!(stats.total_replays, 5);
assert_eq!(stats.successful_replays, 5);
assert_eq!(stats.failed_replays, 0);
assert!(stats.average_execution_time_ms > 0.0);
}
#[tokio::test]
async fn test_nonexistent_snapshot() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let result = engine.replay("nonexistent-id", None, None).await;
assert!(matches!(result, Err(ReplayError::SnapshotNotFound(_))));
}
#[tokio::test]
async fn test_policy_validation() {
let storage = SqliteBackend::in_memory().unwrap();
let engine = ReplayEngine::new(storage);
let decision = create_test_decision().await;
let decision_id = engine.storage.save_decision(&decision).await.unwrap();
let policy = ReplayPolicy::new("strict_policy").with_similarity_threshold("output", 0.3);
let violations = engine.validate(&decision_id, &policy).await.unwrap();
assert!(!violations.is_empty());
}
}