Skip to main content

datasynth_eval/coherence/
je_risk_scoring.rs

1//! Journal entry risk scoring evaluator.
2//!
3//! Scores each journal entry for fraud/error risk attributes and computes
4//! aggregate statistics including anomaly separability.
5
6use datasynth_core::models::JournalEntry;
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11// ─── Result types ─────────────────────────────────────────────────────────────
12
13/// Aggregate result of JE risk scoring.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct JeRiskScoringResult {
16    /// Total number of journal entries supplied.
17    pub total_entries: usize,
18    /// Number of entries that were actually scored.
19    pub scored_entries: usize,
20    /// Distribution of entries across risk bands.
21    pub risk_distribution: RiskDistribution,
22    /// Per-attribute statistics.
23    pub risk_attributes: Vec<RiskAttributeStats>,
24    /// Average anomaly score minus average clean score.
25    /// Pass threshold: > 0.10.
26    pub anomaly_separability: f64,
27    /// True when anomaly_separability > 0.10 (or no anomaly labels present).
28    pub passes: bool,
29}
30
31/// Count of entries in each risk band.
32#[derive(Debug, Clone, Default, Serialize, Deserialize)]
33pub struct RiskDistribution {
34    /// Score < 0.30.
35    pub low_risk: usize,
36    /// 0.30 ≤ score < 0.60.
37    pub medium_risk: usize,
38    /// Score ≥ 0.60.
39    pub high_risk: usize,
40}
41
42/// Statistics for one risk attribute across all scored entries.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct RiskAttributeStats {
45    /// Attribute name (e.g. "RoundNumber").
46    pub attribute: String,
47    /// Number of entries where this attribute was triggered.
48    pub count: usize,
49    /// Percentage of total scored entries (0–100).
50    pub percentage: f64,
51}
52
53// ─── Per-entry score ──────────────────────────────────────────────────────────
54
/// Every fraud/error risk signal the scorer can detect.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum RiskAttribute {
    /// Total debit is a positive whole multiple of 1,000.
    RoundNumber,
    /// Entry was created outside the 07:00–21:59 window.
    UnusualHour,
    /// Posting date is a Saturday or Sunday (no holiday calendar is consulted).
    WeekendHoliday,
    /// Author has fewer than five postings in the scored batch.
    NonStandardUser,
    /// Total debit sits just below one of the approval thresholds.
    BelowApprovalThreshold,
    /// Manually sourced entry touching a normally auto-posted GL account.
    ManualToAutomatedAccount,
    /// The same GL account is both debited and credited within one entry.
    LargeRoundTrip,
}

impl RiskAttribute {
    /// All attributes in declaration order; this order drives the layout of per-attribute
    /// statistics in the aggregate result.
    const ALL: [RiskAttribute; 7] = [
        Self::RoundNumber,
        Self::UnusualHour,
        Self::WeekendHoliday,
        Self::NonStandardUser,
        Self::BelowApprovalThreshold,
        Self::ManualToAutomatedAccount,
        Self::LargeRoundTrip,
    ];

    /// Stable, human-readable identifier (matches the variant name).
    fn name(self) -> &'static str {
        match self {
            Self::RoundNumber => "RoundNumber",
            Self::UnusualHour => "UnusualHour",
            Self::WeekendHoliday => "WeekendHoliday",
            Self::NonStandardUser => "NonStandardUser",
            Self::BelowApprovalThreshold => "BelowApprovalThreshold",
            Self::ManualToAutomatedAccount => "ManualToAutomatedAccount",
            Self::LargeRoundTrip => "LargeRoundTrip",
        }
    }

    /// Contribution of this attribute to an entry's risk score.
    ///
    /// The weights of all attributes add up to 1.0 (0.10 + 6 × 0.15).
    fn weight(self) -> f64 {
        match self {
            Self::RoundNumber => 0.10,
            Self::UnusualHour
            | Self::WeekendHoliday
            | Self::NonStandardUser
            | Self::BelowApprovalThreshold
            | Self::ManualToAutomatedAccount
            | Self::LargeRoundTrip => 0.15,
        }
    }

    /// Every attribute, in declaration order.
    fn all() -> &'static [RiskAttribute] {
        &Self::ALL
    }
}
104
105// ─── Detection helpers ────────────────────────────────────────────────────────
106
/// Approval limits checked for "just below the limit" amounts (split-payment detection).
const APPROVAL_THRESHOLDS: &[u64] = &[1_000, 2_500, 5_000, 10_000, 25_000, 50_000, 100_000];
109
/// GL account prefixes that are normally auto-posted; a manually sourced entry touching one of
/// them is flagged as `ManualToAutomatedAccount`.
///
/// Accounts are matched with `str::starts_with`, so e.g. `"1000"` and `"10050"` both match
/// `"100"`, while `"103…"` matches nothing. The prefixes fall into three families:
/// - `"100"`, `"101"`, `"102"` — bank / cash
/// - `"110"`, `"111"` — AR clearing
/// - `"200"`, `"201"` — AP clearing
const AUTOMATED_ACCOUNT_PREFIXES: &[&str] = &["100", "101", "102", "200", "201", "110", "111"];
113
114fn is_round_number(amount: Decimal) -> bool {
115    let thousand = Decimal::from(1000u32);
116    amount > Decimal::ZERO && (amount % thousand).is_zero()
117}
118
/// Returns `true` for hours outside the normal 07:00–21:59 working window.
///
/// Any value above 21 (including out-of-range ones) counts as unusual.
fn is_unusual_hour(hour: u32) -> bool {
    !matches!(hour, 7..=21)
}
122
123fn is_weekend(weekday: chrono::Weekday) -> bool {
124    weekday == chrono::Weekday::Sat || weekday == chrono::Weekday::Sun
125}
126
127fn is_below_approval_threshold(amount: Decimal) -> bool {
128    for &threshold in APPROVAL_THRESHOLDS {
129        let low = Decimal::from(threshold - 100);
130        let high = Decimal::from(threshold - 1);
131        if amount >= low && amount <= high {
132            return true;
133        }
134    }
135    false
136}
137
138fn is_manual_to_automated_account(entry: &JournalEntry) -> bool {
139    use datasynth_core::models::TransactionSource;
140    if entry.header.source != TransactionSource::Manual {
141        return false;
142    }
143    entry.lines.iter().any(|line| {
144        AUTOMATED_ACCOUNT_PREFIXES
145            .iter()
146            .any(|prefix| line.gl_account.starts_with(prefix))
147    })
148}
149
150fn has_round_trip(entry: &JournalEntry) -> bool {
151    // Same account appears on both debit and credit sides within one entry.
152    let debited: std::collections::HashSet<_> = entry
153        .lines
154        .iter()
155        .filter(|l| l.debit_amount > Decimal::ZERO)
156        .map(|l| l.gl_account.as_str())
157        .collect();
158    let credited: std::collections::HashSet<_> = entry
159        .lines
160        .iter()
161        .filter(|l| l.credit_amount > Decimal::ZERO)
162        .map(|l| l.gl_account.as_str())
163        .collect();
164    debited.intersection(&credited).next().is_some()
165}
166
167// ─── Pre-computation pass ─────────────────────────────────────────────────────
168
169/// Count postings per user across all entries.
170fn build_user_posting_counts(entries: &[JournalEntry]) -> HashMap<String, usize> {
171    let mut counts: HashMap<String, usize> = HashMap::new();
172    for entry in entries {
173        *counts.entry(entry.header.created_by.clone()).or_default() += 1;
174    }
175    counts
176}
177
178// ─── Scoring ──────────────────────────────────────────────────────────────────
179
180/// Score a single journal entry; returns (score, triggered_attributes).
181fn score_entry(
182    entry: &JournalEntry,
183    user_counts: &HashMap<String, usize>,
184) -> (f64, Vec<RiskAttribute>) {
185    use chrono::Datelike as _;
186    use chrono::Timelike as _;
187
188    let mut triggered = Vec::new();
189
190    // Derive a representative "amount" from the entry (sum of debit amounts).
191    let total_debit: Decimal = entry.lines.iter().map(|l| l.debit_amount).sum();
192
193    // RoundNumber
194    if is_round_number(total_debit) {
195        triggered.push(RiskAttribute::RoundNumber);
196    }
197
198    // UnusualHour
199    let hour = entry.header.created_at.hour();
200    if is_unusual_hour(hour) {
201        triggered.push(RiskAttribute::UnusualHour);
202    }
203
204    // WeekendHoliday
205    if is_weekend(entry.header.posting_date.weekday()) {
206        triggered.push(RiskAttribute::WeekendHoliday);
207    }
208
209    // NonStandardUser (fewer than 5 postings)
210    let user_count = user_counts
211        .get(&entry.header.created_by)
212        .copied()
213        .unwrap_or(0);
214    if user_count < 5 {
215        triggered.push(RiskAttribute::NonStandardUser);
216    }
217
218    // BelowApprovalThreshold
219    if is_below_approval_threshold(total_debit) {
220        triggered.push(RiskAttribute::BelowApprovalThreshold);
221    }
222
223    // ManualToAutomatedAccount
224    if is_manual_to_automated_account(entry) {
225        triggered.push(RiskAttribute::ManualToAutomatedAccount);
226    }
227
228    // LargeRoundTrip
229    if has_round_trip(entry) {
230        triggered.push(RiskAttribute::LargeRoundTrip);
231    }
232
233    let raw_score: f64 = triggered.iter().map(|a| a.weight()).sum();
234    let score = raw_score.min(1.0_f64);
235
236    (score, triggered)
237}
238
239// ─── Public API ───────────────────────────────────────────────────────────────
240
241/// Score all journal entries and return aggregate statistics.
242pub fn score_entries(entries: &[JournalEntry]) -> JeRiskScoringResult {
243    let user_counts = build_user_posting_counts(entries);
244
245    let mut distribution = RiskDistribution::default();
246    let mut attribute_counts: HashMap<RiskAttribute, usize> = HashMap::new();
247    let mut anomaly_scores: Vec<f64> = Vec::new();
248    let mut clean_scores: Vec<f64> = Vec::new();
249
250    for entry in entries {
251        let (score, triggered) = score_entry(entry, &user_counts);
252
253        // Risk band
254        if score < 0.30 {
255            distribution.low_risk += 1;
256        } else if score < 0.60 {
257            distribution.medium_risk += 1;
258        } else {
259            distribution.high_risk += 1;
260        }
261
262        // Attribute counts
263        for attr in &triggered {
264            *attribute_counts.entry(*attr).or_default() += 1;
265        }
266
267        // Separability tracking
268        if entry.header.is_anomaly || entry.header.is_fraud {
269            anomaly_scores.push(score);
270        } else {
271            clean_scores.push(score);
272        }
273    }
274
275    let total = entries.len();
276    let risk_attributes: Vec<RiskAttributeStats> = RiskAttribute::all()
277        .iter()
278        .map(|&attr| {
279            let count = attribute_counts.get(&attr).copied().unwrap_or(0);
280            let percentage = if total > 0 {
281                count as f64 / total as f64 * 100.0
282            } else {
283                0.0
284            };
285            RiskAttributeStats {
286                attribute: attr.name().to_string(),
287                count,
288                percentage,
289            }
290        })
291        .collect();
292
293    let avg = |v: &[f64]| -> f64 {
294        if v.is_empty() {
295            0.0
296        } else {
297            v.iter().sum::<f64>() / v.len() as f64
298        }
299    };
300
301    let anomaly_separability = if anomaly_scores.is_empty() {
302        // No anomaly labels → vacuously pass
303        1.0
304    } else {
305        avg(&anomaly_scores) - avg(&clean_scores)
306    };
307
308    let passes = anomaly_separability > 0.10;
309
310    JeRiskScoringResult {
311        total_entries: total,
312        scored_entries: total,
313        risk_distribution: distribution,
314        risk_attributes,
315        anomaly_separability,
316        passes,
317    }
318}
319
320// ─── Tests ────────────────────────────────────────────────────────────────────
321
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{
        JournalEntry, JournalEntryHeader, JournalEntryLine, TransactionSource,
    };
    use rust_decimal_macros::dec;

    // ── Fixtures ──────────────────────────────────────────────────────────────

    /// 2024-01-03, a Wednesday.
    fn weekday_date() -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(2024, 1, 3).unwrap()
    }

    /// 2024-01-06, a Saturday.
    fn weekend_date() -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(2024, 1, 6).unwrap()
    }

    /// Builds a balanced two-line entry (debit on line 1, credit on line 2) for `amount`.
    fn make_je(
        company: &str,
        posting_date: chrono::NaiveDate,
        debit_account: &str,
        credit_account: &str,
        amount: Decimal,
        user: &str,
        source: TransactionSource,
    ) -> JournalEntry {
        let mut header = JournalEntryHeader::new(company.to_string(), posting_date);
        header.created_by = user.to_string();
        header.source = source;
        let doc_id = header.document_id;

        let mut entry = JournalEntry::new(header);
        let debit = JournalEntryLine::debit(doc_id, 1, debit_account.to_string(), amount);
        let credit = JournalEntryLine::credit(doc_id, 2, credit_account.to_string(), amount);
        entry.add_line(debit);
        entry.add_line(credit);
        entry
    }

    /// Automated C001 entry debiting 6000 and crediting 2000, the shape most tests need.
    fn standard_je(posting_date: chrono::NaiveDate, amount: Decimal, user: &str) -> JournalEntry {
        make_je("C001", posting_date, "6000", "2000", amount, user, TransactionSource::Automated)
    }

    /// Weekday entry by "alice" for `amount`.
    fn simple_je(amount: Decimal) -> JournalEntry {
        standard_je(weekday_date(), amount, "alice")
    }

    /// `n` weekday postings by `user`, used to push a user past the NonStandardUser cutoff.
    fn weekday_postings(n: usize, user: &str, amount: Decimal) -> Vec<JournalEntry> {
        (0..n)
            .map(|_| standard_je(weekday_date(), amount, user))
            .collect()
    }

    /// Sets the anomaly label on `entry` and hands it back.
    fn labelled(mut entry: JournalEntry, is_anomaly: bool) -> JournalEntry {
        entry.header.is_anomaly = is_anomaly;
        entry
    }

    // ── Round-number detection ────────────────────────────────────────────────

    #[test]
    fn test_round_number_detected() {
        for amount in [dec!(1000), dec!(5000), dec!(100000)] {
            assert!(is_round_number(amount), "{amount} should be round");
        }
    }

    #[test]
    fn test_non_round_number() {
        for amount in [dec!(1234.56), dec!(999), dec!(0)] {
            assert!(!is_round_number(amount), "{amount} should not be round");
        }
    }

    // ── Weekend detection ─────────────────────────────────────────────────────

    #[test]
    fn test_weekend_detected() {
        let entry = standard_je(weekend_date(), dec!(500), "alice");
        let counts = build_user_posting_counts(std::slice::from_ref(&entry));
        let (_, triggered) = score_entry(&entry, &counts);
        assert!(
            triggered.contains(&RiskAttribute::WeekendHoliday),
            "Saturday should trigger WeekendHoliday"
        );
    }

    #[test]
    fn test_weekday_not_flagged() {
        let entry = standard_je(weekday_date(), dec!(500), "alice");
        // Ten extra postings keep alice from being a NonStandardUser.
        let mut entries = weekday_postings(10, "alice", dec!(500));
        entries.push(entry.clone());
        let counts = build_user_posting_counts(&entries);
        let (_, triggered) = score_entry(&entry, &counts);
        assert!(
            !triggered.contains(&RiskAttribute::WeekendHoliday),
            "Wednesday should not trigger WeekendHoliday"
        );
    }

    // ── Score range ───────────────────────────────────────────────────────────

    #[test]
    fn test_score_within_range() {
        let entries = vec![simple_je(dec!(500)), simple_je(dec!(1000))];
        let counts = build_user_posting_counts(&entries);
        for (score, _) in entries.iter().map(|entry| score_entry(entry, &counts)) {
            assert!((0.0..=1.0).contains(&score), "Score {score} out of [0,1]");
        }
    }

    #[test]
    fn test_multi_attribute_higher_score() {
        // Round amount on a Saturday vs. an odd amount on a Wednesday, same author.
        let risky = standard_je(weekend_date(), dec!(5000), "alice");
        let clean = standard_je(weekday_date(), dec!(1234), "alice");

        let mut entries = vec![risky.clone()];
        entries.extend(weekday_postings(10, "alice", dec!(100)));
        entries.push(clean.clone());

        let counts = build_user_posting_counts(&entries);
        let (risky_score, _) = score_entry(&risky, &counts);
        let (clean_score, _) = score_entry(&clean, &counts);
        assert!(
            risky_score >= clean_score,
            "Risky entry ({risky_score}) should score >= clean ({clean_score})"
        );
    }

    // ── Below-threshold detection ─────────────────────────────────────────────

    #[test]
    fn test_below_approval_threshold() {
        assert!(is_below_approval_threshold(dec!(4999)));
        assert!(is_below_approval_threshold(dec!(4950)));
        assert!(!is_below_approval_threshold(dec!(5000)));
        assert!(!is_below_approval_threshold(dec!(6000)));
    }

    // ── Round-trip detection ──────────────────────────────────────────────────

    #[test]
    fn test_round_trip_detected() {
        // Same account on both sides of the entry.
        let entry = make_je(
            "C001",
            weekday_date(),
            "1000",
            "1000",
            dec!(100),
            "alice",
            TransactionSource::Automated,
        );
        assert!(
            has_round_trip(&entry),
            "Same account debit+credit should be detected"
        );
    }

    #[test]
    fn test_no_round_trip() {
        assert!(
            !has_round_trip(&simple_je(dec!(100))),
            "Different accounts should not trigger round-trip"
        );
    }

    // ── Aggregate scoring ─────────────────────────────────────────────────────

    #[test]
    fn test_score_entries_basic() {
        let entries: Vec<JournalEntry> = (0..20)
            .map(|i| standard_je(weekday_date(), Decimal::from(i * 100 + 50), "alice"))
            .collect();
        let result = score_entries(&entries);

        let bands = &result.risk_distribution;
        assert_eq!(result.total_entries, 20);
        assert_eq!(result.scored_entries, 20);
        assert_eq!(bands.low_risk + bands.medium_risk + bands.high_risk, 20);
        assert_eq!(result.risk_attributes.len(), RiskAttribute::all().len());
    }

    #[test]
    fn test_anomaly_separability_passes_with_no_labels() {
        let entries: Vec<JournalEntry> = (0..5).map(|_| simple_je(dec!(100))).collect();
        // No anomaly labels, so the check vacuously passes.
        assert!(score_entries(&entries).passes, "No anomaly labels → should pass");
    }

    #[test]
    fn test_anomaly_separability_with_flagged_entries() {
        // Clean group: 5 explicitly unlabelled weekday entries by bob with an odd amount...
        let mut entries: Vec<JournalEntry> = (0..5)
            .map(|_| labelled(standard_je(weekday_date(), dec!(123), "bob"), false))
            .collect();
        // ...plus 10 more postings, so bob has 15 in total and is never a NonStandardUser.
        entries.extend(weekday_postings(10, "bob", dec!(50)));

        // Labelled group: weekend + round amount. "zz_rare_user" ends up with exactly 5
        // postings, which does NOT meet the `< 5` NonStandardUser cutoff.
        entries.extend((0..5).map(|_| {
            labelled(standard_je(weekend_date(), dec!(5000), "zz_rare_user"), true)
        }));

        let result = score_entries(&entries);
        assert!(
            result.anomaly_separability > 0.0,
            "Anomaly entries should have higher average score than clean entries"
        );
    }
}