datasynth_eval/ml/
anomaly_scoring.rs1use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8
/// A single record paired with the anomaly score a detector assigned to it.
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredRecord {
    /// Identifier of the scored record.
    pub record_id: String,
    /// Detector output. `AnomalyScoringAnalyzer` ranks records by descending score,
    /// i.e. higher values are treated as more anomalous.
    pub score: f64,
    /// Ground-truth label: `true` if the record is a known anomaly.
    pub is_anomaly: bool,
}
19
/// Pass/fail thresholds applied by `AnomalyScoringAnalyzer`.
#[derive(Debug, Clone, PartialEq)]
pub struct AnomalyScoringThresholds {
    /// Minimum anomaly separability (AUC-ROC) an analysis must reach to pass.
    pub min_anomaly_separability: f64,
}
26
27impl Default for AnomalyScoringThresholds {
28 fn default() -> Self {
29 Self {
30 min_anomaly_separability: 0.70,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct AnomalyScoringAnalysis {
38 pub anomaly_separability: f64,
40 pub avg_anomaly_score: f64,
42 pub avg_normal_score: f64,
44 pub per_type_separability: Vec<(String, f64)>,
46 pub total_records: usize,
48 pub passes: bool,
50 pub issues: Vec<String>,
52}
53
54pub struct AnomalyScoringAnalyzer {
56 thresholds: AnomalyScoringThresholds,
57}
58
59impl AnomalyScoringAnalyzer {
60 pub fn new() -> Self {
62 Self {
63 thresholds: AnomalyScoringThresholds::default(),
64 }
65 }
66
67 pub fn with_thresholds(thresholds: AnomalyScoringThresholds) -> Self {
69 Self { thresholds }
70 }
71
72 pub fn analyze(&self, records: &[ScoredRecord]) -> EvalResult<AnomalyScoringAnalysis> {
74 let mut issues = Vec::new();
75 let total_records = records.len();
76
77 if records.is_empty() {
78 return Ok(AnomalyScoringAnalysis {
79 anomaly_separability: 0.0,
80 avg_anomaly_score: 0.0,
81 avg_normal_score: 0.0,
82 per_type_separability: Vec::new(),
83 total_records: 0,
84 passes: true,
85 issues: vec!["No records provided".to_string()],
86 });
87 }
88
89 let anomaly_scores: Vec<f64> = records
91 .iter()
92 .filter(|r| r.is_anomaly)
93 .map(|r| r.score)
94 .collect();
95 let normal_scores: Vec<f64> = records
96 .iter()
97 .filter(|r| !r.is_anomaly)
98 .map(|r| r.score)
99 .collect();
100
101 let avg_anomaly_score = if anomaly_scores.is_empty() {
102 0.0
103 } else {
104 anomaly_scores.iter().sum::<f64>() / anomaly_scores.len() as f64
105 };
106
107 let avg_normal_score = if normal_scores.is_empty() {
108 0.0
109 } else {
110 normal_scores.iter().sum::<f64>() / normal_scores.len() as f64
111 };
112
113 let anomaly_separability = if anomaly_scores.is_empty() || normal_scores.is_empty() {
115 issues.push("Need both anomaly and normal records for AUC-ROC".to_string());
116 0.5
117 } else {
118 self.compute_auc_roc(records)
119 };
120
121 if anomaly_separability < self.thresholds.min_anomaly_separability {
123 issues.push(format!(
124 "Anomaly separability {:.4} < {:.4} (threshold)",
125 anomaly_separability, self.thresholds.min_anomaly_separability
126 ));
127 }
128
129 let passes = issues.is_empty();
130
131 Ok(AnomalyScoringAnalysis {
132 anomaly_separability,
133 avg_anomaly_score,
134 avg_normal_score,
135 per_type_separability: Vec::new(),
136 total_records,
137 passes,
138 issues,
139 })
140 }
141
142 fn compute_auc_roc(&self, records: &[ScoredRecord]) -> f64 {
147 let total_positives = records.iter().filter(|r| r.is_anomaly).count();
148 let total_negatives = records.len() - total_positives;
149
150 if total_positives == 0 || total_negatives == 0 {
151 return 0.5;
152 }
153
154 let mut sorted: Vec<&ScoredRecord> = records.iter().collect();
156 sorted.sort_by(|a, b| {
157 b.score
158 .partial_cmp(&a.score)
159 .unwrap_or(std::cmp::Ordering::Equal)
160 });
161
162 let mut tp = 0usize;
163 let mut fp = 0usize;
164 let mut auc = 0.0;
165 let mut prev_fpr = 0.0;
166 let mut prev_tpr = 0.0;
167
168 for record in &sorted {
169 if record.is_anomaly {
170 tp += 1;
171 } else {
172 fp += 1;
173 }
174
175 let tpr = tp as f64 / total_positives as f64;
176 let fpr = fp as f64 / total_negatives as f64;
177
178 auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
180
181 prev_fpr = fpr;
182 prev_tpr = tpr;
183 }
184
185 auc
186 }
187}
188
189impl Default for AnomalyScoringAnalyzer {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
#[cfg(test)]
// Test code unwraps `analyze`'s result directly: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a `ScoredRecord` with the given id, score and label.
    fn record(id: &str, score: f64, is_anomaly: bool) -> ScoredRecord {
        ScoredRecord {
            record_id: id.to_string(),
            score,
            is_anomaly,
        }
    }

    #[test]
    fn test_valid_anomaly_scoring() {
        let records = vec![
            record("r1", 0.9, true),
            record("r2", 0.85, true),
            record("r3", 0.1, false),
            record("r4", 0.15, false),
            record("r5", 0.05, false),
        ];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert_eq!(result.total_records, 5);
        assert!(result.anomaly_separability > 0.7);
        // Every anomaly outranks every normal record, so the AUC is exactly 1.
        assert!((result.anomaly_separability - 1.0).abs() < 1e-12);
        assert!(result.avg_anomaly_score > result.avg_normal_score);
        assert!(result.passes);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_invalid_anomaly_scoring() {
        // Anomalies score lower than every normal record: fully inverted ranking.
        let records = vec![
            record("r1", 0.1, true),
            record("r2", 0.05, true),
            record("r3", 0.9, false),
            record("r4", 0.85, false),
        ];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert!(result.anomaly_separability < 0.7);
        assert!(result.anomaly_separability.abs() < 1e-12);
        assert!(!result.passes);
    }

    #[test]
    fn test_single_class_falls_back_to_chance() {
        let records = vec![record("r1", 0.9, true), record("r2", 0.8, true)];

        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&records).unwrap();

        assert_eq!(result.total_records, 2);
        assert_eq!(result.anomaly_separability, 0.5);
        assert_eq!(result.avg_normal_score, 0.0);
        assert!((result.avg_anomaly_score - 0.85).abs() < 1e-12);
        assert!(!result.passes);
        assert_eq!(result.issues.len(), 2);
        assert_eq!(
            result.issues[0],
            "Need both anomaly and normal records for AUC-ROC"
        );
    }

    #[test]
    fn test_empty_records() {
        let analyzer = AnomalyScoringAnalyzer::new();
        let result = analyzer.analyze(&[]).unwrap();

        assert_eq!(result.total_records, 0);
        assert_eq!(result.anomaly_separability, 0.0);
        assert!(result.passes);
        assert_eq!(result.issues, vec!["No records provided".to_string()]);
    }
}