Skip to main content

datasynth_generators/
drift_event_generator.rs

//! Generator for drift events (statistical, temporal, market, behavioral, organizational, process).
//!
//! Produces [`LabeledDriftEvent`] instances by:
//! 1. Observing organizational events and deriving organizational drift labels.
//! 2. Observing process-evolution events and deriving process drift labels.
//! 3. Generating standalone statistical, temporal, market, and behavioral drifts.
//!
//! All generation is fully deterministic given the same seed and configuration.
9
10use chrono::NaiveDate;
11use datasynth_core::utils::seeded_rng;
12use rand::Rng;
13use rand_chacha::ChaCha8Rng;
14use std::collections::HashMap;
15
16use datasynth_core::models::drift_events::{
17    BehavioralDriftEvent, DetectionDifficulty, DriftEventType, LabeledDriftEvent, MarketDriftEvent,
18    MarketEventType, OrganizationalDriftEvent, ProcessDriftEvent, StatisticalDriftEvent,
19    StatisticalShiftType, TemporalDriftEvent, TemporalShiftType,
20};
21use datasynth_core::models::organizational_event::{OrganizationalEvent, OrganizationalEventType};
22use datasynth_core::models::process_evolution::{ProcessEvolutionEvent, ProcessEvolutionType};
23
/// Tuning knobs for [`DriftEventGenerator`].
///
/// The two `*_prob` fields are interpreted as probabilities (values in `[0.0, 1.0]`);
/// the rate field is an average count per year of the requested date range.
#[derive(Debug, Clone)]
pub struct DriftEventGeneratorConfig {
    /// Mean number of standalone (statistical, temporal, market, behavioral)
    /// drifts to emit per year of the requested range.
    pub standalone_drifts_per_year: f64,
    /// Chance that an individual organizational event yields an organizational drift label.
    pub org_event_drift_prob: f64,
    /// Chance that an individual process-evolution event yields a process drift label.
    pub process_event_drift_prob: f64,
}

impl Default for DriftEventGeneratorConfig {
    /// Six standalone drifts per year; 80% of organizational events and 70% of
    /// process-evolution events receive a drift label.
    fn default() -> Self {
        let standalone_drifts_per_year = 6.0;
        let org_event_drift_prob = 0.8;
        let process_event_drift_prob = 0.7;

        Self {
            standalone_drifts_per_year,
            org_event_drift_prob,
            process_event_drift_prob,
        }
    }
}
44
45/// Generates [`LabeledDriftEvent`] instances from organizational events,
46/// process evolution events, and standalone random drifts.
47pub struct DriftEventGenerator {
48    rng: ChaCha8Rng,
49    config: DriftEventGeneratorConfig,
50    event_counter: usize,
51}
52
/// Stream discriminator combined with the caller's seed (via `seeded_rng`) so
/// that this generator's random sequence differs from other generators that
/// are seeded with the same base value.
const SEED_DISCRIMINATOR: u64 = 0xAE0D;
56
57impl DriftEventGenerator {
58    /// Create a new generator with the given seed and default config.
59    pub fn new(seed: u64) -> Self {
60        Self {
61            rng: seeded_rng(seed, SEED_DISCRIMINATOR),
62            config: DriftEventGeneratorConfig::default(),
63            event_counter: 0,
64        }
65    }
66
67    /// Create a new generator with the given seed and custom config.
68    pub fn with_config(seed: u64, config: DriftEventGeneratorConfig) -> Self {
69        Self {
70            rng: seeded_rng(seed, SEED_DISCRIMINATOR),
71            config,
72            event_counter: 0,
73        }
74    }
75
76    /// Generate all drift events: from org events, process events, and standalone.
77    ///
78    /// Results are returned sorted by `start_date`.
79    pub fn generate_all(
80        &mut self,
81        start_date: NaiveDate,
82        end_date: NaiveDate,
83        org_events: &[OrganizationalEvent],
84        proc_events: &[ProcessEvolutionEvent],
85    ) -> Vec<LabeledDriftEvent> {
86        let mut all = Vec::new();
87
88        let mut from_org = self.generate_from_org_events(org_events);
89        let mut from_proc = self.generate_from_process_events(proc_events);
90        let mut standalone = self.generate_standalone_drifts(start_date, end_date);
91
92        all.append(&mut from_org);
93        all.append(&mut from_proc);
94        all.append(&mut standalone);
95
96        all.sort_by_key(|e| e.start_date);
97        all
98    }
99
100    /// Generate drift events derived from organizational events.
101    ///
102    /// For each org event, with probability `config.org_event_drift_prob`, creates
103    /// an Organizational drift label.
104    pub fn generate_from_org_events(
105        &mut self,
106        org_events: &[OrganizationalEvent],
107    ) -> Vec<LabeledDriftEvent> {
108        let mut drifts = Vec::new();
109
110        for org_event in org_events {
111            if !self.rng.random_bool(self.config.org_event_drift_prob) {
112                continue;
113            }
114
115            let event_id = self.next_event_id();
116
117            let (detection_difficulty, magnitude) = match &org_event.event_type {
118                OrganizationalEventType::Merger(_) | OrganizationalEventType::Acquisition(_) => {
119                    let mag = self.rng.random_range(0.3..0.8);
120                    (DetectionDifficulty::Easy, mag)
121                }
122                OrganizationalEventType::Reorganization(_)
123                | OrganizationalEventType::WorkforceReduction(_) => {
124                    let mag = self.rng.random_range(0.1..0.4);
125                    (DetectionDifficulty::Medium, mag)
126                }
127                OrganizationalEventType::LeadershipChange(_) => {
128                    let mag = self.rng.random_range(0.1..0.4);
129                    (DetectionDifficulty::Hard, mag)
130                }
131                OrganizationalEventType::Divestiture(_) => {
132                    let mag = self.rng.random_range(0.1..0.4);
133                    (DetectionDifficulty::Medium, mag)
134                }
135            };
136
137            let duration_days = self.rng.random_range(30..90_i64);
138            let end_date = org_event.effective_date + chrono::Duration::days(duration_days);
139
140            let affected_entities: Vec<String> = org_event
141                .tags
142                .iter()
143                .filter(|t| t.starts_with("company:"))
144                .cloned()
145                .collect();
146
147            let drift_type = DriftEventType::Organizational(OrganizationalDriftEvent {
148                event_type: org_event.event_type.type_name().to_string(),
149                related_event_id: org_event.event_id.clone(),
150                detection_difficulty,
151                affected_entities: affected_entities.clone(),
152                impact_metrics: HashMap::new(),
153            });
154
155            let start_period = 0_u32;
156            let end_period = (duration_days / 30) as u32;
157
158            let mut labeled = LabeledDriftEvent::new(
159                event_id,
160                drift_type,
161                org_event.effective_date,
162                start_period,
163                magnitude,
164            );
165            labeled.end_date = Some(end_date);
166            labeled.end_period = Some(end_period);
167            labeled.related_org_event = Some(org_event.event_id.clone());
168            labeled.affected_fields = affected_entities;
169            labeled.tags = vec![
170                format!("source:organizational"),
171                format!("org_type:{}", org_event.event_type.type_name()),
172            ];
173
174            drifts.push(labeled);
175        }
176
177        drifts
178    }
179
180    /// Generate drift events derived from process evolution events.
181    ///
182    /// For each process event, with probability `config.process_event_drift_prob`,
183    /// creates a Process drift label.
184    pub fn generate_from_process_events(
185        &mut self,
186        proc_events: &[ProcessEvolutionEvent],
187    ) -> Vec<LabeledDriftEvent> {
188        let mut drifts = Vec::new();
189
190        for proc_event in proc_events {
191            if !self.rng.random_bool(self.config.process_event_drift_prob) {
192                continue;
193            }
194
195            let event_id = self.next_event_id();
196
197            let detection_difficulty = match &proc_event.event_type {
198                ProcessEvolutionType::ProcessAutomation(_)
199                | ProcessEvolutionType::ApprovalWorkflowChange(_) => DetectionDifficulty::Medium,
200                ProcessEvolutionType::PolicyChange(_)
201                | ProcessEvolutionType::ControlEnhancement(_) => DetectionDifficulty::Hard,
202            };
203
204            let transition_months = proc_event.event_type.transition_months();
205            let duration_days = (transition_months as i64) * 30;
206            let end_date = proc_event.effective_date + chrono::Duration::days(duration_days);
207
208            // Magnitude based on error_rate_impact, scaled to 0.1..0.6 range
209            let raw_impact = proc_event.event_type.error_rate_impact().abs();
210            let magnitude = (raw_impact * 6.0).clamp(0.1, 0.6);
211
212            let drift_type = DriftEventType::Process(ProcessDriftEvent {
213                process_type: proc_event.event_type.type_name().to_string(),
214                related_event_id: proc_event.event_id.clone(),
215                detection_difficulty,
216                affected_processes: proc_event.tags.clone(),
217            });
218
219            let start_period = 0_u32;
220            let end_period = transition_months;
221
222            let mut labeled = LabeledDriftEvent::new(
223                event_id,
224                drift_type,
225                proc_event.effective_date,
226                start_period,
227                magnitude,
228            );
229            labeled.end_date = Some(end_date);
230            labeled.end_period = Some(end_period);
231            labeled.tags = vec![
232                "source:process".to_string(),
233                format!("process_type:{}", proc_event.event_type.type_name()),
234            ];
235
236            drifts.push(labeled);
237        }
238
239        drifts
240    }
241
242    /// Generate standalone drifts (statistical, temporal, market, behavioral)
243    /// randomly distributed across the date range.
244    pub fn generate_standalone_drifts(
245        &mut self,
246        start_date: NaiveDate,
247        end_date: NaiveDate,
248    ) -> Vec<LabeledDriftEvent> {
249        let total_days = (end_date - start_date).num_days().max(1) as f64;
250        let total_years = total_days / 365.25;
251        let expected_count =
252            (self.config.standalone_drifts_per_year * total_years).round() as usize;
253        let count = expected_count.max(1);
254
255        let mut drifts = Vec::with_capacity(count);
256
257        for _ in 0..count {
258            let event_id = self.next_event_id();
259
260            // Pick random start date within range
261            let days_offset = self.rng.random_range(0..total_days as i64);
262            let drift_start = start_date + chrono::Duration::days(days_offset);
263            let duration_days = self.rng.random_range(30..180_i64);
264            let drift_end = drift_start + chrono::Duration::days(duration_days);
265
266            // Pick random category: 0=Statistical, 1=Temporal, 2=Market, 3=Behavioral
267            let category = self.rng.random_range(0..4_u32);
268
269            let (drift_type, magnitude) = match category {
270                0 => self.build_statistical_drift(),
271                1 => self.build_temporal_drift(),
272                2 => self.build_market_drift(),
273                _ => self.build_behavioral_drift(),
274            };
275
276            // Detection difficulty derived from magnitude
277            let detection_difficulty = if magnitude > 0.3 {
278                DetectionDifficulty::Easy
279            } else if magnitude > 0.15 {
280                DetectionDifficulty::Medium
281            } else {
282                DetectionDifficulty::Hard
283            };
284
285            let start_period = 0_u32;
286            let end_period = (duration_days / 30) as u32;
287
288            let mut labeled =
289                LabeledDriftEvent::new(event_id, drift_type, drift_start, start_period, magnitude);
290            labeled.end_date = Some(drift_end);
291            labeled.end_period = Some(end_period);
292            labeled.detection_difficulty = detection_difficulty;
293            labeled.tags = vec!["source:standalone".to_string()];
294
295            drifts.push(labeled);
296        }
297
298        drifts
299    }
300
301    // ------------------------------------------------------------------
302    // Standalone drift type builders
303    // ------------------------------------------------------------------
304
305    fn build_statistical_drift(&mut self) -> (DriftEventType, f64) {
306        let shift_types = [
307            StatisticalShiftType::MeanShift,
308            StatisticalShiftType::VarianceChange,
309            StatisticalShiftType::DistributionChange,
310            StatisticalShiftType::CorrelationChange,
311            StatisticalShiftType::TailChange,
312            StatisticalShiftType::BenfordDeviation,
313        ];
314        let idx = self.rng.random_range(0..shift_types.len());
315        let shift_type = shift_types[idx];
316
317        let fields = [
318            "amount",
319            "line_count",
320            "processing_time",
321            "approval_duration",
322        ];
323        let field_idx = self.rng.random_range(0..fields.len());
324        let affected_field = fields[field_idx].to_string();
325
326        let magnitude = self.rng.random_range(0.05..0.40);
327
328        let detection_difficulty = if magnitude > 0.3 {
329            DetectionDifficulty::Easy
330        } else if magnitude > 0.15 {
331            DetectionDifficulty::Medium
332        } else {
333            DetectionDifficulty::Hard
334        };
335
336        let drift_type = DriftEventType::Statistical(StatisticalDriftEvent {
337            shift_type,
338            affected_field,
339            magnitude,
340            detection_difficulty,
341            metrics: HashMap::new(),
342        });
343
344        (drift_type, magnitude)
345    }
346
347    fn build_temporal_drift(&mut self) -> (DriftEventType, f64) {
348        let shift_types = [
349            TemporalShiftType::SeasonalityChange,
350            TemporalShiftType::TrendChange,
351            TemporalShiftType::PeriodicityChange,
352            TemporalShiftType::IntradayChange,
353            TemporalShiftType::LagChange,
354        ];
355        let idx = self.rng.random_range(0..shift_types.len());
356        let shift_type = shift_types[idx];
357
358        let magnitude = self.rng.random_range(0.10..0.50);
359
360        let detection_difficulty = if magnitude > 0.3 {
361            DetectionDifficulty::Easy
362        } else if magnitude > 0.15 {
363            DetectionDifficulty::Medium
364        } else {
365            DetectionDifficulty::Hard
366        };
367
368        let drift_type = DriftEventType::Temporal(TemporalDriftEvent {
369            shift_type,
370            affected_field: None,
371            detection_difficulty,
372            magnitude,
373            description: None,
374        });
375
376        (drift_type, magnitude)
377    }
378
379    fn build_market_drift(&mut self) -> (DriftEventType, f64) {
380        let market_types = [
381            MarketEventType::EconomicCycle,
382            MarketEventType::RecessionStart,
383            MarketEventType::RecessionEnd,
384            MarketEventType::PriceShock,
385            MarketEventType::CommodityChange,
386        ];
387        let idx = self.rng.random_range(0..market_types.len());
388        let market_type = market_types[idx];
389
390        let magnitude = self.rng.random_range(0.10..0.60);
391
392        let is_recession = matches!(
393            market_type,
394            MarketEventType::RecessionStart | MarketEventType::RecessionEnd
395        );
396
397        let detection_difficulty = if magnitude > 0.3 {
398            DetectionDifficulty::Easy
399        } else if magnitude > 0.15 {
400            DetectionDifficulty::Medium
401        } else {
402            DetectionDifficulty::Hard
403        };
404
405        let drift_type = DriftEventType::Market(MarketDriftEvent {
406            market_type,
407            detection_difficulty,
408            magnitude,
409            is_recession,
410            affected_sectors: Vec::new(),
411        });
412
413        (drift_type, magnitude)
414    }
415
416    fn build_behavioral_drift(&mut self) -> (DriftEventType, f64) {
417        let behavior_types = [
418            "vendor_quality",
419            "customer_payment",
420            "employee_productivity",
421            "approval_pattern",
422        ];
423        let entity_types = ["vendor", "customer", "employee"];
424
425        let bt_idx = self.rng.random_range(0..behavior_types.len());
426        let et_idx = self.rng.random_range(0..entity_types.len());
427
428        let behavior_type = behavior_types[bt_idx].to_string();
429        let entity_type = entity_types[et_idx].to_string();
430
431        let magnitude = self.rng.random_range(0.05..0.40);
432
433        let detection_difficulty = if magnitude > 0.3 {
434            DetectionDifficulty::Easy
435        } else if magnitude > 0.15 {
436            DetectionDifficulty::Medium
437        } else {
438            DetectionDifficulty::Hard
439        };
440
441        let drift_type = DriftEventType::Behavioral(BehavioralDriftEvent {
442            behavior_type,
443            entity_type,
444            detection_difficulty,
445            metrics: HashMap::new(),
446        });
447
448        (drift_type, magnitude)
449    }
450
451    // ------------------------------------------------------------------
452    // Helper
453    // ------------------------------------------------------------------
454
455    fn next_event_id(&mut self) -> String {
456        self.event_counter += 1;
457        format!("DRIFT-{:06}", self.event_counter)
458    }
459}
460
#[cfg(test)]
// `unwrap` is applied only to hard-coded dates that are known to be valid.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::organizational_event::{
        AcquisitionConfig, MergerConfig, OrganizationalEventType,
    };
    use datasynth_core::models::process_evolution::{
        ProcessAutomationConfig, ProcessEvolutionType,
    };

    /// Build a calendar date from literal parts that are known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// The date range shared by the single-year tests: all of 2024.
    fn year_2024() -> (NaiveDate, NaiveDate) {
        (ymd(2024, 1, 1), ymd(2024, 12, 31))
    }

    /// An acquisition on 2024-03-01 (company C001) followed by a merger on
    /// 2024-06-01 (company C002).
    fn make_org_events() -> Vec<OrganizationalEvent> {
        let acquired_on = ymd(2024, 3, 1);
        let merged_on = ymd(2024, 6, 1);

        vec![
            OrganizationalEvent {
                event_id: "ORG-001".to_string(),
                event_type: OrganizationalEventType::Acquisition(AcquisitionConfig {
                    acquisition_date: acquired_on,
                    ..Default::default()
                }),
                effective_date: acquired_on,
                description: Some("Acquisition".to_string()),
                tags: vec!["company:C001".to_string(), "type:acquisition".to_string()],
            },
            OrganizationalEvent {
                event_id: "ORG-002".to_string(),
                event_type: OrganizationalEventType::Merger(MergerConfig {
                    merger_date: merged_on,
                    ..Default::default()
                }),
                effective_date: merged_on,
                description: Some("Merger".to_string()),
                tags: vec!["company:C002".to_string(), "type:merger".to_string()],
            },
        ]
    }

    /// Two process-automation rollouts: six months from 2024-02-01 and three
    /// months from 2024-08-01.
    fn make_proc_events() -> Vec<ProcessEvolutionEvent> {
        let rollouts = [
            ("PROC-001", 6, ymd(2024, 2, 1)),
            ("PROC-002", 3, ymd(2024, 8, 1)),
        ];

        rollouts
            .iter()
            .map(|&(id, rollout_months, effective)| {
                ProcessEvolutionEvent::new(
                    id,
                    ProcessEvolutionType::ProcessAutomation(ProcessAutomationConfig {
                        rollout_months,
                        ..Default::default()
                    }),
                    effective,
                )
            })
            .collect()
    }

    #[test]
    fn test_deterministic_generation() {
        let (start, end) = year_2024();
        let org_events = make_org_events();
        let proc_events = make_proc_events();

        // Two independent generators with the same seed must agree exactly.
        let run = || {
            DriftEventGenerator::new(42).generate_all(start, end, &org_events, &proc_events)
        };
        let (first, second) = (run(), run());

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(second.iter()) {
            assert_eq!(a.event_id, b.event_id);
            assert_eq!(a.start_date, b.start_date);
            assert!((a.magnitude - b.magnitude).abs() < 1e-10);
        }
    }

    #[test]
    fn test_drift_from_org_events() {
        let org_events = make_org_events();
        let always = DriftEventGeneratorConfig {
            org_event_drift_prob: 1.0,
            ..Default::default()
        };
        let mut generator = DriftEventGenerator::with_config(42, always);
        let drifts = generator.generate_from_org_events(&org_events);

        // Probability 1.0 means every org event must be labeled.
        assert_eq!(drifts.len(), org_events.len());

        for drift in &drifts {
            assert!(drift.related_org_event.is_some());

            let related_id = drift.related_org_event.as_ref().unwrap();
            assert!(
                org_events.iter().any(|e| &e.event_id == related_id),
                "related_org_event should match an org event id"
            );

            assert_eq!(
                drift.event_type.category_name(),
                "organizational",
                "drift from org event should be Organizational category"
            );
        }
    }

    #[test]
    fn test_drift_from_process_events() {
        let proc_events = make_proc_events();
        let always = DriftEventGeneratorConfig {
            process_event_drift_prob: 1.0,
            ..Default::default()
        };
        let mut generator = DriftEventGenerator::with_config(42, always);
        let drifts = generator.generate_from_process_events(&proc_events);

        // Probability 1.0 means every process event must be labeled.
        assert_eq!(drifts.len(), proc_events.len());

        for drift in &drifts {
            assert_eq!(
                drift.event_type.category_name(),
                "process",
                "drift from process event should be Process category"
            );
        }
    }

    #[test]
    fn test_standalone_drifts() {
        let (start, end) = year_2024();
        let mut generator = DriftEventGenerator::new(42);
        let drifts = generator.generate_standalone_drifts(start, end);

        // The default rate is 6/year over roughly one year, so expect about six.
        assert!(!drifts.is_empty(), "should produce standalone drifts");
        assert!(
            drifts.len() >= 4,
            "should produce at least 4 standalone drifts"
        );
    }

    #[test]
    fn test_magnitude_in_valid_range() {
        let (start, end) = year_2024();
        let org_events = make_org_events();
        let proc_events = make_proc_events();

        let mut generator = DriftEventGenerator::new(42);
        let drifts = generator.generate_all(start, end, &org_events, &proc_events);

        for magnitude in drifts.iter().map(|d| d.magnitude) {
            assert!(
                (0.0..=1.0).contains(&magnitude),
                "magnitude {magnitude} should be in [0.0, 1.0]"
            );
        }
    }

    #[test]
    fn test_detection_difficulty_correlates_with_magnitude() {
        let (start, end) = (ymd(2024, 1, 1), ymd(2025, 12, 31));
        let config = DriftEventGeneratorConfig {
            standalone_drifts_per_year: 100.0,
            org_event_drift_prob: 0.0,
            process_event_drift_prob: 0.0,
        };
        let mut generator = DriftEventGenerator::with_config(42, config);
        let drifts = generator.generate_standalone_drifts(start, end);

        // Standalone drifts map magnitude to difficulty:
        // > 0.3 -> Easy, > 0.15 -> Medium, otherwise Hard.
        for drift in &drifts {
            let (expected, label) = if drift.magnitude > 0.3 {
                (DetectionDifficulty::Easy, "Easy")
            } else if drift.magnitude > 0.15 {
                (DetectionDifficulty::Medium, "Medium")
            } else {
                (DetectionDifficulty::Hard, "Hard")
            };

            assert_eq!(
                drift.detection_difficulty, expected,
                "magnitude {} should be {label}",
                drift.magnitude
            );
        }
    }

    #[test]
    fn test_all_standalone_categories() {
        let (start, end) = year_2024();
        let config = DriftEventGeneratorConfig {
            standalone_drifts_per_year: 60.0,
            org_event_drift_prob: 0.0,
            process_event_drift_prob: 0.0,
        };
        let mut generator = DriftEventGenerator::with_config(42, config);
        let drifts = generator.generate_standalone_drifts(start, end);

        for category in ["statistical", "temporal", "market", "behavioral"] {
            assert!(
                drifts
                    .iter()
                    .any(|d| d.event_type.category_name() == category),
                "should generate {category} drifts"
            );
        }
    }
}