use chrono::NaiveDate;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A drift event classified by category, carrying the category-specific payload.
///
/// Serde representation: the enum is internally tagged through a `category`
/// field whose value is the snake_case variant name (for example `audit_focus`
/// for `AuditFocus`); the payload struct's own fields are serialized alongside
/// that tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "category", rename_all = "snake_case")]
pub enum DriftEventType {
/// Statistical drift; payload is [`StatisticalDriftEvent`].
Statistical(StatisticalDriftEvent),
/// Categorical drift; payload is [`CategoricalDriftEvent`].
Categorical(CategoricalDriftEvent),
/// Temporal drift; payload is [`TemporalDriftEvent`].
Temporal(TemporalDriftEvent),
/// Organizational drift; payload is [`OrganizationalDriftEvent`].
Organizational(OrganizationalDriftEvent),
/// Process drift; payload is [`ProcessDriftEvent`].
Process(ProcessDriftEvent),
/// Technology drift; payload is [`TechnologyDriftEvent`].
Technology(TechnologyDriftEvent),
/// Regulatory drift; payload is [`RegulatoryDriftLabel`].
Regulatory(RegulatoryDriftLabel),
/// Audit-focus drift; payload is [`AuditFocusDriftEvent`].
AuditFocus(AuditFocusDriftEvent),
/// Market drift; payload is [`MarketDriftEvent`].
Market(MarketDriftEvent),
/// Behavioral drift; payload is [`BehavioralDriftEvent`].
Behavioral(BehavioralDriftEvent),
}
impl DriftEventType {
pub fn category_name(&self) -> &'static str {
match self {
Self::Statistical(_) => "statistical",
Self::Categorical(_) => "categorical",
Self::Temporal(_) => "temporal",
Self::Organizational(_) => "organizational",
Self::Process(_) => "process",
Self::Technology(_) => "technology",
Self::Regulatory(_) => "regulatory",
Self::AuditFocus(_) => "audit_focus",
Self::Market(_) => "market",
Self::Behavioral(_) => "behavioral",
}
}
pub fn type_name(&self) -> &str {
match self {
Self::Statistical(e) => e.shift_type.as_str(),
Self::Categorical(e) => e.shift_type.as_str(),
Self::Temporal(e) => e.shift_type.as_str(),
Self::Organizational(e) => &e.event_type,
Self::Process(e) => &e.process_type,
Self::Technology(e) => &e.transition_type,
Self::Regulatory(e) => &e.regulation_type,
Self::AuditFocus(e) => &e.focus_type,
Self::Market(e) => e.market_type.as_str(),
Self::Behavioral(e) => &e.behavior_type,
}
}
pub fn detection_difficulty(&self) -> DetectionDifficulty {
match self {
Self::Statistical(e) => e.detection_difficulty,
Self::Categorical(e) => e.detection_difficulty,
Self::Temporal(e) => e.detection_difficulty,
Self::Organizational(e) => e.detection_difficulty,
Self::Process(e) => e.detection_difficulty,
Self::Technology(e) => e.detection_difficulty,
Self::Regulatory(e) => e.detection_difficulty,
Self::AuditFocus(e) => e.detection_difficulty,
Self::Market(e) => e.detection_difficulty,
Self::Behavioral(e) => e.detection_difficulty,
}
}
}
/// Three-level rating of how hard a drift event is to detect.
///
/// Serialized in snake_case (`easy`, `medium`, `hard`). `Medium` is the
/// `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum DetectionDifficulty {
/// Easy to detect.
Easy,
/// Moderately hard to detect; the `Default` variant.
#[default]
Medium,
/// Hard to detect.
Hard,
}
impl DetectionDifficulty {
pub fn score(&self) -> f64 {
match self {
Self::Easy => 0.0,
Self::Medium => 0.5,
Self::Hard => 1.0,
}
}
}
/// Payload for a statistical drift event ([`DriftEventType::Statistical`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticalDriftEvent {
/// Kind of statistical shift.
pub shift_type: StatisticalShiftType,
/// Name of the field the shift applies to.
pub affected_field: String,
/// Size of the shift. NOTE(review): units/scale are not defined in this file; confirm with producers.
pub magnitude: f64,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Named numeric metrics; an empty map when absent from the input.
#[serde(default)]
pub metrics: HashMap<String, f64>,
}
/// Kinds of statistical shift a [`StatisticalDriftEvent`] can describe.
///
/// Serialized as the snake_case variant name (e.g. `mean_shift`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StatisticalShiftType {
/// Shift in the mean.
MeanShift,
/// Change in variance.
VarianceChange,
/// Change in the distribution.
DistributionChange,
/// Change in correlation.
CorrelationChange,
/// Change in the tail(s) of the distribution.
TailChange,
/// Deviation from Benford's law (presumably of leading digits; confirm with producers).
BenfordDeviation,
}
impl StatisticalShiftType {
pub fn as_str(&self) -> &'static str {
match self {
Self::MeanShift => "mean_shift",
Self::VarianceChange => "variance_change",
Self::DistributionChange => "distribution_change",
Self::CorrelationChange => "correlation_change",
Self::TailChange => "tail_change",
Self::BenfordDeviation => "benford_deviation",
}
}
}
/// Payload for a categorical drift event ([`DriftEventType::Categorical`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoricalDriftEvent {
/// Kind of categorical shift.
pub shift_type: CategoricalShiftType,
/// Name of the field the shift applies to.
pub affected_field: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Proportions before the shift, keyed by string (presumably the category name; confirm). Empty when absent.
#[serde(default)]
pub proportions_before: HashMap<String, f64>,
/// Proportions after the shift, keyed the same way. Empty when absent.
#[serde(default)]
pub proportions_after: HashMap<String, f64>,
/// Categories introduced by the shift; empty when absent from the input.
#[serde(default)]
pub new_categories: Vec<String>,
/// Categories removed by the shift; empty when absent from the input.
#[serde(default)]
pub removed_categories: Vec<String>,
}
/// Kinds of categorical shift a [`CategoricalDriftEvent`] can describe.
///
/// Serialized as the snake_case variant name (e.g. `proportion_shift`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CategoricalShiftType {
/// Shift in the proportions of categories.
ProportionShift,
/// A new category appears.
NewCategory,
/// A category is removed.
CategoryRemoval,
/// Consolidation of categories (presumably several merged into fewer; confirm).
Consolidation,
}
impl CategoricalShiftType {
pub fn as_str(&self) -> &'static str {
match self {
Self::ProportionShift => "proportion_shift",
Self::NewCategory => "new_category",
Self::CategoryRemoval => "category_removal",
Self::Consolidation => "consolidation",
}
}
}
/// Payload for a temporal drift event ([`DriftEventType::Temporal`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalDriftEvent {
/// Kind of temporal shift.
pub shift_type: TemporalShiftType,
/// Name of the affected field, if one is recorded; `None` when absent from the input.
#[serde(default)]
pub affected_field: Option<String>,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Size of the shift; `0.0` when absent from the input. NOTE(review): units/scale are not defined in this file.
#[serde(default)]
pub magnitude: f64,
/// Optional free-text description; `None` when absent from the input.
#[serde(default)]
pub description: Option<String>,
}
/// Kinds of temporal shift a [`TemporalDriftEvent`] can describe.
///
/// Serialized as the snake_case variant name (e.g. `trend_change`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TemporalShiftType {
/// Change in seasonality.
SeasonalityChange,
/// Change in trend.
TrendChange,
/// Change in periodicity.
PeriodicityChange,
/// Change in the intraday pattern.
IntradayChange,
/// Change in lag.
LagChange,
}
impl TemporalShiftType {
pub fn as_str(&self) -> &'static str {
match self {
Self::SeasonalityChange => "seasonality_change",
Self::TrendChange => "trend_change",
Self::PeriodicityChange => "periodicity_change",
Self::IntradayChange => "intraday_change",
Self::LagChange => "lag_change",
}
}
}
/// Payload for an organizational drift event ([`DriftEventType::Organizational`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OrganizationalDriftEvent {
/// Free-form event type string; this is what `DriftEventType::type_name` returns for this payload.
pub event_type: String,
/// Identifier of the related event (presumably the organizational event behind the drift; confirm).
pub related_event_id: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Affected entities; empty when absent from the input.
#[serde(default)]
pub affected_entities: Vec<String>,
/// Named numeric impact metrics; an empty map when absent from the input.
#[serde(default)]
pub impact_metrics: HashMap<String, f64>,
}
/// Payload for a process drift event ([`DriftEventType::Process`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessDriftEvent {
/// Free-form process type string; this is what `DriftEventType::type_name` returns for this payload.
pub process_type: String,
/// Identifier of the related event (presumably the process change behind the drift; confirm).
pub related_event_id: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Affected processes; empty when absent from the input.
#[serde(default)]
pub affected_processes: Vec<String>,
}
/// Payload for a technology drift event ([`DriftEventType::Technology`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TechnologyDriftEvent {
/// Free-form transition type string; this is what `DriftEventType::type_name` returns for this payload.
pub transition_type: String,
/// Identifier of the related event (presumably the technology transition behind the drift; confirm).
pub related_event_id: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Systems involved; empty when absent from the input.
#[serde(default)]
pub systems: Vec<String>,
/// Current phase of the transition, if recorded; `None` when absent from the input.
#[serde(default)]
pub current_phase: Option<String>,
}
/// Payload for a regulatory drift event ([`DriftEventType::Regulatory`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RegulatoryDriftLabel {
/// Free-form regulation type string; this is what `DriftEventType::type_name` returns for this payload.
pub regulation_type: String,
/// Name of the regulation.
pub regulation_name: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Affected accounts; empty when absent from the input.
#[serde(default)]
pub affected_accounts: Vec<String>,
/// Framework the regulation belongs to, if recorded; `None` when absent from the input.
#[serde(default)]
pub framework: Option<String>,
}
/// Payload for an audit-focus drift event ([`DriftEventType::AuditFocus`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditFocusDriftEvent {
/// Free-form focus type string; this is what `DriftEventType::type_name` returns for this payload.
pub focus_type: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Risk areas; empty when absent from the input.
#[serde(default)]
pub risk_areas: Vec<String>,
/// Priority level; `0` when absent from the input. NOTE(review): the scale and direction are not defined in this file.
#[serde(default)]
pub priority_level: u8,
}
/// Payload for a market drift event ([`DriftEventType::Market`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketDriftEvent {
/// Kind of market event.
pub market_type: MarketEventType,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Size of the event; `0.0` when absent from the input. NOTE(review): units/scale are not defined in this file.
#[serde(default)]
pub magnitude: f64,
/// Recession flag; `false` when absent from the input.
#[serde(default)]
pub is_recession: bool,
/// Affected sectors; empty when absent from the input.
#[serde(default)]
pub affected_sectors: Vec<String>,
}
/// Kinds of market event a [`MarketDriftEvent`] can describe.
///
/// Serialized as the snake_case variant name (e.g. `price_shock`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MarketEventType {
/// Economic cycle.
EconomicCycle,
/// Start of a recession.
RecessionStart,
/// End of a recession.
RecessionEnd,
/// Price shock.
PriceShock,
/// Commodity change.
CommodityChange,
}
impl MarketEventType {
pub fn as_str(&self) -> &'static str {
match self {
Self::EconomicCycle => "economic_cycle",
Self::RecessionStart => "recession_start",
Self::RecessionEnd => "recession_end",
Self::PriceShock => "price_shock",
Self::CommodityChange => "commodity_change",
}
}
}
/// Payload for a behavioral drift event ([`DriftEventType::Behavioral`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BehavioralDriftEvent {
/// Free-form behavior type string; this is what `DriftEventType::type_name` returns for this payload.
pub behavior_type: String,
/// Type of entity whose behavior drifted (free-form string).
pub entity_type: String,
/// Detection difficulty; `DetectionDifficulty::default()` (`Medium`) when absent from the input.
#[serde(default)]
pub detection_difficulty: DetectionDifficulty,
/// Named numeric metrics; an empty map when absent from the input.
#[serde(default)]
pub metrics: HashMap<String, f64>,
}
/// A drift event label: a typed drift payload plus when it applies and descriptive metadata.
///
/// Timing is recorded twice, as calendar dates and as `u32` period numbers.
/// Fields marked `#[serde(default)]` may be omitted from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LabeledDriftEvent {
/// Identifier of this label.
pub event_id: String,
/// Category and category-specific payload of the drift.
pub event_type: DriftEventType,
/// Calendar date on which the drift starts.
pub start_date: NaiveDate,
/// Calendar date on which the drift ends; `None` when no end date is recorded.
#[serde(default)]
pub end_date: Option<NaiveDate>,
/// First period in which the drift is active (see `is_active_at`).
pub start_period: u32,
/// Last period in which the drift is active, inclusive; `None` means the event has no end period.
#[serde(default)]
pub end_period: Option<u32>,
/// Names of the affected fields; empty when absent from the input.
#[serde(default)]
pub affected_fields: Vec<String>,
/// Size of the drift. NOTE(review): units/scale are not defined in this file; confirm with producers.
pub magnitude: f64,
/// Detection difficulty of the label; `new` copies it from the payload in `event_type`.
pub detection_difficulty: DetectionDifficulty,
/// Optional reference to a related organizational event (presumably its ID; confirm against producers).
#[serde(default)]
pub related_org_event: Option<String>,
/// Free-form tags; empty when absent from the input.
#[serde(default)]
pub tags: Vec<String>,
/// Free-form string key/value metadata; an empty map when absent from the input.
#[serde(default)]
pub metadata: HashMap<String, String>,
}
impl LabeledDriftEvent {
    /// Creates an open-ended label for `event_type` starting at `start_date` / `start_period`.
    ///
    /// The label's `detection_difficulty` is copied from the payload inside
    /// `event_type`. `end_date`, `end_period` and `related_org_event` start as
    /// `None`; `affected_fields`, `tags` and `metadata` start empty.
    pub fn new(
        event_id: impl Into<String>,
        event_type: DriftEventType,
        start_date: NaiveDate,
        start_period: u32,
        magnitude: f64,
    ) -> Self {
        // Read the difficulty before `event_type` is moved into the struct.
        let detection_difficulty = event_type.detection_difficulty();
        Self {
            event_id: event_id.into(),
            event_type,
            start_date,
            end_date: None,
            start_period,
            end_period: None,
            affected_fields: Vec::new(),
            magnitude,
            detection_difficulty,
            related_org_event: None,
            tags: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Returns `true` when `period` lies within the event's active range.
    ///
    /// The range is `start_period..=end_period`; with no `end_period` the event
    /// is active for every period from `start_period` onward.
    pub fn is_active_at(&self, period: u32) -> bool {
        if period < self.start_period {
            return false;
        }
        match self.end_period {
            Some(end) => period <= end,
            None => true,
        }
    }

    /// Returns the number of periods the event spans, counting both endpoints.
    ///
    /// * `None` when the event has no `end_period` (open-ended).
    /// * `Some(0)` when `end_period` is before `start_period`: such an event
    ///   is never active according to [`Self::is_active_at`], so it spans no
    ///   periods. The unchecked subtraction used previously would panic in
    ///   debug builds and wrap around in release builds for this input.
    /// * Otherwise `Some(end_period - start_period + 1)`, saturating at
    ///   `u32::MAX` when the inclusive count does not fit in a `u32`.
    pub fn duration_periods(&self) -> Option<u32> {
        self.end_period.map(|end| {
            end.checked_sub(self.start_period)
                .map_or(0, |span| span.saturating_add(1))
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a mean-shift statistical payload on the `amount` field.
    fn mean_shift_on_amount(
        magnitude: f64,
        detection_difficulty: DetectionDifficulty,
    ) -> DriftEventType {
        DriftEventType::Statistical(StatisticalDriftEvent {
            shift_type: StatisticalShiftType::MeanShift,
            affected_field: "amount".to_string(),
            magnitude,
            detection_difficulty,
            metrics: HashMap::new(),
        })
    }

    #[test]
    fn test_drift_event_type_names() {
        let payload = mean_shift_on_amount(0.15, DetectionDifficulty::Easy);

        assert_eq!(payload.category_name(), "statistical");
        assert_eq!(payload.type_name(), "mean_shift");
    }

    #[test]
    fn test_labeled_drift_event() {
        let labeled = LabeledDriftEvent::new(
            "DRIFT-001",
            mean_shift_on_amount(0.20, DetectionDifficulty::Medium),
            NaiveDate::from_ymd_opt(2024, 6, 1).unwrap(),
            6,
            0.20,
        );

        // No end period is set, so the event stays active from period 6 onward.
        assert!(labeled.is_active_at(6));
        assert!(labeled.is_active_at(12));
        assert!(!labeled.is_active_at(5));
    }

    #[test]
    fn test_detection_difficulty_score() {
        let easy = DetectionDifficulty::Easy.score();
        let medium = DetectionDifficulty::Medium.score();
        let hard = DetectionDifficulty::Hard.score();

        assert!(easy < medium);
        assert!(medium < hard);
    }
}