1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// One evaluated scheme: its difficulty label and the score it received from the detector.
#[derive(Debug, Clone, PartialEq)]
pub struct SchemeRecord {
    /// Identifier of the scheme. Carried for reporting only; the detectability analysis does
    /// not read it.
    pub scheme_id: String,
    /// Difficulty label. The analyzer recognises exactly (case-sensitively) `"trivial"`,
    /// `"easy"`, `"moderate"`, `"hard"` and `"expert"`, easiest to hardest.
    pub difficulty: String,
    /// Detection score for this scheme. The analysis expects it to fall as difficulty rises.
    /// NOTE(review): presumably a rate in `[0, 1]` — not enforced here.
    pub detection_score: f64,
}
21
/// Pass/fail thresholds used by the scheme detectability analyzer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SchemeDetectabilityThresholds {
    /// Minimum detectability score required to pass. The score is the negated Spearman rank
    /// correlation between difficulty and detection score, clamped to `[0, 1]`.
    pub min_detectability_score: f64,
}

impl Default for SchemeDetectabilityThresholds {
    /// Requires a detectability score of at least `0.60`.
    fn default() -> Self {
        Self {
            min_detectability_score: 0.60,
        }
    }
}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SchemeDetectabilityAnalysis {
40 pub difficulty_ordering_valid: bool,
42 pub detectability_score: f64,
44 pub per_difficulty_rates: Vec<(String, f64)>,
46 pub total_schemes: usize,
48 pub passes: bool,
50 pub issues: Vec<String>,
52}
53
54pub struct SchemeDetectabilityAnalyzer {
56 thresholds: SchemeDetectabilityThresholds,
57}
58
59impl SchemeDetectabilityAnalyzer {
60 const DIFFICULTY_ORDER: &'static [&'static str] =
62 &["trivial", "easy", "moderate", "hard", "expert"];
63
64 pub fn new() -> Self {
66 Self {
67 thresholds: SchemeDetectabilityThresholds::default(),
68 }
69 }
70
71 pub fn with_thresholds(thresholds: SchemeDetectabilityThresholds) -> Self {
73 Self { thresholds }
74 }
75
76 pub fn analyze(&self, records: &[SchemeRecord]) -> EvalResult<SchemeDetectabilityAnalysis> {
78 let mut issues = Vec::new();
79 let total_schemes = records.len();
80
81 if records.is_empty() {
82 return Ok(SchemeDetectabilityAnalysis {
83 difficulty_ordering_valid: true,
84 detectability_score: 0.0,
85 per_difficulty_rates: Vec::new(),
86 total_schemes: 0,
87 passes: true,
88 issues: vec!["No scheme records provided".to_string()],
89 });
90 }
91
92 let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
94 for record in records {
95 groups
96 .entry(record.difficulty.clone())
97 .or_default()
98 .push(record.detection_score);
99 }
100
101 let per_difficulty_rates: Vec<(String, f64)> = Self::DIFFICULTY_ORDER
102 .iter()
103 .filter_map(|&d| {
104 groups.get(d).map(|scores| {
105 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
106 (d.to_string(), mean)
107 })
108 })
109 .collect();
110
111 let difficulty_ordering_valid = self.check_monotonic(&per_difficulty_rates);
113 if !difficulty_ordering_valid {
114 issues.push("Difficulty ordering is not monotonically decreasing".to_string());
115 }
116
117 let detectability_score = self.compute_spearman(records);
119
120 if detectability_score < self.thresholds.min_detectability_score {
121 issues.push(format!(
122 "Detectability score {:.4} < {:.4} (threshold)",
123 detectability_score, self.thresholds.min_detectability_score
124 ));
125 }
126
127 let passes = issues.is_empty();
128
129 Ok(SchemeDetectabilityAnalysis {
130 difficulty_ordering_valid,
131 detectability_score,
132 per_difficulty_rates,
133 total_schemes,
134 passes,
135 issues,
136 })
137 }
138
139 fn check_monotonic(&self, rates: &[(String, f64)]) -> bool {
141 if rates.len() < 2 {
142 return true;
143 }
144
145 for i in 1..rates.len() {
146 if rates[i].1 > rates[i - 1].1 {
147 return false;
148 }
149 }
150
151 true
152 }
153
154 fn compute_spearman(&self, records: &[SchemeRecord]) -> f64 {
160 let ordinal_map: HashMap<&str, f64> = Self::DIFFICULTY_ORDER
161 .iter()
162 .enumerate()
163 .map(|(i, &d)| (d, (i + 1) as f64))
164 .collect();
165
166 let pairs: Vec<(f64, f64)> = records
168 .iter()
169 .filter_map(|r| {
170 ordinal_map
171 .get(r.difficulty.as_str())
172 .map(|&ordinal| (ordinal, r.detection_score))
173 })
174 .collect();
175
176 if pairs.len() < 3 {
177 return 0.0;
178 }
179
180 let ordinals: Vec<f64> = pairs.iter().map(|(o, _)| *o).collect();
182 let scores: Vec<f64> = pairs.iter().map(|(_, s)| *s).collect();
183
184 let ranked_ord = compute_ranks(&ordinals);
185 let ranked_scores = compute_ranks(&scores);
186
187 let n = pairs.len() as f64;
189 let mean_o = ranked_ord.iter().sum::<f64>() / n;
190 let mean_s = ranked_scores.iter().sum::<f64>() / n;
191
192 let mut cov = 0.0;
193 let mut var_o = 0.0;
194 let mut var_s = 0.0;
195
196 for i in 0..pairs.len() {
197 let do_ = ranked_ord[i] - mean_o;
198 let ds = ranked_scores[i] - mean_s;
199 cov += do_ * ds;
200 var_o += do_ * do_;
201 var_s += ds * ds;
202 }
203
204 let denom = (var_o * var_s).sqrt();
205 if denom < 1e-12 {
206 return 0.0;
207 }
208
209 let rho = cov / denom;
212 (-rho).clamp(0.0, 1.0)
213 }
214}
215
/// Assigns 1-based fractional ranks to `values`, averaging the ranks of tied entries.
///
/// Values are ordered with `f64::total_cmp`, so the sort is well-defined even if the input
/// contains NaN. Two values are tied when they compare equal (this covers equal infinities)
/// or differ by less than `1e-12` from the first value of their run.
///
/// Returns a vector of the same length as `values`, where `ranks[i]` is the rank of
/// `values[i]`. Every value occupies at least its own rank, so non-finite input terminates.
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is present, which
    // leaves the order unspecified and may make `sort_by` panic; `total_cmp` is total.
    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let anchor = indexed[i].1;
        // Start the run at `i + 1`: a value that is not within tolerance of itself (NaN, or
        // ±inf where inf - inf = NaN) must still consume one slot, otherwise `i` never advances.
        let mut j = i + 1;
        while j < n && (indexed[j].1 == anchor || (indexed[j].1 - anchor).abs() < 1e-12) {
            j += 1;
        }
        // Sorted positions i..j hold 1-based ranks i+1..=j, whose mean is (i + j + 1) / 2.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for &(original_idx, _) in &indexed[i..j] {
            ranks[original_idx] = avg_rank;
        }
        i = j;
    }

    ranks
}
240
241impl Default for SchemeDetectabilityAnalyzer {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single record with the given id, difficulty label and detection score.
    fn record(id: &str, difficulty: &str, detection_score: f64) -> SchemeRecord {
        SchemeRecord {
            scheme_id: id.to_string(),
            difficulty: difficulty.to_string(),
            detection_score,
        }
    }

    /// Builds one record per difficulty level, easiest first, ids `s1`..`s5`.
    fn one_per_level(scores: [f64; 5]) -> Vec<SchemeRecord> {
        ["trivial", "easy", "moderate", "hard", "expert"]
            .iter()
            .zip(scores.iter())
            .enumerate()
            .map(|(idx, (difficulty, &score))| record(&format!("s{}", idx + 1), difficulty, score))
            .collect()
    }

    #[test]
    fn test_valid_ordering() {
        // Scores fall steadily as difficulty rises.
        let records = one_per_level([0.95, 0.80, 0.60, 0.35, 0.10]);

        let result = SchemeDetectabilityAnalyzer::new().analyze(&records).unwrap();

        assert!(result.difficulty_ordering_valid);
        assert!(result.detectability_score > 0.6);
        assert!(result.passes);
    }

    #[test]
    fn test_invalid_ordering() {
        // Inverted: harder schemes are detected more often.
        let records = one_per_level([0.10, 0.30, 0.50, 0.70, 0.90]);

        let result = SchemeDetectabilityAnalyzer::new().analyze(&records).unwrap();

        assert!(!result.difficulty_ordering_valid);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_schemes() {
        let result = SchemeDetectabilityAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_schemes, 0);
        assert!(result.difficulty_ordering_valid);
    }
}