Skip to main content

datasynth_eval/behavioral_fidelity/
velocity_rules.rs

1//! P4 — Velocity-rule trigger rate gap.
2
3use std::collections::{HashMap, HashSet};
4
5use chrono::{Datelike, NaiveDate};
6use serde::{Deserialize, Serialize};
7
8use super::math::{is_weekend, percentile};
9use super::types::{Record, RuleSet, VelocityRuleKind, VelocityRuleSpec};
10
11/// Per-rule trigger rate result.
12#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
13pub struct RuleResult {
14    pub id: String,
15    pub trigger_rate_reference: f64,
16    pub trigger_rate_syn: f64,
17    pub abs_gap: f64,
18}
19
20/// Computes per-rule TR(real) and TR(syn) and the mean absolute gap across rules.
21pub fn evaluate_rule_set<F>(
22    rules: &RuleSet,
23    real: &[Record],
24    syn: &[Record],
25    entity_of: F,
26) -> (Vec<RuleResult>, f64)
27where
28    F: Fn(&Record) -> Option<String> + Copy,
29{
30    let mut results = Vec::with_capacity(rules.rules.len());
31    let mut gaps_sum = 0.0;
32    for rule in &rules.rules {
33        let tr_real = trigger_rate(rule, real, entity_of);
34        let tr_syn = trigger_rate(rule, syn, entity_of);
35        let gap = (tr_real - tr_syn).abs();
36        gaps_sum += gap;
37        results.push(RuleResult {
38            id: rule.id.clone(),
39            trigger_rate_reference: tr_real,
40            trigger_rate_syn: tr_syn,
41            abs_gap: gap,
42        });
43    }
44    let mean_gap = if rules.rules.is_empty() {
45        0.0
46    } else {
47        gaps_sum / rules.rules.len() as f64
48    };
49    (results, mean_gap)
50}
51
52fn trigger_rate<F>(rule: &VelocityRuleSpec, records: &[Record], entity_of: F) -> f64
53where
54    F: Fn(&Record) -> Option<String> + Copy,
55{
56    let by_entity: HashMap<String, Vec<&Record>> = group_by_entity(records, entity_of);
57    if by_entity.is_empty() {
58        return 0.0;
59    }
60    let triggered: usize = by_entity
61        .values()
62        .filter(|rows| entity_triggers(rule, rows))
63        .count();
64    triggered as f64 / by_entity.len() as f64
65}
66
67fn group_by_entity<F>(records: &[Record], entity_of: F) -> HashMap<String, Vec<&Record>>
68where
69    F: Fn(&Record) -> Option<String> + Copy,
70{
71    let mut by: HashMap<String, Vec<&Record>> = HashMap::new();
72    for r in records {
73        if let Some(e) = entity_of(r) {
74            by.entry(e).or_default().push(r);
75        }
76    }
77    by
78}
79
80fn entity_triggers(rule: &VelocityRuleSpec, rows: &[&Record]) -> bool {
81    match &rule.kind {
82        VelocityRuleKind::CountPerEntityPerDay { threshold } => {
83            let mut by_day: HashMap<NaiveDate, u32> = HashMap::new();
84            for r in rows {
85                if !is_weekend(r.entry_date) {
86                    *by_day.entry(r.entry_date).or_insert(0) += 1;
87                }
88            }
89            by_day.values().any(|&c| c > *threshold)
90        }
91        VelocityRuleKind::DistinctAccountsPerEntityPerDay { threshold } => {
92            let mut by_day: HashMap<NaiveDate, HashSet<&str>> = HashMap::new();
93            for r in rows {
94                by_day
95                    .entry(r.entry_date)
96                    .or_default()
97                    .insert(r.gl_account.as_str());
98            }
99            by_day.values().any(|s| s.len() > *threshold as usize)
100        }
101        VelocityRuleKind::SumAmountPerEntityPerDayAbovePercentile { pct } => {
102            let mut by_day: HashMap<NaiveDate, f64> = HashMap::new();
103            for r in rows {
104                *by_day.entry(r.entry_date).or_insert(0.0) += r.functional_amount.abs();
105            }
106            let sums: Vec<f64> = by_day.values().copied().collect();
107            if sums.is_empty() {
108                return false;
109            }
110            let threshold = percentile(&sums, *pct);
111            sums.iter().any(|&s| s > threshold)
112        }
113        VelocityRuleKind::DormantAccountActivity { inactivity_days } => {
114            let mut last_seen: HashMap<&str, NaiveDate> = HashMap::new();
115            let mut sorted = rows.to_vec();
116            sorted.sort_by_key(|r| r.entry_date);
117            for r in sorted {
118                if let Some(prev) = last_seen.get(r.gl_account.as_str()) {
119                    if (r.entry_date - *prev).num_days() >= *inactivity_days {
120                        return true;
121                    }
122                }
123                last_seen.insert(r.gl_account.as_str(), r.entry_date);
124            }
125            false
126        }
127        VelocityRuleKind::DistinctTradingPartnersPerEntityPerDay { threshold } => {
128            let mut by_day: HashMap<NaiveDate, HashSet<&str>> = HashMap::new();
129            for r in rows {
130                if let Some(tp) = r.trading_partner.as_deref() {
131                    by_day.entry(r.entry_date).or_default().insert(tp);
132                }
133            }
134            by_day.values().any(|s| s.len() > *threshold as usize)
135        }
136        VelocityRuleKind::AmountSpikeRatio { window_days, ratio } => {
137            let mut sorted = rows.to_vec();
138            sorted.sort_by_key(|r| r.entry_date);
139            for end_idx in 0..sorted.len() {
140                let end_date = sorted[end_idx].entry_date;
141                let start_date = end_date - chrono::Duration::days(*window_days);
142                let window: Vec<f64> = sorted
143                    .iter()
144                    .filter(|r| r.entry_date >= start_date && r.entry_date <= end_date)
145                    .map(|r| r.functional_amount.abs())
146                    .filter(|x| *x > 0.0)
147                    .collect();
148                if window.len() < 3 {
149                    continue;
150                }
151                let max = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
152                let med = percentile(&window, 0.5);
153                if med > 0.0 && (max / med) > *ratio {
154                    return true;
155                }
156            }
157            false
158        }
159        VelocityRuleKind::OffHoursPosting => rows.iter().any(|r| is_weekend(r.entry_date)),
160        VelocityRuleKind::PostClosePosting {
161            tolerance_business_days,
162        } => rows.iter().any(|r| {
163            let period_end = last_day_of_month(r.effective_date);
164            let tol_days = *tolerance_business_days * 7 / 5;
165            (r.entry_date - period_end).num_days() > tol_days
166        }),
167        VelocityRuleKind::RoundDollarConcentration { share_threshold } => {
168            let n = rows.len();
169            if n == 0 {
170                return false;
171            }
172            let rounds = rows
173                .iter()
174                .filter(|r| {
175                    let amt = r.functional_amount.abs().round() as i64;
176                    amt > 0 && amt % 1000 == 0
177                })
178                .count();
179            (rounds as f64 / n as f64) > *share_threshold
180        }
181        VelocityRuleKind::BackdatingDays { gap_days } => rows
182            .iter()
183            .any(|r| (r.effective_date - r.entry_date).num_days() > *gap_days),
184    }
185}
186
187fn last_day_of_month(d: NaiveDate) -> NaiveDate {
188    let (y, m) = (d.year(), d.month());
189    let (ny, nm) = if m == 12 { (y + 1, 1) } else { (y, m + 1) };
190    NaiveDate::from_ymd_opt(ny, nm, 1)
191        .expect("valid month+1 date")
192        .pred_opt()
193        .expect("date before first of month always exists")
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record dated 2022-01-`day` (both effective and entry date)
    /// with the given source, GL account, amount and optional trading partner.
    fn r(src: &str, gl: &str, day: u32, amt: f64, tp: Option<&str>) -> Record {
        let date = NaiveDate::from_ymd_opt(2022, 1, day).unwrap();
        Record {
            source: src.into(),
            gl_account: gl.into(),
            cost_center: None,
            profit_center: None,
            trading_partner: tp.map(String::from),
            je_number: format!("J{src}{day}{gl}"),
            je_line_number: "001".into(),
            effective_date: date,
            entry_date: date,
            created_at: None,
            functional_amount: amt,
            header_text: String::new(),
            line_text: String::new(),
        }
    }

    /// Wraps a rule kind in a spec with the given id and an empty description.
    fn spec(id: &str, kind: VelocityRuleKind) -> VelocityRuleSpec {
        VelocityRuleSpec {
            id: id.into(),
            description: "".into(),
            kind,
        }
    }

    #[test]
    fn r1_count_per_day_triggers() {
        // Six postings on 2022-01-03 against a threshold of five.
        let rule = spec("R1", VelocityRuleKind::CountPerEntityPerDay { threshold: 5 });
        let recs: Vec<Record> = std::iter::repeat_with(|| r("A", "100", 3, 1.0, None))
            .take(6)
            .collect();
        let refs: Vec<&Record> = recs.iter().collect();
        assert!(entity_triggers(&rule, &refs));
    }

    #[test]
    fn r7_off_hours_triggers_on_weekend_record() {
        // 2022-01-01 was a Saturday.
        let rule = spec("R7", VelocityRuleKind::OffHoursPosting);
        let rec = r("A", "1", 1, 1.0, None);
        assert!(entity_triggers(&rule, &[&rec]));
    }

    #[test]
    fn r10_backdating_triggers_when_eff_minus_entry_over_30() {
        let rule = spec("R10", VelocityRuleKind::BackdatingDays { gap_days: 30 });
        let mut rec = r("A", "1", 1, 1.0, None);
        rec.effective_date = NaiveDate::from_ymd_opt(2022, 3, 31).unwrap();
        rec.entry_date = NaiveDate::from_ymd_opt(2022, 1, 15).unwrap();
        assert!(entity_triggers(&rule, &[&rec]));
    }

    #[test]
    fn canonical_rules_has_ten_entries() {
        let rule_set = RuleSet::canonical_gl_rules();
        assert_eq!(rule_set.rules.len(), 10);

        let expected = ["R1", "R2", "R3", "R4", "R5", "R6", "R7", "R8", "R9", "R10"];
        let ids: Vec<&str> = rule_set.rules.iter().map(|rule| rule.id.as_str()).collect();
        assert_eq!(ids, expected);
    }
}