1use chrono::{NaiveDate, NaiveTime, Timelike};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use datasynth_core::models::{AnomalyType, SeverityLevel, StatisticalAnomalyType};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BehavioralBaselineConfig {
16 pub enabled: bool,
18 pub baseline_period_days: u32,
20 pub min_observations: u32,
22 pub amount_deviation_threshold: f64,
24 pub frequency_deviation_threshold: f64,
26 pub recency_decay_factor: f64,
28}
29
30impl Default for BehavioralBaselineConfig {
31 fn default() -> Self {
32 Self {
33 enabled: true,
34 baseline_period_days: 90,
35 min_observations: 10,
36 amount_deviation_threshold: 3.0,
37 frequency_deviation_threshold: 2.0,
38 recency_decay_factor: 0.95,
39 }
40 }
41}
42
43pub struct BehavioralBaseline {
45 config: BehavioralBaselineConfig,
46 entity_baselines: HashMap<String, EntityBaseline>,
48}
49
50impl Default for BehavioralBaseline {
51 fn default() -> Self {
52 Self::new(BehavioralBaselineConfig::default())
53 }
54}
55
56impl BehavioralBaseline {
57 pub fn new(config: BehavioralBaselineConfig) -> Self {
59 Self {
60 config,
61 entity_baselines: HashMap::new(),
62 }
63 }
64
65 pub fn record_observation(&mut self, entity_id: impl Into<String>, observation: Observation) {
67 let id = entity_id.into();
68 let baseline = self.entity_baselines.entry(id).or_default();
69 baseline.add_observation(observation);
70 }
71
72 pub fn get_baseline(&self, entity_id: &str) -> Option<&EntityBaseline> {
74 self.entity_baselines.get(entity_id)
75 }
76
77 pub fn check_deviation(
79 &self,
80 entity_id: &str,
81 observation: &Observation,
82 ) -> Vec<BehavioralDeviation> {
83 if !self.config.enabled {
84 return Vec::new();
85 }
86
87 let baseline = match self.get_baseline(entity_id) {
88 Some(b) if b.observation_count >= self.config.min_observations => b,
89 _ => return Vec::new(),
90 };
91
92 let mut deviations = Vec::new();
93
94 if let Some(amount) = observation.amount {
96 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
97 if baseline.amount_std_dev > 0.0 {
98 let z_score =
99 (amount_f64 - baseline.avg_transaction_amount) / baseline.amount_std_dev;
100 if z_score.abs() > self.config.amount_deviation_threshold {
101 deviations.push(BehavioralDeviation {
102 deviation_type: DeviationType::AmountAnomaly,
103 std_deviations: z_score.abs(),
104 expected_value: baseline.avg_transaction_amount,
105 actual_value: amount_f64,
106 label: AnomalyType::Statistical(
107 StatisticalAnomalyType::UnusuallyHighAmount,
108 ),
109 severity: Self::severity_from_std_dev(z_score.abs()),
110 description: format!(
111 "Amount ${:.2} is {:.1} std devs from mean ${:.2}",
112 amount_f64,
113 z_score.abs(),
114 baseline.avg_transaction_amount
115 ),
116 });
117 }
118 }
119 }
120
121 if let Some(time) = observation.time {
123 if !baseline.is_within_typical_hours(time) {
124 deviations.push(BehavioralDeviation {
125 deviation_type: DeviationType::TimingAnomaly,
126 std_deviations: 0.0,
127 expected_value: 0.0,
128 actual_value: 0.0,
129 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualTiming),
130 severity: SeverityLevel::Low,
131 description: format!(
132 "Transaction at {} outside typical hours {:02}:00-{:02}:00",
133 time, baseline.typical_posting_hours.0, baseline.typical_posting_hours.1
134 ),
135 });
136 }
137 }
138
139 if let Some(ref counterparty) = observation.counterparty {
141 if !baseline.common_counterparties.contains(counterparty)
142 && baseline.common_counterparties.len() >= 5
143 {
144 deviations.push(BehavioralDeviation {
145 deviation_type: DeviationType::NewCounterparty,
146 std_deviations: 0.0,
147 expected_value: 0.0,
148 actual_value: 0.0,
149 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
150 severity: SeverityLevel::Low,
151 description: format!(
152 "New counterparty '{}' not in typical partners",
153 counterparty
154 ),
155 });
156 }
157 }
158
159 if let Some(ref account) = observation.account_code {
161 if !baseline.usual_account_codes.contains(account)
162 && baseline.usual_account_codes.len() >= 3
163 {
164 deviations.push(BehavioralDeviation {
165 deviation_type: DeviationType::UnusualAccount,
166 std_deviations: 0.0,
167 expected_value: 0.0,
168 actual_value: 0.0,
169 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
170 severity: SeverityLevel::Low,
171 description: format!("Account '{}' not typically used by this entity", account),
172 });
173 }
174 }
175
176 deviations
177 }
178
179 fn severity_from_std_dev(std_devs: f64) -> SeverityLevel {
181 if std_devs > 5.0 {
182 SeverityLevel::Critical
183 } else if std_devs > 4.0 {
184 SeverityLevel::High
185 } else if std_devs > 3.5 {
186 SeverityLevel::Medium
187 } else {
188 SeverityLevel::Low
189 }
190 }
191
192 pub fn check_frequency_deviation(
194 &self,
195 entity_id: &str,
196 current_frequency: f64,
197 ) -> Option<BehavioralDeviation> {
198 if !self.config.enabled {
199 return None;
200 }
201
202 let baseline = self.get_baseline(entity_id)?;
203
204 if baseline.observation_count < self.config.min_observations {
205 return None;
206 }
207
208 if baseline.frequency_std_dev <= 0.0 {
209 return None;
210 }
211
212 let z_score =
213 (current_frequency - baseline.transaction_frequency) / baseline.frequency_std_dev;
214
215 if z_score.abs() > self.config.frequency_deviation_threshold {
216 Some(BehavioralDeviation {
217 deviation_type: DeviationType::FrequencyAnomaly,
218 std_deviations: z_score.abs(),
219 expected_value: baseline.transaction_frequency,
220 actual_value: current_frequency,
221 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualFrequency),
222 severity: Self::severity_from_std_dev(z_score.abs()),
223 description: format!(
224 "Frequency {:.2}/day is {:.1} std devs from normal {:.2}/day",
225 current_frequency,
226 z_score.abs(),
227 baseline.transaction_frequency
228 ),
229 })
230 } else {
231 None
232 }
233 }
234
235 pub fn entity_count(&self) -> usize {
237 self.entity_baselines.len()
238 }
239
240 pub fn config(&self) -> &BehavioralBaselineConfig {
242 &self.config
243 }
244
245 pub fn clear(&mut self) {
247 self.entity_baselines.clear();
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct Observation {
254 pub date: NaiveDate,
256 pub time: Option<NaiveTime>,
258 pub amount: Option<Decimal>,
260 pub counterparty: Option<String>,
262 pub account_code: Option<String>,
264}
265
266impl Observation {
267 pub fn new(date: NaiveDate) -> Self {
269 Self {
270 date,
271 time: None,
272 amount: None,
273 counterparty: None,
274 account_code: None,
275 }
276 }
277
278 pub fn with_time(mut self, time: NaiveTime) -> Self {
280 self.time = Some(time);
281 self
282 }
283
284 pub fn with_amount(mut self, amount: Decimal) -> Self {
286 self.amount = Some(amount);
287 self
288 }
289
290 pub fn with_counterparty(mut self, counterparty: impl Into<String>) -> Self {
292 self.counterparty = Some(counterparty.into());
293 self
294 }
295
296 pub fn with_account(mut self, account: impl Into<String>) -> Self {
298 self.account_code = Some(account.into());
299 self
300 }
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
305pub struct EntityBaseline {
306 pub avg_transaction_amount: f64,
308 pub amount_std_dev: f64,
310 pub transaction_frequency: f64,
312 pub frequency_std_dev: f64,
314 pub typical_posting_hours: (u8, u8),
316 pub common_counterparties: Vec<String>,
318 pub usual_account_codes: Vec<String>,
320 pub observation_count: u32,
322 #[serde(skip)]
324 amount_sum: f64,
325 #[serde(skip)]
327 amount_sum_sq: f64,
328 #[serde(skip)]
330 daily_counts: HashMap<NaiveDate, u32>,
331 #[serde(skip)]
333 hour_counts: [u32; 24],
334 #[serde(skip)]
336 counterparty_freq: HashMap<String, u32>,
337 #[serde(skip)]
339 account_freq: HashMap<String, u32>,
340}
341
342impl Default for EntityBaseline {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl EntityBaseline {
349 pub fn new() -> Self {
351 Self {
352 avg_transaction_amount: 0.0,
353 amount_std_dev: 0.0,
354 transaction_frequency: 0.0,
355 frequency_std_dev: 0.0,
356 typical_posting_hours: (8, 18),
357 common_counterparties: Vec::new(),
358 usual_account_codes: Vec::new(),
359 observation_count: 0,
360 amount_sum: 0.0,
361 amount_sum_sq: 0.0,
362 daily_counts: HashMap::new(),
363 hour_counts: [0; 24],
364 counterparty_freq: HashMap::new(),
365 account_freq: HashMap::new(),
366 }
367 }
368
369 pub fn add_observation(&mut self, observation: Observation) {
371 self.observation_count += 1;
372
373 if let Some(amount) = observation.amount {
375 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
376 self.amount_sum += amount_f64;
377 self.amount_sum_sq += amount_f64 * amount_f64;
378 self.avg_transaction_amount = self.amount_sum / self.observation_count as f64;
379
380 if self.observation_count > 1 {
381 let variance = (self.amount_sum_sq
382 - (self.amount_sum * self.amount_sum) / self.observation_count as f64)
383 / (self.observation_count - 1) as f64;
384 self.amount_std_dev = variance.max(0.0).sqrt();
385 }
386 }
387
388 *self.daily_counts.entry(observation.date).or_insert(0) += 1;
390 self.update_frequency_stats();
391
392 if let Some(time) = observation.time {
394 self.hour_counts[time.hour() as usize] += 1;
395 self.update_typical_hours();
396 }
397
398 if let Some(ref counterparty) = observation.counterparty {
400 *self
401 .counterparty_freq
402 .entry(counterparty.clone())
403 .or_insert(0) += 1;
404 self.update_common_counterparties();
405 }
406
407 if let Some(ref account) = observation.account_code {
409 *self.account_freq.entry(account.clone()).or_insert(0) += 1;
410 self.update_usual_accounts();
411 }
412 }
413
414 fn update_frequency_stats(&mut self) {
416 if self.daily_counts.is_empty() {
417 return;
418 }
419
420 let counts: Vec<f64> = self.daily_counts.values().map(|&c| c as f64).collect();
421 let n = counts.len() as f64;
422
423 self.transaction_frequency = counts.iter().sum::<f64>() / n;
424
425 if counts.len() > 1 {
426 let variance: f64 = counts
427 .iter()
428 .map(|c| (c - self.transaction_frequency).powi(2))
429 .sum::<f64>()
430 / (n - 1.0);
431 self.frequency_std_dev = variance.sqrt();
432 }
433 }
434
435 fn update_typical_hours(&mut self) {
437 let total: u32 = self.hour_counts.iter().sum();
438 if total == 0 {
439 return;
440 }
441
442 let threshold = (total as f64 * 0.1) as u32; let mut cumsum = 0u32;
446 let mut start_hour = 0u8;
447 for (hour, &count) in self.hour_counts.iter().enumerate() {
448 cumsum += count;
449 if cumsum > threshold {
450 start_hour = hour as u8;
451 break;
452 }
453 }
454
455 cumsum = 0;
456 let mut end_hour = 23u8;
457 for (hour, &count) in self.hour_counts.iter().enumerate().rev() {
458 cumsum += count;
459 if cumsum > threshold {
460 end_hour = hour as u8;
461 break;
462 }
463 }
464
465 self.typical_posting_hours = (start_hour, end_hour.max(start_hour + 1));
466 }
467
468 fn update_common_counterparties(&mut self) {
470 let mut sorted: Vec<_> = self.counterparty_freq.iter().collect();
471 sorted.sort_by(|a, b| b.1.cmp(a.1));
472 self.common_counterparties = sorted
473 .into_iter()
474 .take(10)
475 .map(|(k, _)| k.clone())
476 .collect();
477 }
478
479 fn update_usual_accounts(&mut self) {
481 let mut sorted: Vec<_> = self.account_freq.iter().collect();
482 sorted.sort_by(|a, b| b.1.cmp(a.1));
483 self.usual_account_codes = sorted.into_iter().take(5).map(|(k, _)| k.clone()).collect();
484 }
485
486 pub fn is_within_typical_hours(&self, time: NaiveTime) -> bool {
488 let hour = time.hour() as u8;
489 hour >= self.typical_posting_hours.0 && hour <= self.typical_posting_hours.1
490 }
491
492 pub fn is_established(&self, min_observations: u32) -> bool {
494 self.observation_count >= min_observations
495 }
496}
497
498#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
500pub enum DeviationType {
501 AmountAnomaly,
503 FrequencyAnomaly,
505 NewCounterparty,
507 TimingAnomaly,
509 UnusualAccount,
511}
512
513#[derive(Debug, Clone)]
515pub struct BehavioralDeviation {
516 pub deviation_type: DeviationType,
518 pub std_deviations: f64,
520 pub expected_value: f64,
522 pub actual_value: f64,
524 pub label: AnomalyType,
526 pub severity: SeverityLevel,
528 pub description: String,
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    #[test]
    fn test_entity_baseline_creation() {
        let baseline = EntityBaseline::new();
        assert_eq!(baseline.observation_count, 0);
        assert!(baseline.avg_transaction_amount.abs() < 0.01);
    }

    #[test]
    fn test_observation_builder() {
        let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
            .with_amount(dec!(1000))
            .with_counterparty("VENDOR001")
            .with_account("5000");

        assert_eq!(obs.amount, Some(dec!(1000)));
        assert_eq!(obs.counterparty, Some("VENDOR001".to_string()));
        assert_eq!(obs.account_code, Some("5000".to_string()));
    }

    #[test]
    fn test_baseline_amount_tracking() {
        let mut baseline = EntityBaseline::new();

        for amount in [1000.0, 1100.0, 900.0, 1050.0, 950.0] {
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
                .with_amount(Decimal::try_from(amount).unwrap());
            baseline.add_observation(obs);
        }

        assert_eq!(baseline.observation_count, 5);
        assert!((baseline.avg_transaction_amount - 1000.0).abs() < 1.0);
        assert!(baseline.amount_std_dev > 0.0);
    }

    #[test]
    fn test_behavioral_baseline_deviation_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        let amounts = [
            900, 950, 1000, 1050, 1100, 920, 980, 1020, 1080, 950, 960, 1000, 1040, 990, 970, 1010,
            1030, 1000, 980, 1020,
        ];
        for (i, &amount) in amounts.iter().enumerate() {
            let obs = Observation::new(
                NaiveDate::from_ymd_opt(2024, 6, 1).unwrap()
                    + chrono::Duration::days(i as i64 % 10),
            )
            .with_amount(Decimal::from(amount))
            .with_counterparty("VENDOR001")
            .with_time(NaiveTime::from_hms_opt(10, 0, 0).unwrap());
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(50000))
            .with_counterparty("VENDOR001");

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::AmountAnomaly));
    }

    #[test]
    fn test_new_counterparty_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..15 {
            let cp = format!("VENDOR{:03}", i % 5);
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 1).unwrap())
                .with_amount(dec!(1000))
                .with_counterparty(&cp);
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let new_cp_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_counterparty("NEW_VENDOR");

        let deviations = baseline_mgr.check_deviation("ENTITY1", &new_cp_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::NewCounterparty));
    }

    #[test]
    fn test_unusual_account_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..15 {
            let account = ["1000", "2000", "3000"][i % 3];
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 1).unwrap())
                .with_amount(dec!(1000))
                .with_account(account);
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let usual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_account("2000");
        assert!(baseline_mgr
            .check_deviation("ENTITY1", &usual_obs)
            .is_empty());

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_account("9999");
        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::UnusualAccount));
    }

    #[test]
    fn test_timing_anomaly_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..15 {
            let hour = 9 + (i % 8);
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 1).unwrap())
                .with_amount(dec!(1000))
                .with_time(NaiveTime::from_hms_opt(hour, 0, 0).unwrap());
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_time_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_time(NaiveTime::from_hms_opt(3, 0, 0).unwrap());

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_time_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::TimingAnomaly));
    }

    #[test]
    fn test_frequency_deviation() {
        let mut baseline_mgr = BehavioralBaseline::default();

        let daily_counts = [
            2, 1, 3, 2, 2, 1, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3,
            2,
        ];
        for (day, &count) in daily_counts.iter().enumerate() {
            for _ in 0..count {
                let obs = Observation::new(
                    NaiveDate::from_ymd_opt(2024, 6, 1).unwrap()
                        + chrono::Duration::days(day as i64),
                )
                .with_amount(dec!(1000));
                baseline_mgr.record_observation("ENTITY1", obs);
            }
        }

        let deviation = baseline_mgr.check_frequency_deviation("ENTITY1", 10.0);

        assert!(deviation.is_some());
        assert_eq!(
            deviation.unwrap().deviation_type,
            DeviationType::FrequencyAnomaly
        );
    }

    #[test]
    fn test_disabled_config_reports_nothing() {
        let mut baseline_mgr = BehavioralBaseline::new(BehavioralBaselineConfig {
            enabled: false,
            ..BehavioralBaselineConfig::default()
        });

        for day in 0..20i64 {
            let obs = Observation::new(
                NaiveDate::from_ymd_opt(2024, 6, 1).unwrap() + chrono::Duration::days(day % 7),
            )
            .with_amount(Decimal::from(900 + day * 10));
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(50000))
            .with_time(NaiveTime::from_hms_opt(3, 0, 0).unwrap());

        assert!(baseline_mgr
            .check_deviation("ENTITY1", &unusual_obs)
            .is_empty());
        assert!(baseline_mgr
            .check_frequency_deviation("ENTITY1", 100.0)
            .is_none());
    }

    #[test]
    fn test_insufficient_baseline() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..5 {
            let obs = Observation::new(
                NaiveDate::from_ymd_opt(2024, 6, 1).unwrap() + chrono::Duration::days(i),
            )
            .with_amount(dec!(1000));
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(50000));

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_obs);

        assert!(deviations.is_empty());
    }

    #[test]
    fn test_typical_hours_calculation() {
        let mut baseline = EntityBaseline::new();

        for _ in 0..10 {
            for hour in 9..17 {
                let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
                    .with_time(NaiveTime::from_hms_opt(hour, 30, 0).unwrap());
                baseline.add_observation(obs);
            }
        }

        assert_eq!(baseline.typical_posting_hours, (9, 16));
        assert!(baseline.is_within_typical_hours(NaiveTime::from_hms_opt(10, 0, 0).unwrap()));
        assert!(baseline.is_within_typical_hours(NaiveTime::from_hms_opt(14, 0, 0).unwrap()));
        assert!(!baseline.is_within_typical_hours(NaiveTime::from_hms_opt(3, 0, 0).unwrap()));
        assert!(!baseline.is_within_typical_hours(NaiveTime::from_hms_opt(20, 0, 0).unwrap()));
    }

    #[test]
    fn test_severity_bands() {
        assert!(matches!(
            BehavioralBaseline::severity_from_std_dev(5.5),
            SeverityLevel::Critical
        ));
        assert!(matches!(
            BehavioralBaseline::severity_from_std_dev(4.5),
            SeverityLevel::High
        ));
        assert!(matches!(
            BehavioralBaseline::severity_from_std_dev(3.7),
            SeverityLevel::Medium
        ));
        assert!(matches!(
            BehavioralBaseline::severity_from_std_dev(3.0),
            SeverityLevel::Low
        ));
    }
}