// scry_learn/metrics/roc.rs — ROC and precision-recall curve metrics.
/// A receiver operating characteristic (ROC) curve together with its area.
///
/// `fpr` and `tpr` are parallel vectors: entry `k` of each describes the same
/// point of the curve. The first point is the origin `(0.0, 0.0)`, and every
/// later point has a matching entry in `thresholds`, so
/// `thresholds.len() == fpr.len() - 1`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RocCurve {
    /// False-positive rate at each point of the curve.
    pub fpr: Vec<f64>,
    /// True-positive rate at each point of the curve.
    pub tpr: Vec<f64>,
    /// Score threshold behind each point after the origin, highest score first.
    pub thresholds: Vec<f64>,
    /// Area under the curve (trapezoidal rule); `NaN` when the curve is undefined.
    pub auc: f64,
}

/// A precision-recall curve together with its average-precision summary.
///
/// `precision` and `recall` are parallel vectors: entry `k` of each describes
/// the same point of the curve. The first point is `(recall, precision) =
/// (0.0, 1.0)`, and every later point has a matching entry in `thresholds`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct PrCurve {
    /// Precision at each point of the curve.
    pub precision: Vec<f64>,
    /// Recall at each point of the curve.
    pub recall: Vec<f64>,
    /// Score threshold behind each point after the first, highest score first.
    pub thresholds: Vec<f64>,
    /// Average precision summarising the whole curve.
    pub avg_precision: f64,
}

32impl RocCurve {
33 pub fn new(fpr: Vec<f64>, tpr: Vec<f64>, thresholds: Vec<f64>, auc: f64) -> Self {
35 Self {
36 fpr,
37 tpr,
38 thresholds,
39 auc,
40 }
41 }
42}
43
44impl PrCurve {
45 pub fn new(
47 precision: Vec<f64>,
48 recall: Vec<f64>,
49 thresholds: Vec<f64>,
50 avg_precision: f64,
51 ) -> Self {
52 Self {
53 precision,
54 recall,
55 thresholds,
56 avg_precision,
57 }
58 }
59}
60
61pub fn roc_curve(y_true: &[f64], y_scores: &[f64]) -> RocCurve {
68 let n = y_true.len();
69 let pos_count = y_true.iter().filter(|&&v| v > 0.5).count();
70 let neg_count = n - pos_count;
71
72 if pos_count == 0 || neg_count == 0 {
74 return RocCurve {
75 fpr: vec![0.0],
76 tpr: vec![0.0],
77 thresholds: vec![],
78 auc: f64::NAN,
79 };
80 }
81
82 let mut indices: Vec<usize> = (0..n).collect();
84 indices.sort_unstable_by(|&a, &b| {
85 y_scores[b]
86 .partial_cmp(&y_scores[a])
87 .unwrap_or(std::cmp::Ordering::Equal)
88 });
89
90 let mut fpr = vec![0.0];
91 let mut tpr = vec![0.0];
92 let mut thresholds = Vec::new();
93 let mut tp = 0;
94 let mut fp = 0;
95
96 for &i in &indices {
97 if y_true[i] > 0.5 {
98 tp += 1;
99 } else {
100 fp += 1;
101 }
102 let current_tpr = if pos_count > 0 {
103 tp as f64 / pos_count as f64
104 } else {
105 0.0
106 };
107 let current_fpr = if neg_count > 0 {
108 fp as f64 / neg_count as f64
109 } else {
110 0.0
111 };
112
113 fpr.push(current_fpr);
114 tpr.push(current_tpr);
115 thresholds.push(y_scores[i]);
116 }
117
118 let auc = compute_auc(&fpr, &tpr);
120
121 RocCurve {
122 fpr,
123 tpr,
124 thresholds,
125 auc,
126 }
127}
128
129pub fn roc_auc_score(y_true: &[f64], y_scores: &[f64]) -> f64 {
131 roc_curve(y_true, y_scores).auc
132}
133
134pub fn pr_curve(y_true: &[f64], y_scores: &[f64]) -> PrCurve {
136 let n = y_true.len();
137 let pos_count = y_true.iter().filter(|&&v| v > 0.5).count();
138
139 let mut indices: Vec<usize> = (0..n).collect();
140 indices.sort_unstable_by(|&a, &b| {
141 y_scores[b]
142 .partial_cmp(&y_scores[a])
143 .unwrap_or(std::cmp::Ordering::Equal)
144 });
145
146 let mut prec = vec![1.0];
147 let mut rec = vec![0.0];
148 let mut thresholds = Vec::new();
149 let mut tp = 0;
150 let mut fp = 0;
151
152 for &i in &indices {
153 if y_true[i] > 0.5 {
154 tp += 1;
155 } else {
156 fp += 1;
157 }
158 let p = tp as f64 / (tp + fp) as f64;
159 let r = if pos_count > 0 {
160 tp as f64 / pos_count as f64
161 } else {
162 0.0
163 };
164 prec.push(p);
165 rec.push(r);
166 thresholds.push(y_scores[i]);
167 }
168
169 let avg_precision = compute_auc(&rec, &prec);
170
171 PrCurve {
172 precision: prec,
173 recall: rec,
174 thresholds,
175 avg_precision,
176 }
177}
178
/// Area under the piecewise-linear curve through the points `(x[i], y[i])`,
/// computed with the trapezoidal rule.
///
/// Points are integrated in the order given, so a segment where `x` decreases
/// contributes negative area. Fewer than two points yield `0.0`. Only the
/// first `x.len()` entries of `y` are read; indexing panics if `y` is shorter
/// than `x`.
fn compute_auc(x: &[f64], y: &[f64]) -> f64 {
    (1..x.len()).fold(0.0, |area, k| {
        let width = x[k] - x[k - 1];
        area + width * (y[k] + y[k - 1]) / 2.0
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_roc_auc_perfect() {
        // Every positive outscores every negative.
        let auc = roc_auc_score(&[0.0, 0.0, 1.0, 1.0], &[0.1, 0.2, 0.8, 0.9]);
        assert!(
            (auc - 1.0).abs() < 1e-6,
            "perfect separation should give AUC=1.0, got {auc}"
        );
    }

    #[test]
    fn test_roc_auc_random() {
        // Alternating labels with one shared score: no ranking signal at all.
        let labels = [0.0, 1.0].repeat(4);
        let scores = [0.5; 8];
        let auc = roc_auc_score(&labels, &scores);
        assert!(
            (0.0..=1.0).contains(&auc),
            "AUC should be in [0,1], got {auc}"
        );
    }

    #[test]
    fn test_roc_curve_length() {
        let curve = roc_curve(&[0.0, 1.0, 0.0, 1.0], &[0.1, 0.9, 0.2, 0.8]);
        let points = curve.fpr.len();
        assert_eq!(points, curve.tpr.len());
        assert!(points > 2);
    }
}