1use chrono::{NaiveDate, NaiveTime, Timelike};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use datasynth_core::models::{AnomalyType, SeverityLevel, StatisticalAnomalyType};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BehavioralBaselineConfig {
16 pub enabled: bool,
18 pub baseline_period_days: u32,
20 pub min_observations: u32,
22 pub amount_deviation_threshold: f64,
24 pub frequency_deviation_threshold: f64,
26 pub recency_decay_factor: f64,
28}
29
30impl Default for BehavioralBaselineConfig {
31 fn default() -> Self {
32 Self {
33 enabled: true,
34 baseline_period_days: 90,
35 min_observations: 10,
36 amount_deviation_threshold: 3.0,
37 frequency_deviation_threshold: 2.0,
38 recency_decay_factor: 0.95,
39 }
40 }
41}
42
43pub struct BehavioralBaseline {
45 config: BehavioralBaselineConfig,
46 entity_baselines: HashMap<String, EntityBaseline>,
48}
49
50impl Default for BehavioralBaseline {
51 fn default() -> Self {
52 Self::new(BehavioralBaselineConfig::default())
53 }
54}
55
56impl BehavioralBaseline {
57 pub fn new(config: BehavioralBaselineConfig) -> Self {
59 Self {
60 config,
61 entity_baselines: HashMap::new(),
62 }
63 }
64
65 pub fn record_observation(&mut self, entity_id: impl Into<String>, observation: Observation) {
67 let id = entity_id.into();
68 let baseline = self.entity_baselines.entry(id).or_default();
69 baseline.add_observation(observation);
70 }
71
72 pub fn get_baseline(&self, entity_id: &str) -> Option<&EntityBaseline> {
74 self.entity_baselines.get(entity_id)
75 }
76
77 pub fn check_deviation(
79 &self,
80 entity_id: &str,
81 observation: &Observation,
82 ) -> Vec<BehavioralDeviation> {
83 if !self.config.enabled {
84 return Vec::new();
85 }
86
87 let baseline = match self.get_baseline(entity_id) {
88 Some(b) if b.observation_count >= self.config.min_observations => b,
89 _ => return Vec::new(),
90 };
91
92 let mut deviations = Vec::new();
93
94 if let Some(amount) = observation.amount {
96 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
97 if baseline.amount_std_dev > 0.0 {
98 let z_score =
99 (amount_f64 - baseline.avg_transaction_amount) / baseline.amount_std_dev;
100 if z_score.abs() > self.config.amount_deviation_threshold {
101 deviations.push(BehavioralDeviation {
102 deviation_type: DeviationType::AmountAnomaly,
103 std_deviations: z_score.abs(),
104 expected_value: baseline.avg_transaction_amount,
105 actual_value: amount_f64,
106 label: AnomalyType::Statistical(
107 StatisticalAnomalyType::UnusuallyHighAmount,
108 ),
109 severity: Self::severity_from_std_dev(z_score.abs()),
110 description: format!(
111 "Amount ${:.2} is {:.1} std devs from mean ${:.2}",
112 amount_f64,
113 z_score.abs(),
114 baseline.avg_transaction_amount
115 ),
116 });
117 }
118 }
119 }
120
121 if let Some(time) = observation.time {
123 if !baseline.is_within_typical_hours(time) {
124 deviations.push(BehavioralDeviation {
125 deviation_type: DeviationType::TimingAnomaly,
126 std_deviations: 0.0,
127 expected_value: 0.0,
128 actual_value: 0.0,
129 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualTiming),
130 severity: SeverityLevel::Low,
131 description: format!(
132 "Transaction at {} outside typical hours {:02}:00-{:02}:00",
133 time, baseline.typical_posting_hours.0, baseline.typical_posting_hours.1
134 ),
135 });
136 }
137 }
138
139 if let Some(ref counterparty) = observation.counterparty {
141 if !baseline.common_counterparties.contains(counterparty)
142 && baseline.common_counterparties.len() >= 5
143 {
144 deviations.push(BehavioralDeviation {
145 deviation_type: DeviationType::NewCounterparty,
146 std_deviations: 0.0,
147 expected_value: 0.0,
148 actual_value: 0.0,
149 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
150 severity: SeverityLevel::Low,
151 description: format!(
152 "New counterparty '{counterparty}' not in typical partners"
153 ),
154 });
155 }
156 }
157
158 if let Some(ref account) = observation.account_code {
160 if !baseline.usual_account_codes.contains(account)
161 && baseline.usual_account_codes.len() >= 3
162 {
163 deviations.push(BehavioralDeviation {
164 deviation_type: DeviationType::UnusualAccount,
165 std_deviations: 0.0,
166 expected_value: 0.0,
167 actual_value: 0.0,
168 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
169 severity: SeverityLevel::Low,
170 description: format!("Account '{account}' not typically used by this entity"),
171 });
172 }
173 }
174
175 deviations
176 }
177
178 fn severity_from_std_dev(std_devs: f64) -> SeverityLevel {
180 if std_devs > 5.0 {
181 SeverityLevel::Critical
182 } else if std_devs > 4.0 {
183 SeverityLevel::High
184 } else if std_devs > 3.5 {
185 SeverityLevel::Medium
186 } else {
187 SeverityLevel::Low
188 }
189 }
190
191 pub fn check_frequency_deviation(
193 &self,
194 entity_id: &str,
195 current_frequency: f64,
196 ) -> Option<BehavioralDeviation> {
197 if !self.config.enabled {
198 return None;
199 }
200
201 let baseline = self.get_baseline(entity_id)?;
202
203 if baseline.observation_count < self.config.min_observations {
204 return None;
205 }
206
207 if baseline.frequency_std_dev <= 0.0 {
208 return None;
209 }
210
211 let z_score =
212 (current_frequency - baseline.transaction_frequency) / baseline.frequency_std_dev;
213
214 if z_score.abs() > self.config.frequency_deviation_threshold {
215 Some(BehavioralDeviation {
216 deviation_type: DeviationType::FrequencyAnomaly,
217 std_deviations: z_score.abs(),
218 expected_value: baseline.transaction_frequency,
219 actual_value: current_frequency,
220 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualFrequency),
221 severity: Self::severity_from_std_dev(z_score.abs()),
222 description: format!(
223 "Frequency {:.2}/day is {:.1} std devs from normal {:.2}/day",
224 current_frequency,
225 z_score.abs(),
226 baseline.transaction_frequency
227 ),
228 })
229 } else {
230 None
231 }
232 }
233
234 pub fn entity_count(&self) -> usize {
236 self.entity_baselines.len()
237 }
238
239 pub fn config(&self) -> &BehavioralBaselineConfig {
241 &self.config
242 }
243
244 pub fn clear(&mut self) {
246 self.entity_baselines.clear();
247 }
248}
249
250#[derive(Debug, Clone)]
252pub struct Observation {
253 pub date: NaiveDate,
255 pub time: Option<NaiveTime>,
257 pub amount: Option<Decimal>,
259 pub counterparty: Option<String>,
261 pub account_code: Option<String>,
263}
264
265impl Observation {
266 pub fn new(date: NaiveDate) -> Self {
268 Self {
269 date,
270 time: None,
271 amount: None,
272 counterparty: None,
273 account_code: None,
274 }
275 }
276
277 pub fn with_time(mut self, time: NaiveTime) -> Self {
279 self.time = Some(time);
280 self
281 }
282
283 pub fn with_amount(mut self, amount: Decimal) -> Self {
285 self.amount = Some(amount);
286 self
287 }
288
289 pub fn with_counterparty(mut self, counterparty: impl Into<String>) -> Self {
291 self.counterparty = Some(counterparty.into());
292 self
293 }
294
295 pub fn with_account(mut self, account: impl Into<String>) -> Self {
297 self.account_code = Some(account.into());
298 self
299 }
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct EntityBaseline {
305 pub avg_transaction_amount: f64,
307 pub amount_std_dev: f64,
309 pub transaction_frequency: f64,
311 pub frequency_std_dev: f64,
313 pub typical_posting_hours: (u8, u8),
315 pub common_counterparties: Vec<String>,
317 pub usual_account_codes: Vec<String>,
319 pub observation_count: u32,
321 #[serde(skip)]
323 amount_sum: f64,
324 #[serde(skip)]
326 amount_sum_sq: f64,
327 #[serde(skip)]
329 daily_counts: HashMap<NaiveDate, u32>,
330 #[serde(skip)]
332 hour_counts: [u32; 24],
333 #[serde(skip)]
335 counterparty_freq: HashMap<String, u32>,
336 #[serde(skip)]
338 account_freq: HashMap<String, u32>,
339}
340
341impl Default for EntityBaseline {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347impl EntityBaseline {
348 pub fn new() -> Self {
350 Self {
351 avg_transaction_amount: 0.0,
352 amount_std_dev: 0.0,
353 transaction_frequency: 0.0,
354 frequency_std_dev: 0.0,
355 typical_posting_hours: (8, 18),
356 common_counterparties: Vec::new(),
357 usual_account_codes: Vec::new(),
358 observation_count: 0,
359 amount_sum: 0.0,
360 amount_sum_sq: 0.0,
361 daily_counts: HashMap::new(),
362 hour_counts: [0; 24],
363 counterparty_freq: HashMap::new(),
364 account_freq: HashMap::new(),
365 }
366 }
367
368 pub fn add_observation(&mut self, observation: Observation) {
370 self.observation_count += 1;
371
372 if let Some(amount) = observation.amount {
374 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
375 self.amount_sum += amount_f64;
376 self.amount_sum_sq += amount_f64 * amount_f64;
377 self.avg_transaction_amount = self.amount_sum / self.observation_count as f64;
378
379 if self.observation_count > 1 {
380 let variance = (self.amount_sum_sq
381 - (self.amount_sum * self.amount_sum) / self.observation_count as f64)
382 / (self.observation_count - 1) as f64;
383 self.amount_std_dev = variance.max(0.0).sqrt();
384 }
385 }
386
387 *self.daily_counts.entry(observation.date).or_insert(0) += 1;
389 self.update_frequency_stats();
390
391 if let Some(time) = observation.time {
393 self.hour_counts[time.hour() as usize] += 1;
394 self.update_typical_hours();
395 }
396
397 if let Some(ref counterparty) = observation.counterparty {
399 *self
400 .counterparty_freq
401 .entry(counterparty.clone())
402 .or_insert(0) += 1;
403 self.update_common_counterparties();
404 }
405
406 if let Some(ref account) = observation.account_code {
408 *self.account_freq.entry(account.clone()).or_insert(0) += 1;
409 self.update_usual_accounts();
410 }
411 }
412
413 fn update_frequency_stats(&mut self) {
415 if self.daily_counts.is_empty() {
416 return;
417 }
418
419 let counts: Vec<f64> = self.daily_counts.values().map(|&c| c as f64).collect();
420 let n = counts.len() as f64;
421
422 self.transaction_frequency = counts.iter().sum::<f64>() / n;
423
424 if counts.len() > 1 {
425 let variance: f64 = counts
426 .iter()
427 .map(|c| (c - self.transaction_frequency).powi(2))
428 .sum::<f64>()
429 / (n - 1.0);
430 self.frequency_std_dev = variance.sqrt();
431 }
432 }
433
434 fn update_typical_hours(&mut self) {
436 let total: u32 = self.hour_counts.iter().sum();
437 if total == 0 {
438 return;
439 }
440
441 let threshold = (total as f64 * 0.1) as u32; let mut cumsum = 0u32;
445 let mut start_hour = 0u8;
446 for (hour, &count) in self.hour_counts.iter().enumerate() {
447 cumsum += count;
448 if cumsum > threshold {
449 start_hour = hour as u8;
450 break;
451 }
452 }
453
454 cumsum = 0;
455 let mut end_hour = 23u8;
456 for (hour, &count) in self.hour_counts.iter().enumerate().rev() {
457 cumsum += count;
458 if cumsum > threshold {
459 end_hour = hour as u8;
460 break;
461 }
462 }
463
464 self.typical_posting_hours = (start_hour, end_hour.max(start_hour + 1));
465 }
466
467 fn update_common_counterparties(&mut self) {
469 let mut sorted: Vec<_> = self.counterparty_freq.iter().collect();
470 sorted.sort_by(|a, b| b.1.cmp(a.1));
471 self.common_counterparties = sorted
472 .into_iter()
473 .take(10)
474 .map(|(k, _)| k.clone())
475 .collect();
476 }
477
478 fn update_usual_accounts(&mut self) {
480 let mut sorted: Vec<_> = self.account_freq.iter().collect();
481 sorted.sort_by(|a, b| b.1.cmp(a.1));
482 self.usual_account_codes = sorted.into_iter().take(5).map(|(k, _)| k.clone()).collect();
483 }
484
485 pub fn is_within_typical_hours(&self, time: NaiveTime) -> bool {
487 let hour = time.hour() as u8;
488 hour >= self.typical_posting_hours.0 && hour <= self.typical_posting_hours.1
489 }
490
491 pub fn is_established(&self, min_observations: u32) -> bool {
493 self.observation_count >= min_observations
494 }
495}
496
497#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
499pub enum DeviationType {
500 AmountAnomaly,
502 FrequencyAnomaly,
504 NewCounterparty,
506 TimingAnomaly,
508 UnusualAccount,
510}
511
512#[derive(Debug, Clone)]
514pub struct BehavioralDeviation {
515 pub deviation_type: DeviationType,
517 pub std_deviations: f64,
519 pub expected_value: f64,
521 pub actual_value: f64,
523 pub label: AnomalyType,
525 pub severity: SeverityLevel,
527 pub description: String,
529}
530
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Entity id shared by the tracker tests.
    const ENTITY: &str = "ENTITY1";

    /// A day in June 2024.
    fn june(day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, day).unwrap()
    }

    /// A time of day with zero seconds.
    fn clock(hour: u32, minute: u32) -> NaiveTime {
        NaiveTime::from_hms_opt(hour, minute, 0).unwrap()
    }

    /// Whether `found` contains a deviation of the `wanted` type.
    fn reports(found: &[BehavioralDeviation], wanted: DeviationType) -> bool {
        found.iter().any(|d| d.deviation_type == wanted)
    }

    #[test]
    fn test_entity_baseline_creation() {
        let fresh = EntityBaseline::new();

        assert_eq!(fresh.observation_count, 0);
        assert!((fresh.avg_transaction_amount - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_observation_builder() {
        let built = Observation::new(june(15))
            .with_amount(dec!(1000))
            .with_counterparty("VENDOR001")
            .with_account("5000");

        assert_eq!(built.amount, Some(dec!(1000)));
        assert_eq!(built.counterparty, Some("VENDOR001".to_string()));
        assert_eq!(built.account_code, Some("5000".to_string()));
    }

    #[test]
    fn test_baseline_amount_tracking() {
        let mut stats = EntityBaseline::new();

        for value in [1000.0, 1100.0, 900.0, 1050.0, 950.0] {
            let amount = Decimal::try_from(value).unwrap();
            stats.add_observation(Observation::new(june(15)).with_amount(amount));
        }

        assert_eq!(stats.observation_count, 5);
        assert!((stats.avg_transaction_amount - 1000.0).abs() < 1.0);
        assert!(stats.amount_std_dev > 0.0);
    }

    #[test]
    fn test_behavioral_baseline_deviation_detection() {
        let mut tracker = BehavioralBaseline::default();

        // Twenty ordinary amounts spread over the first ten days of June.
        let history = [
            900, 950, 1000, 1050, 1100, 920, 980, 1020, 1080, 950, 960, 1000, 1040, 990, 970, 1010,
            1030, 1000, 980, 1020,
        ];
        for (index, amount) in history.iter().copied().enumerate() {
            let day = 1 + (index % 10) as u32;
            let routine = Observation::new(june(day))
                .with_amount(Decimal::from(amount))
                .with_counterparty("VENDOR001")
                .with_time(clock(10, 0));
            tracker.record_observation(ENTITY, routine);
        }

        let outlier = Observation::new(june(25))
            .with_amount(dec!(50000))
            .with_counterparty("VENDOR001");
        let found = tracker.check_deviation(ENTITY, &outlier);

        assert!(reports(&found, DeviationType::AmountAnomaly));
    }

    #[test]
    fn test_new_counterparty_detection() {
        let mut tracker = BehavioralBaseline::default();

        // Fifteen observations cycling through five known vendors.
        for round in 0..15 {
            let vendor = format!("VENDOR{:03}", round % 5);
            let routine = Observation::new(june(1))
                .with_amount(dec!(1000))
                .with_counterparty(vendor);
            tracker.record_observation(ENTITY, routine);
        }

        let stranger = Observation::new(june(25))
            .with_amount(dec!(1000))
            .with_counterparty("NEW_VENDOR");
        let found = tracker.check_deviation(ENTITY, &stranger);

        assert!(reports(&found, DeviationType::NewCounterparty));
    }

    #[test]
    fn test_timing_anomaly_detection() {
        let mut tracker = BehavioralBaseline::default();

        // Fifteen observations between 09:00 and 16:00.
        for round in 0..15 {
            let routine = Observation::new(june(1))
                .with_amount(dec!(1000))
                .with_time(clock(9 + (round % 8), 0));
            tracker.record_observation(ENTITY, routine);
        }

        let night_owl = Observation::new(june(25))
            .with_amount(dec!(1000))
            .with_time(clock(3, 0));
        let found = tracker.check_deviation(ENTITY, &night_owl);

        assert!(reports(&found, DeviationType::TimingAnomaly));
    }

    #[test]
    fn test_frequency_deviation() {
        let mut tracker = BehavioralBaseline::default();

        // One entry per day of June: how many observations that day gets.
        let per_day = [
            2, 1, 3, 2, 2, 1, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3,
            2,
        ];
        for (offset, &count) in per_day.iter().enumerate() {
            let date = june(1 + offset as u32);
            for _ in 0..count {
                let routine = Observation::new(date).with_amount(dec!(1000));
                tracker.record_observation(ENTITY, routine);
            }
        }

        let found = tracker.check_frequency_deviation(ENTITY, 10.0);

        assert!(found.is_some());
        assert_eq!(
            found.unwrap().deviation_type,
            DeviationType::FrequencyAnomaly
        );
    }

    #[test]
    fn test_insufficient_baseline() {
        let mut tracker = BehavioralBaseline::default();

        // Five observations: fewer than the default minimum of ten.
        for offset in 0..5 {
            let routine = Observation::new(june(1 + offset)).with_amount(dec!(1000));
            tracker.record_observation(ENTITY, routine);
        }

        let outlier = Observation::new(june(25)).with_amount(dec!(50000));
        let found = tracker.check_deviation(ENTITY, &outlier);

        assert!(found.is_empty());
    }

    #[test]
    fn test_typical_hours_calculation() {
        let mut stats = EntityBaseline::new();

        // Ten observations at half past every hour from 09:30 to 16:30.
        for _ in 0..10 {
            for hour in 9..17 {
                stats.add_observation(Observation::new(june(15)).with_time(clock(hour, 30)));
            }
        }

        assert!(stats.is_within_typical_hours(clock(10, 0)));
        assert!(stats.is_within_typical_hours(clock(14, 0)));
    }
}