// touchstone_rs/metrics/thresholding.rs

/// A strategy that derives a single cut-off value from a slice of scores.
///
/// The supertraits require every strategy to be `Sync + Send`, so a strategy
/// (or a `dyn Threshold` trait object) can be shared across threads.
pub trait Threshold: Sync + Send {
    /// Computes the cut-off for `scores`. Some strategies ignore the input entirely.
    fn threshold(&self, scores: &[f32]) -> f32;

    /// Short identifier for this strategy (the implementations in this module
    /// return `"fixed"`, `"percentile"` and `"sigma"`).
    // NOTE(review): presumably not called anywhere yet, hence the allow — confirm before removing.
    #[allow(dead_code)]
    fn name(&self) -> &str;
}
9
/// Threshold strategy that ignores the scores and always yields the wrapped value.
// NOTE(review): presumably not constructed anywhere in the crate yet, hence the
// allow — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FixedValueThreshold(pub f32);
13
/// Threshold strategy that picks the score at the wrapped percentile.
///
/// The value is on a `0.0..=100.0` scale (it is divided by 100 before use).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PercentileThreshold(pub f64);
16
/// Threshold strategy `mean + k * std`, where `k` is the wrapped multiplier.
// NOTE(review): presumably not constructed anywhere in the crate yet, hence the
// allow — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SigmaThreshold(pub f64);
20
21impl Threshold for FixedValueThreshold {
22 fn threshold(&self, _scores: &[f32]) -> f32 {
23 self.0
24 }
25 fn name(&self) -> &str {
26 "fixed"
27 }
28}
29
30impl Threshold for PercentileThreshold {
31 fn threshold(&self, scores: &[f32]) -> f32 {
32 let mut sorted = scores.to_vec();
33 sorted.sort_by(|a, b| a.total_cmp(b));
34 let idx = ((self.0 / 100.0) * (sorted.len() - 1) as f64).round() as usize;
35 sorted[idx.min(sorted.len() - 1)]
36 }
37 fn name(&self) -> &str {
38 "percentile"
39 }
40}
41
42impl Threshold for SigmaThreshold {
43 fn threshold(&self, scores: &[f32]) -> f32 {
44 let n = scores.len() as f64;
45 let mean = scores.iter().map(|&s| s as f64).sum::<f64>() / n;
46 let var = scores
47 .iter()
48 .map(|&s| (s as f64 - mean).powi(2))
49 .sum::<f64>()
50 / n;
51 (mean + self.0 * var.sqrt()) as f32
52 }
53 fn name(&self) -> &str {
54 "sigma"
55 }
56}
57
/// Binarises `scores` against `thresh`: `1` where a score is `>= thresh`, `0` otherwise.
///
/// Any comparison involving NaN (a NaN score or a NaN threshold) is false and so
/// produces `0`. The output has the same length and order as `scores`.
pub(crate) fn apply_threshold(scores: &[f32], thresh: f32) -> Vec<u8> {
    let mut labels = Vec::with_capacity(scores.len());
    for &score in scores {
        labels.push(u8::from(score >= thresh));
    }
    labels
}