Skip to main content

scry_learn/metrics/
roc.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! ROC and Precision-Recall curve computation.
3
/// A receiver operating characteristic (ROC) curve.
///
/// `fpr` and `tpr` are parallel vectors: element `k` of each describes one
/// point of the curve. Curves produced by [`roc_curve`] start with a
/// synthetic origin point `(0.0, 0.0)` that has no associated threshold, so
/// `thresholds` holds one element fewer than `fpr`/`tpr`: `thresholds[k]`
/// belongs to the point `(fpr[k + 1], tpr[k + 1])`.
///
/// The type is `#[non_exhaustive]`; code outside this crate must build it
/// through [`RocCurve::new`] rather than with a struct literal.
///
/// Equality is field-wise on `f64` values, so a curve whose `auc` is `NaN`
/// (the single-class result of [`roc_curve`]) never compares equal to
/// anything, itself included.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct RocCurve {
    /// False positive rate at each curve point.
    pub fpr: Vec<f64>,
    /// True positive rate at each curve point.
    pub tpr: Vec<f64>,
    /// Score thresholds, in the order the curve points were generated.
    pub thresholds: Vec<f64>,
    /// Area under the ROC curve.
    pub auc: f64,
}
17
/// A precision-recall (PR) curve.
///
/// `precision` and `recall` are parallel vectors: element `k` of each
/// describes one point of the curve. Curves produced by [`pr_curve`] start
/// with a synthetic point `(recall = 0.0, precision = 1.0)` that has no
/// associated threshold, so `thresholds` holds one element fewer than
/// `precision`/`recall`: `thresholds[k]` belongs to point `k + 1`.
///
/// The type is `#[non_exhaustive]`; code outside this crate must build it
/// through [`PrCurve::new`] rather than with a struct literal.
///
/// Equality is field-wise on `f64` values, so any `NaN` stored in a curve
/// makes it compare unequal to everything, itself included.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct PrCurve {
    /// Precision at each curve point.
    pub precision: Vec<f64>,
    /// Recall at each curve point.
    pub recall: Vec<f64>,
    /// Score thresholds, in the order the curve points were generated.
    pub thresholds: Vec<f64>,
    /// Average precision (area under PR curve).
    pub avg_precision: f64,
}
31
32impl RocCurve {
33    /// Create a new ROC curve from precomputed values.
34    pub fn new(fpr: Vec<f64>, tpr: Vec<f64>, thresholds: Vec<f64>, auc: f64) -> Self {
35        Self {
36            fpr,
37            tpr,
38            thresholds,
39            auc,
40        }
41    }
42}
43
44impl PrCurve {
45    /// Create a new precision-recall curve from precomputed values.
46    pub fn new(
47        precision: Vec<f64>,
48        recall: Vec<f64>,
49        thresholds: Vec<f64>,
50        avg_precision: f64,
51    ) -> Self {
52        Self {
53            precision,
54            recall,
55            thresholds,
56            avg_precision,
57        }
58    }
59}
60
61/// Compute the ROC curve and AUC.
62///
63/// `y_true` should be binary (0.0 or 1.0).
64/// `y_scores` should be continuous scores (e.g., predicted probabilities).
65///
66/// Returns `auc = NaN` when only one class is present (ROC is undefined).
67pub fn roc_curve(y_true: &[f64], y_scores: &[f64]) -> RocCurve {
68    let n = y_true.len();
69    let pos_count = y_true.iter().filter(|&&v| v > 0.5).count();
70    let neg_count = n - pos_count;
71
72    // ROC is undefined when only one class is present.
73    if pos_count == 0 || neg_count == 0 {
74        return RocCurve {
75            fpr: vec![0.0],
76            tpr: vec![0.0],
77            thresholds: vec![],
78            auc: f64::NAN,
79        };
80    }
81
82    // Sort by descending score.
83    let mut indices: Vec<usize> = (0..n).collect();
84    indices.sort_unstable_by(|&a, &b| {
85        y_scores[b]
86            .partial_cmp(&y_scores[a])
87            .unwrap_or(std::cmp::Ordering::Equal)
88    });
89
90    let mut fpr = vec![0.0];
91    let mut tpr = vec![0.0];
92    let mut thresholds = Vec::new();
93    let mut tp = 0;
94    let mut fp = 0;
95
96    for &i in &indices {
97        if y_true[i] > 0.5 {
98            tp += 1;
99        } else {
100            fp += 1;
101        }
102        let current_tpr = if pos_count > 0 {
103            tp as f64 / pos_count as f64
104        } else {
105            0.0
106        };
107        let current_fpr = if neg_count > 0 {
108            fp as f64 / neg_count as f64
109        } else {
110            0.0
111        };
112
113        fpr.push(current_fpr);
114        tpr.push(current_tpr);
115        thresholds.push(y_scores[i]);
116    }
117
118    // Compute AUC via trapezoidal rule.
119    let auc = compute_auc(&fpr, &tpr);
120
121    RocCurve {
122        fpr,
123        tpr,
124        thresholds,
125        auc,
126    }
127}
128
129/// Compute the area under the ROC curve directly.
130pub fn roc_auc_score(y_true: &[f64], y_scores: &[f64]) -> f64 {
131    roc_curve(y_true, y_scores).auc
132}
133
134/// Compute the precision-recall curve.
135pub fn pr_curve(y_true: &[f64], y_scores: &[f64]) -> PrCurve {
136    let n = y_true.len();
137    let pos_count = y_true.iter().filter(|&&v| v > 0.5).count();
138
139    let mut indices: Vec<usize> = (0..n).collect();
140    indices.sort_unstable_by(|&a, &b| {
141        y_scores[b]
142            .partial_cmp(&y_scores[a])
143            .unwrap_or(std::cmp::Ordering::Equal)
144    });
145
146    let mut prec = vec![1.0];
147    let mut rec = vec![0.0];
148    let mut thresholds = Vec::new();
149    let mut tp = 0;
150    let mut fp = 0;
151
152    for &i in &indices {
153        if y_true[i] > 0.5 {
154            tp += 1;
155        } else {
156            fp += 1;
157        }
158        let p = tp as f64 / (tp + fp) as f64;
159        let r = if pos_count > 0 {
160            tp as f64 / pos_count as f64
161        } else {
162            0.0
163        };
164        prec.push(p);
165        rec.push(r);
166        thresholds.push(y_scores[i]);
167    }
168
169    let avg_precision = compute_auc(&rec, &prec);
170
171    PrCurve {
172        precision: prec,
173        recall: rec,
174        thresholds,
175        avg_precision,
176    }
177}
178
/// Integrate `y` over `x` with the trapezoidal rule.
///
/// Each pair of neighbouring points contributes
/// `(x[k + 1] - x[k]) * (y[k + 1] + y[k]) / 2`. The width is signed, so a
/// segment where `x` decreases subtracts from the total instead of adding to
/// it; callers are expected to pass a monotonically increasing `x`.
///
/// Fewer than two `x` values give `0.0`. Indexing panics if `y` is shorter
/// than `x` (and `x` has at least two elements); surplus `y` values are
/// ignored.
fn compute_auc(x: &[f64], y: &[f64]) -> f64 {
    x.windows(2).enumerate().fold(0.0, |total, (k, pair)| {
        let width = pair[1] - pair[0];
        total + width * (y[k + 1] + y[k]) / 2.0
    })
}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Every positive outranks every negative, so the AUC must be 1.
    #[test]
    fn test_roc_auc_perfect() {
        let labels = [0.0, 0.0, 1.0, 1.0];
        let scores = [0.1, 0.2, 0.8, 0.9];
        let auc = roc_auc_score(&labels, &scores);
        let distance_from_one = (auc - 1.0).abs();
        assert!(
            distance_from_one < 1e-6,
            "perfect separation should give AUC=1.0, got {auc}"
        );
    }

    /// All scores are identical, so they carry no ranking information; this
    /// test only checks that the AUC stays inside the valid range.
    #[test]
    fn test_roc_auc_random() {
        let labels = [0.0, 1.0].repeat(4);
        let scores = vec![0.5; 8];
        let auc = roc_auc_score(&labels, &scores);
        let valid_range = 0.0..=1.0;
        assert!(
            valid_range.contains(&auc),
            "AUC should be in [0,1], got {auc}"
        );
    }

    /// `fpr` and `tpr` must be parallel and contain more than the two
    /// end points for a four-sample input with distinct scores.
    #[test]
    fn test_roc_curve_length() {
        let labels = [0.0, 1.0, 0.0, 1.0];
        let scores = [0.1, 0.9, 0.2, 0.8];
        let RocCurve { fpr, tpr, .. } = roc_curve(&labels, &scores);
        assert_eq!(fpr.len(), tpr.len());
        assert!(fpr.len() > 2);
    }
}