Skip to main content

datasynth_eval/ml/
scheme_detectability.rs

1//! Scheme detectability evaluation.
2//!
3//! Validates that injected fraud schemes follow an expected difficulty ordering:
4//! trivial > easy > moderate > hard > expert in terms of detection score,
5//! and computes Spearman rank correlation.
6
7use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// One injected fraud-scheme instance: its intended difficulty label and the
/// detection score observed for it.
#[derive(Debug, Clone)]
pub struct SchemeRecord {
    /// Identifier of this particular scheme instance.
    pub scheme_id: String,
    /// Difficulty label. The analyzer recognises "trivial", "easy",
    /// "moderate", "hard" and "expert"; other labels are left out of its
    /// ordering checks.
    pub difficulty: String,
    /// Probability (0.0–1.0) that this scheme gets detected.
    pub detection_score: f64,
}
21
/// Pass/fail limits applied by the scheme detectability analysis.
#[derive(Debug, Clone)]
pub struct SchemeDetectabilityThresholds {
    /// Lowest acceptable detectability score; an analysis whose score falls
    /// below this value is reported as failing.
    pub min_detectability_score: f64,
}

impl Default for SchemeDetectabilityThresholds {
    /// Requires a detectability score of at least 0.60.
    fn default() -> Self {
        let min_detectability_score = 0.60;
        Self {
            min_detectability_score,
        }
    }
}
36
37/// Results of scheme detectability analysis.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SchemeDetectabilityAnalysis {
40    /// Whether difficulty ordering is monotonically valid.
41    pub difficulty_ordering_valid: bool,
42    /// Spearman rank correlation between difficulty ordinal and detection rate.
43    pub detectability_score: f64,
44    /// Mean detection rate per difficulty level.
45    pub per_difficulty_rates: Vec<(String, f64)>,
46    /// Total number of scheme records analyzed.
47    pub total_schemes: usize,
48    /// Whether the analysis passes all thresholds.
49    pub passes: bool,
50    /// Issues found during analysis.
51    pub issues: Vec<String>,
52}
53
54/// Analyzer for scheme detectability.
55pub struct SchemeDetectabilityAnalyzer {
56    thresholds: SchemeDetectabilityThresholds,
57}
58
59impl SchemeDetectabilityAnalyzer {
60    /// Canonical difficulty ordering (from most detectable to least).
61    const DIFFICULTY_ORDER: &'static [&'static str] =
62        &["trivial", "easy", "moderate", "hard", "expert"];
63
64    /// Create a new analyzer with default thresholds.
65    pub fn new() -> Self {
66        Self {
67            thresholds: SchemeDetectabilityThresholds::default(),
68        }
69    }
70
71    /// Create an analyzer with custom thresholds.
72    pub fn with_thresholds(thresholds: SchemeDetectabilityThresholds) -> Self {
73        Self { thresholds }
74    }
75
76    /// Analyze scheme detectability.
77    pub fn analyze(&self, records: &[SchemeRecord]) -> EvalResult<SchemeDetectabilityAnalysis> {
78        let mut issues = Vec::new();
79        let total_schemes = records.len();
80
81        if records.is_empty() {
82            return Ok(SchemeDetectabilityAnalysis {
83                difficulty_ordering_valid: true,
84                detectability_score: 0.0,
85                per_difficulty_rates: Vec::new(),
86                total_schemes: 0,
87                passes: true,
88                issues: vec!["No scheme records provided".to_string()],
89            });
90        }
91
92        // Group by difficulty and compute mean detection score
93        let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
94        for record in records {
95            groups
96                .entry(record.difficulty.clone())
97                .or_default()
98                .push(record.detection_score);
99        }
100
101        let per_difficulty_rates: Vec<(String, f64)> = Self::DIFFICULTY_ORDER
102            .iter()
103            .filter_map(|&d| {
104                groups.get(d).map(|scores| {
105                    let mean = scores.iter().sum::<f64>() / scores.len() as f64;
106                    (d.to_string(), mean)
107                })
108            })
109            .collect();
110
111        // Check monotonic ordering: trivial should have highest detection rate
112        let difficulty_ordering_valid = self.check_monotonic(&per_difficulty_rates);
113        if !difficulty_ordering_valid {
114            issues.push("Difficulty ordering is not monotonically decreasing".to_string());
115        }
116
117        // Compute Spearman rank correlation
118        let detectability_score = self.compute_spearman(records);
119
120        if detectability_score < self.thresholds.min_detectability_score {
121            issues.push(format!(
122                "Detectability score {:.4} < {:.4} (threshold)",
123                detectability_score, self.thresholds.min_detectability_score
124            ));
125        }
126
127        let passes = issues.is_empty();
128
129        Ok(SchemeDetectabilityAnalysis {
130            difficulty_ordering_valid,
131            detectability_score,
132            per_difficulty_rates,
133            total_schemes,
134            passes,
135            issues,
136        })
137    }
138
139    /// Check if per-difficulty rates are monotonically decreasing.
140    fn check_monotonic(&self, rates: &[(String, f64)]) -> bool {
141        if rates.len() < 2 {
142            return true;
143        }
144
145        for i in 1..rates.len() {
146            if rates[i].1 > rates[i - 1].1 {
147                return false;
148            }
149        }
150
151        true
152    }
153
154    /// Compute Spearman rank correlation between difficulty ordinal and detection rate.
155    ///
156    /// Assigns ordinal ranks to difficulties (trivial=1, easy=2, etc.) and
157    /// correlates with detection scores. A positive correlation means
158    /// easier schemes have higher detection rates (expected).
159    fn compute_spearman(&self, records: &[SchemeRecord]) -> f64 {
160        let ordinal_map: HashMap<&str, f64> = Self::DIFFICULTY_ORDER
161            .iter()
162            .enumerate()
163            .map(|(i, &d)| (d, (i + 1) as f64))
164            .collect();
165
166        // Filter to records with known difficulty levels
167        let pairs: Vec<(f64, f64)> = records
168            .iter()
169            .filter_map(|r| {
170                ordinal_map
171                    .get(r.difficulty.as_str())
172                    .map(|&ordinal| (ordinal, r.detection_score))
173            })
174            .collect();
175
176        if pairs.len() < 3 {
177            return 0.0;
178        }
179
180        // Rank both columns
181        let ordinals: Vec<f64> = pairs.iter().map(|(o, _)| *o).collect();
182        let scores: Vec<f64> = pairs.iter().map(|(_, s)| *s).collect();
183
184        let ranked_ord = compute_ranks(&ordinals);
185        let ranked_scores = compute_ranks(&scores);
186
187        // Pearson correlation on ranks
188        let n = pairs.len() as f64;
189        let mean_o = ranked_ord.iter().sum::<f64>() / n;
190        let mean_s = ranked_scores.iter().sum::<f64>() / n;
191
192        let mut cov = 0.0;
193        let mut var_o = 0.0;
194        let mut var_s = 0.0;
195
196        for i in 0..pairs.len() {
197            let do_ = ranked_ord[i] - mean_o;
198            let ds = ranked_scores[i] - mean_s;
199            cov += do_ * ds;
200            var_o += do_ * do_;
201            var_s += ds * ds;
202        }
203
204        let denom = (var_o * var_s).sqrt();
205        if denom < 1e-12 {
206            return 0.0;
207        }
208
209        // We expect negative correlation (higher ordinal = harder = lower detection),
210        // so we negate to get a positive "detectability score"
211        let rho = cov / denom;
212        (-rho).clamp(0.0, 1.0)
213    }
214}
215
/// Compute 1-based ranks for `values`, giving tied values the average of the
/// ranks they span (fractional ranking, as used by Spearman correlation).
///
/// Two values are tied when they are exactly equal (so `-0.0`/`0.0` and equal
/// infinities tie) or differ by less than `1e-12` from the first value of
/// their group.
///
/// Ordering uses [`f64::total_cmp`], so non-finite inputs receive a
/// deterministic rank instead of breaking the sort or stalling the tie scan.
///
/// Returns a vector the same length as `values`, where element `i` is the rank
/// of `values[i]`.
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `total_cmp` is a true total order (NaN included). By contrast,
    // `partial_cmp(..).unwrap_or(Equal)` hands `sort_by` an inconsistent
    // comparator when NaN is present.
    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

    // Exact equality catches equal infinities, where `inf - inf` is NaN.
    let is_tie = |a: f64, b: f64| a == b || (a - b).abs() < 1e-12;

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let anchor = indexed[start].1;
        // The anchor always belongs to its own group, so scan from the next
        // element. This guarantees progress even when `anchor - anchor` is NaN.
        let mut end = start + 1;
        while end < n && is_tie(indexed[end].1, anchor) {
            end += 1;
        }
        // Sorted positions start..end hold ranks start+1..=end; use their mean.
        let avg_rank = (start + end + 1) as f64 / 2.0;
        for &(original_idx, _) in &indexed[start..end] {
            ranks[original_idx] = avg_rank;
        }
        start = end;
    }

    ranks
}
240
241impl Default for SchemeDetectabilityAnalyzer {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for building a [`SchemeRecord`].
    fn record(scheme_id: &str, difficulty: &str, detection_score: f64) -> SchemeRecord {
        SchemeRecord {
            scheme_id: scheme_id.into(),
            difficulty: difficulty.into(),
            detection_score,
        }
    }

    /// Build one record per canonical level, from trivial to expert.
    ///
    /// Records get ids "s1".."s5" and are paired positionally with `scores`.
    fn one_per_level(scores: [f64; 5]) -> Vec<SchemeRecord> {
        ["trivial", "easy", "moderate", "hard", "expert"]
            .iter()
            .zip(scores.iter())
            .enumerate()
            .map(|(i, (&difficulty, &score))| record(&format!("s{}", i + 1), difficulty, score))
            .collect()
    }

    #[test]
    fn test_valid_ordering() {
        // Detection falls steadily as difficulty rises: the expected shape.
        let records = one_per_level([0.95, 0.80, 0.60, 0.35, 0.10]);

        let result = SchemeDetectabilityAnalyzer::new().analyze(&records).unwrap();

        assert!(result.difficulty_ordering_valid);
        assert!(result.detectability_score > 0.6);
        assert!(result.passes);
    }

    #[test]
    fn test_invalid_ordering() {
        // Fully inverted: trivial schemes are the hardest to detect.
        let records = one_per_level([0.10, 0.30, 0.50, 0.70, 0.90]);

        let result = SchemeDetectabilityAnalyzer::new().analyze(&records).unwrap();

        assert!(!result.difficulty_ordering_valid);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_schemes() {
        let result = SchemeDetectabilityAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_schemes, 0);
        assert!(result.difficulty_ordering_valid);
    }
}