use chrono::{DateTime, Utc};
use uuid::Uuid;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Aggregate outcome of comparing a candidate model against production over the samples
/// recorded in a [`ShadowSession`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ShadowComparison {
    /// Fraction of samples (in `[0.0, 1.0]`) on which the candidate agreed with production;
    /// `0.0` when no samples were recorded.
    pub agreement_rate: f64,
    /// Number of samples the comparison was computed from.
    pub sample_count: usize,
    /// Mean of the per-sample latency differences, in milliseconds. It is checked against
    /// `ShadowConfig::max_latency_increase_ms`, so positive values count as an increase.
    pub latency_diff_ms: f64,
}
/// Thresholds and sampling parameters controlling a [`ShadowSession`].
///
/// Use [`ShadowConfig::validate`] to check a hand-built or deserialized config.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ShadowConfig {
    /// Must be in `[0.0, 1.0]`. Not read by `ShadowSession` itself; presumably the fraction
    /// of traffic mirrored to the candidate (TODO: confirm with callers).
    pub sample_rate: f64,
    /// Minimum number of recorded samples before a session can reach a verdict; must be `> 0`.
    pub min_samples: usize,
    /// Minimum agreement rate, in `[0.0, 1.0]`, required for a session to pass.
    pub min_agreement: f64,
    /// Maximum allowed mean latency difference in milliseconds; must be `>= 0`.
    pub max_latency_increase_ms: f64,
}
impl Default for ShadowConfig {
    /// Balanced defaults: `sample_rate` 0.1, at least 1000 samples, at least 95% agreement,
    /// and at most 50 ms of mean latency increase.
    fn default() -> Self {
        let sample_rate = 0.1;
        let min_samples = 1000;
        let min_agreement = 0.95;
        let max_latency_increase_ms = 50.0;
        Self {
            sample_rate,
            min_samples,
            min_agreement,
            max_latency_increase_ms,
        }
    }
}
impl ShadowConfig {
    /// Preset for fast, lower-confidence checks: `sample_rate` 0.2, 100 samples, 90% agreement,
    /// at most 100 ms of mean latency increase.
    pub fn quick() -> Self {
        Self {
            sample_rate: 0.2,
            min_samples: 100,
            min_agreement: 0.90,
            max_latency_increase_ms: 100.0,
        }
    }

    /// Preset for high-confidence checks: `sample_rate` 0.1, 10 000 samples, 99% agreement,
    /// at most 20 ms of mean latency increase.
    pub fn strict() -> Self {
        Self {
            sample_rate: 0.1,
            min_samples: 10000,
            min_agreement: 0.99,
            max_latency_increase_ms: 20.0,
        }
    }

    /// Checks that every field is within its valid range.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first invalid field found:
    /// - `sample_rate` outside `[0.0, 1.0]` (NaN included),
    /// - `min_samples == 0`,
    /// - `min_agreement` outside `[0.0, 1.0]` (NaN included),
    /// - `max_latency_increase_ms` negative or NaN.
    pub fn validate(&self) -> Result<(), String> {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected here too.
        if !(0.0..=1.0).contains(&self.sample_rate) {
            return Err(format!(
                "sample_rate must be in [0.0, 1.0], got {}",
                self.sample_rate
            ));
        }
        if self.min_samples == 0 {
            return Err("min_samples must be > 0".to_string());
        }
        if !(0.0..=1.0).contains(&self.min_agreement) {
            return Err(format!(
                "min_agreement must be in [0.0, 1.0], got {}",
                self.min_agreement
            ));
        }
        // `NaN < 0.0` is false, so NaN must be checked explicitly. A NaN threshold would make
        // every `latency <= threshold` comparison false and fail every session.
        if self.max_latency_increase_ms.is_nan() || self.max_latency_increase_ms < 0.0 {
            return Err(format!(
                "max_latency_increase_ms must be >= 0, got {}",
                self.max_latency_increase_ms
            ));
        }
        Ok(())
    }
}
/// Lifecycle state of a [`ShadowSession`].
///
/// With the `serde` feature, variant names are serialized in snake_case
/// (`running`, `passed`, `failed`, `cancelled`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
pub enum ShadowState {
/// Collecting samples; the only state in which `record_sample` and `evaluate` act.
Running {
/// When the session was created.
started_at: DateTime<Utc>,
/// Number of samples recorded so far (kept equal to the session's sample count).
samples_collected: usize,
},
/// `evaluate` found both the agreement and latency thresholds met.
Passed {
/// Comparison over all samples at evaluation time.
comparison: ShadowComparison,
/// When the verdict was reached.
completed_at: DateTime<Utc>,
},
/// `evaluate` found at least one threshold missed.
Failed {
/// Comparison over all samples at evaluation time.
comparison: ShadowComparison,
/// Human-readable description of the missed threshold(s).
reason: String,
/// When the verdict was reached.
completed_at: DateTime<Utc>,
},
/// Set by `ShadowSession::cancel`.
Cancelled {
/// Caller-supplied cancellation reason.
reason: String,
/// When the session was cancelled.
cancelled_at: DateTime<Utc>,
},
}
/// One comparison run between a production model and a candidate model.
///
/// It accumulates per-sample results and, once enough samples exist, `evaluate` records a
/// pass/fail verdict in its state.
#[derive(Debug)]
pub struct ShadowSession {
/// Session id: random v4 from `new`, or caller-supplied via `with_id`.
pub id: Uuid,
/// Id of the production model being compared against.
pub production_model_id: Uuid,
/// Id of the candidate model under evaluation.
pub candidate_model_id: Uuid,
/// Thresholds used by `evaluate`; not validated on construction.
pub config: ShadowConfig,
// Private so state transitions only happen through the session's methods.
state: ShadowState,
// One `(agreed, latency_diff_ms)` pair per recorded sample, in recording order.
samples: Vec<(bool, f64)>,
}
impl ShadowSession {
    /// Creates a session in the [`ShadowState::Running`] state with no samples and a fresh
    /// random (v4) id.
    ///
    /// `config` is stored as given and is not validated here. Call [`ShadowConfig::validate`]
    /// first if the values come from an untrusted source.
    pub fn new(production_model_id: Uuid, candidate_model_id: Uuid, config: ShadowConfig) -> Self {
        Self {
            id: Uuid::new_v4(),
            production_model_id,
            candidate_model_id,
            config,
            state: ShadowState::Running {
                started_at: Utc::now(),
                samples_collected: 0,
            },
            samples: Vec::new(),
        }
    }

    /// Same as [`ShadowSession::new`], but uses the caller-supplied `id`.
    pub fn with_id(
        id: Uuid,
        production_model_id: Uuid,
        candidate_model_id: Uuid,
        config: ShadowConfig,
    ) -> Self {
        let mut session = Self::new(production_model_id, candidate_model_id, config);
        session.id = id;
        session
    }

    /// Records one comparison sample.
    ///
    /// * `agreed` — whether the candidate agreed with production on this sample.
    /// * `latency_diff_ms` — latency difference for this sample in milliseconds. The mean of
    ///   these is checked against `max_latency_increase_ms`, so positive values count as an
    ///   increase.
    ///
    /// Samples are accepted only while the session is running. After it has passed, failed or
    /// been cancelled, calls are ignored.
    pub fn record_sample(&mut self, agreed: bool, latency_diff_ms: f64) {
        if let ShadowState::Running {
            samples_collected, ..
        } = &mut self.state
        {
            self.samples.push((agreed, latency_diff_ms));
            *samples_collected = self.samples.len();
        }
    }

    /// Tries to reach a verdict and returns the (possibly updated) state.
    ///
    /// If the session is not running, or fewer than `config.min_samples` samples have been
    /// recorded, the state is left unchanged. Otherwise the session moves to
    /// [`ShadowState::Passed`] when the agreement rate and mean latency difference both meet
    /// their thresholds. If not, it moves to [`ShadowState::Failed`] with a reason listing every
    /// missed threshold.
    pub fn evaluate(&mut self) -> &ShadowState {
        if self.is_running() && self.samples.len() >= self.config.min_samples {
            let comparison = self.compute_comparison();
            let completed_at = Utc::now();
            self.state = if self.meets_agreement(&comparison) && self.meets_latency(&comparison)
            {
                ShadowState::Passed {
                    comparison,
                    completed_at,
                }
            } else {
                let reason = self.failure_reason(&comparison);
                ShadowState::Failed {
                    comparison,
                    reason,
                    completed_at,
                }
            };
        }
        &self.state
    }

    /// Cancels a running session, recording `reason` and the cancellation time.
    ///
    /// Has no effect if the session has already passed, failed or been cancelled, so an
    /// existing outcome and its timestamp are never overwritten.
    pub fn cancel(&mut self, reason: impl Into<String>) {
        if self.is_running() {
            self.state = ShadowState::Cancelled {
                reason: reason.into(),
                cancelled_at: Utc::now(),
            };
        }
    }

    /// Current lifecycle state.
    pub fn state(&self) -> &ShadowState {
        &self.state
    }

    /// Number of samples recorded so far.
    pub fn sample_count(&self) -> usize {
        self.samples.len()
    }

    /// `true` if the session is in [`ShadowState::Passed`].
    pub fn passed(&self) -> bool {
        matches!(self.state, ShadowState::Passed { .. })
    }

    /// `true` if the session is in [`ShadowState::Failed`].
    pub fn failed(&self) -> bool {
        matches!(self.state, ShadowState::Failed { .. })
    }

    /// `true` if the session is still collecting samples.
    pub fn is_running(&self) -> bool {
        matches!(self.state, ShadowState::Running { .. })
    }

    /// `true` once the session has passed, failed or been cancelled.
    pub fn is_complete(&self) -> bool {
        !self.is_running()
    }

    /// Comparison over all samples recorded so far, without changing the state.
    pub fn current_comparison(&self) -> ShadowComparison {
        self.compute_comparison()
    }

    /// Fraction of the required samples collected so far, clamped to `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when `config.min_samples` is `0`, since no samples are required.
    pub fn progress(&self) -> f64 {
        if self.config.min_samples == 0 {
            return 1.0;
        }
        (self.samples.len() as f64 / self.config.min_samples as f64).min(1.0)
    }

    /// Whether `comparison` meets the configured minimum agreement rate.
    fn meets_agreement(&self, comparison: &ShadowComparison) -> bool {
        comparison.agreement_rate >= self.config.min_agreement
    }

    /// Whether `comparison` stays within the configured maximum latency increase.
    fn meets_latency(&self, comparison: &ShadowComparison) -> bool {
        comparison.latency_diff_ms <= self.config.max_latency_increase_ms
    }

    /// Aggregates all recorded samples: the fraction that agreed and the mean latency
    /// difference. With no samples every field is zero.
    fn compute_comparison(&self) -> ShadowComparison {
        let sample_count = self.samples.len();
        if sample_count == 0 {
            return ShadowComparison {
                agreement_rate: 0.0,
                sample_count: 0,
                latency_diff_ms: 0.0,
            };
        }
        let agreed_count = self.samples.iter().filter(|(agreed, _)| *agreed).count();
        let total_latency: f64 = self.samples.iter().map(|(_, latency)| *latency).sum();
        let n = sample_count as f64;
        ShadowComparison {
            agreement_rate: agreed_count as f64 / n,
            sample_count,
            latency_diff_ms: total_latency / n,
        }
    }

    /// Builds a `"; "`-separated description of every threshold `comparison` misses.
    ///
    /// It uses the same predicates as [`ShadowSession::evaluate`]. NaN values (from a
    /// non-finite latency sample, or an unvalidated NaN threshold) are therefore reported under
    /// the check that rejected them, rather than falling through to "Unknown failure".
    fn failure_reason(&self, comparison: &ShadowComparison) -> String {
        let mut reasons = Vec::new();
        if !self.meets_agreement(comparison) {
            reasons.push(format!(
                "Agreement rate {:.2}% below threshold {:.2}%",
                comparison.agreement_rate * 100.0,
                self.config.min_agreement * 100.0
            ));
        }
        if !self.meets_latency(comparison) {
            reasons.push(format!(
                "Latency increase {:.1}ms exceeds threshold {:.1}ms",
                comparison.latency_diff_ms, self.config.max_latency_increase_ms
            ));
        }
        if reasons.is_empty() {
            // Not reachable from `evaluate`, which only calls this after a check failed.
            // Kept as a defensive fallback.
            "Unknown failure".to_string()
        } else {
            reasons.join("; ")
        }
    }
}
// Unit tests for `ShadowConfig` presets/validation and the `ShadowSession` lifecycle.
#[cfg(test)]
mod tests {
use super::*;
// Default values match the documented defaults and pass validation.
#[test]
fn test_shadow_config_default() {
let config = ShadowConfig::default();
assert_eq!(config.sample_rate, 0.1);
assert_eq!(config.min_samples, 1000);
assert_eq!(config.min_agreement, 0.95);
assert_eq!(config.max_latency_increase_ms, 50.0);
assert!(config.validate().is_ok());
}
// The `quick` preset requires 100 samples and is valid.
#[test]
fn test_shadow_config_quick() {
let config = ShadowConfig::quick();
assert_eq!(config.min_samples, 100);
assert!(config.validate().is_ok());
}
// The `strict` preset requires 10 000 samples and 99% agreement, and is valid.
#[test]
fn test_shadow_config_strict() {
let config = ShadowConfig::strict();
assert_eq!(config.min_samples, 10000);
assert_eq!(config.min_agreement, 0.99);
assert!(config.validate().is_ok());
}
// Each field is made invalid in turn (restoring the previous one first), and every case
// must be rejected.
#[test]
fn test_shadow_config_validation() {
let mut config = ShadowConfig::default();
config.sample_rate = 1.5;
assert!(config.validate().is_err());
config.sample_rate = 0.1;
config.min_samples = 0;
assert!(config.validate().is_err());
config.min_samples = 100;
config.min_agreement = 1.5;
assert!(config.validate().is_err());
config.min_agreement = 0.95;
config.max_latency_increase_ms = -10.0;
assert!(config.validate().is_err());
}
// A new session stores both model ids and starts running with no samples.
#[test]
fn test_shadow_session_creation() {
let prod_id = Uuid::new_v4();
let cand_id = Uuid::new_v4();
let config = ShadowConfig::quick();
let session = ShadowSession::new(prod_id, cand_id, config);
assert_eq!(session.production_model_id, prod_id);
assert_eq!(session.candidate_model_id, cand_id);
assert!(session.is_running());
assert!(!session.passed());
assert!(!session.failed());
assert_eq!(session.sample_count(), 0);
}
// Three samples: 2/3 agreement and mean latency (5 + 10 - 2) / 3 = 13/3 ms.
#[test]
fn test_shadow_session_record_sample() {
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), ShadowConfig::quick());
session.record_sample(true, 5.0);
session.record_sample(false, 10.0);
session.record_sample(true, -2.0);
assert_eq!(session.sample_count(), 3);
let comparison = session.current_comparison();
assert_eq!(comparison.sample_count, 3);
assert!((comparison.agreement_rate - 2.0 / 3.0).abs() < 0.001);
assert!((comparison.latency_diff_ms - 13.0 / 3.0).abs() < 0.001);
}
// 10/10 agreement at 5 ms meets both thresholds, so the session passes.
#[test]
fn test_shadow_session_evaluate_pass() {
let config = ShadowConfig {
sample_rate: 0.1,
min_samples: 10,
min_agreement: 0.90,
max_latency_increase_ms: 50.0,
};
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), config);
for _ in 0..10 {
session.record_sample(true, 5.0);
}
let state = session.evaluate();
assert!(matches!(state, ShadowState::Passed { .. }));
assert!(session.passed());
}
// 5/10 agreement is below 90%, so the session fails with an agreement reason.
#[test]
fn test_shadow_session_evaluate_fail_agreement() {
let config = ShadowConfig {
sample_rate: 0.1,
min_samples: 10,
min_agreement: 0.90,
max_latency_increase_ms: 50.0,
};
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), config);
for i in 0..10 {
session.record_sample(i < 5, 5.0);
}
session.evaluate();
assert!(session.failed());
if let ShadowState::Failed { reason, .. } = session.state() {
assert!(reason.contains("Agreement rate"));
} else {
panic!("Expected Failed state");
}
}
// A 50 ms mean latency exceeds the 10 ms limit, so the session fails with a latency reason.
#[test]
fn test_shadow_session_evaluate_fail_latency() {
let config = ShadowConfig {
sample_rate: 0.1,
min_samples: 10,
min_agreement: 0.90,
max_latency_increase_ms: 10.0,
};
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), config);
for _ in 0..10 {
session.record_sample(true, 50.0); }
session.evaluate();
assert!(session.failed());
if let ShadowState::Failed { reason, .. } = session.state() {
assert!(reason.contains("Latency increase"));
} else {
panic!("Expected Failed state");
}
}
// With 50 of the 100 required samples, `evaluate` leaves the session running.
#[test]
fn test_shadow_session_evaluate_not_enough_samples() {
let config = ShadowConfig {
sample_rate: 0.1,
min_samples: 100,
min_agreement: 0.90,
max_latency_increase_ms: 50.0,
};
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), config);
for _ in 0..50 {
session.record_sample(true, 5.0);
}
session.evaluate();
assert!(session.is_running());
assert!(!session.passed());
assert!(!session.failed());
}
// Cancelling a running session records the given reason and marks it complete.
#[test]
fn test_shadow_session_cancel() {
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), ShadowConfig::quick());
session.record_sample(true, 5.0);
session.cancel("Manual cancellation for testing");
assert!(!session.is_running());
assert!(session.is_complete());
assert!(!session.passed());
if let ShadowState::Cancelled { reason, .. } = session.state() {
assert_eq!(reason, "Manual cancellation for testing");
} else {
panic!("Expected Cancelled state");
}
}
// Samples recorded after cancellation are ignored.
#[test]
fn test_shadow_session_no_samples_after_cancel() {
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), ShadowConfig::quick());
session.record_sample(true, 5.0);
assert_eq!(session.sample_count(), 1);
session.cancel("done");
session.record_sample(true, 5.0);
assert_eq!(session.sample_count(), 1);
}
// Progress goes 0.0 -> 0.5 -> 1.0 and stays clamped at 1.0 beyond `min_samples`.
#[test]
fn test_shadow_session_progress() {
let config = ShadowConfig {
sample_rate: 0.1,
min_samples: 100,
min_agreement: 0.90,
max_latency_increase_ms: 50.0,
};
let mut session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), config);
assert_eq!(session.progress(), 0.0);
for _ in 0..50 {
session.record_sample(true, 5.0);
}
assert!((session.progress() - 0.5).abs() < 0.001);
for _ in 0..50 {
session.record_sample(true, 5.0);
}
assert!((session.progress() - 1.0).abs() < 0.001);
for _ in 0..50 {
session.record_sample(true, 5.0);
}
assert!((session.progress() - 1.0).abs() < 0.001);
}
// With no samples, the comparison is all zeros.
#[test]
fn test_shadow_comparison_empty() {
let session = ShadowSession::new(Uuid::new_v4(), Uuid::new_v4(), ShadowConfig::quick());
let comparison = session.current_comparison();
assert_eq!(comparison.sample_count, 0);
assert_eq!(comparison.agreement_rate, 0.0);
assert_eq!(comparison.latency_diff_ms, 0.0);
}
// `with_id` keeps the caller-supplied session id.
#[test]
fn test_shadow_session_with_id() {
let session_id = Uuid::new_v4();
let session = ShadowSession::with_id(
session_id,
Uuid::new_v4(),
Uuid::new_v4(),
ShadowConfig::quick(),
);
assert_eq!(session.id, session_id);
}
}