1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// A single evaluated scheme: its identifier, difficulty label and detection score.
#[derive(Debug, Clone)]
pub struct SchemeRecord {
    /// Identifier of the scheme; carried through but not used in any computation here.
    pub scheme_id: String,
    /// Difficulty label. The analyzer in this module recognises exactly
    /// `"trivial"`, `"easy"`, `"moderate"`, `"hard"` and `"expert"` (case-sensitive).
    pub difficulty: String,
    /// Detection score for the scheme. The analyzer expects this to fall as
    /// difficulty rises.
    pub detection_score: f64,
}
21
/// Pass/fail thresholds used by `SchemeDetectabilityAnalyzer`.
#[derive(Debug, Clone)]
pub struct SchemeDetectabilityThresholds {
    /// Smallest acceptable detectability score.
    ///
    /// The score is the negated Spearman correlation between difficulty and detection
    /// score, clamped to `[0, 1]`. Any analysis scoring strictly below this value
    /// records an issue and fails.
    pub min_detectability_score: f64,
}

impl Default for SchemeDetectabilityThresholds {
    /// Uses a minimum detectability score of `0.60`.
    fn default() -> Self {
        let min_detectability_score = 0.60;
        SchemeDetectabilityThresholds {
            min_detectability_score,
        }
    }
}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct SchemeDetectabilityAnalysis {
40 pub difficulty_ordering_valid: bool,
42 pub detectability_score: f64,
44 pub per_difficulty_rates: Vec<(String, f64)>,
46 pub total_schemes: usize,
48 pub passes: bool,
50 pub issues: Vec<String>,
52}
53
54pub struct SchemeDetectabilityAnalyzer {
56 thresholds: SchemeDetectabilityThresholds,
57}
58
59impl SchemeDetectabilityAnalyzer {
60 const DIFFICULTY_ORDER: &'static [&'static str] =
62 &["trivial", "easy", "moderate", "hard", "expert"];
63
64 pub fn new() -> Self {
66 Self {
67 thresholds: SchemeDetectabilityThresholds::default(),
68 }
69 }
70
71 pub fn with_thresholds(thresholds: SchemeDetectabilityThresholds) -> Self {
73 Self { thresholds }
74 }
75
76 pub fn analyze(&self, records: &[SchemeRecord]) -> EvalResult<SchemeDetectabilityAnalysis> {
78 let mut issues = Vec::new();
79 let total_schemes = records.len();
80
81 if records.is_empty() {
82 return Ok(SchemeDetectabilityAnalysis {
83 difficulty_ordering_valid: true,
84 detectability_score: 0.0,
85 per_difficulty_rates: Vec::new(),
86 total_schemes: 0,
87 passes: true,
88 issues: vec!["No scheme records provided".to_string()],
89 });
90 }
91
92 let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
94 for record in records {
95 groups
96 .entry(record.difficulty.clone())
97 .or_default()
98 .push(record.detection_score);
99 }
100
101 let per_difficulty_rates: Vec<(String, f64)> = Self::DIFFICULTY_ORDER
102 .iter()
103 .filter_map(|&d| {
104 groups.get(d).map(|scores| {
105 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
106 (d.to_string(), mean)
107 })
108 })
109 .collect();
110
111 let difficulty_ordering_valid = self.check_monotonic(&per_difficulty_rates);
113 if !difficulty_ordering_valid {
114 issues.push("Difficulty ordering is not monotonically decreasing".to_string());
115 }
116
117 let detectability_score = self.compute_spearman(records);
119
120 if detectability_score < self.thresholds.min_detectability_score {
121 issues.push(format!(
122 "Detectability score {:.4} < {:.4} (threshold)",
123 detectability_score, self.thresholds.min_detectability_score
124 ));
125 }
126
127 let passes = issues.is_empty();
128
129 Ok(SchemeDetectabilityAnalysis {
130 difficulty_ordering_valid,
131 detectability_score,
132 per_difficulty_rates,
133 total_schemes,
134 passes,
135 issues,
136 })
137 }
138
139 fn check_monotonic(&self, rates: &[(String, f64)]) -> bool {
141 if rates.len() < 2 {
142 return true;
143 }
144
145 for i in 1..rates.len() {
146 if rates[i].1 > rates[i - 1].1 {
147 return false;
148 }
149 }
150
151 true
152 }
153
154 fn compute_spearman(&self, records: &[SchemeRecord]) -> f64 {
160 let ordinal_map: HashMap<&str, f64> = Self::DIFFICULTY_ORDER
161 .iter()
162 .enumerate()
163 .map(|(i, &d)| (d, (i + 1) as f64))
164 .collect();
165
166 let pairs: Vec<(f64, f64)> = records
168 .iter()
169 .filter_map(|r| {
170 ordinal_map
171 .get(r.difficulty.as_str())
172 .map(|&ordinal| (ordinal, r.detection_score))
173 })
174 .collect();
175
176 if pairs.len() < 3 {
177 return 0.0;
178 }
179
180 let ordinals: Vec<f64> = pairs.iter().map(|(o, _)| *o).collect();
182 let scores: Vec<f64> = pairs.iter().map(|(_, s)| *s).collect();
183
184 let ranked_ord = compute_ranks(&ordinals);
185 let ranked_scores = compute_ranks(&scores);
186
187 let n = pairs.len() as f64;
189 let mean_o = ranked_ord.iter().sum::<f64>() / n;
190 let mean_s = ranked_scores.iter().sum::<f64>() / n;
191
192 let mut cov = 0.0;
193 let mut var_o = 0.0;
194 let mut var_s = 0.0;
195
196 for i in 0..pairs.len() {
197 let do_ = ranked_ord[i] - mean_o;
198 let ds = ranked_scores[i] - mean_s;
199 cov += do_ * ds;
200 var_o += do_ * do_;
201 var_s += ds * ds;
202 }
203
204 let denom = (var_o * var_s).sqrt();
205 if denom < 1e-12 {
206 return 0.0;
207 }
208
209 let rho = cov / denom;
212 (-rho).clamp(0.0, 1.0)
213 }
214}
215
/// Assigns 1-based fractional ("average") ranks to `values`, returned in input order.
///
/// Values are sorted with `f64::total_cmp`. A run of sorted values forms a tie when
/// each value is exactly equal to, or within `1e-12` of, the first value in the run.
/// Every member of a tie receives the mean of the ranks the run spans, so
/// `[1.0, 1.0, 2.0]` ranks as `[1.5, 1.5, 3.0]`.
///
/// Each run always contains its own first element. This keeps the scan advancing for
/// NaN or infinite values, for which `x - x` is NaN and a pure tolerance test would
/// never succeed.
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    // Absolute tolerance below which two distinct values still count as tied.
    const TIE_EPSILON: f64 = 1e-12;

    let n = values.len();

    // `total_cmp` is a genuine total order, so sorting stays well-defined (and the
    // comparator cannot be rejected as inconsistent) even if NaN is present.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by(|&a, &b| values[a].total_cmp(&values[b]));

    let mut ranks = vec![0.0; n];
    let mut start = 0;
    while start < n {
        let anchor = values[order[start]];
        let is_tied = |v: f64| v == anchor || (v - anchor).abs() < TIE_EPSILON;

        let mut end = start + 1;
        while end < n && is_tied(values[order[end]]) {
            end += 1;
        }

        // Sorted positions start..end hold 1-based ranks start+1..=end; share their mean.
        let avg_rank = (start + end + 1) as f64 / 2.0;
        for &idx in &order[start..end] {
            ranks[idx] = avg_rank;
        }
        start = end;
    }

    ranks
}
240
241impl Default for SchemeDetectabilityAnalyzer {
242 fn default() -> Self {
243 Self::new()
244 }
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a single `SchemeRecord` from borrowed parts.
    fn record(id: &str, difficulty: &str, detection_score: f64) -> SchemeRecord {
        SchemeRecord {
            scheme_id: id.to_string(),
            difficulty: difficulty.to_string(),
            detection_score,
        }
    }

    /// Builds one record per difficulty level, easiest first, with ids `s1`..`s5`.
    fn ladder(scores: [f64; 5]) -> Vec<SchemeRecord> {
        let levels = ["trivial", "easy", "moderate", "hard", "expert"];
        levels
            .iter()
            .zip(scores.iter())
            .enumerate()
            .map(|(i, (&difficulty, &score))| record(&format!("s{}", i + 1), difficulty, score))
            .collect()
    }

    #[test]
    fn test_valid_ordering() {
        // Scores fall steadily as difficulty rises: the expected shape.
        let records = ladder([0.95, 0.80, 0.60, 0.35, 0.10]);

        let result = SchemeDetectabilityAnalyzer::new()
            .analyze(&records)
            .unwrap();

        assert!(result.difficulty_ordering_valid);
        assert!(result.detectability_score > 0.6);
        assert!(result.passes);
    }

    #[test]
    fn test_invalid_ordering() {
        // Scores rise with difficulty: the opposite of the expected shape.
        let records = ladder([0.10, 0.30, 0.50, 0.70, 0.90]);

        let result = SchemeDetectabilityAnalyzer::new()
            .analyze(&records)
            .unwrap();

        assert!(!result.difficulty_ordering_valid);
        assert!(!result.passes);
    }

    #[test]
    fn test_empty_schemes() {
        let result = SchemeDetectabilityAnalyzer::new().analyze(&[]).unwrap();

        assert_eq!(result.total_schemes, 0);
        assert!(result.difficulty_ordering_valid);
    }
}
336}