1use chrono::{NaiveDate, NaiveTime, Timelike};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use datasynth_core::models::{AnomalyType, SeverityLevel, StatisticalAnomalyType};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BehavioralBaselineConfig {
16 pub enabled: bool,
18 pub baseline_period_days: u32,
20 pub min_observations: u32,
22 pub amount_deviation_threshold: f64,
24 pub frequency_deviation_threshold: f64,
26 pub recency_decay_factor: f64,
28}
29
30impl Default for BehavioralBaselineConfig {
31 fn default() -> Self {
32 Self {
33 enabled: true,
34 baseline_period_days: 90,
35 min_observations: 10,
36 amount_deviation_threshold: 3.0,
37 frequency_deviation_threshold: 2.0,
38 recency_decay_factor: 0.95,
39 }
40 }
41}
42
43pub struct BehavioralBaseline {
45 config: BehavioralBaselineConfig,
46 entity_baselines: HashMap<String, EntityBaseline>,
48}
49
50impl Default for BehavioralBaseline {
51 fn default() -> Self {
52 Self::new(BehavioralBaselineConfig::default())
53 }
54}
55
56impl BehavioralBaseline {
57 pub fn new(config: BehavioralBaselineConfig) -> Self {
59 Self {
60 config,
61 entity_baselines: HashMap::new(),
62 }
63 }
64
65 pub fn record_observation(&mut self, entity_id: impl Into<String>, observation: Observation) {
67 let id = entity_id.into();
68 let baseline = self.entity_baselines.entry(id).or_default();
69 baseline.add_observation(observation);
70 }
71
72 pub fn get_baseline(&self, entity_id: &str) -> Option<&EntityBaseline> {
74 self.entity_baselines.get(entity_id)
75 }
76
77 pub fn check_deviation(
79 &self,
80 entity_id: &str,
81 observation: &Observation,
82 ) -> Vec<BehavioralDeviation> {
83 if !self.config.enabled {
84 return Vec::new();
85 }
86
87 let baseline = match self.get_baseline(entity_id) {
88 Some(b) if b.observation_count >= self.config.min_observations => b,
89 _ => return Vec::new(),
90 };
91
92 let mut deviations = Vec::new();
93
94 if let Some(amount) = observation.amount {
96 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
97 if baseline.amount_std_dev > 0.0 {
98 let z_score =
99 (amount_f64 - baseline.avg_transaction_amount) / baseline.amount_std_dev;
100 if z_score.abs() > self.config.amount_deviation_threshold {
101 deviations.push(BehavioralDeviation {
102 deviation_type: DeviationType::AmountAnomaly,
103 std_deviations: z_score.abs(),
104 expected_value: baseline.avg_transaction_amount,
105 actual_value: amount_f64,
106 label: AnomalyType::Statistical(
107 StatisticalAnomalyType::UnusuallyHighAmount,
108 ),
109 severity: Self::severity_from_std_dev(z_score.abs()),
110 description: format!(
111 "Amount ${:.2} is {:.1} std devs from mean ${:.2}",
112 amount_f64,
113 z_score.abs(),
114 baseline.avg_transaction_amount
115 ),
116 });
117 }
118 }
119 }
120
121 if let Some(time) = observation.time {
123 if !baseline.is_within_typical_hours(time) {
124 deviations.push(BehavioralDeviation {
125 deviation_type: DeviationType::TimingAnomaly,
126 std_deviations: 0.0,
127 expected_value: 0.0,
128 actual_value: 0.0,
129 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualTiming),
130 severity: SeverityLevel::Low,
131 description: format!(
132 "Transaction at {} outside typical hours {:02}:00-{:02}:00",
133 time, baseline.typical_posting_hours.0, baseline.typical_posting_hours.1
134 ),
135 });
136 }
137 }
138
139 if let Some(ref counterparty) = observation.counterparty {
141 if !baseline.common_counterparties.contains(counterparty)
142 && baseline.common_counterparties.len() >= 5
143 {
144 deviations.push(BehavioralDeviation {
145 deviation_type: DeviationType::NewCounterparty,
146 std_deviations: 0.0,
147 expected_value: 0.0,
148 actual_value: 0.0,
149 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
150 severity: SeverityLevel::Low,
151 description: format!(
152 "New counterparty '{}' not in typical partners",
153 counterparty
154 ),
155 });
156 }
157 }
158
159 if let Some(ref account) = observation.account_code {
161 if !baseline.usual_account_codes.contains(account)
162 && baseline.usual_account_codes.len() >= 3
163 {
164 deviations.push(BehavioralDeviation {
165 deviation_type: DeviationType::UnusualAccount,
166 std_deviations: 0.0,
167 expected_value: 0.0,
168 actual_value: 0.0,
169 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
170 severity: SeverityLevel::Low,
171 description: format!("Account '{}' not typically used by this entity", account),
172 });
173 }
174 }
175
176 deviations
177 }
178
179 fn severity_from_std_dev(std_devs: f64) -> SeverityLevel {
181 if std_devs > 5.0 {
182 SeverityLevel::Critical
183 } else if std_devs > 4.0 {
184 SeverityLevel::High
185 } else if std_devs > 3.5 {
186 SeverityLevel::Medium
187 } else {
188 SeverityLevel::Low
189 }
190 }
191
192 pub fn check_frequency_deviation(
194 &self,
195 entity_id: &str,
196 current_frequency: f64,
197 ) -> Option<BehavioralDeviation> {
198 if !self.config.enabled {
199 return None;
200 }
201
202 let baseline = self.get_baseline(entity_id)?;
203
204 if baseline.observation_count < self.config.min_observations {
205 return None;
206 }
207
208 if baseline.frequency_std_dev <= 0.0 {
209 return None;
210 }
211
212 let z_score =
213 (current_frequency - baseline.transaction_frequency) / baseline.frequency_std_dev;
214
215 if z_score.abs() > self.config.frequency_deviation_threshold {
216 Some(BehavioralDeviation {
217 deviation_type: DeviationType::FrequencyAnomaly,
218 std_deviations: z_score.abs(),
219 expected_value: baseline.transaction_frequency,
220 actual_value: current_frequency,
221 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualFrequency),
222 severity: Self::severity_from_std_dev(z_score.abs()),
223 description: format!(
224 "Frequency {:.2}/day is {:.1} std devs from normal {:.2}/day",
225 current_frequency,
226 z_score.abs(),
227 baseline.transaction_frequency
228 ),
229 })
230 } else {
231 None
232 }
233 }
234
235 pub fn entity_count(&self) -> usize {
237 self.entity_baselines.len()
238 }
239
240 pub fn config(&self) -> &BehavioralBaselineConfig {
242 &self.config
243 }
244
245 pub fn clear(&mut self) {
247 self.entity_baselines.clear();
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct Observation {
254 pub date: NaiveDate,
256 pub time: Option<NaiveTime>,
258 pub amount: Option<Decimal>,
260 pub counterparty: Option<String>,
262 pub account_code: Option<String>,
264}
265
266impl Observation {
267 pub fn new(date: NaiveDate) -> Self {
269 Self {
270 date,
271 time: None,
272 amount: None,
273 counterparty: None,
274 account_code: None,
275 }
276 }
277
278 pub fn with_time(mut self, time: NaiveTime) -> Self {
280 self.time = Some(time);
281 self
282 }
283
284 pub fn with_amount(mut self, amount: Decimal) -> Self {
286 self.amount = Some(amount);
287 self
288 }
289
290 pub fn with_counterparty(mut self, counterparty: impl Into<String>) -> Self {
292 self.counterparty = Some(counterparty.into());
293 self
294 }
295
296 pub fn with_account(mut self, account: impl Into<String>) -> Self {
298 self.account_code = Some(account.into());
299 self
300 }
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
305pub struct EntityBaseline {
306 pub avg_transaction_amount: f64,
308 pub amount_std_dev: f64,
310 pub transaction_frequency: f64,
312 pub frequency_std_dev: f64,
314 pub typical_posting_hours: (u8, u8),
316 pub common_counterparties: Vec<String>,
318 pub usual_account_codes: Vec<String>,
320 pub observation_count: u32,
322 #[serde(skip)]
324 amount_sum: f64,
325 #[serde(skip)]
327 amount_sum_sq: f64,
328 #[serde(skip)]
330 daily_counts: HashMap<NaiveDate, u32>,
331 #[serde(skip)]
333 hour_counts: [u32; 24],
334 #[serde(skip)]
336 counterparty_freq: HashMap<String, u32>,
337 #[serde(skip)]
339 account_freq: HashMap<String, u32>,
340}
341
342impl Default for EntityBaseline {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl EntityBaseline {
349 pub fn new() -> Self {
351 Self {
352 avg_transaction_amount: 0.0,
353 amount_std_dev: 0.0,
354 transaction_frequency: 0.0,
355 frequency_std_dev: 0.0,
356 typical_posting_hours: (8, 18),
357 common_counterparties: Vec::new(),
358 usual_account_codes: Vec::new(),
359 observation_count: 0,
360 amount_sum: 0.0,
361 amount_sum_sq: 0.0,
362 daily_counts: HashMap::new(),
363 hour_counts: [0; 24],
364 counterparty_freq: HashMap::new(),
365 account_freq: HashMap::new(),
366 }
367 }
368
369 pub fn add_observation(&mut self, observation: Observation) {
371 self.observation_count += 1;
372
373 if let Some(amount) = observation.amount {
375 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
376 self.amount_sum += amount_f64;
377 self.amount_sum_sq += amount_f64 * amount_f64;
378 self.avg_transaction_amount = self.amount_sum / self.observation_count as f64;
379
380 if self.observation_count > 1 {
381 let variance = (self.amount_sum_sq
382 - (self.amount_sum * self.amount_sum) / self.observation_count as f64)
383 / (self.observation_count - 1) as f64;
384 self.amount_std_dev = variance.max(0.0).sqrt();
385 }
386 }
387
388 *self.daily_counts.entry(observation.date).or_insert(0) += 1;
390 self.update_frequency_stats();
391
392 if let Some(time) = observation.time {
394 self.hour_counts[time.hour() as usize] += 1;
395 self.update_typical_hours();
396 }
397
398 if let Some(ref counterparty) = observation.counterparty {
400 *self
401 .counterparty_freq
402 .entry(counterparty.clone())
403 .or_insert(0) += 1;
404 self.update_common_counterparties();
405 }
406
407 if let Some(ref account) = observation.account_code {
409 *self.account_freq.entry(account.clone()).or_insert(0) += 1;
410 self.update_usual_accounts();
411 }
412 }
413
414 fn update_frequency_stats(&mut self) {
416 if self.daily_counts.is_empty() {
417 return;
418 }
419
420 let counts: Vec<f64> = self.daily_counts.values().map(|&c| c as f64).collect();
421 let n = counts.len() as f64;
422
423 self.transaction_frequency = counts.iter().sum::<f64>() / n;
424
425 if counts.len() > 1 {
426 let variance: f64 = counts
427 .iter()
428 .map(|c| (c - self.transaction_frequency).powi(2))
429 .sum::<f64>()
430 / (n - 1.0);
431 self.frequency_std_dev = variance.sqrt();
432 }
433 }
434
435 fn update_typical_hours(&mut self) {
437 let total: u32 = self.hour_counts.iter().sum();
438 if total == 0 {
439 return;
440 }
441
442 let threshold = (total as f64 * 0.1) as u32; let mut cumsum = 0u32;
446 let mut start_hour = 0u8;
447 for (hour, &count) in self.hour_counts.iter().enumerate() {
448 cumsum += count;
449 if cumsum > threshold {
450 start_hour = hour as u8;
451 break;
452 }
453 }
454
455 cumsum = 0;
456 let mut end_hour = 23u8;
457 for (hour, &count) in self.hour_counts.iter().enumerate().rev() {
458 cumsum += count;
459 if cumsum > threshold {
460 end_hour = hour as u8;
461 break;
462 }
463 }
464
465 self.typical_posting_hours = (start_hour, end_hour.max(start_hour + 1));
466 }
467
468 fn update_common_counterparties(&mut self) {
470 let mut sorted: Vec<_> = self.counterparty_freq.iter().collect();
471 sorted.sort_by(|a, b| b.1.cmp(a.1));
472 self.common_counterparties = sorted
473 .into_iter()
474 .take(10)
475 .map(|(k, _)| k.clone())
476 .collect();
477 }
478
479 fn update_usual_accounts(&mut self) {
481 let mut sorted: Vec<_> = self.account_freq.iter().collect();
482 sorted.sort_by(|a, b| b.1.cmp(a.1));
483 self.usual_account_codes = sorted.into_iter().take(5).map(|(k, _)| k.clone()).collect();
484 }
485
486 pub fn is_within_typical_hours(&self, time: NaiveTime) -> bool {
488 let hour = time.hour() as u8;
489 hour >= self.typical_posting_hours.0 && hour <= self.typical_posting_hours.1
490 }
491
492 pub fn is_established(&self, min_observations: u32) -> bool {
494 self.observation_count >= min_observations
495 }
496}
497
498#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
500pub enum DeviationType {
501 AmountAnomaly,
503 FrequencyAnomaly,
505 NewCounterparty,
507 TimingAnomaly,
509 UnusualAccount,
511}
512
513#[derive(Debug, Clone)]
515pub struct BehavioralDeviation {
516 pub deviation_type: DeviationType,
518 pub std_deviations: f64,
520 pub expected_value: f64,
522 pub actual_value: f64,
524 pub label: AnomalyType,
526 pub severity: SeverityLevel,
528 pub description: String,
530}
531
#[cfg(test)]
// Fixtures unwrap literal, known-valid dates, times and decimals.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A fresh baseline has no observations and a zero mean amount.
    #[test]
    fn test_entity_baseline_creation() {
        let baseline = EntityBaseline::new();
        assert_eq!(baseline.observation_count, 0);
        assert!((baseline.avg_transaction_amount - 0.0).abs() < 0.01);
    }

    /// The `with_*` builders populate the corresponding optional fields.
    #[test]
    fn test_observation_builder() {
        let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
            .with_amount(dec!(1000))
            .with_counterparty("VENDOR001")
            .with_account("5000");

        assert_eq!(obs.amount, Some(dec!(1000)));
        assert_eq!(obs.counterparty, Some("VENDOR001".to_string()));
        assert_eq!(obs.account_code, Some("5000".to_string()));
    }

    /// Five amounts centred on 1000 give a mean of ~1000 and a non-zero spread.
    #[test]
    fn test_baseline_amount_tracking() {
        let mut baseline = EntityBaseline::new();

        for amount in [1000.0, 1100.0, 900.0, 1050.0, 950.0] {
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
                .with_amount(Decimal::try_from(amount).unwrap());
            baseline.add_observation(obs);
        }

        assert_eq!(baseline.observation_count, 5);
        assert!((baseline.avg_transaction_amount - 1000.0).abs() < 1.0);
        assert!(baseline.amount_std_dev > 0.0);
    }

    /// Twenty amounts around 1000 (spread over 10 days) form the baseline; a
    /// 50,000 amount must be reported as an amount anomaly.
    #[test]
    fn test_behavioral_baseline_deviation_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        let amounts = [
            900, 950, 1000, 1050, 1100, 920, 980, 1020, 1080, 950, 960, 1000, 1040, 990, 970, 1010,
            1030, 1000, 980, 1020,
        ];
        for (i, &amount) in amounts.iter().enumerate() {
            let obs = Observation::new(
                NaiveDate::from_ymd_opt(2024, 6, 1).unwrap()
                    + chrono::Duration::days(i as i64 % 10),
            )
            .with_amount(Decimal::from(amount))
            .with_counterparty("VENDOR001")
            .with_time(NaiveTime::from_hms_opt(10, 0, 0).unwrap());
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(50000))
            .with_counterparty("VENDOR001");

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::AmountAnomaly));
    }

    /// Fifteen observations rotate through five vendors, which fills the
    /// five-counterparty minimum; an unseen vendor must be reported.
    #[test]
    fn test_new_counterparty_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..15 {
            let cp = format!("VENDOR{:03}", i % 5);
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 1).unwrap())
                .with_amount(dec!(1000))
                .with_counterparty(&cp);
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let new_cp_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_counterparty("NEW_VENDOR");

        let deviations = baseline_mgr.check_deviation("ENTITY1", &new_cp_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::NewCounterparty));
    }

    /// Baseline postings fall between 09:00 and 16:00; a 03:00 posting must be
    /// reported as a timing anomaly.
    #[test]
    fn test_timing_anomaly_detection() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..15 {
            let hour = 9 + (i % 8);
            let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 1).unwrap())
                .with_amount(dec!(1000))
                .with_time(NaiveTime::from_hms_opt(hour, 0, 0).unwrap());
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_time_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(1000))
            .with_time(NaiveTime::from_hms_opt(3, 0, 0).unwrap());

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_time_obs);

        assert!(deviations
            .iter()
            .any(|d| d.deviation_type == DeviationType::TimingAnomaly));
    }

    /// Thirty days with 1-3 observations each form the frequency baseline; a
    /// rate of 10 per day must be reported as a frequency anomaly.
    #[test]
    fn test_frequency_deviation() {
        let mut baseline_mgr = BehavioralBaseline::default();

        let daily_counts = [
            2, 1, 3, 2, 2, 1, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3,
            2,
        ];
        for (day, &count) in daily_counts.iter().enumerate() {
            for _ in 0..count {
                let obs = Observation::new(
                    NaiveDate::from_ymd_opt(2024, 6, 1).unwrap()
                        + chrono::Duration::days(day as i64),
                )
                .with_amount(dec!(1000));
                baseline_mgr.record_observation("ENTITY1", obs);
            }
        }

        let deviation = baseline_mgr.check_frequency_deviation("ENTITY1", 10.0);

        assert!(deviation.is_some());
        assert_eq!(
            deviation.unwrap().deviation_type,
            DeviationType::FrequencyAnomaly
        );
    }

    /// With only 5 observations (below the default minimum of 10), even an
    /// extreme amount yields no deviations.
    #[test]
    fn test_insufficient_baseline() {
        let mut baseline_mgr = BehavioralBaseline::default();

        for i in 0..5 {
            let obs = Observation::new(
                NaiveDate::from_ymd_opt(2024, 6, 1).unwrap() + chrono::Duration::days(i),
            )
            .with_amount(dec!(1000));
            baseline_mgr.record_observation("ENTITY1", obs);
        }

        let unusual_obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 25).unwrap())
            .with_amount(dec!(50000));

        let deviations = baseline_mgr.check_deviation("ENTITY1", &unusual_obs);

        assert!(deviations.is_empty());
    }

    /// Postings spread evenly across the 09-16 hours keep mid-day times
    /// inside the computed typical window.
    #[test]
    fn test_typical_hours_calculation() {
        let mut baseline = EntityBaseline::new();

        for _ in 0..10 {
            for hour in 9..17 {
                let obs = Observation::new(NaiveDate::from_ymd_opt(2024, 6, 15).unwrap())
                    .with_time(NaiveTime::from_hms_opt(hour, 30, 0).unwrap());
                baseline.add_observation(obs);
            }
        }

        assert!(baseline.is_within_typical_hours(NaiveTime::from_hms_opt(10, 0, 0).unwrap()));
        assert!(baseline.is_within_typical_hours(NaiveTime::from_hms_opt(14, 0, 0).unwrap()));
    }
}
728}