1use crate::EvalError;
4
5pub fn top_k_accuracy(
10 scores: &[f32],
11 num_classes: usize,
12 targets: &[usize],
13 k: usize,
14) -> Result<f32, EvalError> {
15 if scores.is_empty() || num_classes == 0 {
16 return Ok(0.0);
17 }
18 let n = scores.len() / num_classes;
19 if n != targets.len() {
20 return Err(EvalError::CountLengthMismatch {
21 ground_truth: targets.len(),
22 predictions: n,
23 });
24 }
25
26 let mut correct = 0;
27 for i in 0..n {
28 let row = &scores[i * num_classes..(i + 1) * num_classes];
29 let mut indices: Vec<usize> = (0..num_classes).collect();
30 indices.sort_unstable_by(|&a, &b| {
31 row[b]
32 .partial_cmp(&row[a])
33 .unwrap_or(std::cmp::Ordering::Equal)
34 });
35 if indices[..k.min(num_classes)].contains(&targets[i]) {
36 correct += 1;
37 }
38 }
39 Ok(correct as f32 / n as f32)
40}
41
42pub fn roc_curve(
46 scores: &[f32],
47 labels: &[bool],
48) -> Result<(Vec<f32>, Vec<f32>, Vec<f32>), EvalError> {
49 if scores.len() != labels.len() {
50 return Err(EvalError::CountLengthMismatch {
51 ground_truth: labels.len(),
52 predictions: scores.len(),
53 });
54 }
55
56 let n = scores.len();
57 let total_pos = labels.iter().filter(|&&l| l).count() as f32;
58 let total_neg = n as f32 - total_pos;
59
60 if total_pos == 0.0 || total_neg == 0.0 {
61 return Ok((
62 vec![0.0, 1.0],
63 vec![0.0, 1.0],
64 vec![f32::INFINITY, f32::NEG_INFINITY],
65 ));
66 }
67
68 let mut indices: Vec<usize> = (0..n).collect();
70 indices.sort_unstable_by(|&a, &b| {
71 scores[b]
72 .partial_cmp(&scores[a])
73 .unwrap_or(std::cmp::Ordering::Equal)
74 });
75
76 let mut fpr_list = vec![0.0f32];
77 let mut tpr_list = vec![0.0f32];
78 let mut thresholds = vec![f32::INFINITY];
79
80 let mut tp = 0.0f32;
81 let mut fp = 0.0f32;
82
83 for &i in &indices {
84 if labels[i] {
85 tp += 1.0;
86 } else {
87 fp += 1.0;
88 }
89 fpr_list.push(fp / total_neg);
90 tpr_list.push(tp / total_pos);
91 thresholds.push(scores[i]);
92 }
93
94 Ok((fpr_list, tpr_list, thresholds))
95}
96
97pub fn auc(x: &[f32], y: &[f32]) -> Result<f32, EvalError> {
101 if x.len() != y.len() {
102 return Err(EvalError::CountLengthMismatch {
103 ground_truth: x.len(),
104 predictions: y.len(),
105 });
106 }
107 if x.len() < 2 {
108 return Ok(0.0);
109 }
110
111 let mut area = 0.0f32;
112 for i in 1..x.len() {
113 area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
114 }
115 Ok(area.abs())
116}
117
118pub fn mean_iou(
122 predictions: &[usize],
123 targets: &[usize],
124 num_classes: usize,
125) -> Result<f32, EvalError> {
126 if predictions.len() != targets.len() {
127 return Err(EvalError::CountLengthMismatch {
128 ground_truth: targets.len(),
129 predictions: predictions.len(),
130 });
131 }
132
133 let mut intersection = vec![0usize; num_classes];
134 let mut union = vec![0usize; num_classes];
135
136 for (&p, &t) in predictions.iter().zip(targets.iter()) {
137 if t < num_classes {
138 if p == t {
139 intersection[t] += 1;
140 }
141 union[t] += 1;
142 }
143 if p < num_classes && p != t {
144 union[p] += 1;
145 }
146 }
147
148 let mut sum_iou = 0.0f32;
149 let mut valid_classes = 0;
150 for c in 0..num_classes {
151 if union[c] > 0 {
152 sum_iou += intersection[c] as f32 / union[c] as f32;
153 valid_classes += 1;
154 }
155 }
156
157 if valid_classes == 0 {
158 return Ok(0.0);
159 }
160 Ok(sum_iou / valid_classes as f32)
161}
162
/// Per-class Dice coefficient (equivalently, per-class F1 score).
///
/// Returns a vector of length `num_classes`:
/// - Entry `c` is `2·TP / (2·TP + FP + FN)` for class `c`.
/// - It is `0.0` when the class is neither predicted nor present in the targets.
///
/// A label `>= num_classes` is not scored as a class. A valid label paired with it
/// still counts as an FP/FN for that valid class.
///
/// The two slices are paired element-wise. If their lengths differ, the extra tail of
/// the longer slice is ignored and no error is reported.
pub fn dice_score(predictions: &[usize], targets: &[usize], num_classes: usize) -> Vec<f32> {
    let mut true_pos = vec![0usize; num_classes];
    let mut false_pos = vec![0usize; num_classes];
    let mut false_neg = vec![0usize; num_classes];

    for (&pred, &truth) in predictions.iter().zip(targets) {
        let pred_known = pred < num_classes;
        if pred == truth {
            if pred_known {
                true_pos[pred] += 1;
            }
            continue;
        }
        if pred_known {
            false_pos[pred] += 1;
        }
        if truth < num_classes {
            false_neg[truth] += 1;
        }
    }

    true_pos
        .iter()
        .zip(&false_pos)
        .zip(&false_neg)
        .map(|((&tp, &fp), &fneg)| match 2 * tp + fp + fneg {
            0 => 0.0,
            denominator => (2 * tp) as f32 / denominator as f32,
        })
        .collect()
}
198
/// Per-class intersection-over-union (Jaccard index).
///
/// Returns a vector of length `num_classes`:
/// - Entry `c` is `TP / (TP + FP + FN)` for class `c`.
/// - It is `0.0` when the class is neither predicted nor present in the targets.
///
/// A label `>= num_classes` is not scored as a class. A valid label paired with it
/// still counts as an FP/FN for that valid class.
///
/// The two slices are paired element-wise. If their lengths differ, the extra tail of
/// the longer slice is ignored and no error is reported.
pub fn per_class_iou(predictions: &[usize], targets: &[usize], num_classes: usize) -> Vec<f32> {
    // counts[c] = [TP, FP, FN] for class c; `get_mut` skips out-of-range labels.
    let mut counts = vec![[0usize; 3]; num_classes];

    for (&pred, &truth) in predictions.iter().zip(targets) {
        if pred == truth {
            if let Some(slot) = counts.get_mut(pred) {
                slot[0] += 1;
            }
        } else {
            if let Some(slot) = counts.get_mut(pred) {
                slot[1] += 1;
            }
            if let Some(slot) = counts.get_mut(truth) {
                slot[2] += 1;
            }
        }
    }

    counts
        .iter()
        .map(|&[tp, fp, fneg]| {
            let union = tp + fp + fneg;
            if union == 0 {
                0.0
            } else {
                tp as f32 / union as f32
            }
        })
        .collect()
}
234
235pub fn ssim(img1: &[f32], img2: &[f32]) -> Result<f32, EvalError> {
239 if img1.len() != img2.len() {
240 return Err(EvalError::CountLengthMismatch {
241 ground_truth: img1.len(),
242 predictions: img2.len(),
243 });
244 }
245 let n = img1.len() as f32;
246 if n == 0.0 {
247 return Ok(1.0);
248 }
249
250 let c1 = (0.01f32 * 1.0).powi(2); let c2 = (0.03f32 * 1.0).powi(2);
252
253 let mu1: f32 = img1.iter().sum::<f32>() / n;
254 let mu2: f32 = img2.iter().sum::<f32>() / n;
255
256 let sigma1_sq: f32 = img1.iter().map(|&v| (v - mu1).powi(2)).sum::<f32>() / n;
257 let sigma2_sq: f32 = img2.iter().map(|&v| (v - mu2).powi(2)).sum::<f32>() / n;
258 let sigma12: f32 = img1
259 .iter()
260 .zip(img2.iter())
261 .map(|(&a, &b)| (a - mu1) * (b - mu2))
262 .sum::<f32>()
263 / n;
264
265 let numerator = (2.0 * mu1 * mu2 + c1) * (2.0 * sigma12 + c2);
266 let denominator = (mu1.powi(2) + mu2.powi(2) + c1) * (sigma1_sq + sigma2_sq + c2);
267
268 Ok(numerator / denominator)
269}
270
271pub fn psnr(img1: &[f32], img2: &[f32], max_val: f32) -> Result<f32, EvalError> {
275 if img1.len() != img2.len() {
276 return Err(EvalError::CountLengthMismatch {
277 ground_truth: img1.len(),
278 predictions: img2.len(),
279 });
280 }
281 let mse: f32 = img1
282 .iter()
283 .zip(img2.iter())
284 .map(|(&a, &b)| (a - b).powi(2))
285 .sum::<f32>()
286 / img1.len() as f32;
287
288 if mse == 0.0 {
289 return Ok(f32::INFINITY);
290 }
291 Ok(10.0 * (max_val.powi(2) / mse).log10())
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_accuracy() {
        // Row 0 ranks classes 1 > 2 > 0; row 1 ranks them 0 > 1 > 2.
        let scores = vec![0.1, 0.8, 0.5, 0.7, 0.2, 0.1];
        let targets = vec![2, 0];
        let acc = top_k_accuracy(&scores, 3, &targets, 1).unwrap();
        assert!((acc - 0.5).abs() < 1e-6);
        let acc_k2 = top_k_accuracy(&scores, 3, &targets, 2).unwrap();
        assert!((acc_k2 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn top_k_accuracy_k_covering_all_classes_is_perfect() {
        let scores = vec![0.1, 0.8, 0.5, 0.7, 0.2, 0.1];
        let acc = top_k_accuracy(&scores, 3, &[2, 0], 5).unwrap();
        assert!((acc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn top_k_accuracy_rejects_row_count_mismatch() {
        // Six scores over three classes is two rows, but only one target is given.
        let scores = vec![0.1, 0.8, 0.5, 0.7, 0.2, 0.1];
        assert!(top_k_accuracy(&scores, 3, &[2], 1).is_err());
    }

    #[test]
    fn test_roc_curve_and_auc() {
        // Both positives outrank every negative, so the separation is perfect.
        let scores = vec![0.9, 0.8, 0.4, 0.3, 0.1];
        let labels = vec![true, true, false, false, false];
        let (fpr, tpr, thresholds) = roc_curve(&scores, &labels).unwrap();
        assert_eq!(fpr.len(), tpr.len());
        assert_eq!(fpr.len(), thresholds.len());
        assert_eq!((fpr[0], tpr[0]), (0.0, 0.0));
        assert_eq!((fpr[fpr.len() - 1], tpr[tpr.len() - 1]), (1.0, 1.0));
        let area = auc(&fpr, &tpr).unwrap();
        assert!((area - 1.0).abs() < 1e-6, "AUC should be 1.0: {area}");
    }

    #[test]
    fn roc_curve_rejects_length_mismatch() {
        assert!(roc_curve(&[0.5, 0.2], &[true]).is_err());
    }

    #[test]
    fn test_auc_perfect() {
        let fpr = vec![0.0, 0.0, 1.0];
        let tpr = vec![0.0, 1.0, 1.0];
        let area = auc(&fpr, &tpr).unwrap();
        assert!((area - 1.0).abs() < 1e-6);
    }

    #[test]
    fn auc_needs_two_points_and_matching_lengths() {
        assert_eq!(auc(&[0.3], &[0.7]).unwrap(), 0.0);
        assert!(auc(&[0.0, 1.0], &[0.0]).is_err());
    }

    #[test]
    fn test_mean_iou() {
        let preds = vec![0, 0, 1, 1, 2, 2];
        let targets = vec![0, 0, 1, 1, 2, 2];
        let miou = mean_iou(&preds, &targets, 3).unwrap();
        assert!((miou - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_iou_partial() {
        // Each class has TP = 1, FP = 1, FN = 1, so both IoUs are 1/3.
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let miou = mean_iou(&preds, &targets, 2).unwrap();
        assert!((miou - 1.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn mean_iou_ignores_classes_that_never_occur() {
        // Class 2 is neither predicted nor present, so it is excluded from the mean.
        let miou = mean_iou(&[0, 1], &[0, 1], 3).unwrap();
        assert!((miou - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_ssim_identical() {
        let img = vec![0.5f32; 100];
        let val = ssim(&img, &img).unwrap();
        assert!((val - 1.0).abs() < 1e-4);
    }

    #[test]
    fn ssim_anti_correlated_images_is_negative() {
        // Equal means, opposite deviations: SSIM = (-0.5 + C2) / (0.5 + C2), C2 = 0.0009.
        let val = ssim(&[0.0, 1.0], &[1.0, 0.0]).unwrap();
        assert!((val - (-0.4991 / 0.5009)).abs() < 1e-4, "got {val}");
    }

    #[test]
    fn test_psnr_identical() {
        let img = vec![0.5f32; 100];
        let val = psnr(&img, &img, 1.0).unwrap();
        assert!(val.is_infinite() && val > 0.0);
    }

    #[test]
    fn test_psnr_different() {
        // MSE = 1 with max_val = 1 gives 10 * log10(1) = 0 dB.
        let img1 = vec![0.0f32; 100];
        let img2 = vec![1.0f32; 100];
        let val = psnr(&img1, &img2, 1.0).unwrap();
        assert!(val.abs() < 1e-6);
    }

    #[test]
    fn psnr_known_value() {
        // MSE = 0.01 with max_val = 1 gives 10 * log10(100) = 20 dB.
        let val = psnr(&[0.0, 0.0], &[0.1, 0.1], 1.0).unwrap();
        assert!((val - 20.0).abs() < 1e-3, "got {val}");
    }

    #[test]
    fn dice_score_perfect() {
        let preds = vec![0, 0, 1, 1, 2, 2];
        let targets = vec![0, 0, 1, 1, 2, 2];
        let scores = dice_score(&preds, &targets, 3);
        for &s in &scores {
            assert!((s - 1.0).abs() < 1e-6, "expected 1.0, got {s}");
        }
    }

    #[test]
    fn dice_score_partial() {
        // Each class has TP = 1, FP = 1, FN = 1, so Dice = 2 / (2 + 1 + 1) = 0.5.
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let scores = dice_score(&preds, &targets, 2);
        assert!((scores[0] - 0.5).abs() < 1e-6, "class 0: {}", scores[0]);
        assert!((scores[1] - 0.5).abs() < 1e-6, "class 1: {}", scores[1]);
    }

    #[test]
    fn dice_score_absent_class_is_zero() {
        let scores = dice_score(&[1, 1], &[1, 1], 2);
        assert_eq!(scores, vec![0.0, 1.0]);
    }

    #[test]
    fn per_class_iou_known_values() {
        // Each class has TP = 1, FP = 1, FN = 1, so IoU = 1 / 3.
        let preds = vec![0, 1, 1, 0];
        let targets = vec![0, 0, 1, 1];
        let ious = per_class_iou(&preds, &targets, 2);
        assert!((ious[0] - 1.0 / 3.0).abs() < 1e-6, "class 0: {}", ious[0]);
        assert!((ious[1] - 1.0 / 3.0).abs() < 1e-6, "class 1: {}", ious[1]);
    }

    #[test]
    fn per_class_iou_no_overlap() {
        let preds = vec![0, 0, 0, 0];
        let targets = vec![1, 1, 1, 1];
        let ious = per_class_iou(&preds, &targets, 2);
        assert!((ious[0]).abs() < 1e-6, "class 0 should be 0: {}", ious[0]);
        assert!((ious[1]).abs() < 1e-6, "class 1 should be 0: {}", ious[1]);
    }

    #[test]
    fn per_class_iou_absent_class_is_zero() {
        let ious = per_class_iou(&[0, 0], &[0, 0], 2);
        assert_eq!(ious, vec![1.0, 0.0]);
    }
}