datasynth_eval/ml/
detectability.rs1use crate::error::EvalResult;
15use datasynth_core::models::{LabeledAnomaly, ObservabilityClass};
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeMap;
18
/// Summary of how a population of labeled anomalies is distributed across
/// observability classes, as produced by `DetectabilityAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectabilityReport {
    /// Number of labels in the analyzed population.
    pub total: usize,
    /// Label count per observability class, keyed by the class's `as_str()` name.
    /// `analyze` seeds four classes with zero, so those keys are always present.
    pub by_class: BTreeMap<String, usize>,
    /// Each `by_class` count divided by `max(total, 1)`.
    pub fraction_by_class: BTreeMap<String, f64>,
    /// Fraction of labels whose class is `ObservabilityClass::MemoryOnly`.
    pub memory_only_fraction: f64,
    /// Fraction of labels whose class is `ObservabilityClass::PerJeDensity`.
    pub per_je_density_fraction: f64,
    /// Fraction of labels whose class is `ObservabilityClass::RelationalGraph`.
    pub relational_graph_fraction: f64,
    /// Fraction of labels whose class is `ObservabilityClass::Temporal`.
    pub temporal_fraction: f64,
    /// Label count per class and, within each class, per anomaly category
    /// (the string form of `anomaly_type.category()`). Unlike `by_class`, only
    /// classes that actually occur in the input appear here.
    pub by_class_and_category: BTreeMap<String, BTreeMap<String, usize>>,
}
40
/// Field-less analyzer that builds a `DetectabilityReport` from a slice of
/// labeled anomalies.
#[derive(Debug, Clone, Default)]
pub struct DetectabilityAnalyzer;
44
45impl DetectabilityAnalyzer {
46 pub fn new() -> Self {
48 Self
49 }
50
51 pub fn analyze(&self, labels: &[LabeledAnomaly]) -> EvalResult<DetectabilityReport> {
53 let total = labels.len();
54 let mut by_class: BTreeMap<String, usize> = BTreeMap::new();
55 let mut by_class_and_category: BTreeMap<String, BTreeMap<String, usize>> = BTreeMap::new();
56
57 for c in [
59 ObservabilityClass::PerJeDensity,
60 ObservabilityClass::RelationalGraph,
61 ObservabilityClass::Temporal,
62 ObservabilityClass::MemoryOnly,
63 ] {
64 by_class.insert(c.as_str().to_string(), 0);
65 }
66
67 for l in labels {
68 let cls = l.observability.as_str().to_string();
69 *by_class.entry(cls.clone()).or_default() += 1;
70 *by_class_and_category
71 .entry(cls)
72 .or_default()
73 .entry(l.anomaly_type.category().to_string())
74 .or_default() += 1;
75 }
76
77 let denom = total.max(1) as f64;
78 let fraction_by_class: BTreeMap<String, f64> = by_class
79 .iter()
80 .map(|(k, v)| (k.clone(), *v as f64 / denom))
81 .collect();
82 let frac = |c: ObservabilityClass| *by_class.get(c.as_str()).unwrap_or(&0) as f64 / denom;
83
84 Ok(DetectabilityReport {
85 total,
86 memory_only_fraction: frac(ObservabilityClass::MemoryOnly),
87 per_je_density_fraction: frac(ObservabilityClass::PerJeDensity),
88 relational_graph_fraction: frac(ObservabilityClass::RelationalGraph),
89 temporal_fraction: frac(ObservabilityClass::Temporal),
90 by_class,
91 fraction_by_class,
92 by_class_and_category,
93 })
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{AnomalyType, ErrorType, FraudType, RelationalAnomalyType};

    /// Builds a labeled anomaly of the given type with fixed placeholder
    /// values for every other constructor argument.
    fn label(at: AnomalyType) -> LabeledAnomaly {
        let date = NaiveDate::from_ymd_opt(2026, 1, 1).unwrap();
        LabeledAnomaly::new(
            String::from("A"),
            at,
            String::from("JE"),
            String::from("JE"),
            String::from("1000"),
            date,
        )
    }

    /// Four labels that the assertions below expect to land one per
    /// observability class.
    #[test]
    fn profiles_population_across_observability_arms() {
        let anomaly_types = vec![
            AnomalyType::Fraud(FraudType::RoundDollarManipulation),
            AnomalyType::Fraud(FraudType::DuplicatePayment),
            AnomalyType::Relational(RelationalAnomalyType::CircularTransaction),
            AnomalyType::Error(ErrorType::WrongPeriod),
        ];
        let labels: Vec<LabeledAnomaly> = anomaly_types.into_iter().map(label).collect();

        let analyzer = DetectabilityAnalyzer::new();
        let report = analyzer.analyze(&labels).unwrap();

        assert_eq!(report.total, 4);
        for class in ["per_je_density", "memory_only", "relational_graph", "temporal"] {
            assert_eq!(report.by_class[class], 1);
        }

        let close_to = |actual: f64, expected: f64| (actual - expected).abs() < 1e-9;
        assert!(close_to(report.memory_only_fraction, 0.25));
        assert!(close_to(report.per_je_density_fraction, 0.25));
        assert_eq!(report.by_class_and_category["memory_only"]["Fraud"], 1);
    }

    /// With no labels the report still carries all four class keys, zeroed.
    #[test]
    fn empty_population_has_stable_zeroed_schema() {
        let no_labels: Vec<LabeledAnomaly> = Vec::new();
        let analyzer = DetectabilityAnalyzer::new();
        let report = analyzer.analyze(&no_labels).unwrap();

        assert_eq!(report.total, 0);
        assert_eq!(report.by_class.len(), 4);
        assert_eq!(report.memory_only_fraction, 0.0);
        assert_eq!(report.by_class["relational_graph"], 0);
    }
}