1use crate::distributions::drift::{DriftAdjustments, RegimeChange, RegimeChangeType};
7use crate::models::drift_events::{
8 CategoricalDriftEvent, CategoricalShiftType, DetectionDifficulty, DriftEventType,
9 LabeledDriftEvent, MarketDriftEvent, MarketEventType, OrganizationalDriftEvent,
10 ProcessDriftEvent, StatisticalDriftEvent, StatisticalShiftType, TechnologyDriftEvent,
11 TemporalDriftEvent, TemporalShiftType,
12};
13use chrono::NaiveDate;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::io::Write;
17use std::path::Path;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct DriftRecorderConfig {
22 #[serde(default)]
24 pub enabled: bool,
25 #[serde(default = "default_true")]
27 pub statistical: bool,
28 #[serde(default = "default_true")]
30 pub categorical: bool,
31 #[serde(default = "default_true")]
33 pub temporal: bool,
34 #[serde(default = "default_true")]
36 pub organizational: bool,
37 #[serde(default = "default_true")]
39 pub process_events: bool,
40 #[serde(default = "default_true")]
42 pub technology_events: bool,
43 #[serde(default = "default_true")]
45 pub regulatory: bool,
46 #[serde(default = "default_true")]
48 pub market: bool,
49 #[serde(default = "default_true")]
51 pub behavioral: bool,
52 #[serde(default = "default_min_magnitude")]
54 pub min_magnitude_threshold: f64,
55}
56
/// Serde default for the per-category flags of `DriftRecorderConfig`: always `true`.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
60
/// Serde default for `DriftRecorderConfig::min_magnitude_threshold`: `0.05`.
fn default_min_magnitude() -> f64 {
    const DEFAULT_MIN_MAGNITUDE: f64 = 0.05;
    DEFAULT_MIN_MAGNITUDE
}
64
65impl Default for DriftRecorderConfig {
66 fn default() -> Self {
67 Self {
68 enabled: false,
69 statistical: true,
70 categorical: true,
71 temporal: true,
72 organizational: true,
73 process_events: true,
74 technology_events: true,
75 regulatory: true,
76 market: true,
77 behavioral: true,
78 min_magnitude_threshold: 0.05,
79 }
80 }
81}
82
83pub struct DriftLabelRecorder {
85 events: Vec<LabeledDriftEvent>,
87 config: DriftRecorderConfig,
89 start_date: NaiveDate,
91 event_counter: u64,
93 previous_drift: Option<DriftAdjustments>,
95 was_in_recession: bool,
97}
98
99impl DriftLabelRecorder {
100 pub fn new(config: DriftRecorderConfig, start_date: NaiveDate) -> Self {
102 Self {
103 events: Vec::new(),
104 config,
105 start_date,
106 event_counter: 0,
107 previous_drift: None,
108 was_in_recession: false,
109 }
110 }
111
112 pub fn is_enabled(&self) -> bool {
114 self.config.enabled
115 }
116
117 fn next_event_id(&mut self) -> String {
119 self.event_counter += 1;
120 format!("DRIFT-{:06}", self.event_counter)
121 }
122
123 fn period_to_date(&self, period: u32) -> NaiveDate {
125 self.start_date + chrono::Duration::days(period as i64 * 30)
126 }
127
128 pub fn record_regime_change(&mut self, regime: &RegimeChange, period: u32, _date: NaiveDate) {
130 if !self.config.enabled || !self.config.organizational {
131 return;
132 }
133
134 let event_type = match regime.change_type {
135 RegimeChangeType::Acquisition => "acquisition",
136 RegimeChangeType::Divestiture => "divestiture",
137 RegimeChangeType::PriceIncrease => "price_increase",
138 RegimeChangeType::PriceDecrease => "price_decrease",
139 RegimeChangeType::ProductLaunch => "product_launch",
140 RegimeChangeType::ProductDiscontinuation => "product_discontinuation",
141 RegimeChangeType::PolicyChange => "policy_change",
142 RegimeChangeType::CompetitorEntry => "competitor_entry",
143 RegimeChangeType::Custom => "custom",
144 };
145
146 let magnitude = (regime.volume_multiplier() - 1.0)
147 .abs()
148 .max((regime.amount_mean_multiplier() - 1.0).abs());
149
150 if magnitude < self.config.min_magnitude_threshold {
151 return;
152 }
153
154 let detection_difficulty = if magnitude > 0.20 {
155 DetectionDifficulty::Easy
156 } else if magnitude > 0.10 {
157 DetectionDifficulty::Medium
158 } else {
159 DetectionDifficulty::Hard
160 };
161
162 let mut event = LabeledDriftEvent::new(
163 self.next_event_id(),
164 DriftEventType::Organizational(OrganizationalDriftEvent {
165 event_type: event_type.to_string(),
166 related_event_id: regime.description.clone().unwrap_or_default(),
167 detection_difficulty,
168 affected_entities: Vec::new(),
169 impact_metrics: {
170 let mut m = HashMap::new();
171 m.insert("volume_multiplier".to_string(), regime.volume_multiplier());
172 m.insert(
173 "amount_multiplier".to_string(),
174 regime.amount_mean_multiplier(),
175 );
176 m
177 },
178 }),
179 self.period_to_date(period),
180 period,
181 magnitude,
182 );
183
184 event.end_period = Some(period + regime.transition_periods);
185 event.tags.push("regime_change".to_string());
186 event.tags.push(event_type.to_string());
187
188 self.events.push(event);
189 }
190
191 pub fn record_statistical_drift(&mut self, adjustments: &DriftAdjustments, period: u32) {
193 if !self.config.enabled || !self.config.statistical {
194 return;
195 }
196
197 let date = self.period_to_date(period);
198
199 if let Some(ref prev) = self.previous_drift {
201 let mean_delta =
202 (adjustments.amount_mean_multiplier - prev.amount_mean_multiplier).abs();
203 let var_delta =
204 (adjustments.amount_variance_multiplier - prev.amount_variance_multiplier).abs();
205 let prev_mean = prev.amount_mean_multiplier;
206 let current_mean = adjustments.amount_mean_multiplier;
207 let min_threshold = self.config.min_magnitude_threshold;
208
209 if mean_delta >= min_threshold {
210 let detection_difficulty = if mean_delta > 0.20 {
211 DetectionDifficulty::Easy
212 } else if mean_delta > 0.10 {
213 DetectionDifficulty::Medium
214 } else {
215 DetectionDifficulty::Hard
216 };
217
218 let event_id = self.next_event_id();
219 let event = LabeledDriftEvent::new(
220 event_id,
221 DriftEventType::Statistical(StatisticalDriftEvent {
222 shift_type: StatisticalShiftType::MeanShift,
223 affected_field: "amount".to_string(),
224 magnitude: mean_delta,
225 detection_difficulty,
226 metrics: {
227 let mut m = HashMap::new();
228 m.insert("previous_multiplier".to_string(), prev_mean);
229 m.insert("current_multiplier".to_string(), current_mean);
230 m
231 },
232 }),
233 date,
234 period,
235 mean_delta,
236 );
237
238 self.events.push(event);
239 }
240
241 if var_delta >= min_threshold {
243 let event_id = self.next_event_id();
244 let event = LabeledDriftEvent::new(
245 event_id,
246 DriftEventType::Statistical(StatisticalDriftEvent {
247 shift_type: StatisticalShiftType::VarianceChange,
248 affected_field: "amount".to_string(),
249 magnitude: var_delta,
250 detection_difficulty: DetectionDifficulty::Medium,
251 metrics: HashMap::new(),
252 }),
253 date,
254 period,
255 var_delta,
256 );
257
258 self.events.push(event);
259 }
260 }
261
262 if adjustments.sudden_drift_occurred {
264 let event = LabeledDriftEvent::new(
265 self.next_event_id(),
266 DriftEventType::Statistical(StatisticalDriftEvent {
267 shift_type: StatisticalShiftType::DistributionChange,
268 affected_field: "amount".to_string(),
269 magnitude: 0.5, detection_difficulty: DetectionDifficulty::Easy,
271 metrics: HashMap::new(),
272 }),
273 date,
274 period,
275 0.5,
276 );
277
278 self.events.push(event);
279 }
280
281 self.previous_drift = Some(adjustments.clone());
282 }
283
284 pub fn record_market_drift(
286 &mut self,
287 market_type: MarketEventType,
288 period: u32,
289 magnitude: f64,
290 is_recession: bool,
291 ) {
292 if !self.config.enabled || !self.config.market {
293 return;
294 }
295
296 if magnitude < self.config.min_magnitude_threshold
297 && market_type != MarketEventType::RecessionStart
298 && market_type != MarketEventType::RecessionEnd
299 {
300 return;
301 }
302
303 let actual_type = if is_recession && !self.was_in_recession {
305 self.was_in_recession = true;
306 MarketEventType::RecessionStart
307 } else if !is_recession && self.was_in_recession {
308 self.was_in_recession = false;
309 MarketEventType::RecessionEnd
310 } else {
311 market_type
312 };
313
314 let detection_difficulty = match actual_type {
315 MarketEventType::RecessionStart | MarketEventType::RecessionEnd => {
316 DetectionDifficulty::Easy
317 }
318 MarketEventType::PriceShock => DetectionDifficulty::Easy,
319 MarketEventType::EconomicCycle => DetectionDifficulty::Medium,
320 MarketEventType::CommodityChange => DetectionDifficulty::Medium,
321 };
322
323 let event = LabeledDriftEvent::new(
324 self.next_event_id(),
325 DriftEventType::Market(MarketDriftEvent {
326 market_type: actual_type,
327 detection_difficulty,
328 magnitude,
329 is_recession,
330 affected_sectors: Vec::new(),
331 }),
332 self.period_to_date(period),
333 period,
334 magnitude,
335 );
336
337 self.events.push(event);
338 }
339
340 pub fn record_process_drift(
342 &mut self,
343 process_type: &str,
344 related_event_id: &str,
345 period: u32,
346 magnitude: f64,
347 affected_processes: Vec<String>,
348 ) {
349 if !self.config.enabled || !self.config.process_events {
350 return;
351 }
352
353 if magnitude < self.config.min_magnitude_threshold {
354 return;
355 }
356
357 let mut event = LabeledDriftEvent::new(
358 self.next_event_id(),
359 DriftEventType::Process(ProcessDriftEvent {
360 process_type: process_type.to_string(),
361 related_event_id: related_event_id.to_string(),
362 detection_difficulty: DetectionDifficulty::Medium,
363 affected_processes,
364 }),
365 self.period_to_date(period),
366 period,
367 magnitude,
368 );
369
370 event.related_org_event = Some(related_event_id.to_string());
371 self.events.push(event);
372 }
373
374 pub fn record_technology_drift(
376 &mut self,
377 transition_type: &str,
378 related_event_id: &str,
379 period: u32,
380 magnitude: f64,
381 systems: Vec<String>,
382 current_phase: Option<&str>,
383 ) {
384 if !self.config.enabled || !self.config.technology_events {
385 return;
386 }
387
388 if magnitude < self.config.min_magnitude_threshold {
389 return;
390 }
391
392 let mut event = LabeledDriftEvent::new(
393 self.next_event_id(),
394 DriftEventType::Technology(TechnologyDriftEvent {
395 transition_type: transition_type.to_string(),
396 related_event_id: related_event_id.to_string(),
397 detection_difficulty: DetectionDifficulty::Easy, systems,
399 current_phase: current_phase.map(String::from),
400 }),
401 self.period_to_date(period),
402 period,
403 magnitude,
404 );
405
406 event.related_org_event = Some(related_event_id.to_string());
407 self.events.push(event);
408 }
409
410 pub fn record_temporal_drift(
412 &mut self,
413 shift_type: TemporalShiftType,
414 period: u32,
415 magnitude: f64,
416 affected_field: Option<&str>,
417 description: Option<&str>,
418 ) {
419 if !self.config.enabled || !self.config.temporal {
420 return;
421 }
422
423 if magnitude < self.config.min_magnitude_threshold {
424 return;
425 }
426
427 let event = LabeledDriftEvent::new(
428 self.next_event_id(),
429 DriftEventType::Temporal(TemporalDriftEvent {
430 shift_type,
431 affected_field: affected_field.map(String::from),
432 detection_difficulty: DetectionDifficulty::Hard, magnitude,
434 description: description.map(String::from),
435 }),
436 self.period_to_date(period),
437 period,
438 magnitude,
439 );
440
441 self.events.push(event);
442 }
443
444 pub fn record_categorical_drift(
446 &mut self,
447 shift_type: CategoricalShiftType,
448 affected_field: &str,
449 period: u32,
450 proportions_before: HashMap<String, f64>,
451 proportions_after: HashMap<String, f64>,
452 ) {
453 if !self.config.enabled || !self.config.categorical {
454 return;
455 }
456
457 let magnitude = proportions_before
459 .keys()
460 .chain(proportions_after.keys())
461 .map(|k| {
462 let before = proportions_before.get(k).copied().unwrap_or(0.0);
463 let after = proportions_after.get(k).copied().unwrap_or(0.0);
464 (after - before).abs()
465 })
466 .fold(0.0f64, f64::max);
467
468 if magnitude < self.config.min_magnitude_threshold {
469 return;
470 }
471
472 let new_categories: Vec<String> = proportions_after
473 .keys()
474 .filter(|k| !proportions_before.contains_key(*k))
475 .cloned()
476 .collect();
477
478 let removed_categories: Vec<String> = proportions_before
479 .keys()
480 .filter(|k| !proportions_after.contains_key(*k))
481 .cloned()
482 .collect();
483
484 let event = LabeledDriftEvent::new(
485 self.next_event_id(),
486 DriftEventType::Categorical(CategoricalDriftEvent {
487 shift_type,
488 affected_field: affected_field.to_string(),
489 detection_difficulty: DetectionDifficulty::Medium,
490 proportions_before,
491 proportions_after,
492 new_categories,
493 removed_categories,
494 }),
495 self.period_to_date(period),
496 period,
497 magnitude,
498 );
499
500 self.events.push(event);
501 }
502
503 pub fn events(&self) -> &[LabeledDriftEvent] {
505 &self.events
506 }
507
508 pub fn events_in_range(&self, start_period: u32, end_period: u32) -> Vec<&LabeledDriftEvent> {
510 self.events
511 .iter()
512 .filter(|e| e.start_period >= start_period && e.start_period <= end_period)
513 .collect()
514 }
515
516 pub fn events_by_category(&self, category: &str) -> Vec<&LabeledDriftEvent> {
518 self.events
519 .iter()
520 .filter(|e| e.event_type.category_name() == category)
521 .collect()
522 }
523
524 pub fn event_count(&self) -> usize {
526 self.events.len()
527 }
528
529 pub fn export_to_csv(&self, path: &Path) -> std::io::Result<usize> {
531 let mut file = std::fs::File::create(path)?;
532
533 writeln!(
535 file,
536 "event_id,category,type,start_date,end_date,start_period,end_period,magnitude,detection_difficulty,affected_fields,tags"
537 )?;
538
539 for event in &self.events {
541 let end_date = event.end_date.map(|d| d.to_string()).unwrap_or_default();
542 let end_period = event.end_period.map(|p| p.to_string()).unwrap_or_default();
543 let affected_fields = event.affected_fields.join(";");
544 let tags = event.tags.join(";");
545
546 writeln!(
547 file,
548 "{},{},{},{},{},{},{},{:.4},{:?},{},{}",
549 event.event_id,
550 event.event_type.category_name(),
551 event.event_type.type_name(),
552 event.start_date,
553 end_date,
554 event.start_period,
555 end_period,
556 event.magnitude,
557 event.detection_difficulty,
558 affected_fields,
559 tags
560 )?;
561 }
562
563 Ok(self.events.len())
564 }
565
566 pub fn export_to_json(&self, path: &Path) -> std::io::Result<usize> {
568 let json = serde_json::to_string_pretty(&self.events).map_err(std::io::Error::other)?;
569 std::fs::write(path, json)?;
570 Ok(self.events.len())
571 }
572
573 pub fn summary(&self) -> DriftRecorderSummary {
575 let mut by_category: HashMap<String, usize> = HashMap::new();
576 let mut by_difficulty: HashMap<String, usize> = HashMap::new();
577 let mut total_magnitude = 0.0;
578
579 for event in &self.events {
580 *by_category
581 .entry(event.event_type.category_name().to_string())
582 .or_insert(0) += 1;
583 *by_difficulty
584 .entry(format!("{:?}", event.detection_difficulty))
585 .or_insert(0) += 1;
586 total_magnitude += event.magnitude;
587 }
588
589 DriftRecorderSummary {
590 total_events: self.events.len(),
591 by_category,
592 by_difficulty,
593 avg_magnitude: if self.events.is_empty() {
594 0.0
595 } else {
596 total_magnitude / self.events.len() as f64
597 },
598 }
599 }
600}
601
602#[derive(Debug, Clone, Serialize, Deserialize)]
604pub struct DriftRecorderSummary {
605 pub total_events: usize,
607 pub by_category: HashMap<String, usize>,
609 pub by_difficulty: HashMap<String, usize>,
611 pub avg_magnitude: f64,
613}
614
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // A freshly built, enabled recorder reports itself enabled and holds no events.
    #[test]
    fn test_drift_recorder_creation() {
        let config = DriftRecorderConfig {
            enabled: true,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let recorder = DriftLabelRecorder::new(config, start);

        assert!(recorder.is_enabled());
        assert_eq!(recorder.event_count(), 0);
    }

    // With the threshold at zero, a regime change is recorded as one
    // organizational event.
    #[test]
    fn test_record_regime_change() {
        let config = DriftRecorderConfig {
            enabled: true,
            min_magnitude_threshold: 0.0,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = DriftLabelRecorder::new(config, start);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);

        assert_eq!(recorder.event_count(), 1);
        let event = &recorder.events()[0];
        assert_eq!(event.event_type.category_name(), "organizational");
    }

    // The first call only establishes the baseline; the second call's mean
    // change of 0.25 is expected to produce exactly one event.
    #[test]
    fn test_record_statistical_drift() {
        let config = DriftRecorderConfig {
            enabled: true,
            min_magnitude_threshold: 0.01,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = DriftLabelRecorder::new(config, start);

        let adj1 = DriftAdjustments {
            amount_mean_multiplier: 1.0,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&adj1, 0);

        let adj2 = DriftAdjustments {
            amount_mean_multiplier: 1.25,
            ..DriftAdjustments::none()
        };
        recorder.record_statistical_drift(&adj2, 1);

        assert_eq!(recorder.event_count(), 1);
    }

    // The summary counts the single recorded event under its category.
    #[test]
    fn test_summary() {
        let config = DriftRecorderConfig {
            enabled: true,
            min_magnitude_threshold: 0.0,
            ..Default::default()
        };
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let mut recorder = DriftLabelRecorder::new(config, start);

        let regime = RegimeChange::new(6, RegimeChangeType::Acquisition);
        recorder.record_regime_change(&regime, 6, start);

        let summary = recorder.summary();
        assert_eq!(summary.total_events, 1);
        assert!(summary.by_category.contains_key("organizational"));
    }
}