use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One scheme observation: an identifier, a difficulty label, and a detection score.
///
/// This is the input record type consumed by `SchemeDetectabilityAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub struct SchemeRecord {
/// Identifier of the scheme. The analyzer in this file never reads it.
pub scheme_id: String,
/// Difficulty label. The analyzer only recognizes the exact (case-sensitive) strings
/// `"trivial"`, `"easy"`, `"moderate"`, `"hard"` and `"expert"`; records carrying any
/// other label are left out of the per-difficulty means and the rank correlation.
pub difficulty: String,
/// Detection score for this scheme. It is averaged per difficulty and rank-correlated
/// against the difficulty level. NOTE(review): presumably a rate in `[0, 1]` (the unit
/// tests only use such values), but no range is enforced here — confirm with callers.
pub detection_score: f64,
}
/// Pass/fail thresholds applied by the scheme detectability analysis.
#[derive(Debug, Clone)]
pub struct SchemeDetectabilityThresholds {
    /// Minimum acceptable detectability score; a strictly lower score is reported as an issue.
    pub min_detectability_score: f64,
}

impl Default for SchemeDetectabilityThresholds {
    /// Returns thresholds with a minimum detectability score of `0.60`.
    fn default() -> Self {
        let min_detectability_score = 0.60;
        SchemeDetectabilityThresholds {
            min_detectability_score,
        }
    }
}
/// Result of a scheme detectability analysis, as produced by
/// `SchemeDetectabilityAnalyzer::analyze`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemeDetectabilityAnalysis {
/// `true` when the mean detection score never increases from one recognized difficulty
/// level to the next (levels taken in the analyzer's fixed difficulty order).
pub difficulty_ordering_valid: bool,
/// Negated rank correlation between difficulty level and detection score, clamped to
/// `[0, 1]`. It is `0.0` for empty input, when fewer than three records carry a
/// recognized difficulty, or when either ranking has (near-)zero variance.
pub detectability_score: f64,
/// `(difficulty label, mean detection score)` pairs in the analyzer's difficulty order.
/// Only recognized labels that occur in the input are listed.
pub per_difficulty_rates: Vec<(String, f64)>,
/// Number of input records, including those with an unrecognized difficulty label.
pub total_schemes: usize,
/// `true` when no issue was recorded. Empty input is the one exception: it is reported
/// as passing even though `issues` then carries a single informational message.
pub passes: bool,
/// Human-readable descriptions of every problem found during the analysis.
pub issues: Vec<String>,
}
/// Analyzer that checks whether detection scores fall as scheme difficulty rises.
///
/// Build one with `new` (default thresholds) or `with_thresholds`, then call `analyze`.
pub struct SchemeDetectabilityAnalyzer {
// Thresholds the computed detectability score is compared against in `analyze`.
thresholds: SchemeDetectabilityThresholds,
}
impl SchemeDetectabilityAnalyzer {
const DIFFICULTY_ORDER: &'static [&'static str] =
&["trivial", "easy", "moderate", "hard", "expert"];
pub fn new() -> Self {
Self {
thresholds: SchemeDetectabilityThresholds::default(),
}
}
pub fn with_thresholds(thresholds: SchemeDetectabilityThresholds) -> Self {
Self { thresholds }
}
pub fn analyze(&self, records: &[SchemeRecord]) -> EvalResult<SchemeDetectabilityAnalysis> {
let mut issues = Vec::new();
let total_schemes = records.len();
if records.is_empty() {
return Ok(SchemeDetectabilityAnalysis {
difficulty_ordering_valid: true,
detectability_score: 0.0,
per_difficulty_rates: Vec::new(),
total_schemes: 0,
passes: true,
issues: vec!["No scheme records provided".to_string()],
});
}
let mut groups: HashMap<String, Vec<f64>> = HashMap::new();
for record in records {
groups
.entry(record.difficulty.clone())
.or_default()
.push(record.detection_score);
}
let per_difficulty_rates: Vec<(String, f64)> = Self::DIFFICULTY_ORDER
.iter()
.filter_map(|&d| {
groups.get(d).map(|scores| {
let mean = scores.iter().sum::<f64>() / scores.len() as f64;
(d.to_string(), mean)
})
})
.collect();
let difficulty_ordering_valid = self.check_monotonic(&per_difficulty_rates);
if !difficulty_ordering_valid {
issues.push("Difficulty ordering is not monotonically decreasing".to_string());
}
let detectability_score = self.compute_spearman(records);
if detectability_score < self.thresholds.min_detectability_score {
issues.push(format!(
"Detectability score {:.4} < {:.4} (threshold)",
detectability_score, self.thresholds.min_detectability_score
));
}
let passes = issues.is_empty();
Ok(SchemeDetectabilityAnalysis {
difficulty_ordering_valid,
detectability_score,
per_difficulty_rates,
total_schemes,
passes,
issues,
})
}
fn check_monotonic(&self, rates: &[(String, f64)]) -> bool {
if rates.len() < 2 {
return true;
}
for i in 1..rates.len() {
if rates[i].1 > rates[i - 1].1 {
return false;
}
}
true
}
fn compute_spearman(&self, records: &[SchemeRecord]) -> f64 {
let ordinal_map: HashMap<&str, f64> = Self::DIFFICULTY_ORDER
.iter()
.enumerate()
.map(|(i, &d)| (d, (i + 1) as f64))
.collect();
let pairs: Vec<(f64, f64)> = records
.iter()
.filter_map(|r| {
ordinal_map
.get(r.difficulty.as_str())
.map(|&ordinal| (ordinal, r.detection_score))
})
.collect();
if pairs.len() < 3 {
return 0.0;
}
let ordinals: Vec<f64> = pairs.iter().map(|(o, _)| *o).collect();
let scores: Vec<f64> = pairs.iter().map(|(_, s)| *s).collect();
let ranked_ord = compute_ranks(&ordinals);
let ranked_scores = compute_ranks(&scores);
let n = pairs.len() as f64;
let mean_o = ranked_ord.iter().sum::<f64>() / n;
let mean_s = ranked_scores.iter().sum::<f64>() / n;
let mut cov = 0.0;
let mut var_o = 0.0;
let mut var_s = 0.0;
for i in 0..pairs.len() {
let do_ = ranked_ord[i] - mean_o;
let ds = ranked_scores[i] - mean_s;
cov += do_ * ds;
var_o += do_ * do_;
var_s += ds * ds;
}
let denom = (var_o * var_s).sqrt();
if denom < 1e-12 {
return 0.0;
}
let rho = cov / denom;
(-rho).clamp(0.0, 1.0)
}
}
/// Assigns 1-based fractional ("average") ranks to `values`.
///
/// Values are ranked in ascending order. Entries that are equal to, or within `1e-12` of,
/// the first entry of their run in sorted order form a tie group, and every member receives
/// the mean of the rank positions that group occupies. The result has the same length and
/// order as the input: `ranks[i]` is the rank of `values[i]`.
///
/// Non-finite inputs are handled without hanging: sorting uses [`f64::total_cmp`], equal
/// infinities tie with each other, and every NaN is ranked on its own (it never ties).
fn compute_ranks(values: &[f64]) -> Vec<f64> {
    const TIE_EPSILON: f64 = 1e-12;

    let n = values.len();
    let mut indexed: Vec<(usize, f64)> = values.iter().copied().enumerate().collect();
    // `total_cmp` is a genuine total order, so NaN cannot make the comparator inconsistent.
    indexed.sort_by(|a, b| a.1.total_cmp(&b.1));

    // The `a == b` arm covers equal infinities, whose difference would otherwise be NaN.
    let is_tied = |a: f64, b: f64| a == b || (a - b).abs() < TIE_EPSILON;

    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        // A run always contains its own first element, so begin scanning at `i + 1`.
        // This guarantees progress even when the element is NaN and compares unequal
        // to itself.
        let mut j = i + 1;
        while j < n && is_tied(indexed[j].1, indexed[i].1) {
            j += 1;
        }
        // Sorted positions `i..j` hold the 1-based ranks `i + 1..=j`; their mean is
        // `(i + j + 1) / 2`.
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for &(original_index, _) in &indexed[i..j] {
            ranks[original_index] = avg_rank;
        }
        i = j;
    }
    ranks
}
impl Default for SchemeDetectabilityAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds five records `s1`..`s5`, one per difficulty from `trivial` to `expert`,
    /// carrying the given detection scores in that order.
    fn one_record_per_level(scores: [f64; 5]) -> Vec<SchemeRecord> {
        let levels = ["trivial", "easy", "moderate", "hard", "expert"];
        levels
            .iter()
            .zip(scores.iter())
            .enumerate()
            .map(|(idx, (&difficulty, &detection_score))| SchemeRecord {
                scheme_id: format!("s{}", idx + 1),
                difficulty: difficulty.to_string(),
                detection_score,
            })
            .collect()
    }

    /// Scores that fall as difficulty rises satisfy both checks.
    #[test]
    fn test_valid_ordering() {
        let records = one_record_per_level([0.95, 0.80, 0.60, 0.35, 0.10]);
        let result = SchemeDetectabilityAnalyzer::new()
            .analyze(&records)
            .unwrap();
        assert!(result.difficulty_ordering_valid);
        assert!(result.detectability_score > 0.6);
        assert!(result.passes);
    }

    /// Scores that rise with difficulty are flagged and fail the analysis.
    #[test]
    fn test_invalid_ordering() {
        let records = one_record_per_level([0.10, 0.30, 0.50, 0.70, 0.90]);
        let result = SchemeDetectabilityAnalyzer::new()
            .analyze(&records)
            .unwrap();
        assert!(!result.difficulty_ordering_valid);
        assert!(!result.passes);
    }

    /// Empty input yields a zero count and a trivially valid ordering.
    #[test]
    fn test_empty_schemes() {
        let result = SchemeDetectabilityAnalyzer::new().analyze(&[]).unwrap();
        assert_eq!(result.total_schemes, 0);
        assert!(result.difficulty_ordering_valid);
    }
}