Skip to main content

datasynth_core/models/
drift_events.rs

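//! Drift event types for ground truth labeling.
//!
//! Provides drift event types used as ground-truth labels for ML model training:
//! - Statistical shifts (mean, variance, distribution shape, correlation, tails, Benford deviation)
//! - Categorical shifts (proportions, new or removed categories, consolidation)
//! - Temporal shifts (seasonality, trends, periodicity, intraday patterns, processing lags)
//! - Organizational, process, and technology changes
//! - Regulatory and audit focus changes
//! - Market/economic and behavioral shifts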
8
9use chrono::NaiveDate;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13/// Drift event type with associated metadata.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "category", rename_all = "snake_case")]
16pub enum DriftEventType {
17    /// Statistical distribution shift.
18    Statistical(StatisticalDriftEvent),
19    /// Categorical distribution shift.
20    Categorical(CategoricalDriftEvent),
21    /// Temporal pattern shift.
22    Temporal(TemporalDriftEvent),
23    /// Organizational event drift.
24    Organizational(OrganizationalDriftEvent),
25    /// Process evolution drift.
26    Process(ProcessDriftEvent),
27    /// Technology transition drift.
28    Technology(TechnologyDriftEvent),
29    /// Regulatory change drift.
30    Regulatory(RegulatoryDriftLabel),
31    /// Audit focus shift.
32    AuditFocus(AuditFocusDriftEvent),
33    /// Market/economic drift.
34    Market(MarketDriftEvent),
35    /// Behavioral drift.
36    Behavioral(BehavioralDriftEvent),
37}
38
39impl DriftEventType {
40    /// Get the category name.
41    pub fn category_name(&self) -> &'static str {
42        match self {
43            Self::Statistical(_) => "statistical",
44            Self::Categorical(_) => "categorical",
45            Self::Temporal(_) => "temporal",
46            Self::Organizational(_) => "organizational",
47            Self::Process(_) => "process",
48            Self::Technology(_) => "technology",
49            Self::Regulatory(_) => "regulatory",
50            Self::AuditFocus(_) => "audit_focus",
51            Self::Market(_) => "market",
52            Self::Behavioral(_) => "behavioral",
53        }
54    }
55
56    /// Get the specific type name.
57    pub fn type_name(&self) -> &str {
58        match self {
59            Self::Statistical(e) => e.shift_type.as_str(),
60            Self::Categorical(e) => e.shift_type.as_str(),
61            Self::Temporal(e) => e.shift_type.as_str(),
62            Self::Organizational(e) => &e.event_type,
63            Self::Process(e) => &e.process_type,
64            Self::Technology(e) => &e.transition_type,
65            Self::Regulatory(e) => &e.regulation_type,
66            Self::AuditFocus(e) => &e.focus_type,
67            Self::Market(e) => e.market_type.as_str(),
68            Self::Behavioral(e) => &e.behavior_type,
69        }
70    }
71
72    /// Get the detection difficulty.
73    pub fn detection_difficulty(&self) -> DetectionDifficulty {
74        match self {
75            Self::Statistical(e) => e.detection_difficulty,
76            Self::Categorical(e) => e.detection_difficulty,
77            Self::Temporal(e) => e.detection_difficulty,
78            Self::Organizational(e) => e.detection_difficulty,
79            Self::Process(e) => e.detection_difficulty,
80            Self::Technology(e) => e.detection_difficulty,
81            Self::Regulatory(e) => e.detection_difficulty,
82            Self::AuditFocus(e) => e.detection_difficulty,
83            Self::Market(e) => e.detection_difficulty,
84            Self::Behavioral(e) => e.detection_difficulty,
85        }
86    }
87}
88
89/// Detection difficulty level for drift events.
90#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
91#[serde(rename_all = "snake_case")]
92pub enum DetectionDifficulty {
93    /// Easy to detect (large magnitude, clear signal).
94    Easy,
95    /// Medium difficulty (moderate signal-to-noise).
96    #[default]
97    Medium,
98    /// Hard to detect (subtle, gradual, or noisy).
99    Hard,
100}
101
102impl DetectionDifficulty {
103    /// Get a numeric score (0.0 = easy, 1.0 = hard).
104    pub fn score(&self) -> f64 {
105        match self {
106            Self::Easy => 0.0,
107            Self::Medium => 0.5,
108            Self::Hard => 1.0,
109        }
110    }
111}
112
113// =============================================================================
114// Statistical Drift Events
115// =============================================================================
116
117/// Statistical drift event.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct StatisticalDriftEvent {
120    /// Type of statistical shift.
121    pub shift_type: StatisticalShiftType,
122    /// Affected field/feature.
123    pub affected_field: String,
124    /// Magnitude of the shift.
125    pub magnitude: f64,
126    /// Detection difficulty.
127    #[serde(default)]
128    pub detection_difficulty: DetectionDifficulty,
129    /// Additional metrics.
130    #[serde(default)]
131    pub metrics: HashMap<String, f64>,
132}
133
134/// Type of statistical shift.
135#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
136#[serde(rename_all = "snake_case")]
137pub enum StatisticalShiftType {
138    /// Mean shift.
139    MeanShift,
140    /// Variance change.
141    VarianceChange,
142    /// Distribution shape change.
143    DistributionChange,
144    /// Correlation change.
145    CorrelationChange,
146    /// Tail behavior change.
147    TailChange,
148    /// Benford distribution deviation.
149    BenfordDeviation,
150}
151
152impl StatisticalShiftType {
153    /// Get the type as a string.
154    pub fn as_str(&self) -> &'static str {
155        match self {
156            Self::MeanShift => "mean_shift",
157            Self::VarianceChange => "variance_change",
158            Self::DistributionChange => "distribution_change",
159            Self::CorrelationChange => "correlation_change",
160            Self::TailChange => "tail_change",
161            Self::BenfordDeviation => "benford_deviation",
162        }
163    }
164}
165
166// =============================================================================
167// Categorical Drift Events
168// =============================================================================
169
170/// Categorical drift event.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct CategoricalDriftEvent {
173    /// Type of categorical shift.
174    pub shift_type: CategoricalShiftType,
175    /// Affected field/feature.
176    pub affected_field: String,
177    /// Detection difficulty.
178    #[serde(default)]
179    pub detection_difficulty: DetectionDifficulty,
180    /// Category proportions before (if applicable).
181    #[serde(default)]
182    pub proportions_before: HashMap<String, f64>,
183    /// Category proportions after (if applicable).
184    #[serde(default)]
185    pub proportions_after: HashMap<String, f64>,
186    /// New categories introduced.
187    #[serde(default)]
188    pub new_categories: Vec<String>,
189    /// Categories removed.
190    #[serde(default)]
191    pub removed_categories: Vec<String>,
192}
193
194/// Type of categorical shift.
195#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
196#[serde(rename_all = "snake_case")]
197pub enum CategoricalShiftType {
198    /// Proportion shift between existing categories.
199    ProportionShift,
200    /// New category introduced.
201    NewCategory,
202    /// Category removed/deprecated.
203    CategoryRemoval,
204    /// Category consolidation.
205    Consolidation,
206}
207
208impl CategoricalShiftType {
209    /// Get the type as a string.
210    pub fn as_str(&self) -> &'static str {
211        match self {
212            Self::ProportionShift => "proportion_shift",
213            Self::NewCategory => "new_category",
214            Self::CategoryRemoval => "category_removal",
215            Self::Consolidation => "consolidation",
216        }
217    }
218}
219
220// =============================================================================
221// Temporal Drift Events
222// =============================================================================
223
224/// Temporal drift event.
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct TemporalDriftEvent {
227    /// Type of temporal shift.
228    pub shift_type: TemporalShiftType,
229    /// Affected field/feature.
230    #[serde(default)]
231    pub affected_field: Option<String>,
232    /// Detection difficulty.
233    #[serde(default)]
234    pub detection_difficulty: DetectionDifficulty,
235    /// Magnitude of change.
236    #[serde(default)]
237    pub magnitude: f64,
238    /// Description.
239    #[serde(default)]
240    pub description: Option<String>,
241}
242
243/// Type of temporal shift.
244#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
245#[serde(rename_all = "snake_case")]
246pub enum TemporalShiftType {
247    /// Seasonality pattern change.
248    SeasonalityChange,
249    /// Trend change.
250    TrendChange,
251    /// Periodicity change.
252    PeriodicityChange,
253    /// Intraday pattern change.
254    IntradayChange,
255    /// Processing lag change.
256    LagChange,
257}
258
259impl TemporalShiftType {
260    /// Get the type as a string.
261    pub fn as_str(&self) -> &'static str {
262        match self {
263            Self::SeasonalityChange => "seasonality_change",
264            Self::TrendChange => "trend_change",
265            Self::PeriodicityChange => "periodicity_change",
266            Self::IntradayChange => "intraday_change",
267            Self::LagChange => "lag_change",
268        }
269    }
270}
271
272// =============================================================================
273// Organizational, Process, and Technology Drift Events
274// =============================================================================
275
276/// Organizational drift event.
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct OrganizationalDriftEvent {
279    /// Event type (e.g., "acquisition", "divestiture").
280    pub event_type: String,
281    /// Related event ID.
282    pub related_event_id: String,
283    /// Detection difficulty.
284    #[serde(default)]
285    pub detection_difficulty: DetectionDifficulty,
286    /// Affected entities.
287    #[serde(default)]
288    pub affected_entities: Vec<String>,
289    /// Impact metrics.
290    #[serde(default)]
291    pub impact_metrics: HashMap<String, f64>,
292}
293
294/// Process drift event.
295#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct ProcessDriftEvent {
297    /// Process type (e.g., "automation", "workflow_change").
298    pub process_type: String,
299    /// Related event ID.
300    pub related_event_id: String,
301    /// Detection difficulty.
302    #[serde(default)]
303    pub detection_difficulty: DetectionDifficulty,
304    /// Affected processes.
305    #[serde(default)]
306    pub affected_processes: Vec<String>,
307}
308
309/// Technology drift event.
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct TechnologyDriftEvent {
312    /// Transition type (e.g., "erp_migration", "module_implementation").
313    pub transition_type: String,
314    /// Related event ID.
315    pub related_event_id: String,
316    /// Detection difficulty.
317    #[serde(default)]
318    pub detection_difficulty: DetectionDifficulty,
319    /// Systems involved.
320    #[serde(default)]
321    pub systems: Vec<String>,
322    /// Current phase.
323    #[serde(default)]
324    pub current_phase: Option<String>,
325}
326
327// =============================================================================
328// Regulatory and Audit Drift Events
329// =============================================================================
330
331/// Regulatory drift event.
332#[derive(Debug, Clone, Serialize, Deserialize)]
333pub struct RegulatoryDriftLabel {
334    /// Regulation type.
335    pub regulation_type: String,
336    /// Standard or regulation name.
337    pub regulation_name: String,
338    /// Detection difficulty.
339    #[serde(default)]
340    pub detection_difficulty: DetectionDifficulty,
341    /// Affected accounts.
342    #[serde(default)]
343    pub affected_accounts: Vec<String>,
344    /// Compliance framework.
345    #[serde(default)]
346    pub framework: Option<String>,
347}
348
349/// Audit focus drift event.
350#[derive(Debug, Clone, Serialize, Deserialize)]
351pub struct AuditFocusDriftEvent {
352    /// Focus type.
353    pub focus_type: String,
354    /// Detection difficulty.
355    #[serde(default)]
356    pub detection_difficulty: DetectionDifficulty,
357    /// Risk areas.
358    #[serde(default)]
359    pub risk_areas: Vec<String>,
360    /// Priority level.
361    #[serde(default)]
362    pub priority_level: u8,
363}
364
365// =============================================================================
366// Market and Behavioral Drift Events
367// =============================================================================
368
369/// Market drift event.
370#[derive(Debug, Clone, Serialize, Deserialize)]
371pub struct MarketDriftEvent {
372    /// Market type.
373    pub market_type: MarketEventType,
374    /// Detection difficulty.
375    #[serde(default)]
376    pub detection_difficulty: DetectionDifficulty,
377    /// Magnitude.
378    #[serde(default)]
379    pub magnitude: f64,
380    /// Is recession.
381    #[serde(default)]
382    pub is_recession: bool,
383    /// Affected sectors.
384    #[serde(default)]
385    pub affected_sectors: Vec<String>,
386}
387
388/// Market event type.
389#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
390#[serde(rename_all = "snake_case")]
391pub enum MarketEventType {
392    /// Economic cycle change.
393    EconomicCycle,
394    /// Recession start.
395    RecessionStart,
396    /// Recession end.
397    RecessionEnd,
398    /// Price shock.
399    PriceShock,
400    /// Commodity price change.
401    CommodityChange,
402}
403
404impl MarketEventType {
405    /// Get the type as a string.
406    pub fn as_str(&self) -> &'static str {
407        match self {
408            Self::EconomicCycle => "economic_cycle",
409            Self::RecessionStart => "recession_start",
410            Self::RecessionEnd => "recession_end",
411            Self::PriceShock => "price_shock",
412            Self::CommodityChange => "commodity_change",
413        }
414    }
415}
416
417/// Behavioral drift event.
418#[derive(Debug, Clone, Serialize, Deserialize)]
419pub struct BehavioralDriftEvent {
420    /// Behavior type (e.g., "vendor_quality", "customer_payment").
421    pub behavior_type: String,
422    /// Entity type affected.
423    pub entity_type: String,
424    /// Detection difficulty.
425    #[serde(default)]
426    pub detection_difficulty: DetectionDifficulty,
427    /// Behavior metrics.
428    #[serde(default)]
429    pub metrics: HashMap<String, f64>,
430}
431
432// =============================================================================
433// Labeled Drift Event
434// =============================================================================
435
436/// A labeled drift event with full metadata.
437#[derive(Debug, Clone, Serialize, Deserialize)]
438pub struct LabeledDriftEvent {
439    /// Unique event ID.
440    pub event_id: String,
441    /// Event type.
442    pub event_type: DriftEventType,
443    /// Start date.
444    pub start_date: NaiveDate,
445    /// End date (None for ongoing).
446    #[serde(default)]
447    pub end_date: Option<NaiveDate>,
448    /// Start period (0-indexed).
449    pub start_period: u32,
450    /// End period (None for ongoing).
451    #[serde(default)]
452    pub end_period: Option<u32>,
453    /// Affected fields/features.
454    #[serde(default)]
455    pub affected_fields: Vec<String>,
456    /// Magnitude of the drift.
457    pub magnitude: f64,
458    /// Detection difficulty.
459    pub detection_difficulty: DetectionDifficulty,
460    /// Related organizational event ID if applicable.
461    #[serde(default)]
462    pub related_org_event: Option<String>,
463    /// Tags for categorization.
464    #[serde(default)]
465    pub tags: Vec<String>,
466    /// Additional metadata.
467    #[serde(default)]
468    pub metadata: HashMap<String, String>,
469}
470
471impl LabeledDriftEvent {
472    /// Create a new labeled drift event.
473    pub fn new(
474        event_id: impl Into<String>,
475        event_type: DriftEventType,
476        start_date: NaiveDate,
477        start_period: u32,
478        magnitude: f64,
479    ) -> Self {
480        let detection_difficulty = event_type.detection_difficulty();
481        Self {
482            event_id: event_id.into(),
483            event_type,
484            start_date,
485            end_date: None,
486            start_period,
487            end_period: None,
488            affected_fields: Vec::new(),
489            magnitude,
490            detection_difficulty,
491            related_org_event: None,
492            tags: Vec::new(),
493            metadata: HashMap::new(),
494        }
495    }
496
497    /// Check if the event is active at a given period.
498    pub fn is_active_at(&self, period: u32) -> bool {
499        if period < self.start_period {
500            return false;
501        }
502        match self.end_period {
503            Some(end) => period <= end,
504            None => true,
505        }
506    }
507
508    /// Get the duration in periods (None if ongoing).
509    pub fn duration_periods(&self) -> Option<u32> {
510        self.end_period.map(|end| end - self.start_period + 1)
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mean-shift statistical event on `amount` with the given difficulty.
    fn mean_shift(magnitude: f64, difficulty: DetectionDifficulty) -> DriftEventType {
        DriftEventType::Statistical(StatisticalDriftEvent {
            shift_type: StatisticalShiftType::MeanShift,
            affected_field: "amount".to_string(),
            magnitude,
            detection_difficulty: difficulty,
            metrics: HashMap::new(),
        })
    }

    /// Fixed start date shared by the labeled-event tests.
    fn june_first() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, 1).unwrap()
    }

    #[test]
    fn test_drift_event_type_names() {
        let stat_event = mean_shift(0.15, DetectionDifficulty::Easy);

        assert_eq!(stat_event.category_name(), "statistical");
        assert_eq!(stat_event.type_name(), "mean_shift");
        assert_eq!(stat_event.detection_difficulty(), DetectionDifficulty::Easy);
    }

    #[test]
    fn test_free_form_type_names() {
        let audit = DriftEventType::AuditFocus(AuditFocusDriftEvent {
            focus_type: "revenue_recognition".to_string(),
            detection_difficulty: DetectionDifficulty::Hard,
            risk_areas: vec!["revenue".to_string()],
            priority_level: 2,
        });

        assert_eq!(audit.category_name(), "audit_focus");
        assert_eq!(audit.type_name(), "revenue_recognition");
        assert_eq!(audit.detection_difficulty(), DetectionDifficulty::Hard);
    }

    #[test]
    fn test_labeled_drift_event() {
        let event = LabeledDriftEvent::new(
            "DRIFT-001",
            mean_shift(0.20, DetectionDifficulty::Medium),
            june_first(),
            6,
            0.20,
        );

        // `new` copies the difficulty from the wrapped event type.
        assert_eq!(event.detection_difficulty, DetectionDifficulty::Medium);
        assert!(event.is_active_at(6));
        assert!(event.is_active_at(12)); // Ongoing
        assert!(!event.is_active_at(5));
        assert_eq!(event.duration_periods(), None);
    }

    #[test]
    fn test_bounded_labeled_drift_event() {
        let mut event = LabeledDriftEvent::new(
            "DRIFT-002",
            mean_shift(0.10, DetectionDifficulty::Hard),
            june_first(),
            6,
            0.10,
        );
        event.end_period = Some(12);

        assert_eq!(event.detection_difficulty, DetectionDifficulty::Hard);
        assert!(event.is_active_at(6));
        assert!(event.is_active_at(12)); // End period is inclusive.
        assert!(!event.is_active_at(13));
        assert_eq!(event.duration_periods(), Some(7));

        event.end_period = Some(6);
        assert_eq!(event.duration_periods(), Some(1));
    }

    #[test]
    fn test_detection_difficulty_score() {
        assert!(DetectionDifficulty::Easy.score() < DetectionDifficulty::Medium.score());
        assert!(DetectionDifficulty::Medium.score() < DetectionDifficulty::Hard.score());
        assert_eq!(DetectionDifficulty::default(), DetectionDifficulty::Medium);
    }
}