1use chrono::{NaiveDate, NaiveTime, Timelike};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use datasynth_core::models::{AnomalyType, SeverityLevel, StatisticalAnomalyType};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct BehavioralBaselineConfig {
16 pub enabled: bool,
18 pub baseline_period_days: u32,
20 pub min_observations: u32,
22 pub amount_deviation_threshold: f64,
24 pub frequency_deviation_threshold: f64,
26 pub recency_decay_factor: f64,
28}
29
30impl Default for BehavioralBaselineConfig {
31 fn default() -> Self {
32 Self {
33 enabled: true,
34 baseline_period_days: 90,
35 min_observations: 10,
36 amount_deviation_threshold: 3.0,
37 frequency_deviation_threshold: 2.0,
38 recency_decay_factor: 0.95,
39 }
40 }
41}
42
43pub struct BehavioralBaseline {
45 config: BehavioralBaselineConfig,
46 entity_baselines: HashMap<String, EntityBaseline>,
48}
49
50impl Default for BehavioralBaseline {
51 fn default() -> Self {
52 Self::new(BehavioralBaselineConfig::default())
53 }
54}
55
56impl BehavioralBaseline {
57 pub fn new(config: BehavioralBaselineConfig) -> Self {
59 Self {
60 config,
61 entity_baselines: HashMap::new(),
62 }
63 }
64
65 pub fn record_observation(&mut self, entity_id: impl Into<String>, observation: Observation) {
67 let id = entity_id.into();
68 let baseline = self.entity_baselines.entry(id).or_default();
69 baseline.add_observation(observation);
70 }
71
72 pub fn get_baseline(&self, entity_id: &str) -> Option<&EntityBaseline> {
74 self.entity_baselines.get(entity_id)
75 }
76
77 pub fn check_deviation(
79 &self,
80 entity_id: &str,
81 observation: &Observation,
82 ) -> Vec<BehavioralDeviation> {
83 if !self.config.enabled {
84 return Vec::new();
85 }
86
87 let baseline = match self.get_baseline(entity_id) {
88 Some(b) if b.observation_count >= self.config.min_observations => b,
89 _ => return Vec::new(),
90 };
91
92 let mut deviations = Vec::new();
93
94 if let Some(amount) = observation.amount {
96 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
97 if baseline.amount_std_dev > 0.0 {
98 let z_score =
99 (amount_f64 - baseline.avg_transaction_amount) / baseline.amount_std_dev;
100 if z_score.abs() > self.config.amount_deviation_threshold {
101 deviations.push(BehavioralDeviation {
102 deviation_type: DeviationType::AmountAnomaly,
103 std_deviations: z_score.abs(),
104 expected_value: baseline.avg_transaction_amount,
105 actual_value: amount_f64,
106 label: AnomalyType::Statistical(
107 StatisticalAnomalyType::UnusuallyHighAmount,
108 ),
109 severity: Self::severity_from_std_dev(z_score.abs()),
110 description: format!(
111 "Amount ${:.2} is {:.1} std devs from mean ${:.2}",
112 amount_f64,
113 z_score.abs(),
114 baseline.avg_transaction_amount
115 ),
116 });
117 }
118 }
119 }
120
121 if let Some(time) = observation.time {
123 if !baseline.is_within_typical_hours(time) {
124 deviations.push(BehavioralDeviation {
125 deviation_type: DeviationType::TimingAnomaly,
126 std_deviations: 0.0,
127 expected_value: 0.0,
128 actual_value: 0.0,
129 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualTiming),
130 severity: SeverityLevel::Low,
131 description: format!(
132 "Transaction at {} outside typical hours {:02}:00-{:02}:00",
133 time, baseline.typical_posting_hours.0, baseline.typical_posting_hours.1
134 ),
135 });
136 }
137 }
138
139 if let Some(ref counterparty) = observation.counterparty {
141 if !baseline.common_counterparties.contains(counterparty)
142 && baseline.common_counterparties.len() >= 5
143 {
144 deviations.push(BehavioralDeviation {
145 deviation_type: DeviationType::NewCounterparty,
146 std_deviations: 0.0,
147 expected_value: 0.0,
148 actual_value: 0.0,
149 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
150 severity: SeverityLevel::Low,
151 description: format!(
152 "New counterparty '{counterparty}' not in typical partners"
153 ),
154 });
155 }
156 }
157
158 if let Some(ref account) = observation.account_code {
160 if !baseline.usual_account_codes.contains(account)
161 && baseline.usual_account_codes.len() >= 3
162 {
163 deviations.push(BehavioralDeviation {
164 deviation_type: DeviationType::UnusualAccount,
165 std_deviations: 0.0,
166 expected_value: 0.0,
167 actual_value: 0.0,
168 label: AnomalyType::Statistical(StatisticalAnomalyType::StatisticalOutlier),
169 severity: SeverityLevel::Low,
170 description: format!("Account '{account}' not typically used by this entity"),
171 });
172 }
173 }
174
175 deviations
176 }
177
178 fn severity_from_std_dev(std_devs: f64) -> SeverityLevel {
180 if std_devs > 5.0 {
181 SeverityLevel::Critical
182 } else if std_devs > 4.0 {
183 SeverityLevel::High
184 } else if std_devs > 3.5 {
185 SeverityLevel::Medium
186 } else {
187 SeverityLevel::Low
188 }
189 }
190
191 pub fn check_frequency_deviation(
193 &self,
194 entity_id: &str,
195 current_frequency: f64,
196 ) -> Option<BehavioralDeviation> {
197 if !self.config.enabled {
198 return None;
199 }
200
201 let baseline = self.get_baseline(entity_id)?;
202
203 if baseline.observation_count < self.config.min_observations {
204 return None;
205 }
206
207 if baseline.frequency_std_dev <= 0.0 {
208 return None;
209 }
210
211 let z_score =
212 (current_frequency - baseline.transaction_frequency) / baseline.frequency_std_dev;
213
214 if z_score.abs() > self.config.frequency_deviation_threshold {
215 Some(BehavioralDeviation {
216 deviation_type: DeviationType::FrequencyAnomaly,
217 std_deviations: z_score.abs(),
218 expected_value: baseline.transaction_frequency,
219 actual_value: current_frequency,
220 label: AnomalyType::Statistical(StatisticalAnomalyType::UnusualFrequency),
221 severity: Self::severity_from_std_dev(z_score.abs()),
222 description: format!(
223 "Frequency {:.2}/day is {:.1} std devs from normal {:.2}/day",
224 current_frequency,
225 z_score.abs(),
226 baseline.transaction_frequency
227 ),
228 })
229 } else {
230 None
231 }
232 }
233
234 pub fn entity_count(&self) -> usize {
236 self.entity_baselines.len()
237 }
238
239 pub fn config(&self) -> &BehavioralBaselineConfig {
241 &self.config
242 }
243
244 pub fn clear(&mut self) {
246 self.entity_baselines.clear();
247 }
248}
249
250#[derive(Debug, Clone)]
252pub struct Observation {
253 pub date: NaiveDate,
255 pub time: Option<NaiveTime>,
257 pub amount: Option<Decimal>,
259 pub counterparty: Option<String>,
261 pub account_code: Option<String>,
263}
264
265impl Observation {
266 pub fn new(date: NaiveDate) -> Self {
268 Self {
269 date,
270 time: None,
271 amount: None,
272 counterparty: None,
273 account_code: None,
274 }
275 }
276
277 pub fn with_time(mut self, time: NaiveTime) -> Self {
279 self.time = Some(time);
280 self
281 }
282
283 pub fn with_amount(mut self, amount: Decimal) -> Self {
285 self.amount = Some(amount);
286 self
287 }
288
289 pub fn with_counterparty(mut self, counterparty: impl Into<String>) -> Self {
291 self.counterparty = Some(counterparty.into());
292 self
293 }
294
295 pub fn with_account(mut self, account: impl Into<String>) -> Self {
297 self.account_code = Some(account.into());
298 self
299 }
300}
301
302#[derive(Debug, Clone, Serialize, Deserialize)]
304pub struct EntityBaseline {
305 pub avg_transaction_amount: f64,
307 pub amount_std_dev: f64,
309 pub transaction_frequency: f64,
311 pub frequency_std_dev: f64,
313 pub typical_posting_hours: (u8, u8),
315 pub common_counterparties: Vec<String>,
317 pub usual_account_codes: Vec<String>,
319 pub observation_count: u32,
321 #[serde(skip)]
323 amount_sum: f64,
324 #[serde(skip)]
326 amount_sum_sq: f64,
327 #[serde(skip)]
329 daily_counts: HashMap<NaiveDate, u32>,
330 #[serde(skip)]
332 hour_counts: [u32; 24],
333 #[serde(skip)]
335 counterparty_freq: HashMap<String, u32>,
336 #[serde(skip)]
338 account_freq: HashMap<String, u32>,
339}
340
341impl Default for EntityBaseline {
342 fn default() -> Self {
343 Self::new()
344 }
345}
346
347impl EntityBaseline {
348 pub fn new() -> Self {
350 Self {
351 avg_transaction_amount: 0.0,
352 amount_std_dev: 0.0,
353 transaction_frequency: 0.0,
354 frequency_std_dev: 0.0,
355 typical_posting_hours: (8, 18),
356 common_counterparties: Vec::new(),
357 usual_account_codes: Vec::new(),
358 observation_count: 0,
359 amount_sum: 0.0,
360 amount_sum_sq: 0.0,
361 daily_counts: HashMap::new(),
362 hour_counts: [0; 24],
363 counterparty_freq: HashMap::new(),
364 account_freq: HashMap::new(),
365 }
366 }
367
368 pub fn add_observation(&mut self, observation: Observation) {
370 self.observation_count += 1;
371
372 if let Some(amount) = observation.amount {
374 let amount_f64: f64 = amount.try_into().unwrap_or(0.0);
375 self.amount_sum += amount_f64;
376 self.amount_sum_sq += amount_f64 * amount_f64;
377 self.avg_transaction_amount = self.amount_sum / self.observation_count as f64;
378
379 if self.observation_count > 1 {
380 let variance = (self.amount_sum_sq
381 - (self.amount_sum * self.amount_sum) / self.observation_count as f64)
382 / (self.observation_count - 1) as f64;
383 self.amount_std_dev = variance.max(0.0).sqrt();
384 }
385 }
386
387 *self.daily_counts.entry(observation.date).or_insert(0) += 1;
389 self.update_frequency_stats();
390
391 if let Some(time) = observation.time {
393 self.hour_counts[time.hour() as usize] += 1;
394 self.update_typical_hours();
395 }
396
397 if let Some(ref counterparty) = observation.counterparty {
399 *self
400 .counterparty_freq
401 .entry(counterparty.clone())
402 .or_insert(0) += 1;
403 self.update_common_counterparties();
404 }
405
406 if let Some(ref account) = observation.account_code {
408 *self.account_freq.entry(account.clone()).or_insert(0) += 1;
409 self.update_usual_accounts();
410 }
411 }
412
413 fn update_frequency_stats(&mut self) {
415 if self.daily_counts.is_empty() {
416 return;
417 }
418
419 let counts: Vec<f64> = self.daily_counts.values().map(|&c| c as f64).collect();
420 let n = counts.len() as f64;
421
422 self.transaction_frequency = counts.iter().sum::<f64>() / n;
423
424 if counts.len() > 1 {
425 let variance: f64 = counts
426 .iter()
427 .map(|c| (c - self.transaction_frequency).powi(2))
428 .sum::<f64>()
429 / (n - 1.0);
430 self.frequency_std_dev = variance.sqrt();
431 }
432 }
433
434 fn update_typical_hours(&mut self) {
436 let total: u32 = self.hour_counts.iter().sum();
437 if total == 0 {
438 return;
439 }
440
441 let threshold = (total as f64 * 0.1) as u32; let mut cumsum = 0u32;
445 let mut start_hour = 0u8;
446 for (hour, &count) in self.hour_counts.iter().enumerate() {
447 cumsum += count;
448 if cumsum > threshold {
449 start_hour = hour as u8;
450 break;
451 }
452 }
453
454 cumsum = 0;
455 let mut end_hour = 23u8;
456 for (hour, &count) in self.hour_counts.iter().enumerate().rev() {
457 cumsum += count;
458 if cumsum > threshold {
459 end_hour = hour as u8;
460 break;
461 }
462 }
463
464 self.typical_posting_hours = (start_hour, end_hour.max(start_hour + 1));
465 }
466
467 fn update_common_counterparties(&mut self) {
469 let mut sorted: Vec<_> = self.counterparty_freq.iter().collect();
470 sorted.sort_by(|a, b| b.1.cmp(a.1));
471 self.common_counterparties = sorted
472 .into_iter()
473 .take(10)
474 .map(|(k, _)| k.clone())
475 .collect();
476 }
477
478 fn update_usual_accounts(&mut self) {
480 let mut sorted: Vec<_> = self.account_freq.iter().collect();
481 sorted.sort_by(|a, b| b.1.cmp(a.1));
482 self.usual_account_codes = sorted.into_iter().take(5).map(|(k, _)| k.clone()).collect();
483 }
484
485 pub fn is_within_typical_hours(&self, time: NaiveTime) -> bool {
487 let hour = time.hour() as u8;
488 hour >= self.typical_posting_hours.0 && hour <= self.typical_posting_hours.1
489 }
490
491 pub fn is_established(&self, min_observations: u32) -> bool {
493 self.observation_count >= min_observations
494 }
495}
496
497#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
499pub enum DeviationType {
500 AmountAnomaly,
502 FrequencyAnomaly,
504 NewCounterparty,
506 TimingAnomaly,
508 UnusualAccount,
510}
511
512#[derive(Debug, Clone)]
514pub struct BehavioralDeviation {
515 pub deviation_type: DeviationType,
517 pub std_deviations: f64,
519 pub expected_value: f64,
521 pub actual_value: f64,
523 pub label: AnomalyType,
525 pub severity: SeverityLevel,
527 pub description: String,
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Entity id shared by the manager-level tests.
    const ENTITY: &str = "ENTITY1";

    /// Returns the given day of June 2024.
    fn june(day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, day).unwrap()
    }

    /// Returns the time `hour:minute:00`.
    fn clock(hour: u32, minute: u32) -> NaiveTime {
        NaiveTime::from_hms_opt(hour, minute, 0).unwrap()
    }

    /// True when `deviations` holds at least one entry of the given kind.
    fn reports(deviations: &[BehavioralDeviation], kind: DeviationType) -> bool {
        deviations.iter().any(|d| d.deviation_type == kind)
    }

    #[test]
    fn test_entity_baseline_creation() {
        let fresh = EntityBaseline::new();

        assert_eq!(fresh.observation_count, 0);
        assert!((fresh.avg_transaction_amount - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_observation_builder() {
        let built = Observation::new(june(15))
            .with_amount(dec!(1000))
            .with_counterparty("VENDOR001")
            .with_account("5000");

        assert_eq!(built.amount, Some(dec!(1000)));
        assert_eq!(built.counterparty, Some("VENDOR001".to_string()));
        assert_eq!(built.account_code, Some("5000".to_string()));
    }

    #[test]
    fn test_baseline_amount_tracking() {
        let mut tracked = EntityBaseline::new();

        // Five amounts centred on 1000.
        [1000.0, 1100.0, 900.0, 1050.0, 950.0]
            .iter()
            .map(|&value| Observation::new(june(15)).with_amount(Decimal::try_from(value).unwrap()))
            .for_each(|obs| tracked.add_observation(obs));

        assert_eq!(tracked.observation_count, 5);
        assert!((tracked.avg_transaction_amount - 1000.0).abs() < 1.0);
        assert!(tracked.amount_std_dev > 0.0);
    }

    #[test]
    fn test_behavioral_baseline_deviation_detection() {
        let mut manager = BehavioralBaseline::default();

        // Twenty ordinary amounts around 1000, spread over ten days.
        let history = [
            900, 950, 1000, 1050, 1100, 920, 980, 1020, 1080, 950, 960, 1000, 1040, 990, 970, 1010,
            1030, 1000, 980, 1020,
        ];
        for (index, value) in history.iter().copied().enumerate() {
            let date = june(1) + chrono::Duration::days(index as i64 % 10);
            manager.record_observation(
                ENTITY,
                Observation::new(date)
                    .with_amount(Decimal::from(value))
                    .with_counterparty("VENDOR001")
                    .with_time(clock(10, 0)),
            );
        }

        // An amount far above anything in the history.
        let outlier = Observation::new(june(25))
            .with_amount(dec!(50000))
            .with_counterparty("VENDOR001");

        let found = manager.check_deviation(ENTITY, &outlier);

        assert!(reports(&found, DeviationType::AmountAnomaly));
    }

    #[test]
    fn test_new_counterparty_detection() {
        let mut manager = BehavioralBaseline::default();

        // Fifteen observations cycling through five known vendors.
        for index in 0..15 {
            let vendor = format!("VENDOR{:03}", index % 5);
            manager.record_observation(
                ENTITY,
                Observation::new(june(1))
                    .with_amount(dec!(1000))
                    .with_counterparty(&vendor),
            );
        }

        let stranger = Observation::new(june(25))
            .with_amount(dec!(1000))
            .with_counterparty("NEW_VENDOR");

        let found = manager.check_deviation(ENTITY, &stranger);

        assert!(reports(&found, DeviationType::NewCounterparty));
    }

    #[test]
    fn test_timing_anomaly_detection() {
        let mut manager = BehavioralBaseline::default();

        // Fifteen observations between 09:00 and 16:00.
        for index in 0..15u32 {
            manager.record_observation(
                ENTITY,
                Observation::new(june(1))
                    .with_amount(dec!(1000))
                    .with_time(clock(9 + (index % 8), 0)),
            );
        }

        // 03:00 is well outside that window.
        let night_owl = Observation::new(june(25))
            .with_amount(dec!(1000))
            .with_time(clock(3, 0));

        let found = manager.check_deviation(ENTITY, &night_owl);

        assert!(reports(&found, DeviationType::TimingAnomaly));
    }

    #[test]
    fn test_frequency_deviation() {
        let mut manager = BehavioralBaseline::default();

        // Thirty days with one to three observations each.
        let per_day = [
            2, 1, 3, 2, 2, 1, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 1, 2, 3, 2, 1, 2, 2, 3,
            2,
        ];
        for (offset, &count) in per_day.iter().enumerate() {
            let date = june(1) + chrono::Duration::days(offset as i64);
            for _ in 0..count {
                manager.record_observation(ENTITY, Observation::new(date).with_amount(dec!(1000)));
            }
        }

        // Ten per day is far above the one-to-three norm.
        let found = manager.check_frequency_deviation(ENTITY, 10.0);

        assert!(found.is_some());
        assert_eq!(
            found.unwrap().deviation_type,
            DeviationType::FrequencyAnomaly
        );
    }

    #[test]
    fn test_insufficient_baseline() {
        let mut manager = BehavioralBaseline::default();

        // Five observations: below the default minimum of ten.
        for offset in 0..5 {
            let date = june(1) + chrono::Duration::days(offset);
            manager.record_observation(ENTITY, Observation::new(date).with_amount(dec!(1000)));
        }

        let outlier = Observation::new(june(25)).with_amount(dec!(50000));

        let found = manager.check_deviation(ENTITY, &outlier);

        // No deviation may be reported from an unestablished baseline.
        assert!(found.is_empty());
    }

    #[test]
    fn test_typical_hours_calculation() {
        let mut tracked = EntityBaseline::new();

        // Ten rounds of one observation per hour from 09:30 to 16:30.
        for _ in 0..10 {
            for hour in 9..17 {
                tracked.add_observation(Observation::new(june(15)).with_time(clock(hour, 30)));
            }
        }

        assert!(tracked.is_within_typical_hours(clock(10, 0)));
        assert!(tracked.is_within_typical_hours(clock(14, 0)));
    }
}