1use crate::{Error, Result};
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
/// How a single field's value changes each time drift is applied.
///
/// Serialized with `snake_case` variant names (e.g. `"random_walk"`, `"state_machine"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DriftStrategy {
    /// Adds the rule's `rate` to a numeric value, then clamps it to the rule's bounds.
    Linear,
    /// Adds the rule's `rate`, truncated to an integer, to an integer value.
    Stepped,
    /// Moves a string state to a successor drawn from the rule's weighted `transitions`.
    StateMachine,
    /// Adds a uniform random delta in `[-rate, rate]`, then clamps to the rule's bounds.
    RandomWalk,
    /// Strategy identified by a custom expression string; the engine leaves the value unchanged.
    Custom(String),
}
29
/// One drift rule: which top-level JSON field to change, and how.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DriftRule {
    /// Key of the top-level JSON object field this rule rewrites.
    pub field: String,
    /// Algorithm used to compute the field's next value.
    pub strategy: DriftStrategy,
    /// Extra strategy parameters.
    /// NOTE(review): not read by the engine code in this file.
    pub params: HashMap<String, Value>,
    /// Strategy-dependent magnitude: the delta for `Linear`, the integer step (truncated)
    /// for `Stepped`, and the half-width of the random range for `RandomWalk`.
    /// `DriftRule::new` sets it to `1.0`.
    pub rate: f64,
    /// Optional lower bound used to clamp numeric drift results; ignored unless it is a
    /// JSON number.
    pub min_value: Option<Value>,
    /// Optional upper bound used to clamp numeric drift results; ignored unless it is a
    /// JSON number.
    pub max_value: Option<Value>,
    /// Declared states for `StateMachine`; validation only checks that it is present.
    pub states: Option<Vec<String>>,
    /// For `StateMachine`: maps a current state to its `(next_state, probability)`
    /// candidates, tried in order against a cumulative probability. A state with no entry,
    /// or whose probabilities sum below the random draw, keeps its value.
    pub transitions: Option<HashMap<String, Vec<(String, f64)>>>,
}
50
51impl DriftRule {
52 pub fn new(field: String, strategy: DriftStrategy) -> Self {
54 Self {
55 field,
56 strategy,
57 params: HashMap::new(),
58 rate: 1.0,
59 min_value: None,
60 max_value: None,
61 states: None,
62 transitions: None,
63 }
64 }
65
66 pub fn with_rate(mut self, rate: f64) -> Self {
68 self.rate = rate;
69 self
70 }
71
72 pub fn with_bounds(mut self, min: Value, max: Value) -> Self {
74 self.min_value = Some(min);
75 self.max_value = Some(max);
76 self
77 }
78
79 pub fn with_states(mut self, states: Vec<String>) -> Self {
81 self.states = Some(states);
82 self
83 }
84
85 pub fn with_transitions(mut self, transitions: HashMap<String, Vec<(String, f64)>>) -> Self {
87 self.transitions = Some(transitions);
88 self
89 }
90
91 pub fn with_param(mut self, key: String, value: Value) -> Self {
93 self.params.insert(key, value);
94 self
95 }
96
97 pub fn validate(&self) -> Result<()> {
99 if self.field.is_empty() {
100 return Err(Error::generic("Field name cannot be empty"));
101 }
102
103 if self.rate < 0.0 {
104 return Err(Error::generic("Rate must be non-negative"));
105 }
106
107 if self.strategy == DriftStrategy::StateMachine
108 && (self.states.is_none() || self.transitions.is_none())
109 {
110 return Err(Error::generic("State machine strategy requires states and transitions"));
111 }
112
113 Ok(())
114 }
115}
116
/// Engine-wide drift configuration: the rules, plus when they fire.
///
/// Time-based mode takes precedence over request-based mode; when neither is enabled,
/// every request drifts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataDriftConfig {
    /// Rules applied, in order, on each drifting request.
    pub rules: Vec<DriftRule>,
    /// Drift when the engine's elapsed whole seconds are a multiple of `interval`.
    pub time_based: bool,
    /// Drift on every `interval`-th request (only consulted when `time_based` is false).
    pub request_based: bool,
    /// Seconds (time-based) or requests (request-based) between drifts; must be non-zero.
    pub interval: u64,
    /// RNG seed for reproducible random drift; `None` draws a random seed at engine creation.
    pub seed: Option<u64>,
}
131
132impl Default for DataDriftConfig {
133 fn default() -> Self {
134 Self {
135 rules: Vec::new(),
136 time_based: false,
137 request_based: true,
138 interval: 1,
139 seed: None,
140 }
141 }
142}
143
144impl DataDriftConfig {
145 pub fn new() -> Self {
147 Self::default()
148 }
149
150 pub fn with_rule(mut self, rule: DriftRule) -> Self {
152 self.rules.push(rule);
153 self
154 }
155
156 pub fn with_time_based(mut self, interval_secs: u64) -> Self {
158 self.time_based = true;
159 self.interval = interval_secs;
160 self
161 }
162
163 pub fn with_request_based(mut self, interval_requests: u64) -> Self {
165 self.request_based = true;
166 self.interval = interval_requests;
167 self
168 }
169
170 pub fn with_seed(mut self, seed: u64) -> Self {
172 self.seed = Some(seed);
173 self
174 }
175
176 pub fn validate(&self) -> Result<()> {
178 for rule in &self.rules {
179 rule.validate()?;
180 }
181
182 if self.interval == 0 {
183 return Err(Error::generic("Interval must be greater than 0"));
184 }
185
186 Ok(())
187 }
188}
189
/// Mutable bookkeeping shared by all `apply_drift` calls on one engine.
#[derive(Debug)]
struct DriftState {
    /// NOTE(review): only ever cleared (in `reset`); no code in this file reads or writes
    /// entries, so drifted values are not remembered between calls.
    values: HashMap<String, Value>,
    /// Number of `apply_drift` calls since engine creation or the last `reset`.
    request_count: u64,
    /// Reference point for time-based drift and `elapsed_secs`; restarted by `reset`.
    start_time: std::time::Instant,
    /// Randomness for `RandomWalk` deltas and `StateMachine` transitions.
    rng: rand::rngs::StdRng,
}
202
203pub struct DataDriftEngine {
205 config: DataDriftConfig,
207 state: Arc<RwLock<DriftState>>,
209}
210
211impl DataDriftEngine {
212 pub fn new(config: DataDriftConfig) -> Result<Self> {
214 config.validate()?;
215
216 use rand::SeedableRng;
217 let rng = if let Some(seed) = config.seed {
218 rand::rngs::StdRng::seed_from_u64(seed)
219 } else {
220 rand::rngs::StdRng::seed_from_u64(fastrand::u64(..))
221 };
222
223 let state = DriftState {
224 values: HashMap::new(),
225 request_count: 0,
226 start_time: std::time::Instant::now(),
227 rng,
228 };
229
230 Ok(Self {
231 config,
232 state: Arc::new(RwLock::new(state)),
233 })
234 }
235
236 pub async fn apply_drift(&self, mut data: Value) -> Result<Value> {
238 let mut state = self.state.write().await;
239 state.request_count += 1;
240
241 let should_drift = if self.config.time_based {
243 let elapsed_secs = state.start_time.elapsed().as_secs();
244 elapsed_secs % self.config.interval == 0
245 } else if self.config.request_based {
246 state.request_count % self.config.interval == 0
247 } else {
248 true };
250
251 if !should_drift {
252 return Ok(data);
253 }
254
255 for rule in &self.config.rules {
257 if let Some(obj) = data.as_object_mut() {
258 if let Some(field_value) = obj.get(&rule.field) {
259 let new_value = self.apply_rule(rule, field_value.clone(), &mut state)?;
260 obj.insert(rule.field.clone(), new_value);
261 }
262 }
263 }
264
265 Ok(data)
266 }
267
268 fn apply_rule(
270 &self,
271 rule: &DriftRule,
272 current: Value,
273 state: &mut DriftState,
274 ) -> Result<Value> {
275 use rand::Rng;
276
277 match &rule.strategy {
278 DriftStrategy::Linear => {
279 if let Some(num) = current.as_f64() {
281 let delta = rule.rate;
282 let mut new_val = num + delta;
283
284 if let Some(min) = &rule.min_value {
286 if let Some(min_num) = min.as_f64() {
287 new_val = new_val.max(min_num);
288 }
289 }
290 if let Some(max) = &rule.max_value {
291 if let Some(max_num) = max.as_f64() {
292 new_val = new_val.min(max_num);
293 }
294 }
295
296 Ok(Value::from(new_val))
297 } else {
298 Ok(current)
299 }
300 }
301 DriftStrategy::Stepped => {
302 if let Some(num) = current.as_i64() {
304 let step = rule.rate as i64;
305 let new_val = num + step;
306 Ok(Value::from(new_val))
307 } else {
308 Ok(current)
309 }
310 }
311 DriftStrategy::StateMachine => {
312 if let Some(current_state) = current.as_str() {
314 if let Some(transitions) = &rule.transitions {
315 if let Some(possible_transitions) = transitions.get(current_state) {
316 let random_val: f64 = state.rng.random();
318 let mut cumulative = 0.0;
319
320 for (next_state, probability) in possible_transitions {
321 cumulative += probability;
322 if random_val <= cumulative {
323 return Ok(Value::String(next_state.clone()));
324 }
325 }
326 }
327 }
328 }
329 Ok(current)
330 }
331 DriftStrategy::RandomWalk => {
332 if let Some(num) = current.as_f64() {
334 let delta = state.rng.random_range(-rule.rate..=rule.rate);
335 let mut new_val = num + delta;
336
337 if let Some(min) = &rule.min_value {
339 if let Some(min_num) = min.as_f64() {
340 new_val = new_val.max(min_num);
341 }
342 }
343 if let Some(max) = &rule.max_value {
344 if let Some(max_num) = max.as_f64() {
345 new_val = new_val.min(max_num);
346 }
347 }
348
349 Ok(Value::from(new_val))
350 } else {
351 Ok(current)
352 }
353 }
354 DriftStrategy::Custom(_expr) => {
355 Ok(current)
357 }
358 }
359 }
360
361 pub async fn reset(&self) {
363 let mut state = self.state.write().await;
364 state.values.clear();
365 state.request_count = 0;
366 state.start_time = std::time::Instant::now();
367 }
368
369 pub async fn request_count(&self) -> u64 {
371 self.state.read().await.request_count
372 }
373
374 pub async fn elapsed_secs(&self) -> u64 {
376 self.state.read().await.start_time.elapsed().as_secs()
377 }
378
379 pub fn update_config(&mut self, config: DataDriftConfig) -> Result<()> {
381 config.validate()?;
382 self.config = config;
383 Ok(())
384 }
385
386 pub fn config(&self) -> &DataDriftConfig {
388 &self.config
389 }
390}
391
392pub mod scenarios {
394 use super::*;
395
396 pub fn order_status_drift() -> DriftRule {
398 let mut transitions = HashMap::new();
399 transitions.insert(
400 "pending".to_string(),
401 vec![
402 ("processing".to_string(), 0.7),
403 ("cancelled".to_string(), 0.3),
404 ],
405 );
406 transitions.insert(
407 "processing".to_string(),
408 vec![("shipped".to_string(), 0.9), ("cancelled".to_string(), 0.1)],
409 );
410 transitions.insert("shipped".to_string(), vec![("delivered".to_string(), 1.0)]);
411 transitions.insert("delivered".to_string(), vec![]);
412 transitions.insert("cancelled".to_string(), vec![]);
413
414 DriftRule::new("status".to_string(), DriftStrategy::StateMachine)
415 .with_states(vec![
416 "pending".to_string(),
417 "processing".to_string(),
418 "shipped".to_string(),
419 "delivered".to_string(),
420 "cancelled".to_string(),
421 ])
422 .with_transitions(transitions)
423 }
424
425 pub fn stock_depletion_drift() -> DriftRule {
427 DriftRule::new("quantity".to_string(), DriftStrategy::Linear)
428 .with_rate(-1.0)
429 .with_bounds(Value::from(0), Value::from(1000))
430 }
431
432 pub fn price_fluctuation_drift() -> DriftRule {
434 DriftRule::new("price".to_string(), DriftStrategy::RandomWalk)
435 .with_rate(0.5)
436 .with_bounds(Value::from(0.0), Value::from(10000.0))
437 }
438
439 pub fn activity_score_drift() -> DriftRule {
441 DriftRule::new("activity_score".to_string(), DriftStrategy::Linear)
442 .with_rate(0.1)
443 .with_bounds(Value::from(0.0), Value::from(100.0))
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an engine from `config`, panicking if validation fails.
    fn engine_with(config: DataDriftConfig) -> DataDriftEngine {
        DataDriftEngine::new(config).expect("test config should validate")
    }

    #[test]
    fn test_drift_strategy_serde() {
        let strategy = DriftStrategy::Linear;
        let serialized = serde_json::to_string(&strategy).unwrap();
        let deserialized: DriftStrategy = serde_json::from_str(&serialized).unwrap();
        assert_eq!(strategy, deserialized);
    }

    #[test]
    fn test_drift_strategy_serde_names_are_snake_case() {
        let serialized = serde_json::to_string(&DriftStrategy::RandomWalk).unwrap();
        assert_eq!(serialized, "\"random_walk\"");
        let parsed: DriftStrategy = serde_json::from_str("\"state_machine\"").unwrap();
        assert_eq!(parsed, DriftStrategy::StateMachine);
    }

    #[test]
    fn test_drift_rule_builder() {
        let rule = DriftRule::new("quantity".to_string(), DriftStrategy::Linear)
            .with_rate(1.5)
            .with_bounds(Value::from(0), Value::from(100));

        assert_eq!(rule.field, "quantity");
        assert_eq!(rule.strategy, DriftStrategy::Linear);
        assert_eq!(rule.rate, 1.5);
    }

    #[test]
    fn test_drift_rule_validate() {
        let rule = DriftRule::new("test".to_string(), DriftStrategy::Linear);
        assert!(rule.validate().is_ok());
    }

    #[test]
    fn test_drift_rule_validate_empty_field() {
        let rule = DriftRule::new("".to_string(), DriftStrategy::Linear);
        assert!(rule.validate().is_err());
    }

    #[test]
    fn test_state_machine_rule_requires_transitions() {
        let rule = DriftRule::new("status".to_string(), DriftStrategy::StateMachine);
        assert!(rule.validate().is_err());
    }

    #[test]
    fn test_drift_config_builder() {
        let rule = DriftRule::new("field".to_string(), DriftStrategy::Linear);
        let config = DataDriftConfig::new().with_rule(rule).with_request_based(10).with_seed(42);

        assert_eq!(config.rules.len(), 1);
        assert!(config.request_based);
        assert_eq!(config.interval, 10);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_drift_config_validate_zero_interval() {
        let config = DataDriftConfig::new().with_request_based(0);
        assert!(config.validate().is_err());
        assert!(DataDriftEngine::new(config).is_err());
    }

    #[tokio::test]
    async fn test_drift_engine_creation() {
        let config = DataDriftConfig::new();
        let result = DataDriftEngine::new(config);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_drift_engine_reset() {
        let engine = engine_with(DataDriftConfig::new());
        engine.apply_drift(json!({})).await.unwrap();
        assert_eq!(engine.request_count().await, 1);

        engine.reset().await;
        assert_eq!(engine.request_count().await, 0);
    }

    #[tokio::test]
    async fn test_apply_drift_linear_float_field_with_clamping() {
        let rule = DriftRule::new("x".to_string(), DriftStrategy::Linear)
            .with_bounds(json!(0.0), json!(10.0));
        let engine = engine_with(DataDriftConfig::new().with_rule(rule).with_seed(7));

        let out = engine.apply_drift(json!({"x": 1.5, "label": "keep"})).await.unwrap();
        assert_eq!(out, json!({"x": 2.5, "label": "keep"}));

        let clamped = engine.apply_drift(json!({"x": 9.5})).await.unwrap();
        assert_eq!(clamped, json!({"x": 10.0}));
        assert_eq!(engine.request_count().await, 2);
    }

    #[tokio::test]
    async fn test_apply_drift_respects_request_interval() {
        let rule = DriftRule::new("x".to_string(), DriftStrategy::Linear);
        let engine = engine_with(DataDriftConfig::new().with_rule(rule).with_request_based(2));

        let first = engine.apply_drift(json!({"x": 1.5})).await.unwrap();
        assert_eq!(first, json!({"x": 1.5}));
        let second = engine.apply_drift(json!({"x": 1.5})).await.unwrap();
        assert_eq!(second, json!({"x": 2.5}));
    }

    #[tokio::test]
    async fn test_apply_drift_passes_through_non_objects_and_missing_fields() {
        let rule = DriftRule::new("x".to_string(), DriftStrategy::Linear);
        let engine = engine_with(DataDriftConfig::new().with_rule(rule));

        assert_eq!(engine.apply_drift(json!([1, 2])).await.unwrap(), json!([1, 2]));
        assert_eq!(engine.apply_drift(json!({"y": 1})).await.unwrap(), json!({"y": 1}));
    }

    #[tokio::test]
    async fn test_apply_drift_state_machine_transitions() {
        let config = DataDriftConfig::new()
            .with_rule(scenarios::order_status_drift())
            .with_seed(42);
        let engine = engine_with(config);

        let shipped = engine.apply_drift(json!({"status": "shipped"})).await.unwrap();
        assert_eq!(shipped, json!({"status": "delivered"}));

        let delivered = engine.apply_drift(json!({"status": "delivered"})).await.unwrap();
        assert_eq!(delivered, json!({"status": "delivered"}));

        let unknown = engine.apply_drift(json!({"status": "lost"})).await.unwrap();
        assert_eq!(unknown, json!({"status": "lost"}));

        let pending = engine.apply_drift(json!({"status": "pending"})).await.unwrap();
        let next = pending["status"].as_str().unwrap();
        assert!(next == "processing" || next == "cancelled", "unexpected state {}", next);
    }

    #[tokio::test]
    async fn test_apply_drift_random_walk_stays_within_rate() {
        let config = DataDriftConfig::new()
            .with_rule(scenarios::price_fluctuation_drift())
            .with_seed(3);
        let engine = engine_with(config);

        for _ in 0..20 {
            let out = engine.apply_drift(json!({"price": 100.0})).await.unwrap();
            let price = out["price"].as_f64().unwrap();
            assert!((99.5..=100.5).contains(&price), "price {} drifted too far", price);
        }
    }

    #[tokio::test]
    async fn test_apply_drift_custom_is_passthrough() {
        let rule = DriftRule::new("x".to_string(), DriftStrategy::Custom("x * 2".to_string()));
        let engine = engine_with(DataDriftConfig::new().with_rule(rule));

        assert_eq!(engine.apply_drift(json!({"x": 4})).await.unwrap(), json!({"x": 4}));
    }

    #[test]
    fn test_order_status_drift_scenario() {
        let rule = scenarios::order_status_drift();
        assert_eq!(rule.field, "status");
        assert_eq!(rule.strategy, DriftStrategy::StateMachine);
    }

    #[test]
    fn test_stock_depletion_drift_scenario() {
        let rule = scenarios::stock_depletion_drift();
        assert_eq!(rule.field, "quantity");
        assert_eq!(rule.strategy, DriftStrategy::Linear);
        assert_eq!(rule.rate, -1.0);
    }

    #[test]
    fn test_numeric_scenarios_target_expected_fields() {
        let price = scenarios::price_fluctuation_drift();
        assert_eq!(price.field, "price");
        assert_eq!(price.strategy, DriftStrategy::RandomWalk);

        let activity = scenarios::activity_score_drift();
        assert_eq!(activity.field, "activity_score");
        assert_eq!(activity.strategy, DriftStrategy::Linear);
    }
}