1use crate::distributions::drift::{DriftAdjustments, RegimeChange, RegimeChangeType};
7use crate::models::drift_events::{
8 CategoricalDriftEvent, CategoricalShiftType, DetectionDifficulty, DriftEventType,
9 LabeledDriftEvent, MarketDriftEvent, MarketEventType, OrganizationalDriftEvent,
10 ProcessDriftEvent, StatisticalDriftEvent, StatisticalShiftType, TechnologyDriftEvent,
11 TemporalDriftEvent, TemporalShiftType,
12};
13use chrono::NaiveDate;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::io::Write;
17use std::path::Path;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct DriftRecorderConfig {
22 #[serde(default)]
24 pub enabled: bool,
25 #[serde(default = "default_true")]
27 pub statistical: bool,
28 #[serde(default = "default_true")]
30 pub categorical: bool,
31 #[serde(default = "default_true")]
33 pub temporal: bool,
34 #[serde(default = "default_true")]
36 pub organizational: bool,
37 #[serde(default = "default_true")]
39 pub process_events: bool,
40 #[serde(default = "default_true")]
42 pub technology_events: bool,
43 #[serde(default = "default_true")]
45 pub regulatory: bool,
46 #[serde(default = "default_true")]
48 pub market: bool,
49 #[serde(default = "default_true")]
51 pub behavioral: bool,
52 #[serde(default = "default_min_magnitude")]
54 pub min_magnitude_threshold: f64,
55}
56
/// Serde default for the per-category switches: enabled.
fn default_true() -> bool { true }
60
/// Serde default for `min_magnitude_threshold`: changes below 5% are ignored.
fn default_min_magnitude() -> f64 {
    const FIVE_PERCENT: f64 = 0.05;
    FIVE_PERCENT
}
64
65impl Default for DriftRecorderConfig {
66 fn default() -> Self {
67 Self {
68 enabled: false,
69 statistical: true,
70 categorical: true,
71 temporal: true,
72 organizational: true,
73 process_events: true,
74 technology_events: true,
75 regulatory: true,
76 market: true,
77 behavioral: true,
78 min_magnitude_threshold: 0.05,
79 }
80 }
81}
82
83pub struct DriftLabelRecorder {
85 events: Vec<LabeledDriftEvent>,
87 config: DriftRecorderConfig,
89 start_date: NaiveDate,
91 event_counter: u64,
93 previous_drift: Option<DriftAdjustments>,
95 #[allow(dead_code)]
97 active_regimes: HashMap<String, u32>,
98 was_in_recession: bool,
100}
101
102impl DriftLabelRecorder {
103 pub fn new(config: DriftRecorderConfig, start_date: NaiveDate) -> Self {
105 Self {
106 events: Vec::new(),
107 config,
108 start_date,
109 event_counter: 0,
110 previous_drift: None,
111 active_regimes: HashMap::new(),
112 was_in_recession: false,
113 }
114 }
115
116 pub fn is_enabled(&self) -> bool {
118 self.config.enabled
119 }
120
121 fn next_event_id(&mut self) -> String {
123 self.event_counter += 1;
124 format!("DRIFT-{:06}", self.event_counter)
125 }
126
127 fn period_to_date(&self, period: u32) -> NaiveDate {
129 self.start_date + chrono::Duration::days(period as i64 * 30)
130 }
131
132 pub fn record_regime_change(&mut self, regime: &RegimeChange, period: u32, _date: NaiveDate) {
134 if !self.config.enabled || !self.config.organizational {
135 return;
136 }
137
138 let event_type = match regime.change_type {
139 RegimeChangeType::Acquisition => "acquisition",
140 RegimeChangeType::Divestiture => "divestiture",
141 RegimeChangeType::PriceIncrease => "price_increase",
142 RegimeChangeType::PriceDecrease => "price_decrease",
143 RegimeChangeType::ProductLaunch => "product_launch",
144 RegimeChangeType::ProductDiscontinuation => "product_discontinuation",
145 RegimeChangeType::PolicyChange => "policy_change",
146 RegimeChangeType::CompetitorEntry => "competitor_entry",
147 RegimeChangeType::Custom => "custom",
148 };
149
150 let magnitude = (regime.volume_multiplier() - 1.0)
151 .abs()
152 .max((regime.amount_mean_multiplier() - 1.0).abs());
153
154 if magnitude < self.config.min_magnitude_threshold {
155 return;
156 }
157
158 let detection_difficulty = if magnitude > 0.20 {
159 DetectionDifficulty::Easy
160 } else if magnitude > 0.10 {
161 DetectionDifficulty::Medium
162 } else {
163 DetectionDifficulty::Hard
164 };
165
166 let mut event = LabeledDriftEvent::new(
167 self.next_event_id(),
168 DriftEventType::Organizational(OrganizationalDriftEvent {
169 event_type: event_type.to_string(),
170 related_event_id: regime.description.clone().unwrap_or_default(),
171 detection_difficulty,
172 affected_entities: Vec::new(),
173 impact_metrics: {
174 let mut m = HashMap::new();
175 m.insert("volume_multiplier".to_string(), regime.volume_multiplier());
176 m.insert(
177 "amount_multiplier".to_string(),
178 regime.amount_mean_multiplier(),
179 );
180 m
181 },
182 }),
183 self.period_to_date(period),
184 period,
185 magnitude,
186 );
187
188 event.end_period = Some(period + regime.transition_periods);
189 event.tags.push("regime_change".to_string());
190 event.tags.push(event_type.to_string());
191
192 self.events.push(event);
193 }
194
195 pub fn record_statistical_drift(&mut self, adjustments: &DriftAdjustments, period: u32) {
197 if !self.config.enabled || !self.config.statistical {
198 return;
199 }
200
201 let date = self.period_to_date(period);
202
203 if let Some(ref prev) = self.previous_drift {
205 let mean_delta =
206 (adjustments.amount_mean_multiplier - prev.amount_mean_multiplier).abs();
207 let var_delta =
208 (adjustments.amount_variance_multiplier - prev.amount_variance_multiplier).abs();
209 let prev_mean = prev.amount_mean_multiplier;
210 let current_mean = adjustments.amount_mean_multiplier;
211 let min_threshold = self.config.min_magnitude_threshold;
212
213 if mean_delta >= min_threshold {
214 let detection_difficulty = if mean_delta > 0.20 {
215 DetectionDifficulty::Easy
216 } else if mean_delta > 0.10 {
217 DetectionDifficulty::Medium
218 } else {
219 DetectionDifficulty::Hard
220 };
221
222 let event_id = self.next_event_id();
223 let event = LabeledDriftEvent::new(
224 event_id,
225 DriftEventType::Statistical(StatisticalDriftEvent {
226 shift_type: StatisticalShiftType::MeanShift,
227 affected_field: "amount".to_string(),
228 magnitude: mean_delta,
229 detection_difficulty,
230 metrics: {
231 let mut m = HashMap::new();
232 m.insert("previous_multiplier".to_string(), prev_mean);
233 m.insert("current_multiplier".to_string(), current_mean);
234 m
235 },
236 }),
237 date,
238 period,
239 mean_delta,
240 );
241
242 self.events.push(event);
243 }
244
245 if var_delta >= min_threshold {
247 let event_id = self.next_event_id();
248 let event = LabeledDriftEvent::new(
249 event_id,
250 DriftEventType::Statistical(StatisticalDriftEvent {
251 shift_type: StatisticalShiftType::VarianceChange,
252 affected_field: "amount".to_string(),
253 magnitude: var_delta,
254 detection_difficulty: DetectionDifficulty::Medium,
255 metrics: HashMap::new(),
256 }),
257 date,
258 period,
259 var_delta,
260 );
261
262 self.events.push(event);
263 }
264 }
265
266 if adjustments.sudden_drift_occurred {
268 let event = LabeledDriftEvent::new(
269 self.next_event_id(),
270 DriftEventType::Statistical(StatisticalDriftEvent {
271 shift_type: StatisticalShiftType::DistributionChange,
272 affected_field: "amount".to_string(),
273 magnitude: 0.5, detection_difficulty: DetectionDifficulty::Easy,
275 metrics: HashMap::new(),
276 }),
277 date,
278 period,
279 0.5,
280 );
281
282 self.events.push(event);
283 }
284
285 self.previous_drift = Some(adjustments.clone());
286 }
287
288 pub fn record_market_drift(
290 &mut self,
291 market_type: MarketEventType,
292 period: u32,
293 magnitude: f64,
294 is_recession: bool,
295 ) {
296 if !self.config.enabled || !self.config.market {
297 return;
298 }
299
300 if magnitude < self.config.min_magnitude_threshold
301 && market_type != MarketEventType::RecessionStart
302 && market_type != MarketEventType::RecessionEnd
303 {
304 return;
305 }
306
307 let actual_type = if is_recession && !self.was_in_recession {
309 self.was_in_recession = true;
310 MarketEventType::RecessionStart
311 } else if !is_recession && self.was_in_recession {
312 self.was_in_recession = false;
313 MarketEventType::RecessionEnd
314 } else {
315 market_type
316 };
317
318 let detection_difficulty = match actual_type {
319 MarketEventType::RecessionStart | MarketEventType::RecessionEnd => {
320 DetectionDifficulty::Easy
321 }
322 MarketEventType::PriceShock => DetectionDifficulty::Easy,
323 MarketEventType::EconomicCycle => DetectionDifficulty::Medium,
324 MarketEventType::CommodityChange => DetectionDifficulty::Medium,
325 };
326
327 let event = LabeledDriftEvent::new(
328 self.next_event_id(),
329 DriftEventType::Market(MarketDriftEvent {
330 market_type: actual_type,
331 detection_difficulty,
332 magnitude,
333 is_recession,
334 affected_sectors: Vec::new(),
335 }),
336 self.period_to_date(period),
337 period,
338 magnitude,
339 );
340
341 self.events.push(event);
342 }
343
344 pub fn record_process_drift(
346 &mut self,
347 process_type: &str,
348 related_event_id: &str,
349 period: u32,
350 magnitude: f64,
351 affected_processes: Vec<String>,
352 ) {
353 if !self.config.enabled || !self.config.process_events {
354 return;
355 }
356
357 if magnitude < self.config.min_magnitude_threshold {
358 return;
359 }
360
361 let mut event = LabeledDriftEvent::new(
362 self.next_event_id(),
363 DriftEventType::Process(ProcessDriftEvent {
364 process_type: process_type.to_string(),
365 related_event_id: related_event_id.to_string(),
366 detection_difficulty: DetectionDifficulty::Medium,
367 affected_processes,
368 }),
369 self.period_to_date(period),
370 period,
371 magnitude,
372 );
373
374 event.related_org_event = Some(related_event_id.to_string());
375 self.events.push(event);
376 }
377
378 pub fn record_technology_drift(
380 &mut self,
381 transition_type: &str,
382 related_event_id: &str,
383 period: u32,
384 magnitude: f64,
385 systems: Vec<String>,
386 current_phase: Option<&str>,
387 ) {
388 if !self.config.enabled || !self.config.technology_events {
389 return;
390 }
391
392 if magnitude < self.config.min_magnitude_threshold {
393 return;
394 }
395
396 let mut event = LabeledDriftEvent::new(
397 self.next_event_id(),
398 DriftEventType::Technology(TechnologyDriftEvent {
399 transition_type: transition_type.to_string(),
400 related_event_id: related_event_id.to_string(),
401 detection_difficulty: DetectionDifficulty::Easy, systems,
403 current_phase: current_phase.map(String::from),
404 }),
405 self.period_to_date(period),
406 period,
407 magnitude,
408 );
409
410 event.related_org_event = Some(related_event_id.to_string());
411 self.events.push(event);
412 }
413
414 pub fn record_temporal_drift(
416 &mut self,
417 shift_type: TemporalShiftType,
418 period: u32,
419 magnitude: f64,
420 affected_field: Option<&str>,
421 description: Option<&str>,
422 ) {
423 if !self.config.enabled || !self.config.temporal {
424 return;
425 }
426
427 if magnitude < self.config.min_magnitude_threshold {
428 return;
429 }
430
431 let event = LabeledDriftEvent::new(
432 self.next_event_id(),
433 DriftEventType::Temporal(TemporalDriftEvent {
434 shift_type,
435 affected_field: affected_field.map(String::from),
436 detection_difficulty: DetectionDifficulty::Hard, magnitude,
438 description: description.map(String::from),
439 }),
440 self.period_to_date(period),
441 period,
442 magnitude,
443 );
444
445 self.events.push(event);
446 }
447
448 pub fn record_categorical_drift(
450 &mut self,
451 shift_type: CategoricalShiftType,
452 affected_field: &str,
453 period: u32,
454 proportions_before: HashMap<String, f64>,
455 proportions_after: HashMap<String, f64>,
456 ) {
457 if !self.config.enabled || !self.config.categorical {
458 return;
459 }
460
461 let magnitude = proportions_before
463 .keys()
464 .chain(proportions_after.keys())
465 .map(|k| {
466 let before = proportions_before.get(k).copied().unwrap_or(0.0);
467 let after = proportions_after.get(k).copied().unwrap_or(0.0);
468 (after - before).abs()
469 })
470 .fold(0.0f64, f64::max);
471
472 if magnitude < self.config.min_magnitude_threshold {
473 return;
474 }
475
476 let new_categories: Vec<String> = proportions_after
477 .keys()
478 .filter(|k| !proportions_before.contains_key(*k))
479 .cloned()
480 .collect();
481
482 let removed_categories: Vec<String> = proportions_before
483 .keys()
484 .filter(|k| !proportions_after.contains_key(*k))
485 .cloned()
486 .collect();
487
488 let event = LabeledDriftEvent::new(
489 self.next_event_id(),
490 DriftEventType::Categorical(CategoricalDriftEvent {
491 shift_type,
492 affected_field: affected_field.to_string(),
493 detection_difficulty: DetectionDifficulty::Medium,
494 proportions_before,
495 proportions_after,
496 new_categories,
497 removed_categories,
498 }),
499 self.period_to_date(period),
500 period,
501 magnitude,
502 );
503
504 self.events.push(event);
505 }
506
507 pub fn events(&self) -> &[LabeledDriftEvent] {
509 &self.events
510 }
511
512 pub fn events_in_range(&self, start_period: u32, end_period: u32) -> Vec<&LabeledDriftEvent> {
514 self.events
515 .iter()
516 .filter(|e| e.start_period >= start_period && e.start_period <= end_period)
517 .collect()
518 }
519
520 pub fn events_by_category(&self, category: &str) -> Vec<&LabeledDriftEvent> {
522 self.events
523 .iter()
524 .filter(|e| e.event_type.category_name() == category)
525 .collect()
526 }
527
528 pub fn event_count(&self) -> usize {
530 self.events.len()
531 }
532
533 pub fn export_to_csv(&self, path: &Path) -> std::io::Result<usize> {
535 let mut file = std::fs::File::create(path)?;
536
537 writeln!(
539 file,
540 "event_id,category,type,start_date,end_date,start_period,end_period,magnitude,detection_difficulty,affected_fields,tags"
541 )?;
542
543 for event in &self.events {
545 let end_date = event.end_date.map(|d| d.to_string()).unwrap_or_default();
546 let end_period = event.end_period.map(|p| p.to_string()).unwrap_or_default();
547 let affected_fields = event.affected_fields.join(";");
548 let tags = event.tags.join(";");
549
550 writeln!(
551 file,
552 "{},{},{},{},{},{},{},{:.4},{:?},{},{}",
553 event.event_id,
554 event.event_type.category_name(),
555 event.event_type.type_name(),
556 event.start_date,
557 end_date,
558 event.start_period,
559 end_period,
560 event.magnitude,
561 event.detection_difficulty,
562 affected_fields,
563 tags
564 )?;
565 }
566
567 Ok(self.events.len())
568 }
569
570 pub fn export_to_json(&self, path: &Path) -> std::io::Result<usize> {
572 let json = serde_json::to_string_pretty(&self.events).map_err(std::io::Error::other)?;
573 std::fs::write(path, json)?;
574 Ok(self.events.len())
575 }
576
577 pub fn summary(&self) -> DriftRecorderSummary {
579 let mut by_category: HashMap<String, usize> = HashMap::new();
580 let mut by_difficulty: HashMap<String, usize> = HashMap::new();
581 let mut total_magnitude = 0.0;
582
583 for event in &self.events {
584 *by_category
585 .entry(event.event_type.category_name().to_string())
586 .or_insert(0) += 1;
587 *by_difficulty
588 .entry(format!("{:?}", event.detection_difficulty))
589 .or_insert(0) += 1;
590 total_magnitude += event.magnitude;
591 }
592
593 DriftRecorderSummary {
594 total_events: self.events.len(),
595 by_category,
596 by_difficulty,
597 avg_magnitude: if self.events.is_empty() {
598 0.0
599 } else {
600 total_magnitude / self.events.len() as f64
601 },
602 }
603 }
604}
605
606#[derive(Debug, Clone, Serialize, Deserialize)]
608pub struct DriftRecorderSummary {
609 pub total_events: usize,
611 pub by_category: HashMap<String, usize>,
613 pub by_difficulty: HashMap<String, usize>,
615 pub avg_magnitude: f64,
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled recorder starting 2024-01-01 with the given threshold.
    fn enabled_recorder(min_magnitude_threshold: f64) -> DriftLabelRecorder {
        let config = DriftRecorderConfig {
            enabled: true,
            min_magnitude_threshold,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        DriftLabelRecorder::new(config, start)
    }

    #[test]
    fn test_drift_recorder_creation() {
        let config = DriftRecorderConfig {
            enabled: true,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let recorder = DriftLabelRecorder::new(config, start);

        assert!(recorder.is_enabled());
        assert_eq!(recorder.event_count(), 0);
    }

    #[test]
    fn test_disabled_recorder_records_nothing() {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = DriftLabelRecorder::new(DriftRecorderConfig::default(), start);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);

        assert!(!recorder.is_enabled());
        assert_eq!(recorder.event_count(), 0);
    }

    #[test]
    fn test_record_regime_change() {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = enabled_recorder(0.0);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);

        assert_eq!(recorder.event_count(), 1);
        let event = &recorder.events()[0];
        assert_eq!(event.event_type.category_name(), "organizational");
    }

    #[test]
    fn test_events_in_range() {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = enabled_recorder(0.0);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);
        recorder.record_regime_change(&regime, 12, start);

        assert_eq!(recorder.events_in_range(0, 6).len(), 1);
        assert_eq!(recorder.events_in_range(6, 12).len(), 2);
        assert_eq!(recorder.events_in_range(13, 20).len(), 0);
    }

    #[test]
    fn test_record_statistical_drift() {
        let mut recorder = enabled_recorder(0.01);

        let adj1 = DriftAdjustments {
            amount_mean_multiplier: 1.0,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&adj1, 0);

        let adj2 = DriftAdjustments {
            amount_mean_multiplier: 1.25,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&adj2, 1);

        assert_eq!(recorder.event_count(), 1);
        assert!(matches!(
            &recorder.events()[0].event_type,
            DriftEventType::Statistical(s) if matches!(s.shift_type, StatisticalShiftType::MeanShift)
        ));
    }

    #[test]
    fn test_market_drift_below_threshold_is_ignored() {
        let mut recorder = enabled_recorder(0.05);

        recorder.record_market_drift(MarketEventType::EconomicCycle, 1, 0.01, false);

        assert_eq!(recorder.event_count(), 0);
    }

    #[test]
    fn test_market_drift_recession_start() {
        let mut recorder = enabled_recorder(0.05);

        recorder.record_market_drift(MarketEventType::EconomicCycle, 2, 0.3, true);

        assert_eq!(recorder.event_count(), 1);
        match &recorder.events()[0].event_type {
            DriftEventType::Market(m) => {
                assert!(matches!(m.market_type, MarketEventType::RecessionStart));
                assert!(m.is_recession);
            }
            _ => panic!("expected a market drift event"),
        }
    }

    #[test]
    fn test_export_to_csv_writes_header_and_rows() {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = enabled_recorder(0.0);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);
        recorder.record_regime_change(&regime, 12, start);

        let path = std::env::temp_dir()
            .join(format!("drift_labels_test_{}.csv", std::process::id()));
        let written = recorder.export_to_csv(&path).unwrap();
        let contents = std::fs::read_to_string(&path).unwrap();
        let _ = std::fs::remove_file(&path);

        assert_eq!(written, 2);
        assert_eq!(contents.lines().count(), 3);
        assert!(contents.starts_with("event_id,category,type,"));
    }

    #[test]
    fn test_summary() {
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = enabled_recorder(0.0);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);

        let summary = recorder.summary();
        assert_eq!(summary.total_events, 1);
        assert!(summary.by_category.contains_key("organizational"));
    }
}