use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Strategy that decides how a field's value changes each time drift is applied.
///
/// Variants are (de)serialized with `snake_case` names, e.g. `"random_walk"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DriftStrategy {
/// Adds the rule's `rate` to a numeric value on every drift.
Linear,
/// Adds the rule's `rate`, truncated to an integer, to an integer value.
Stepped,
/// Moves a string value between states using weighted transitions.
StateMachine,
/// Adds a random delta drawn from `[-rate, rate]` to a numeric value.
RandomWalk,
/// Evaluates a small expression: `value <op> <number>` (`+`, `-`, `*`, `%`),
/// `clamp(<min>, <max>)`, or a literal replacement value.
Custom(String),
}
/// Describes how a single field of a JSON object should drift.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftRule {
/// Key of the top-level object field this rule applies to.
pub field: String,
/// How the field's value is changed.
pub strategy: DriftStrategy,
/// Free-form extra parameters; nothing in this file reads them.
pub params: HashMap<String, Value>,
/// Amount of change per drift: the increment for `Linear`, the integer
/// step (after truncation) for `Stepped`, and the maximum step size for
/// `RandomWalk`.
pub rate: f64,
/// Optional lower bound; the engine only uses it when it is a JSON number.
pub min_value: Option<Value>,
/// Optional upper bound; the engine only uses it when it is a JSON number.
pub max_value: Option<Value>,
/// State names for `StateMachine` rules; validation only checks presence.
pub states: Option<Vec<String>>,
/// For `StateMachine` rules: current state -> list of
/// `(next state, probability)` pairs.
pub transitions: Option<HashMap<String, Vec<(String, f64)>>>,
}
impl DriftRule {
pub fn new(field: String, strategy: DriftStrategy) -> Self {
Self {
field,
strategy,
params: HashMap::new(),
rate: 1.0,
min_value: None,
max_value: None,
states: None,
transitions: None,
}
}
pub fn with_rate(mut self, rate: f64) -> Self {
self.rate = rate;
self
}
pub fn with_bounds(mut self, min: Value, max: Value) -> Self {
self.min_value = Some(min);
self.max_value = Some(max);
self
}
pub fn with_states(mut self, states: Vec<String>) -> Self {
self.states = Some(states);
self
}
pub fn with_transitions(mut self, transitions: HashMap<String, Vec<(String, f64)>>) -> Self {
self.transitions = Some(transitions);
self
}
pub fn with_param(mut self, key: String, value: Value) -> Self {
self.params.insert(key, value);
self
}
pub fn validate(&self) -> Result<()> {
if self.field.is_empty() {
return Err(Error::generic("Field name cannot be empty"));
}
if self.rate < 0.0 {
return Err(Error::generic("Rate must be non-negative"));
}
if self.strategy == DriftStrategy::StateMachine
&& (self.states.is_none() || self.transitions.is_none())
{
return Err(Error::generic("State machine strategy requires states and transitions"));
}
Ok(())
}
}
/// Configuration for a drift engine: the rules plus when they fire.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataDriftConfig {
/// Rules applied, in order, whenever drift fires.
pub rules: Vec<DriftRule>,
/// Gate drift on elapsed seconds; checked before `request_based`.
pub time_based: bool,
/// Gate drift on the number of requests seen so far.
pub request_based: bool,
/// Interval shared by both modes: seconds when time-based, requests when
/// request-based. Validation rejects `0`.
pub interval: u64,
/// Optional RNG seed; when absent the engine picks a random one.
pub seed: Option<u64>,
}
impl Default for DataDriftConfig {
    /// Builds the baseline configuration: no rules, request-based drift on
    /// every request (interval `1`), time-based drift off, and no fixed seed.
    fn default() -> Self {
        let rules = Vec::new();
        let seed = None;
        Self {
            seed,
            interval: 1,
            request_based: true,
            time_based: false,
            rules,
        }
    }
}
impl DataDriftConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_rule(mut self, rule: DriftRule) -> Self {
self.rules.push(rule);
self
}
pub fn with_time_based(mut self, interval_secs: u64) -> Self {
self.time_based = true;
self.interval = interval_secs;
self
}
pub fn with_request_based(mut self, interval_requests: u64) -> Self {
self.request_based = true;
self.interval = interval_requests;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
pub fn validate(&self) -> Result<()> {
for rule in &self.rules {
rule.validate()?;
}
if self.interval == 0 {
return Err(Error::generic("Interval must be greater than 0"));
}
Ok(())
}
}
/// Mutable runtime state of a drift engine.
#[derive(Debug)]
struct DriftState {
/// Per-key value storage; within this file it is created empty and only
/// ever cleared by `reset`.
values: HashMap<String, Value>,
/// Number of `apply_drift` calls since creation or the last reset.
request_count: u64,
/// Moment the engine was created or last reset.
start_time: std::time::Instant,
/// Random source for the state-machine and random-walk strategies.
rng: rand::rngs::StdRng,
}
/// Applies the configured drift rules to JSON values.
pub struct DataDriftEngine {
/// Validated configuration (rules, trigger mode, interval, seed).
config: DataDriftConfig,
/// Counters, start time and RNG, shared behind an async read/write lock.
state: Arc<RwLock<DriftState>>,
}
impl DataDriftEngine {
/// Builds an engine from `config`.
///
/// The RNG is seeded from `config.seed` when present, otherwise from a
/// random `fastrand` value. The request counter starts at zero and the
/// start time is "now".
///
/// # Errors
///
/// Returns the validation error when `config` is invalid.
pub fn new(config: DataDriftConfig) -> Result<Self> {
    use rand::SeedableRng;

    config.validate()?;

    let seed = config.seed.unwrap_or_else(|| fastrand::u64(..));
    let initial = DriftState {
        values: HashMap::new(),
        request_count: 0,
        start_time: std::time::Instant::now(),
        rng: rand::rngs::StdRng::seed_from_u64(seed),
    };
    let state = Arc::new(RwLock::new(initial));

    Ok(Self { config, state })
}
/// Counts one request and, when drift is due, applies every rule to `data`.
///
/// Drift is due when:
/// - time-based mode is on and the whole seconds elapsed since the start
///   (or last reset) are a multiple of the interval;
/// - otherwise, request-based mode is on and the request count is a
///   multiple of the interval;
/// - otherwise (both modes off), always.
///
/// Only top-level fields of a JSON object are touched; rules whose field is
/// missing are skipped and non-object values are returned unchanged.
///
/// NOTE(review): in time-based mode every request that lands inside a
/// matching second drifts (including the whole first second) — confirm this
/// is the intended behavior.
pub async fn apply_drift(&self, mut data: Value) -> Result<Value> {
    let mut guard = self.state.write().await;
    guard.request_count += 1;

    let interval = self.config.interval;
    let due = match (self.config.time_based, self.config.request_based) {
        (true, _) => guard.start_time.elapsed().as_secs() % interval == 0,
        (false, true) => guard.request_count % interval == 0,
        (false, false) => true,
    };
    if !due {
        return Ok(data);
    }

    // Borrow the object once; anything that is not an object has no fields
    // for the rules to act on.
    if let Some(fields) = data.as_object_mut() {
        for rule in &self.config.rules {
            if let Some(slot) = fields.get_mut(&rule.field) {
                let drifted = self.apply_rule(rule, slot.clone(), &mut guard)?;
                *slot = drifted;
            }
        }
    }

    Ok(data)
}
fn apply_rule(
&self,
rule: &DriftRule,
current: Value,
state: &mut DriftState,
) -> Result<Value> {
use rand::Rng;
match &rule.strategy {
DriftStrategy::Linear => {
if let Some(num) = current.as_f64() {
let delta = rule.rate;
let mut new_val = num + delta;
if let Some(min) = &rule.min_value {
if let Some(min_num) = min.as_f64() {
new_val = new_val.max(min_num);
}
}
if let Some(max) = &rule.max_value {
if let Some(max_num) = max.as_f64() {
new_val = new_val.min(max_num);
}
}
Ok(Value::from(new_val))
} else {
Ok(current)
}
}
DriftStrategy::Stepped => {
if let Some(num) = current.as_i64() {
let step = rule.rate as i64;
let new_val = num + step;
Ok(Value::from(new_val))
} else {
Ok(current)
}
}
DriftStrategy::StateMachine => {
if let Some(current_state) = current.as_str() {
if let Some(transitions) = &rule.transitions {
if let Some(possible_transitions) = transitions.get(current_state) {
let random_val: f64 = state.rng.random();
let mut cumulative = 0.0;
for (next_state, probability) in possible_transitions {
cumulative += probability;
if random_val <= cumulative {
return Ok(Value::String(next_state.clone()));
}
}
}
}
}
Ok(current)
}
DriftStrategy::RandomWalk => {
if let Some(num) = current.as_f64() {
let delta = state.rng.random_range(-rule.rate..=rule.rate);
let mut new_val = num + delta;
if let Some(min) = &rule.min_value {
if let Some(min_num) = min.as_f64() {
new_val = new_val.max(min_num);
}
}
if let Some(max) = &rule.max_value {
if let Some(max_num) = max.as_f64() {
new_val = new_val.min(max_num);
}
}
Ok(Value::from(new_val))
} else {
Ok(current)
}
}
DriftStrategy::Custom(expr) => {
let expr = expr.trim();
if let Some(num) = current.as_f64() {
if let Some(rest) = expr.strip_prefix("value") {
let rest = rest.trim();
let result = if let Some(operand) = rest.strip_prefix('+') {
operand.trim().parse::<f64>().ok().map(|n| num + n)
} else if let Some(operand) = rest.strip_prefix('-') {
operand.trim().parse::<f64>().ok().map(|n| num - n)
} else if let Some(operand) = rest.strip_prefix('*') {
operand.trim().parse::<f64>().ok().map(|n| num * n)
} else if let Some(operand) = rest.strip_prefix('%') {
operand.trim().parse::<f64>().ok().map(|n| {
if n != 0.0 {
num % n
} else {
num
}
})
} else {
None
};
if let Some(mut new_val) = result {
if let Some(min) = &rule.min_value {
if let Some(min_num) = min.as_f64() {
new_val = new_val.max(min_num);
}
}
if let Some(max) = &rule.max_value {
if let Some(max_num) = max.as_f64() {
new_val = new_val.min(max_num);
}
}
return Ok(Value::from(new_val));
}
}
if let Some(inner) =
expr.strip_prefix("clamp(").and_then(|s| s.strip_suffix(')'))
{
let parts: Vec<&str> = inner.split(',').collect();
if parts.len() == 2 {
if let (Ok(min), Ok(max)) =
(parts[0].trim().parse::<f64>(), parts[1].trim().parse::<f64>())
{
return Ok(Value::from(num.clamp(min, max)));
}
}
}
if let Ok(literal) = expr.parse::<f64>() {
return Ok(Value::from(literal));
}
}
if !expr.starts_with("value") && !expr.starts_with("clamp") {
if let Ok(parsed) = serde_json::from_str::<Value>(expr) {
return Ok(parsed);
}
return Ok(Value::String(expr.to_string()));
}
Ok(current)
}
}
}
pub async fn reset(&self) {
let mut state = self.state.write().await;
state.values.clear();
state.request_count = 0;
state.start_time = std::time::Instant::now();
}
/// Returns how many times `apply_drift` has been called since creation or
/// the last reset.
pub async fn request_count(&self) -> u64 {
    let state = self.state.read().await;
    state.request_count
}
/// Returns the whole seconds elapsed since creation or the last reset.
pub async fn elapsed_secs(&self) -> u64 {
    // `Instant` is `Copy`, so the read lock is released before measuring.
    let started_at = self.state.read().await.start_time;
    started_at.elapsed().as_secs()
}
/// Replaces the configuration after validating it.
///
/// # Errors
///
/// Returns the validation error and keeps the old configuration when
/// `config` is invalid.
pub fn update_config(&mut self, config: DataDriftConfig) -> Result<()> {
    // The assignment only runs on the `Ok` path.
    config.validate().map(|()| {
        self.config = config;
    })
}
pub fn config(&self) -> &DataDriftConfig {
&self.config
}
}
/// Ready-made drift rules for common mock-data scenarios.
pub mod scenarios {
    use super::*;

    /// State machine for an order's `status` field:
    /// `pending` -> `processing` (0.7) or `cancelled` (0.3),
    /// `processing` -> `shipped` (0.9) or `cancelled` (0.1),
    /// `shipped` -> `delivered` (1.0); `delivered` and `cancelled` have no
    /// outgoing transitions.
    pub fn order_status_drift() -> DriftRule {
        // Turns borrowed `(state, probability)` pairs into the owned form
        // the rule stores.
        let edges = |pairs: &[(&str, f64)]| -> Vec<(String, f64)> {
            pairs
                .iter()
                .map(|&(next, weight)| (next.to_string(), weight))
                .collect()
        };

        let transitions = HashMap::from([
            (
                "pending".to_string(),
                edges(&[("processing", 0.7), ("cancelled", 0.3)]),
            ),
            (
                "processing".to_string(),
                edges(&[("shipped", 0.9), ("cancelled", 0.1)]),
            ),
            ("shipped".to_string(), edges(&[("delivered", 1.0)])),
            ("delivered".to_string(), edges(&[])),
            ("cancelled".to_string(), edges(&[])),
        ]);

        let states: Vec<String> = ["pending", "processing", "shipped", "delivered", "cancelled"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        DriftRule::new("status".to_string(), DriftStrategy::StateMachine)
            .with_states(states)
            .with_transitions(transitions)
    }

    /// Linear drift that lowers `quantity` by 1 per drift, bounded to 0..=1000.
    pub fn stock_depletion_drift() -> DriftRule {
        let (floor, ceiling) = (Value::from(0), Value::from(1000));
        let rule = DriftRule::new("quantity".to_string(), DriftStrategy::Linear);
        rule.with_rate(-1.0).with_bounds(floor, ceiling)
    }

    /// Random walk on `price` with steps of at most 0.5, bounded to 0..=10000.
    pub fn price_fluctuation_drift() -> DriftRule {
        let (floor, ceiling) = (Value::from(0.0), Value::from(10000.0));
        let rule = DriftRule::new("price".to_string(), DriftStrategy::RandomWalk);
        rule.with_rate(0.5).with_bounds(floor, ceiling)
    }

    /// Linear drift that raises `activity_score` by 0.1 per drift, bounded to
    /// 0..=100.
    pub fn activity_score_drift() -> DriftRule {
        let (floor, ceiling) = (Value::from(0.0), Value::from(100.0));
        let rule = DriftRule::new("activity_score".to_string(), DriftStrategy::Linear);
        rule.with_rate(0.1).with_bounds(floor, ceiling)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A strategy survives a JSON round trip unchanged.
    #[test]
    fn test_drift_strategy_serde() {
        let original = DriftStrategy::Linear;
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: DriftStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// The rule builder stores field, strategy and rate.
    #[test]
    fn test_drift_rule_builder() {
        let (low, high) = (Value::from(0), Value::from(100));
        let built = DriftRule::new("quantity".to_string(), DriftStrategy::Linear)
            .with_rate(1.5)
            .with_bounds(low, high);
        assert_eq!(built.field, "quantity");
        assert_eq!(built.strategy, DriftStrategy::Linear);
        assert_eq!(built.rate, 1.5);
    }

    /// A plain linear rule with a field name is valid.
    #[test]
    fn test_drift_rule_validate() {
        let outcome = DriftRule::new("test".to_string(), DriftStrategy::Linear).validate();
        assert!(outcome.is_ok());
    }

    /// An empty field name is rejected.
    #[test]
    fn test_drift_rule_validate_empty_field() {
        let outcome = DriftRule::new("".to_string(), DriftStrategy::Linear).validate();
        assert!(outcome.is_err());
    }

    /// The config builder records the rule, trigger mode, interval and seed.
    #[test]
    fn test_drift_config_builder() {
        let built = DataDriftConfig::new()
            .with_rule(DriftRule::new("field".to_string(), DriftStrategy::Linear))
            .with_request_based(10)
            .with_seed(42);
        assert_eq!(built.rules.len(), 1);
        assert!(built.request_based);
        assert_eq!(built.interval, 10);
        assert_eq!(built.seed, Some(42));
    }

    /// The default configuration is accepted by the engine constructor.
    #[tokio::test]
    async fn test_drift_engine_creation() {
        let outcome = DataDriftEngine::new(DataDriftConfig::new());
        assert!(outcome.is_ok());
    }

    /// After a reset the request counter reads zero.
    #[tokio::test]
    async fn test_drift_engine_reset() {
        let engine = DataDriftEngine::new(DataDriftConfig::new()).unwrap();
        engine.reset().await;
        let seen = engine.request_count().await;
        assert_eq!(seen, 0);
    }

    /// The order-status scenario targets `status` with a state machine.
    #[test]
    fn test_order_status_drift_scenario() {
        let scenario = scenarios::order_status_drift();
        assert_eq!(scenario.field, "status");
        assert_eq!(scenario.strategy, DriftStrategy::StateMachine);
    }

    /// The stock-depletion scenario decrements `quantity` linearly.
    #[test]
    fn test_stock_depletion_drift_scenario() {
        let scenario = scenarios::stock_depletion_drift();
        assert_eq!(scenario.field, "quantity");
        assert_eq!(scenario.strategy, DriftStrategy::Linear);
        assert_eq!(scenario.rate, -1.0);
    }
}