1use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
16#[serde(rename_all = "snake_case")]
17pub enum DriftStrategy {
18 Linear,
20 Stepped,
22 StateMachine,
24 RandomWalk,
26 Custom(String),
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct DriftRule {
33 pub field: String,
35 pub strategy: DriftStrategy,
37 pub params: HashMap<String, Value>,
39 pub rate: f64,
41 pub min_value: Option<Value>,
43 pub max_value: Option<Value>,
45 pub states: Option<Vec<String>>,
47 pub transitions: Option<HashMap<String, Vec<(String, f64)>>>,
49}
50
51impl DriftRule {
52 pub fn new(field: String, strategy: DriftStrategy) -> Self {
54 Self {
55 field,
56 strategy,
57 params: HashMap::new(),
58 rate: 1.0,
59 min_value: None,
60 max_value: None,
61 states: None,
62 transitions: None,
63 }
64 }
65
66 pub fn with_rate(mut self, rate: f64) -> Self {
68 self.rate = rate;
69 self
70 }
71
72 pub fn with_bounds(mut self, min: Value, max: Value) -> Self {
74 self.min_value = Some(min);
75 self.max_value = Some(max);
76 self
77 }
78
79 pub fn with_states(mut self, states: Vec<String>) -> Self {
81 self.states = Some(states);
82 self
83 }
84
85 pub fn with_transitions(mut self, transitions: HashMap<String, Vec<(String, f64)>>) -> Self {
87 self.transitions = Some(transitions);
88 self
89 }
90
91 pub fn with_param(mut self, key: String, value: Value) -> Self {
93 self.params.insert(key, value);
94 self
95 }
96
97 pub fn validate(&self) -> Result<()> {
99 if self.field.is_empty() {
100 return Err(Error::generic("Field name cannot be empty"));
101 }
102
103 if self.rate < 0.0 {
104 return Err(Error::generic("Rate must be non-negative"));
105 }
106
107 if self.strategy == DriftStrategy::StateMachine
108 && (self.states.is_none() || self.transitions.is_none())
109 {
110 return Err(Error::generic("State machine strategy requires states and transitions"));
111 }
112
113 Ok(())
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct DataDriftConfig {
120 pub rules: Vec<DriftRule>,
122 pub time_based: bool,
124 pub request_based: bool,
126 pub interval: u64,
128 pub seed: Option<u64>,
130}
131
132impl Default for DataDriftConfig {
133 fn default() -> Self {
134 Self {
135 rules: Vec::new(),
136 time_based: false,
137 request_based: true,
138 interval: 1,
139 seed: None,
140 }
141 }
142}
143
144impl DataDriftConfig {
145 pub fn new() -> Self {
147 Self::default()
148 }
149
150 pub fn with_rule(mut self, rule: DriftRule) -> Self {
152 self.rules.push(rule);
153 self
154 }
155
156 pub fn with_time_based(mut self, interval_secs: u64) -> Self {
158 self.time_based = true;
159 self.interval = interval_secs;
160 self
161 }
162
163 pub fn with_request_based(mut self, interval_requests: u64) -> Self {
165 self.request_based = true;
166 self.interval = interval_requests;
167 self
168 }
169
170 pub fn with_seed(mut self, seed: u64) -> Self {
172 self.seed = Some(seed);
173 self
174 }
175
176 pub fn validate(&self) -> Result<()> {
178 for rule in &self.rules {
179 rule.validate()?;
180 }
181
182 if self.interval == 0 {
183 return Err(Error::generic("Interval must be greater than 0"));
184 }
185
186 Ok(())
187 }
188}
189
190#[derive(Debug)]
192struct DriftState {
193 values: HashMap<String, Value>,
195 request_count: u64,
197 start_time: std::time::Instant,
199 rng: rand::rngs::StdRng,
201}
202
203pub struct DataDriftEngine {
205 config: DataDriftConfig,
207 state: Arc<RwLock<DriftState>>,
209}
210
211impl DataDriftEngine {
212 pub fn new(config: DataDriftConfig) -> Result<Self> {
214 config.validate()?;
215
216 use rand::SeedableRng;
217 let rng = if let Some(seed) = config.seed {
218 rand::rngs::StdRng::seed_from_u64(seed)
219 } else {
220 rand::rngs::StdRng::seed_from_u64(fastrand::u64(..))
221 };
222
223 let state = DriftState {
224 values: HashMap::new(),
225 request_count: 0,
226 start_time: std::time::Instant::now(),
227 rng,
228 };
229
230 Ok(Self {
231 config,
232 state: Arc::new(RwLock::new(state)),
233 })
234 }
235
236 pub async fn apply_drift(&self, mut data: Value) -> Result<Value> {
238 let mut state = self.state.write().await;
239 state.request_count += 1;
240
241 let should_drift = if self.config.time_based {
243 let elapsed_secs = state.start_time.elapsed().as_secs();
244 elapsed_secs % self.config.interval == 0
245 } else if self.config.request_based {
246 state.request_count % self.config.interval == 0
247 } else {
248 true };
250
251 if !should_drift {
252 return Ok(data);
253 }
254
255 for rule in &self.config.rules {
257 if let Some(obj) = data.as_object_mut() {
258 if let Some(field_value) = obj.get(&rule.field) {
259 let new_value = self.apply_rule(rule, field_value.clone(), &mut state)?;
260 obj.insert(rule.field.clone(), new_value);
261 }
262 }
263 }
264
265 Ok(data)
266 }
267
268 fn apply_rule(
270 &self,
271 rule: &DriftRule,
272 current: Value,
273 state: &mut DriftState,
274 ) -> Result<Value> {
275 use rand::Rng;
276
277 match &rule.strategy {
278 DriftStrategy::Linear => {
279 if let Some(num) = current.as_f64() {
281 let delta = rule.rate;
282 let mut new_val = num + delta;
283
284 if let Some(min) = &rule.min_value {
286 if let Some(min_num) = min.as_f64() {
287 new_val = new_val.max(min_num);
288 }
289 }
290 if let Some(max) = &rule.max_value {
291 if let Some(max_num) = max.as_f64() {
292 new_val = new_val.min(max_num);
293 }
294 }
295
296 Ok(Value::from(new_val))
297 } else {
298 Ok(current)
299 }
300 }
301 DriftStrategy::Stepped => {
302 if let Some(num) = current.as_i64() {
304 let step = rule.rate as i64;
305 let new_val = num + step;
306 Ok(Value::from(new_val))
307 } else {
308 Ok(current)
309 }
310 }
311 DriftStrategy::StateMachine => {
312 if let Some(current_state) = current.as_str() {
314 if let Some(transitions) = &rule.transitions {
315 if let Some(possible_transitions) = transitions.get(current_state) {
316 let random_val: f64 = state.rng.random();
318 let mut cumulative = 0.0;
319
320 for (next_state, probability) in possible_transitions {
321 cumulative += probability;
322 if random_val <= cumulative {
323 return Ok(Value::String(next_state.clone()));
324 }
325 }
326 }
327 }
328 }
329 Ok(current)
330 }
331 DriftStrategy::RandomWalk => {
332 if let Some(num) = current.as_f64() {
334 let delta = state.rng.random_range(-rule.rate..=rule.rate);
335 let mut new_val = num + delta;
336
337 if let Some(min) = &rule.min_value {
339 if let Some(min_num) = min.as_f64() {
340 new_val = new_val.max(min_num);
341 }
342 }
343 if let Some(max) = &rule.max_value {
344 if let Some(max_num) = max.as_f64() {
345 new_val = new_val.min(max_num);
346 }
347 }
348
349 Ok(Value::from(new_val))
350 } else {
351 Ok(current)
352 }
353 }
354 DriftStrategy::Custom(expr) => {
355 let expr = expr.trim();
363
364 if let Some(num) = current.as_f64() {
365 if let Some(rest) = expr.strip_prefix("value") {
367 let rest = rest.trim();
368 let result = if let Some(operand) = rest.strip_prefix('+') {
369 operand.trim().parse::<f64>().ok().map(|n| num + n)
370 } else if let Some(operand) = rest.strip_prefix('-') {
371 operand.trim().parse::<f64>().ok().map(|n| num - n)
372 } else if let Some(operand) = rest.strip_prefix('*') {
373 operand.trim().parse::<f64>().ok().map(|n| num * n)
374 } else if let Some(operand) = rest.strip_prefix('%') {
375 operand.trim().parse::<f64>().ok().map(|n| {
376 if n != 0.0 {
377 num % n
378 } else {
379 num
380 }
381 })
382 } else {
383 None
384 };
385
386 if let Some(mut new_val) = result {
387 if let Some(min) = &rule.min_value {
389 if let Some(min_num) = min.as_f64() {
390 new_val = new_val.max(min_num);
391 }
392 }
393 if let Some(max) = &rule.max_value {
394 if let Some(max_num) = max.as_f64() {
395 new_val = new_val.min(max_num);
396 }
397 }
398 return Ok(Value::from(new_val));
399 }
400 }
401
402 if let Some(inner) =
404 expr.strip_prefix("clamp(").and_then(|s| s.strip_suffix(')'))
405 {
406 let parts: Vec<&str> = inner.split(',').collect();
407 if parts.len() == 2 {
408 if let (Ok(min), Ok(max)) =
409 (parts[0].trim().parse::<f64>(), parts[1].trim().parse::<f64>())
410 {
411 return Ok(Value::from(num.clamp(min, max)));
412 }
413 }
414 }
415
416 if let Ok(literal) = expr.parse::<f64>() {
418 return Ok(Value::from(literal));
419 }
420 }
421
422 if !expr.starts_with("value") && !expr.starts_with("clamp") {
424 if let Ok(parsed) = serde_json::from_str::<Value>(expr) {
426 return Ok(parsed);
427 }
428 return Ok(Value::String(expr.to_string()));
430 }
431
432 Ok(current)
433 }
434 }
435 }
436
437 pub async fn reset(&self) {
439 let mut state = self.state.write().await;
440 state.values.clear();
441 state.request_count = 0;
442 state.start_time = std::time::Instant::now();
443 }
444
445 pub async fn request_count(&self) -> u64 {
447 self.state.read().await.request_count
448 }
449
450 pub async fn elapsed_secs(&self) -> u64 {
452 self.state.read().await.start_time.elapsed().as_secs()
453 }
454
455 pub fn update_config(&mut self, config: DataDriftConfig) -> Result<()> {
457 config.validate()?;
458 self.config = config;
459 Ok(())
460 }
461
462 pub fn config(&self) -> &DataDriftConfig {
464 &self.config
465 }
466}
467
468pub mod scenarios {
470 use super::*;
471
472 pub fn order_status_drift() -> DriftRule {
474 let mut transitions = HashMap::new();
475 transitions.insert(
476 "pending".to_string(),
477 vec![
478 ("processing".to_string(), 0.7),
479 ("cancelled".to_string(), 0.3),
480 ],
481 );
482 transitions.insert(
483 "processing".to_string(),
484 vec![("shipped".to_string(), 0.9), ("cancelled".to_string(), 0.1)],
485 );
486 transitions.insert("shipped".to_string(), vec![("delivered".to_string(), 1.0)]);
487 transitions.insert("delivered".to_string(), vec![]);
488 transitions.insert("cancelled".to_string(), vec![]);
489
490 DriftRule::new("status".to_string(), DriftStrategy::StateMachine)
491 .with_states(vec![
492 "pending".to_string(),
493 "processing".to_string(),
494 "shipped".to_string(),
495 "delivered".to_string(),
496 "cancelled".to_string(),
497 ])
498 .with_transitions(transitions)
499 }
500
501 pub fn stock_depletion_drift() -> DriftRule {
503 DriftRule::new("quantity".to_string(), DriftStrategy::Linear)
504 .with_rate(-1.0)
505 .with_bounds(Value::from(0), Value::from(1000))
506 }
507
508 pub fn price_fluctuation_drift() -> DriftRule {
510 DriftRule::new("price".to_string(), DriftStrategy::RandomWalk)
511 .with_rate(0.5)
512 .with_bounds(Value::from(0.0), Value::from(10000.0))
513 }
514
515 pub fn activity_score_drift() -> DriftRule {
517 DriftRule::new("activity_score".to_string(), DriftStrategy::Linear)
518 .with_rate(0.1)
519 .with_bounds(Value::from(0.0), Value::from(100.0))
520 }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// A unit variant survives a JSON round trip unchanged.
    #[test]
    fn test_drift_strategy_serde() {
        let original = DriftStrategy::Linear;
        let json = serde_json::to_string(&original).unwrap();
        let restored: DriftStrategy = serde_json::from_str(&json).unwrap();
        assert_eq!(original, restored);
    }

    /// Builder methods store the values they are given.
    #[test]
    fn test_drift_rule_builder() {
        let built = DriftRule::new("quantity".to_string(), DriftStrategy::Linear)
            .with_rate(1.5)
            .with_bounds(Value::from(0), Value::from(100));

        assert_eq!(built.field, "quantity");
        assert_eq!(built.strategy, DriftStrategy::Linear);
        assert_eq!(built.rate, 1.5);
    }

    /// A freshly created linear rule is valid.
    #[test]
    fn test_drift_rule_validate() {
        let outcome = DriftRule::new("test".to_string(), DriftStrategy::Linear).validate();
        assert!(outcome.is_ok());
    }

    /// An empty field name is rejected.
    #[test]
    fn test_drift_rule_validate_empty_field() {
        let outcome = DriftRule::new("".to_string(), DriftStrategy::Linear).validate();
        assert!(outcome.is_err());
    }

    /// Config builder methods store the values they are given.
    #[test]
    fn test_drift_config_builder() {
        let config = DataDriftConfig::new()
            .with_rule(DriftRule::new("field".to_string(), DriftStrategy::Linear))
            .with_request_based(10)
            .with_seed(42);

        assert_eq!(config.rules.len(), 1);
        assert!(config.request_based);
        assert_eq!(config.interval, 10);
        assert_eq!(config.seed, Some(42));
    }

    /// The default configuration yields a usable engine.
    #[tokio::test]
    async fn test_drift_engine_creation() {
        let outcome = DataDriftEngine::new(DataDriftConfig::new());
        assert!(outcome.is_ok());
    }

    /// After a reset the request counter reads zero.
    #[tokio::test]
    async fn test_drift_engine_reset() {
        let engine = DataDriftEngine::new(DataDriftConfig::new()).unwrap();
        engine.reset().await;
        assert_eq!(engine.request_count().await, 0);
    }

    /// The order-status scenario targets `status` with a state machine.
    #[test]
    fn test_order_status_drift_scenario() {
        let scenario = scenarios::order_status_drift();
        assert_eq!(scenario.field, "status");
        assert_eq!(scenario.strategy, DriftStrategy::StateMachine);
    }

    /// The stock-depletion scenario decrements `quantity` linearly.
    #[test]
    fn test_stock_depletion_drift_scenario() {
        let scenario = scenarios::stock_depletion_drift();
        assert_eq!(scenario.field, "quantity");
        assert_eq!(scenario.strategy, DriftStrategy::Linear);
        assert_eq!(scenario.rate, -1.0);
    }
}