Skip to main content

datasynth_core/distributions/
drift_recorder.rs

1//! Drift label recorder for ground truth generation.
2//!
3//! Records drift events during data generation for use as
4//! ground truth labels in ML model training and evaluation.
5
6use crate::distributions::drift::{DriftAdjustments, RegimeChange, RegimeChangeType};
7use crate::models::drift_events::{
8    CategoricalDriftEvent, CategoricalShiftType, DetectionDifficulty, DriftEventType,
9    LabeledDriftEvent, MarketDriftEvent, MarketEventType, OrganizationalDriftEvent,
10    ProcessDriftEvent, StatisticalDriftEvent, StatisticalShiftType, TechnologyDriftEvent,
11    TemporalDriftEvent, TemporalShiftType,
12};
13use chrono::NaiveDate;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::io::Write;
17use std::path::Path;
18
19/// Configuration for drift recording.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct DriftRecorderConfig {
22    /// Enable recording.
23    #[serde(default)]
24    pub enabled: bool,
25    /// Record statistical drift events.
26    #[serde(default = "default_true")]
27    pub statistical: bool,
28    /// Record categorical drift events.
29    #[serde(default = "default_true")]
30    pub categorical: bool,
31    /// Record temporal drift events.
32    #[serde(default = "default_true")]
33    pub temporal: bool,
34    /// Record organizational events.
35    #[serde(default = "default_true")]
36    pub organizational: bool,
37    /// Record process events.
38    #[serde(default = "default_true")]
39    pub process_events: bool,
40    /// Record technology events.
41    #[serde(default = "default_true")]
42    pub technology_events: bool,
43    /// Record regulatory events.
44    #[serde(default = "default_true")]
45    pub regulatory: bool,
46    /// Record market events.
47    #[serde(default = "default_true")]
48    pub market: bool,
49    /// Record behavioral events.
50    #[serde(default = "default_true")]
51    pub behavioral: bool,
52    /// Minimum magnitude threshold to record.
53    #[serde(default = "default_min_magnitude")]
54    pub min_magnitude_threshold: f64,
55}
56
/// Serde default for the per-category toggles: a category is recorded unless
/// it is explicitly switched off.
fn default_true() -> bool {
    true
}

/// Serde default for `DriftRecorderConfig::min_magnitude_threshold`.
fn default_min_magnitude() -> f64 {
    const DEFAULT_MIN_MAGNITUDE: f64 = 0.05;
    DEFAULT_MIN_MAGNITUDE
}
64
65impl Default for DriftRecorderConfig {
66    fn default() -> Self {
67        Self {
68            enabled: false,
69            statistical: true,
70            categorical: true,
71            temporal: true,
72            organizational: true,
73            process_events: true,
74            technology_events: true,
75            regulatory: true,
76            market: true,
77            behavioral: true,
78            min_magnitude_threshold: 0.05,
79        }
80    }
81}
82
83/// Drift label recorder.
84pub struct DriftLabelRecorder {
85    /// Recorded events.
86    events: Vec<LabeledDriftEvent>,
87    /// Configuration.
88    config: DriftRecorderConfig,
89    /// Start date of the simulation.
90    start_date: NaiveDate,
91    /// Event ID counter.
92    event_counter: u64,
93    /// Track previous drift state for delta detection.
94    previous_drift: Option<DriftAdjustments>,
95    /// Track active regime changes (reserved for future use).
96    #[allow(dead_code)]
97    active_regimes: HashMap<String, u32>,
98    /// Track if in recession (for recession end detection).
99    was_in_recession: bool,
100}
101
102impl DriftLabelRecorder {
103    /// Create a new drift label recorder.
104    pub fn new(config: DriftRecorderConfig, start_date: NaiveDate) -> Self {
105        Self {
106            events: Vec::new(),
107            config,
108            start_date,
109            event_counter: 0,
110            previous_drift: None,
111            active_regimes: HashMap::new(),
112            was_in_recession: false,
113        }
114    }
115
116    /// Check if recording is enabled.
117    pub fn is_enabled(&self) -> bool {
118        self.config.enabled
119    }
120
121    /// Generate a unique event ID.
122    fn next_event_id(&mut self) -> String {
123        self.event_counter += 1;
124        format!("DRIFT-{:06}", self.event_counter)
125    }
126
127    /// Convert a period to a date.
128    fn period_to_date(&self, period: u32) -> NaiveDate {
129        self.start_date + chrono::Duration::days(period as i64 * 30)
130    }
131
132    /// Record a regime change event.
133    pub fn record_regime_change(&mut self, regime: &RegimeChange, period: u32, _date: NaiveDate) {
134        if !self.config.enabled || !self.config.organizational {
135            return;
136        }
137
138        let event_type = match regime.change_type {
139            RegimeChangeType::Acquisition => "acquisition",
140            RegimeChangeType::Divestiture => "divestiture",
141            RegimeChangeType::PriceIncrease => "price_increase",
142            RegimeChangeType::PriceDecrease => "price_decrease",
143            RegimeChangeType::ProductLaunch => "product_launch",
144            RegimeChangeType::ProductDiscontinuation => "product_discontinuation",
145            RegimeChangeType::PolicyChange => "policy_change",
146            RegimeChangeType::CompetitorEntry => "competitor_entry",
147            RegimeChangeType::Custom => "custom",
148        };
149
150        let magnitude = (regime.volume_multiplier() - 1.0)
151            .abs()
152            .max((regime.amount_mean_multiplier() - 1.0).abs());
153
154        if magnitude < self.config.min_magnitude_threshold {
155            return;
156        }
157
158        let detection_difficulty = if magnitude > 0.20 {
159            DetectionDifficulty::Easy
160        } else if magnitude > 0.10 {
161            DetectionDifficulty::Medium
162        } else {
163            DetectionDifficulty::Hard
164        };
165
166        let mut event = LabeledDriftEvent::new(
167            self.next_event_id(),
168            DriftEventType::Organizational(OrganizationalDriftEvent {
169                event_type: event_type.to_string(),
170                related_event_id: regime.description.clone().unwrap_or_default(),
171                detection_difficulty,
172                affected_entities: Vec::new(),
173                impact_metrics: {
174                    let mut m = HashMap::new();
175                    m.insert("volume_multiplier".to_string(), regime.volume_multiplier());
176                    m.insert(
177                        "amount_multiplier".to_string(),
178                        regime.amount_mean_multiplier(),
179                    );
180                    m
181                },
182            }),
183            self.period_to_date(period),
184            period,
185            magnitude,
186        );
187
188        event.end_period = Some(period + regime.transition_periods);
189        event.tags.push("regime_change".to_string());
190        event.tags.push(event_type.to_string());
191
192        self.events.push(event);
193    }
194
195    /// Record statistical drift from drift adjustments.
196    pub fn record_statistical_drift(&mut self, adjustments: &DriftAdjustments, period: u32) {
197        if !self.config.enabled || !self.config.statistical {
198            return;
199        }
200
201        let date = self.period_to_date(period);
202
203        // Check for mean shift - extract values before borrowing self mutably
204        if let Some(ref prev) = self.previous_drift {
205            let mean_delta =
206                (adjustments.amount_mean_multiplier - prev.amount_mean_multiplier).abs();
207            let var_delta =
208                (adjustments.amount_variance_multiplier - prev.amount_variance_multiplier).abs();
209            let prev_mean = prev.amount_mean_multiplier;
210            let current_mean = adjustments.amount_mean_multiplier;
211            let min_threshold = self.config.min_magnitude_threshold;
212
213            if mean_delta >= min_threshold {
214                let detection_difficulty = if mean_delta > 0.20 {
215                    DetectionDifficulty::Easy
216                } else if mean_delta > 0.10 {
217                    DetectionDifficulty::Medium
218                } else {
219                    DetectionDifficulty::Hard
220                };
221
222                let event_id = self.next_event_id();
223                let event = LabeledDriftEvent::new(
224                    event_id,
225                    DriftEventType::Statistical(StatisticalDriftEvent {
226                        shift_type: StatisticalShiftType::MeanShift,
227                        affected_field: "amount".to_string(),
228                        magnitude: mean_delta,
229                        detection_difficulty,
230                        metrics: {
231                            let mut m = HashMap::new();
232                            m.insert("previous_multiplier".to_string(), prev_mean);
233                            m.insert("current_multiplier".to_string(), current_mean);
234                            m
235                        },
236                    }),
237                    date,
238                    period,
239                    mean_delta,
240                );
241
242                self.events.push(event);
243            }
244
245            // Check for variance change
246            if var_delta >= min_threshold {
247                let event_id = self.next_event_id();
248                let event = LabeledDriftEvent::new(
249                    event_id,
250                    DriftEventType::Statistical(StatisticalDriftEvent {
251                        shift_type: StatisticalShiftType::VarianceChange,
252                        affected_field: "amount".to_string(),
253                        magnitude: var_delta,
254                        detection_difficulty: DetectionDifficulty::Medium,
255                        metrics: HashMap::new(),
256                    }),
257                    date,
258                    period,
259                    var_delta,
260                );
261
262                self.events.push(event);
263            }
264        }
265
266        // Check for sudden drift
267        if adjustments.sudden_drift_occurred {
268            let event = LabeledDriftEvent::new(
269                self.next_event_id(),
270                DriftEventType::Statistical(StatisticalDriftEvent {
271                    shift_type: StatisticalShiftType::DistributionChange,
272                    affected_field: "amount".to_string(),
273                    magnitude: 0.5, // Sudden drifts are typically significant
274                    detection_difficulty: DetectionDifficulty::Easy,
275                    metrics: HashMap::new(),
276                }),
277                date,
278                period,
279                0.5,
280            );
281
282            self.events.push(event);
283        }
284
285        self.previous_drift = Some(adjustments.clone());
286    }
287
288    /// Record a market/economic drift event.
289    pub fn record_market_drift(
290        &mut self,
291        market_type: MarketEventType,
292        period: u32,
293        magnitude: f64,
294        is_recession: bool,
295    ) {
296        if !self.config.enabled || !self.config.market {
297            return;
298        }
299
300        if magnitude < self.config.min_magnitude_threshold
301            && market_type != MarketEventType::RecessionStart
302            && market_type != MarketEventType::RecessionEnd
303        {
304            return;
305        }
306
307        // Detect recession transitions
308        let actual_type = if is_recession && !self.was_in_recession {
309            self.was_in_recession = true;
310            MarketEventType::RecessionStart
311        } else if !is_recession && self.was_in_recession {
312            self.was_in_recession = false;
313            MarketEventType::RecessionEnd
314        } else {
315            market_type
316        };
317
318        let detection_difficulty = match actual_type {
319            MarketEventType::RecessionStart | MarketEventType::RecessionEnd => {
320                DetectionDifficulty::Easy
321            }
322            MarketEventType::PriceShock => DetectionDifficulty::Easy,
323            MarketEventType::EconomicCycle => DetectionDifficulty::Medium,
324            MarketEventType::CommodityChange => DetectionDifficulty::Medium,
325        };
326
327        let event = LabeledDriftEvent::new(
328            self.next_event_id(),
329            DriftEventType::Market(MarketDriftEvent {
330                market_type: actual_type,
331                detection_difficulty,
332                magnitude,
333                is_recession,
334                affected_sectors: Vec::new(),
335            }),
336            self.period_to_date(period),
337            period,
338            magnitude,
339        );
340
341        self.events.push(event);
342    }
343
344    /// Record a process evolution drift event.
345    pub fn record_process_drift(
346        &mut self,
347        process_type: &str,
348        related_event_id: &str,
349        period: u32,
350        magnitude: f64,
351        affected_processes: Vec<String>,
352    ) {
353        if !self.config.enabled || !self.config.process_events {
354            return;
355        }
356
357        if magnitude < self.config.min_magnitude_threshold {
358            return;
359        }
360
361        let mut event = LabeledDriftEvent::new(
362            self.next_event_id(),
363            DriftEventType::Process(ProcessDriftEvent {
364                process_type: process_type.to_string(),
365                related_event_id: related_event_id.to_string(),
366                detection_difficulty: DetectionDifficulty::Medium,
367                affected_processes,
368            }),
369            self.period_to_date(period),
370            period,
371            magnitude,
372        );
373
374        event.related_org_event = Some(related_event_id.to_string());
375        self.events.push(event);
376    }
377
378    /// Record a technology transition drift event.
379    pub fn record_technology_drift(
380        &mut self,
381        transition_type: &str,
382        related_event_id: &str,
383        period: u32,
384        magnitude: f64,
385        systems: Vec<String>,
386        current_phase: Option<&str>,
387    ) {
388        if !self.config.enabled || !self.config.technology_events {
389            return;
390        }
391
392        if magnitude < self.config.min_magnitude_threshold {
393            return;
394        }
395
396        let mut event = LabeledDriftEvent::new(
397            self.next_event_id(),
398            DriftEventType::Technology(TechnologyDriftEvent {
399                transition_type: transition_type.to_string(),
400                related_event_id: related_event_id.to_string(),
401                detection_difficulty: DetectionDifficulty::Easy, // Tech transitions are usually obvious
402                systems,
403                current_phase: current_phase.map(String::from),
404            }),
405            self.period_to_date(period),
406            period,
407            magnitude,
408        );
409
410        event.related_org_event = Some(related_event_id.to_string());
411        self.events.push(event);
412    }
413
414    /// Record a temporal pattern drift event.
415    pub fn record_temporal_drift(
416        &mut self,
417        shift_type: TemporalShiftType,
418        period: u32,
419        magnitude: f64,
420        affected_field: Option<&str>,
421        description: Option<&str>,
422    ) {
423        if !self.config.enabled || !self.config.temporal {
424            return;
425        }
426
427        if magnitude < self.config.min_magnitude_threshold {
428            return;
429        }
430
431        let event = LabeledDriftEvent::new(
432            self.next_event_id(),
433            DriftEventType::Temporal(TemporalDriftEvent {
434                shift_type,
435                affected_field: affected_field.map(String::from),
436                detection_difficulty: DetectionDifficulty::Hard, // Temporal drifts are subtle
437                magnitude,
438                description: description.map(String::from),
439            }),
440            self.period_to_date(period),
441            period,
442            magnitude,
443        );
444
445        self.events.push(event);
446    }
447
448    /// Record a categorical drift event.
449    pub fn record_categorical_drift(
450        &mut self,
451        shift_type: CategoricalShiftType,
452        affected_field: &str,
453        period: u32,
454        proportions_before: HashMap<String, f64>,
455        proportions_after: HashMap<String, f64>,
456    ) {
457        if !self.config.enabled || !self.config.categorical {
458            return;
459        }
460
461        // Calculate magnitude as max proportion change
462        let magnitude = proportions_before
463            .keys()
464            .chain(proportions_after.keys())
465            .map(|k| {
466                let before = proportions_before.get(k).copied().unwrap_or(0.0);
467                let after = proportions_after.get(k).copied().unwrap_or(0.0);
468                (after - before).abs()
469            })
470            .fold(0.0f64, f64::max);
471
472        if magnitude < self.config.min_magnitude_threshold {
473            return;
474        }
475
476        let new_categories: Vec<String> = proportions_after
477            .keys()
478            .filter(|k| !proportions_before.contains_key(*k))
479            .cloned()
480            .collect();
481
482        let removed_categories: Vec<String> = proportions_before
483            .keys()
484            .filter(|k| !proportions_after.contains_key(*k))
485            .cloned()
486            .collect();
487
488        let event = LabeledDriftEvent::new(
489            self.next_event_id(),
490            DriftEventType::Categorical(CategoricalDriftEvent {
491                shift_type,
492                affected_field: affected_field.to_string(),
493                detection_difficulty: DetectionDifficulty::Medium,
494                proportions_before,
495                proportions_after,
496                new_categories,
497                removed_categories,
498            }),
499            self.period_to_date(period),
500            period,
501            magnitude,
502        );
503
504        self.events.push(event);
505    }
506
507    /// Get all recorded events.
508    pub fn events(&self) -> &[LabeledDriftEvent] {
509        &self.events
510    }
511
512    /// Get events in a specific period range.
513    pub fn events_in_range(&self, start_period: u32, end_period: u32) -> Vec<&LabeledDriftEvent> {
514        self.events
515            .iter()
516            .filter(|e| e.start_period >= start_period && e.start_period <= end_period)
517            .collect()
518    }
519
520    /// Get events by category.
521    pub fn events_by_category(&self, category: &str) -> Vec<&LabeledDriftEvent> {
522        self.events
523            .iter()
524            .filter(|e| e.event_type.category_name() == category)
525            .collect()
526    }
527
528    /// Get total event count.
529    pub fn event_count(&self) -> usize {
530        self.events.len()
531    }
532
533    /// Export events to CSV file.
534    pub fn export_to_csv(&self, path: &Path) -> std::io::Result<usize> {
535        let mut file = std::fs::File::create(path)?;
536
537        // Write header
538        writeln!(
539            file,
540            "event_id,category,type,start_date,end_date,start_period,end_period,magnitude,detection_difficulty,affected_fields,tags"
541        )?;
542
543        // Write events
544        for event in &self.events {
545            let end_date = event.end_date.map(|d| d.to_string()).unwrap_or_default();
546            let end_period = event.end_period.map(|p| p.to_string()).unwrap_or_default();
547            let affected_fields = event.affected_fields.join(";");
548            let tags = event.tags.join(";");
549
550            writeln!(
551                file,
552                "{},{},{},{},{},{},{},{:.4},{:?},{},{}",
553                event.event_id,
554                event.event_type.category_name(),
555                event.event_type.type_name(),
556                event.start_date,
557                end_date,
558                event.start_period,
559                end_period,
560                event.magnitude,
561                event.detection_difficulty,
562                affected_fields,
563                tags
564            )?;
565        }
566
567        Ok(self.events.len())
568    }
569
570    /// Export events to JSON file.
571    pub fn export_to_json(&self, path: &Path) -> std::io::Result<usize> {
572        let json = serde_json::to_string_pretty(&self.events).map_err(std::io::Error::other)?;
573        std::fs::write(path, json)?;
574        Ok(self.events.len())
575    }
576
577    /// Get summary statistics.
578    pub fn summary(&self) -> DriftRecorderSummary {
579        let mut by_category: HashMap<String, usize> = HashMap::new();
580        let mut by_difficulty: HashMap<String, usize> = HashMap::new();
581        let mut total_magnitude = 0.0;
582
583        for event in &self.events {
584            *by_category
585                .entry(event.event_type.category_name().to_string())
586                .or_insert(0) += 1;
587            *by_difficulty
588                .entry(format!("{:?}", event.detection_difficulty))
589                .or_insert(0) += 1;
590            total_magnitude += event.magnitude;
591        }
592
593        DriftRecorderSummary {
594            total_events: self.events.len(),
595            by_category,
596            by_difficulty,
597            avg_magnitude: if self.events.is_empty() {
598                0.0
599            } else {
600                total_magnitude / self.events.len() as f64
601            },
602        }
603    }
604}
605
606/// Summary statistics for drift recording.
607#[derive(Debug, Clone, Serialize, Deserialize)]
608pub struct DriftRecorderSummary {
609    /// Total number of events.
610    pub total_events: usize,
611    /// Events by category.
612    pub by_category: HashMap<String, usize>,
613    /// Events by detection difficulty.
614    pub by_difficulty: HashMap<String, usize>,
615    /// Average magnitude.
616    pub avg_magnitude: f64,
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Simulation start date shared by every test.
    fn jan_1_2024() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 1, 1).unwrap()
    }

    /// An enabled recorder that keeps events whose magnitude reaches `threshold`.
    fn recorder_with_threshold(threshold: f64) -> DriftLabelRecorder {
        DriftLabelRecorder::new(
            DriftRecorderConfig {
                enabled: true,
                min_magnitude_threshold: threshold,
                ..Default::default()
            },
            jan_1_2024(),
        )
    }

    #[test]
    fn test_drift_recorder_creation() {
        let recorder = DriftLabelRecorder::new(
            DriftRecorderConfig {
                enabled: true,
                ..Default::default()
            },
            jan_1_2024(),
        );

        assert!(recorder.is_enabled());
        assert_eq!(recorder.event_count(), 0);
    }

    #[test]
    fn test_record_regime_change() {
        let mut recorder = recorder_with_threshold(0.0);
        let acquisition = RegimeChange::new(6, RegimeChangeType::Acquisition);

        recorder.record_regime_change(&acquisition, 6, jan_1_2024());

        assert_eq!(recorder.event_count(), 1);
        assert_eq!(
            recorder.events()[0].event_type.category_name(),
            "organizational"
        );
    }

    #[test]
    fn test_record_statistical_drift() {
        // Low but non-zero threshold to stay clear of the zero-threshold edge case.
        let mut recorder = recorder_with_threshold(0.01);

        // Period 0 only establishes the baseline; there is nothing to compare yet.
        let baseline = DriftAdjustments {
            amount_mean_multiplier: 1.0,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&baseline, 0);

        // Period 1 shifts the mean by 0.25, well above the 0.01 threshold.
        let shifted = DriftAdjustments {
            amount_mean_multiplier: 1.25,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&shifted, 1);

        // Only the mean shift is recorded: the variance multiplier is unchanged.
        assert_eq!(recorder.event_count(), 1);
    }

    #[test]
    fn test_summary() {
        let mut recorder = recorder_with_threshold(0.0);
        let acquisition = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&acquisition, 6, jan_1_2024());

        let summary = recorder.summary();

        assert_eq!(summary.total_events, 1);
        assert!(summary.by_category.contains_key("organizational"));
    }
}