datasynth_eval/ml/
anomaly_scoring.rs1use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8
/// A single record scored by an anomaly detector, together with its ground-truth label.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredRecord {
    /// Identifier of the record; not used when computing the analysis.
    pub record_id: String,
    /// Detector score. Higher values are ranked as more anomalous when computing AUC-ROC.
    pub score: f64,
    /// Ground-truth label: `true` if the record is a known anomaly.
    pub is_anomaly: bool,
}
19
/// Pass/fail thresholds applied by the anomaly-scoring analyzer.
#[derive(Debug, Clone, PartialEq)]
pub struct AnomalyScoringThresholds {
    /// Minimum AUC-ROC (anomaly separability) the scores must reach for the analysis to pass.
    pub min_anomaly_separability: f64,
}
26
27impl Default for AnomalyScoringThresholds {
28 fn default() -> Self {
29 Self {
30 min_anomaly_separability: 0.70,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct AnomalyScoringAnalysis {
38 pub anomaly_separability: f64,
40 pub avg_anomaly_score: f64,
42 pub avg_normal_score: f64,
44 pub per_type_separability: Vec<(String, f64)>,
46 pub total_records: usize,
48 pub passes: bool,
50 pub issues: Vec<String>,
52}
53
54pub struct AnomalyScoringAnalyzer {
56 thresholds: AnomalyScoringThresholds,
57}
58
59impl AnomalyScoringAnalyzer {
60 pub fn new() -> Self {
62 Self {
63 thresholds: AnomalyScoringThresholds::default(),
64 }
65 }
66
67 pub fn with_thresholds(thresholds: AnomalyScoringThresholds) -> Self {
69 Self { thresholds }
70 }
71
72 pub fn analyze(&self, records: &[ScoredRecord]) -> EvalResult<AnomalyScoringAnalysis> {
74 let mut issues = Vec::new();
75 let total_records = records.len();
76
77 if records.is_empty() {
78 return Ok(AnomalyScoringAnalysis {
79 anomaly_separability: 0.0,
80 avg_anomaly_score: 0.0,
81 avg_normal_score: 0.0,
82 per_type_separability: Vec::new(),
83 total_records: 0,
84 passes: true,
85 issues: vec!["No records provided".to_string()],
86 });
87 }
88
89 let anomaly_scores: Vec<f64> = records
91 .iter()
92 .filter(|r| r.is_anomaly)
93 .map(|r| r.score)
94 .collect();
95 let normal_scores: Vec<f64> = records
96 .iter()
97 .filter(|r| !r.is_anomaly)
98 .map(|r| r.score)
99 .collect();
100
101 let avg_anomaly_score = if anomaly_scores.is_empty() {
102 0.0
103 } else {
104 anomaly_scores.iter().sum::<f64>() / anomaly_scores.len() as f64
105 };
106
107 let avg_normal_score = if normal_scores.is_empty() {
108 0.0
109 } else {
110 normal_scores.iter().sum::<f64>() / normal_scores.len() as f64
111 };
112
113 let anomaly_separability = if anomaly_scores.is_empty() || normal_scores.is_empty() {
115 issues.push("Need both anomaly and normal records for AUC-ROC".to_string());
116 0.5
117 } else {
118 self.compute_auc_roc(records)
119 };
120
121 if anomaly_separability < self.thresholds.min_anomaly_separability {
123 issues.push(format!(
124 "Anomaly separability {:.4} < {:.4} (threshold)",
125 anomaly_separability, self.thresholds.min_anomaly_separability
126 ));
127 }
128
129 let passes = issues.is_empty();
130
131 Ok(AnomalyScoringAnalysis {
132 anomaly_separability,
133 avg_anomaly_score,
134 avg_normal_score,
135 per_type_separability: Vec::new(),
136 total_records,
137 passes,
138 issues,
139 })
140 }
141
142 fn compute_auc_roc(&self, records: &[ScoredRecord]) -> f64 {
147 let total_positives = records.iter().filter(|r| r.is_anomaly).count();
148 let total_negatives = records.len() - total_positives;
149
150 if total_positives == 0 || total_negatives == 0 {
151 return 0.5;
152 }
153
154 let mut sorted: Vec<&ScoredRecord> = records.iter().collect();
156 sorted.sort_by(|a, b| {
157 b.score
158 .partial_cmp(&a.score)
159 .unwrap_or(std::cmp::Ordering::Equal)
160 });
161
162 let mut tp = 0usize;
163 let mut fp = 0usize;
164 let mut auc = 0.0;
165 let mut prev_fpr = 0.0;
166 let mut prev_tpr = 0.0;
167
168 for record in &sorted {
169 if record.is_anomaly {
170 tp += 1;
171 } else {
172 fp += 1;
173 }
174
175 let tpr = tp as f64 / total_positives as f64;
176 let fpr = fp as f64 / total_negatives as f64;
177
178 auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
180
181 prev_fpr = fpr;
182 prev_tpr = tpr;
183 }
184
185 auc
186 }
187}
188
189impl Default for AnomalyScoringAnalyzer {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a labelled, scored record.
    fn record(id: &str, score: f64, is_anomaly: bool) -> ScoredRecord {
        ScoredRecord {
            record_id: id.to_string(),
            score,
            is_anomaly,
        }
    }

    #[test]
    fn test_valid_anomaly_scoring() {
        // Every anomaly outscores every normal record.
        let records = vec![
            record("r1", 0.9, true),
            record("r2", 0.85, true),
            record("r3", 0.1, false),
            record("r4", 0.15, false),
            record("r5", 0.05, false),
        ];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert_eq!(result.total_records, 5);
        assert!(result.anomaly_separability > 0.7);
        assert!(result.avg_anomaly_score > result.avg_normal_score);
        assert!(result.passes);
    }

    #[test]
    fn test_invalid_anomaly_scoring() {
        // Scores are inverted: normal records outscore anomalies.
        let records = vec![
            record("r1", 0.1, true),
            record("r2", 0.05, true),
            record("r3", 0.9, false),
            record("r4", 0.85, false),
        ];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert!(result.anomaly_separability < 0.7);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_records() {
        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&[]).unwrap();

        assert_eq!(result.total_records, 0);
        assert_eq!(result.anomaly_separability, 0.0);
        // Empty input is reported as an issue but not treated as a failure.
        assert!(result.passes);
        assert_eq!(result.issues, vec!["No records provided".to_string()]);
    }

    #[test]
    fn test_single_class_records() {
        // AUC-ROC needs both classes; chance level is reported and the analysis fails.
        let records = vec![record("r1", 0.9, true), record("r2", 0.8, true)];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert_eq!(result.total_records, 2);
        assert_eq!(result.anomaly_separability, 0.5);
        assert_eq!(result.avg_normal_score, 0.0);
        assert!(!result.passes);
        assert!(result.issues.iter().any(|issue| issue.contains("Need both")));
    }

    #[test]
    fn test_custom_thresholds() {
        // With a zero threshold, even fully inverted scores pass.
        let records = vec![record("r1", 0.1, true), record("r2", 0.9, false)];

        let analyzer = AnomalyScoringAnalyzer::with_thresholds(AnomalyScoringThresholds {
            min_anomaly_separability: 0.0,
        });
        let result = analyzer.analyze(&records).unwrap();

        assert_eq!(result.anomaly_separability, 0.0);
        assert!(result.passes);
    }
}