Skip to main content

datasynth_core/distributions/
drift_recorder.rs

//! Drift label recorder for ground truth generation.
//!
//! Records drift events during data generation for use as
//! ground truth labels in ML model training and evaluation.

6use crate::distributions::drift::{DriftAdjustments, RegimeChange, RegimeChangeType};
7use crate::models::drift_events::{
8    CategoricalDriftEvent, CategoricalShiftType, DetectionDifficulty, DriftEventType,
9    LabeledDriftEvent, MarketDriftEvent, MarketEventType, OrganizationalDriftEvent,
10    ProcessDriftEvent, StatisticalDriftEvent, StatisticalShiftType, TechnologyDriftEvent,
11    TemporalDriftEvent, TemporalShiftType,
12};
13use chrono::NaiveDate;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::io::Write;
17use std::path::Path;
18
19/// Configuration for drift recording.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct DriftRecorderConfig {
22    /// Enable recording.
23    #[serde(default)]
24    pub enabled: bool,
25    /// Record statistical drift events.
26    #[serde(default = "default_true")]
27    pub statistical: bool,
28    /// Record categorical drift events.
29    #[serde(default = "default_true")]
30    pub categorical: bool,
31    /// Record temporal drift events.
32    #[serde(default = "default_true")]
33    pub temporal: bool,
34    /// Record organizational events.
35    #[serde(default = "default_true")]
36    pub organizational: bool,
37    /// Record process events.
38    #[serde(default = "default_true")]
39    pub process_events: bool,
40    /// Record technology events.
41    #[serde(default = "default_true")]
42    pub technology_events: bool,
43    /// Record regulatory events.
44    #[serde(default = "default_true")]
45    pub regulatory: bool,
46    /// Record market events.
47    #[serde(default = "default_true")]
48    pub market: bool,
49    /// Record behavioral events.
50    #[serde(default = "default_true")]
51    pub behavioral: bool,
52    /// Minimum magnitude threshold to record.
53    #[serde(default = "default_min_magnitude")]
54    pub min_magnitude_threshold: f64,
55}
56
/// Serde default for the per-category flags: omitted categories stay enabled.
fn default_true() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}
60
/// Serde default for `min_magnitude_threshold`.
fn default_min_magnitude() -> f64 {
    const DEFAULT_MIN_MAGNITUDE: f64 = 0.05;
    DEFAULT_MIN_MAGNITUDE
}
64
65impl Default for DriftRecorderConfig {
66    fn default() -> Self {
67        Self {
68            enabled: false,
69            statistical: true,
70            categorical: true,
71            temporal: true,
72            organizational: true,
73            process_events: true,
74            technology_events: true,
75            regulatory: true,
76            market: true,
77            behavioral: true,
78            min_magnitude_threshold: 0.05,
79        }
80    }
81}
82
83/// Drift label recorder.
84pub struct DriftLabelRecorder {
85    /// Recorded events.
86    events: Vec<LabeledDriftEvent>,
87    /// Configuration.
88    config: DriftRecorderConfig,
89    /// Start date of the simulation.
90    start_date: NaiveDate,
91    /// Event ID counter.
92    event_counter: u64,
93    /// Track previous drift state for delta detection.
94    previous_drift: Option<DriftAdjustments>,
95    /// Track if in recession (for recession end detection).
96    was_in_recession: bool,
97}
98
99impl DriftLabelRecorder {
100    /// Create a new drift label recorder.
101    pub fn new(config: DriftRecorderConfig, start_date: NaiveDate) -> Self {
102        Self {
103            events: Vec::new(),
104            config,
105            start_date,
106            event_counter: 0,
107            previous_drift: None,
108            was_in_recession: false,
109        }
110    }
111
112    /// Check if recording is enabled.
113    pub fn is_enabled(&self) -> bool {
114        self.config.enabled
115    }
116
117    /// Generate a unique event ID.
118    fn next_event_id(&mut self) -> String {
119        self.event_counter += 1;
120        format!("DRIFT-{:06}", self.event_counter)
121    }
122
123    /// Convert a period to a date.
124    fn period_to_date(&self, period: u32) -> NaiveDate {
125        self.start_date + chrono::Duration::days(period as i64 * 30)
126    }
127
128    /// Record a regime change event.
129    pub fn record_regime_change(&mut self, regime: &RegimeChange, period: u32, _date: NaiveDate) {
130        if !self.config.enabled || !self.config.organizational {
131            return;
132        }
133
134        let event_type = match regime.change_type {
135            RegimeChangeType::Acquisition => "acquisition",
136            RegimeChangeType::Divestiture => "divestiture",
137            RegimeChangeType::PriceIncrease => "price_increase",
138            RegimeChangeType::PriceDecrease => "price_decrease",
139            RegimeChangeType::ProductLaunch => "product_launch",
140            RegimeChangeType::ProductDiscontinuation => "product_discontinuation",
141            RegimeChangeType::PolicyChange => "policy_change",
142            RegimeChangeType::CompetitorEntry => "competitor_entry",
143            RegimeChangeType::Custom => "custom",
144        };
145
146        let magnitude = (regime.volume_multiplier() - 1.0)
147            .abs()
148            .max((regime.amount_mean_multiplier() - 1.0).abs());
149
150        if magnitude < self.config.min_magnitude_threshold {
151            return;
152        }
153
154        let detection_difficulty = if magnitude > 0.20 {
155            DetectionDifficulty::Easy
156        } else if magnitude > 0.10 {
157            DetectionDifficulty::Medium
158        } else {
159            DetectionDifficulty::Hard
160        };
161
162        let mut event = LabeledDriftEvent::new(
163            self.next_event_id(),
164            DriftEventType::Organizational(OrganizationalDriftEvent {
165                event_type: event_type.to_string(),
166                related_event_id: regime.description.clone().unwrap_or_default(),
167                detection_difficulty,
168                affected_entities: Vec::new(),
169                impact_metrics: {
170                    let mut m = HashMap::new();
171                    m.insert("volume_multiplier".to_string(), regime.volume_multiplier());
172                    m.insert(
173                        "amount_multiplier".to_string(),
174                        regime.amount_mean_multiplier(),
175                    );
176                    m
177                },
178            }),
179            self.period_to_date(period),
180            period,
181            magnitude,
182        );
183
184        event.end_period = Some(period + regime.transition_periods);
185        event.tags.push("regime_change".to_string());
186        event.tags.push(event_type.to_string());
187
188        self.events.push(event);
189    }
190
191    /// Record statistical drift from drift adjustments.
192    pub fn record_statistical_drift(&mut self, adjustments: &DriftAdjustments, period: u32) {
193        if !self.config.enabled || !self.config.statistical {
194            return;
195        }
196
197        let date = self.period_to_date(period);
198
199        // Check for mean shift - extract values before borrowing self mutably
200        if let Some(ref prev) = self.previous_drift {
201            let mean_delta =
202                (adjustments.amount_mean_multiplier - prev.amount_mean_multiplier).abs();
203            let var_delta =
204                (adjustments.amount_variance_multiplier - prev.amount_variance_multiplier).abs();
205            let prev_mean = prev.amount_mean_multiplier;
206            let current_mean = adjustments.amount_mean_multiplier;
207            let min_threshold = self.config.min_magnitude_threshold;
208
209            if mean_delta >= min_threshold {
210                let detection_difficulty = if mean_delta > 0.20 {
211                    DetectionDifficulty::Easy
212                } else if mean_delta > 0.10 {
213                    DetectionDifficulty::Medium
214                } else {
215                    DetectionDifficulty::Hard
216                };
217
218                let event_id = self.next_event_id();
219                let event = LabeledDriftEvent::new(
220                    event_id,
221                    DriftEventType::Statistical(StatisticalDriftEvent {
222                        shift_type: StatisticalShiftType::MeanShift,
223                        affected_field: "amount".to_string(),
224                        magnitude: mean_delta,
225                        detection_difficulty,
226                        metrics: {
227                            let mut m = HashMap::new();
228                            m.insert("previous_multiplier".to_string(), prev_mean);
229                            m.insert("current_multiplier".to_string(), current_mean);
230                            m
231                        },
232                    }),
233                    date,
234                    period,
235                    mean_delta,
236                );
237
238                self.events.push(event);
239            }
240
241            // Check for variance change
242            if var_delta >= min_threshold {
243                let event_id = self.next_event_id();
244                let event = LabeledDriftEvent::new(
245                    event_id,
246                    DriftEventType::Statistical(StatisticalDriftEvent {
247                        shift_type: StatisticalShiftType::VarianceChange,
248                        affected_field: "amount".to_string(),
249                        magnitude: var_delta,
250                        detection_difficulty: DetectionDifficulty::Medium,
251                        metrics: HashMap::new(),
252                    }),
253                    date,
254                    period,
255                    var_delta,
256                );
257
258                self.events.push(event);
259            }
260        }
261
262        // Check for sudden drift
263        if adjustments.sudden_drift_occurred {
264            let event = LabeledDriftEvent::new(
265                self.next_event_id(),
266                DriftEventType::Statistical(StatisticalDriftEvent {
267                    shift_type: StatisticalShiftType::DistributionChange,
268                    affected_field: "amount".to_string(),
269                    magnitude: 0.5, // Sudden drifts are typically significant
270                    detection_difficulty: DetectionDifficulty::Easy,
271                    metrics: HashMap::new(),
272                }),
273                date,
274                period,
275                0.5,
276            );
277
278            self.events.push(event);
279        }
280
281        self.previous_drift = Some(adjustments.clone());
282    }
283
284    /// Record a market/economic drift event.
285    pub fn record_market_drift(
286        &mut self,
287        market_type: MarketEventType,
288        period: u32,
289        magnitude: f64,
290        is_recession: bool,
291    ) {
292        if !self.config.enabled || !self.config.market {
293            return;
294        }
295
296        if magnitude < self.config.min_magnitude_threshold
297            && market_type != MarketEventType::RecessionStart
298            && market_type != MarketEventType::RecessionEnd
299        {
300            return;
301        }
302
303        // Detect recession transitions
304        let actual_type = if is_recession && !self.was_in_recession {
305            self.was_in_recession = true;
306            MarketEventType::RecessionStart
307        } else if !is_recession && self.was_in_recession {
308            self.was_in_recession = false;
309            MarketEventType::RecessionEnd
310        } else {
311            market_type
312        };
313
314        let detection_difficulty = match actual_type {
315            MarketEventType::RecessionStart | MarketEventType::RecessionEnd => {
316                DetectionDifficulty::Easy
317            }
318            MarketEventType::PriceShock => DetectionDifficulty::Easy,
319            MarketEventType::EconomicCycle => DetectionDifficulty::Medium,
320            MarketEventType::CommodityChange => DetectionDifficulty::Medium,
321        };
322
323        let event = LabeledDriftEvent::new(
324            self.next_event_id(),
325            DriftEventType::Market(MarketDriftEvent {
326                market_type: actual_type,
327                detection_difficulty,
328                magnitude,
329                is_recession,
330                affected_sectors: Vec::new(),
331            }),
332            self.period_to_date(period),
333            period,
334            magnitude,
335        );
336
337        self.events.push(event);
338    }
339
340    /// Record a process evolution drift event.
341    pub fn record_process_drift(
342        &mut self,
343        process_type: &str,
344        related_event_id: &str,
345        period: u32,
346        magnitude: f64,
347        affected_processes: Vec<String>,
348    ) {
349        if !self.config.enabled || !self.config.process_events {
350            return;
351        }
352
353        if magnitude < self.config.min_magnitude_threshold {
354            return;
355        }
356
357        let mut event = LabeledDriftEvent::new(
358            self.next_event_id(),
359            DriftEventType::Process(ProcessDriftEvent {
360                process_type: process_type.to_string(),
361                related_event_id: related_event_id.to_string(),
362                detection_difficulty: DetectionDifficulty::Medium,
363                affected_processes,
364            }),
365            self.period_to_date(period),
366            period,
367            magnitude,
368        );
369
370        event.related_org_event = Some(related_event_id.to_string());
371        self.events.push(event);
372    }
373
374    /// Record a technology transition drift event.
375    pub fn record_technology_drift(
376        &mut self,
377        transition_type: &str,
378        related_event_id: &str,
379        period: u32,
380        magnitude: f64,
381        systems: Vec<String>,
382        current_phase: Option<&str>,
383    ) {
384        if !self.config.enabled || !self.config.technology_events {
385            return;
386        }
387
388        if magnitude < self.config.min_magnitude_threshold {
389            return;
390        }
391
392        let mut event = LabeledDriftEvent::new(
393            self.next_event_id(),
394            DriftEventType::Technology(TechnologyDriftEvent {
395                transition_type: transition_type.to_string(),
396                related_event_id: related_event_id.to_string(),
397                detection_difficulty: DetectionDifficulty::Easy, // Tech transitions are usually obvious
398                systems,
399                current_phase: current_phase.map(String::from),
400            }),
401            self.period_to_date(period),
402            period,
403            magnitude,
404        );
405
406        event.related_org_event = Some(related_event_id.to_string());
407        self.events.push(event);
408    }
409
410    /// Record a temporal pattern drift event.
411    pub fn record_temporal_drift(
412        &mut self,
413        shift_type: TemporalShiftType,
414        period: u32,
415        magnitude: f64,
416        affected_field: Option<&str>,
417        description: Option<&str>,
418    ) {
419        if !self.config.enabled || !self.config.temporal {
420            return;
421        }
422
423        if magnitude < self.config.min_magnitude_threshold {
424            return;
425        }
426
427        let event = LabeledDriftEvent::new(
428            self.next_event_id(),
429            DriftEventType::Temporal(TemporalDriftEvent {
430                shift_type,
431                affected_field: affected_field.map(String::from),
432                detection_difficulty: DetectionDifficulty::Hard, // Temporal drifts are subtle
433                magnitude,
434                description: description.map(String::from),
435            }),
436            self.period_to_date(period),
437            period,
438            magnitude,
439        );
440
441        self.events.push(event);
442    }
443
444    /// Record a categorical drift event.
445    pub fn record_categorical_drift(
446        &mut self,
447        shift_type: CategoricalShiftType,
448        affected_field: &str,
449        period: u32,
450        proportions_before: HashMap<String, f64>,
451        proportions_after: HashMap<String, f64>,
452    ) {
453        if !self.config.enabled || !self.config.categorical {
454            return;
455        }
456
457        // Calculate magnitude as max proportion change
458        let magnitude = proportions_before
459            .keys()
460            .chain(proportions_after.keys())
461            .map(|k| {
462                let before = proportions_before.get(k).copied().unwrap_or(0.0);
463                let after = proportions_after.get(k).copied().unwrap_or(0.0);
464                (after - before).abs()
465            })
466            .fold(0.0f64, f64::max);
467
468        if magnitude < self.config.min_magnitude_threshold {
469            return;
470        }
471
472        let new_categories: Vec<String> = proportions_after
473            .keys()
474            .filter(|k| !proportions_before.contains_key(*k))
475            .cloned()
476            .collect();
477
478        let removed_categories: Vec<String> = proportions_before
479            .keys()
480            .filter(|k| !proportions_after.contains_key(*k))
481            .cloned()
482            .collect();
483
484        let event = LabeledDriftEvent::new(
485            self.next_event_id(),
486            DriftEventType::Categorical(CategoricalDriftEvent {
487                shift_type,
488                affected_field: affected_field.to_string(),
489                detection_difficulty: DetectionDifficulty::Medium,
490                proportions_before,
491                proportions_after,
492                new_categories,
493                removed_categories,
494            }),
495            self.period_to_date(period),
496            period,
497            magnitude,
498        );
499
500        self.events.push(event);
501    }
502
503    /// Get all recorded events.
504    pub fn events(&self) -> &[LabeledDriftEvent] {
505        &self.events
506    }
507
508    /// Get events in a specific period range.
509    pub fn events_in_range(&self, start_period: u32, end_period: u32) -> Vec<&LabeledDriftEvent> {
510        self.events
511            .iter()
512            .filter(|e| e.start_period >= start_period && e.start_period <= end_period)
513            .collect()
514    }
515
516    /// Get events by category.
517    pub fn events_by_category(&self, category: &str) -> Vec<&LabeledDriftEvent> {
518        self.events
519            .iter()
520            .filter(|e| e.event_type.category_name() == category)
521            .collect()
522    }
523
524    /// Get total event count.
525    pub fn event_count(&self) -> usize {
526        self.events.len()
527    }
528
529    /// Export events to CSV file.
530    pub fn export_to_csv(&self, path: &Path) -> std::io::Result<usize> {
531        let mut file = std::fs::File::create(path)?;
532
533        // Write header
534        writeln!(
535            file,
536            "event_id,category,type,start_date,end_date,start_period,end_period,magnitude,detection_difficulty,affected_fields,tags"
537        )?;
538
539        // Write events
540        for event in &self.events {
541            let end_date = event.end_date.map(|d| d.to_string()).unwrap_or_default();
542            let end_period = event.end_period.map(|p| p.to_string()).unwrap_or_default();
543            let affected_fields = event.affected_fields.join(";");
544            let tags = event.tags.join(";");
545
546            writeln!(
547                file,
548                "{},{},{},{},{},{},{},{:.4},{:?},{},{}",
549                event.event_id,
550                event.event_type.category_name(),
551                event.event_type.type_name(),
552                event.start_date,
553                end_date,
554                event.start_period,
555                end_period,
556                event.magnitude,
557                event.detection_difficulty,
558                affected_fields,
559                tags
560            )?;
561        }
562
563        Ok(self.events.len())
564    }
565
566    /// Export events to JSON file.
567    pub fn export_to_json(&self, path: &Path) -> std::io::Result<usize> {
568        let json = serde_json::to_string_pretty(&self.events).map_err(std::io::Error::other)?;
569        std::fs::write(path, json)?;
570        Ok(self.events.len())
571    }
572
573    /// Get summary statistics.
574    pub fn summary(&self) -> DriftRecorderSummary {
575        let mut by_category: HashMap<String, usize> = HashMap::new();
576        let mut by_difficulty: HashMap<String, usize> = HashMap::new();
577        let mut total_magnitude = 0.0;
578
579        for event in &self.events {
580            *by_category
581                .entry(event.event_type.category_name().to_string())
582                .or_insert(0) += 1;
583            *by_difficulty
584                .entry(format!("{:?}", event.detection_difficulty))
585                .or_insert(0) += 1;
586            total_magnitude += event.magnitude;
587        }
588
589        DriftRecorderSummary {
590            total_events: self.events.len(),
591            by_category,
592            by_difficulty,
593            avg_magnitude: if self.events.is_empty() {
594                0.0
595            } else {
596                total_magnitude / self.events.len() as f64
597            },
598        }
599    }
600}
601
602/// Summary statistics for drift recording.
603#[derive(Debug, Clone, Serialize, Deserialize)]
604pub struct DriftRecorderSummary {
605    /// Total number of events.
606    pub total_events: usize,
607    /// Events by category.
608    pub by_category: HashMap<String, usize>,
609    /// Events by detection difficulty.
610    pub by_difficulty: HashMap<String, usize>,
611    /// Average magnitude.
612    pub avg_magnitude: f64,
613}
614
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build an enabled recorder anchored at 2024-01-01. The default minimum
    /// magnitude threshold is kept unless an override is supplied.
    fn enabled_recorder(threshold: Option<f64>) -> (DriftLabelRecorder, NaiveDate) {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut config = DriftRecorderConfig {
            enabled: true,
            ..Default::default()
        };
        if let Some(value) = threshold {
            config.min_magnitude_threshold = value;
        }
        (DriftLabelRecorder::new(config, start), start)
    }

    #[test]
    fn test_drift_recorder_creation() {
        let (recorder, _) = enabled_recorder(None);

        assert!(recorder.is_enabled());
        assert_eq!(recorder.event_count(), 0);
    }

    #[test]
    fn test_record_regime_change() {
        let (mut recorder, start) = enabled_recorder(Some(0.0));

        let acquisition = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&acquisition, 6, start);

        assert_eq!(recorder.event_count(), 1);
        let recorded = &recorder.events()[0];
        assert_eq!(recorded.event_type.category_name(), "organizational");
    }

    #[test]
    fn test_record_statistical_drift() {
        // Low but non-zero threshold: with 0.0, the zero variance delta would
        // also be recorded as an event.
        let (mut recorder, _) = enabled_recorder(Some(0.01));

        // The first call only establishes the baseline.
        let baseline = DriftAdjustments {
            amount_mean_multiplier: 1.0,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&baseline, 0);

        // A mean shift of 0.25 clears the 0.01 threshold.
        let shifted = DriftAdjustments {
            amount_mean_multiplier: 1.25,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&shifted, 1);

        // Only the mean shift is recorded; the variance delta is zero.
        assert_eq!(recorder.event_count(), 1);
    }

    #[test]
    fn test_summary() {
        let (mut recorder, start) = enabled_recorder(Some(0.0));

        let acquisition = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&acquisition, 6, start);

        let stats = recorder.summary();
        assert_eq!(stats.total_events, 1);
        assert!(stats.by_category.contains_key("organizational"));
    }
}