1use ndarray::Array1;
7
8#[derive(Debug, Clone)]
10pub struct CalibrationResult {
11 pub predicted: Array1<f64>,
13 pub observed: Array1<f64>,
15 pub slope: f64,
17 pub intercept: f64,
19}
20
21impl CalibrationResult {
22 pub fn calibration_error(&self) -> f64 {
24 calculate_calib_error(&self.predicted, &self.observed)
25 }
26
27 pub fn is_well_calibrated(&self, slope_tol: f64, intercept_tol: f64) -> bool {
29 (self.slope - 1.0).abs() <= slope_tol && self.intercept.abs() <= intercept_tol
30 }
31}
32
33pub fn calibration_regression<F>(
64 ppf_fn: F,
65 y: &Array1<f64>,
66 bins: usize,
67 eps: f64,
68) -> CalibrationResult
69where
70 F: Fn(f64) -> Array1<f64>,
71{
72 let pctles: Vec<f64> = (0..bins)
73 .map(|i| eps + (1.0 - 2.0 * eps) * (i as f64) / ((bins - 1) as f64))
74 .collect();
75
76 let mut observed = Vec::with_capacity(bins);
77
78 for &pctle in &pctles {
79 let icdfs = ppf_fn(pctle);
80 let count_below: usize = y
81 .iter()
82 .zip(icdfs.iter())
83 .filter(|&(yi, qi)| yi < qi)
84 .count();
85 observed.push(count_below as f64 / y.len() as f64);
86 }
87
88 let pctles_arr = Array1::from_vec(pctles);
89 let observed_arr = Array1::from_vec(observed);
90
91 let (slope, intercept) = polyfit_1(&pctles_arr, &observed_arr);
92
93 CalibrationResult {
94 predicted: pctles_arr,
95 observed: observed_arr,
96 slope,
97 intercept,
98 }
99}
100
101pub fn calibration_time_to_event(
113 cdf_at_t: &Array1<f64>,
114 event: &Array1<bool>,
115) -> CalibrationResult {
116 let km_result = kaplan_meier(cdf_at_t, event);
119
120 let n_points = 11;
122 let predicted: Vec<f64> = (0..n_points)
123 .map(|i| i as f64 / (n_points - 1) as f64)
124 .collect();
125
126 let mut observed = Vec::with_capacity(n_points);
127 for &p in &predicted {
128 let survival = interpolate_km(&km_result, p);
130 observed.push(1.0 - survival);
131 }
132
133 let predicted_arr = Array1::from_vec(predicted);
134 let observed_arr = Array1::from_vec(observed);
135
136 let (slope, intercept) = polyfit_1(&predicted_arr, &observed_arr);
137
138 CalibrationResult {
139 predicted: predicted_arr,
140 observed: observed_arr,
141 slope,
142 intercept,
143 }
144}
145
146pub fn calculate_calib_error(predicted: &Array1<f64>, observed: &Array1<f64>) -> f64 {
155 let n = predicted.len();
156 if n == 0 {
157 return 0.0;
158 }
159 let sum_sq: f64 = predicted
160 .iter()
161 .zip(observed.iter())
162 .map(|(p, o)| (p - o).powi(2))
163 .sum();
164 sum_sq / n as f64
165}
166
167#[derive(Debug, Clone)]
169pub struct PITHistogramData {
170 pub bin_edges: Array1<f64>,
172 pub densities: Array1<f64>,
174 pub expected_density: f64,
176}
177
178pub fn pit_histogram(cdf_values: &Array1<f64>, n_bins: usize) -> PITHistogramData {
190 let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
191
192 let mut counts = vec![0usize; n_bins];
193 let n = cdf_values.len();
194
195 for &cdf in cdf_values.iter() {
196 let bin_idx = ((cdf * n_bins as f64).floor() as usize).min(n_bins - 1);
197 counts[bin_idx] += 1;
198 }
199
200 let densities: Vec<f64> = counts
201 .iter()
202 .map(|&c| c as f64 / n as f64 * n_bins as f64)
203 .collect();
204
205 PITHistogramData {
206 bin_edges: Array1::from_vec(bin_edges),
207 densities: Array1::from_vec(densities),
208 expected_density: 1.0,
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct CalibrationCurveData {
215 pub predicted: Array1<f64>,
217 pub observed: Array1<f64>,
219 pub fit_x: Array1<f64>,
221 pub fit_y: Array1<f64>,
223 pub slope: f64,
225 pub intercept: f64,
227}
228
229pub fn calibration_curve_data(
238 predicted: &Array1<f64>,
239 observed: &Array1<f64>,
240) -> CalibrationCurveData {
241 let (slope, intercept) = polyfit_1(predicted, observed);
242
243 let fit_x = Array1::linspace(0.0, 1.0, 50);
244 let fit_y = fit_x.mapv(|x| slope * x + intercept);
245
246 CalibrationCurveData {
247 predicted: predicted.clone(),
248 observed: observed.clone(),
249 fit_x,
250 fit_y,
251 slope,
252 intercept,
253 }
254}
255
256pub fn concordance_index(
275 predictions: &Array1<f64>,
276 times: &Array1<f64>,
277 events: &Array1<bool>,
278) -> f64 {
279 let n = times.len();
280 let mut concordant = 0.0;
281 let mut total_comparable = 0.0;
282
283 for i in 0..n {
284 for j in (i + 1)..n {
285 let e_i = events[i];
286 let e_j = events[j];
287 let t_i = times[i];
288 let t_j = times[j];
289 let p_i = predictions[i];
290 let p_j = predictions[j];
291
292 let comparable = if e_i && e_j {
294 true
296 } else if e_i && !e_j && t_i < t_j {
297 true
299 } else if !e_i && e_j && t_i > t_j {
300 true
302 } else {
303 false
304 };
305
306 if comparable {
307 total_comparable += 1.0;
308
309 if (t_i < t_j && p_i > p_j) || (t_i > t_j && p_i < p_j) {
312 concordant += 1.0;
313 } else if (p_i - p_j).abs() < 1e-10 {
314 concordant += 0.5;
316 }
317 }
318 }
319 }
320
321 if total_comparable == 0.0 {
322 return 0.5; }
324
325 concordant / total_comparable
326}
327
328pub fn concordance_index_uncensored_only(
340 predictions: &Array1<f64>,
341 times: &Array1<f64>,
342 events: &Array1<bool>,
343) -> f64 {
344 let uncensored_indices: Vec<usize> = events
346 .iter()
347 .enumerate()
348 .filter(|&(_, e)| *e)
349 .map(|(i, _)| i)
350 .collect();
351
352 let n = uncensored_indices.len();
353 if n < 2 {
354 return 0.5;
355 }
356
357 let mut concordant = 0.0;
358 let mut total = 0.0;
359
360 for i in 0..n {
361 for j in (i + 1)..n {
362 let idx_i = uncensored_indices[i];
363 let idx_j = uncensored_indices[j];
364
365 let t_i = times[idx_i];
366 let t_j = times[idx_j];
367 let p_i = predictions[idx_i];
368 let p_j = predictions[idx_j];
369
370 total += 1.0;
371
372 if (t_i < t_j && p_i > p_j) || (t_i > t_j && p_i < p_j) {
373 concordant += 1.0;
374 } else if (p_i - p_j).abs() < 1e-10 {
375 concordant += 0.5;
376 }
377 }
378 }
379
380 if total == 0.0 {
381 return 0.5;
382 }
383
384 concordant / total
385}
386
387pub fn brier_score(predicted_probs: &Array1<f64>, outcomes: &Array1<f64>) -> f64 {
399 let n = predicted_probs.len();
400 if n == 0 {
401 return 0.0;
402 }
403
404 let sum_sq: f64 = predicted_probs
405 .iter()
406 .zip(outcomes.iter())
407 .map(|(p, o)| (p - o).powi(2))
408 .sum();
409
410 sum_sq / n as f64
411}
412
413pub fn log_loss(predicted_probs: &Array1<f64>, outcomes: &Array1<f64>, eps: f64) -> f64 {
423 let n = predicted_probs.len();
424 if n == 0 {
425 return 0.0;
426 }
427
428 let sum: f64 = predicted_probs
429 .iter()
430 .zip(outcomes.iter())
431 .map(|(&p, &o)| {
432 let p_clamped = p.clamp(eps, 1.0 - eps);
433 -o * p_clamped.ln() - (1.0 - o) * (1.0 - p_clamped).ln()
434 })
435 .sum();
436
437 sum / n as f64
438}
439
440pub fn mean_absolute_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
442 let n = predicted.len();
443 if n == 0 {
444 return 0.0;
445 }
446 let sum: f64 = predicted
447 .iter()
448 .zip(actual.iter())
449 .map(|(p, a)| (p - a).abs())
450 .sum();
451 sum / n as f64
452}
453
454pub fn mean_squared_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
456 let n = predicted.len();
457 if n == 0 {
458 return 0.0;
459 }
460 let sum: f64 = predicted
461 .iter()
462 .zip(actual.iter())
463 .map(|(p, a)| (p - a).powi(2))
464 .sum();
465 sum / n as f64
466}
467
468pub fn root_mean_squared_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
470 mean_squared_error(predicted, actual).sqrt()
471}
472
473fn polyfit_1(x: &Array1<f64>, y: &Array1<f64>) -> (f64, f64) {
479 let n = x.len() as f64;
480 if n < 2.0 {
481 return (1.0, 0.0);
482 }
483
484 let sum_x: f64 = x.iter().sum();
485 let sum_y: f64 = y.iter().sum();
486 let sum_xy: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
487 let sum_x2: f64 = x.iter().map(|xi| xi * xi).sum();
488
489 let mean_x = sum_x / n;
490 let mean_y = sum_y / n;
491
492 let denom = sum_x2 - n * mean_x * mean_x;
493 if denom.abs() < 1e-15 {
494 return (1.0, mean_y - mean_x);
495 }
496
497 let slope = (sum_xy - n * mean_x * mean_y) / denom;
498 let intercept = mean_y - slope * mean_x;
499
500 (slope, intercept)
501}
502
/// Kaplan–Meier step function stored as two parallel vectors: `survival[k]`
/// is the estimated survival probability just after `times[k]`.
struct KaplanMeierResult {
    /// Distinct observation times, in ascending order.
    times: Vec<f64>,
    /// Survival estimate after every event at or before `times[k]` is applied.
    survival: Vec<f64>,
}
510
511fn kaplan_meier(times: &Array1<f64>, events: &Array1<bool>) -> KaplanMeierResult {
513 let mut indices: Vec<usize> = (0..times.len()).collect();
515 indices.sort_by(|&a, &b| times[a].partial_cmp(×[b]).unwrap());
516
517 let mut unique_times = Vec::new();
518 let mut survival_probs = Vec::new();
519
520 let mut at_risk = times.len();
521 let mut survival = 1.0;
522
523 let mut i = 0;
524 while i < indices.len() {
525 let idx = indices[i];
526 let t = times[idx];
527
528 let mut n_events = 0;
530 let mut n_at_time = 0;
531
532 while i < indices.len() && (times[indices[i]] - t).abs() < 1e-10 {
533 if events[indices[i]] {
534 n_events += 1;
535 }
536 n_at_time += 1;
537 i += 1;
538 }
539
540 if n_events > 0 && at_risk > 0 {
541 survival *= 1.0 - (n_events as f64 / at_risk as f64);
542 }
543
544 unique_times.push(t);
545 survival_probs.push(survival);
546
547 at_risk -= n_at_time;
548 }
549
550 KaplanMeierResult {
551 times: unique_times,
552 survival: survival_probs,
553 }
554}
555
556fn interpolate_km(km: &KaplanMeierResult, t: f64) -> f64 {
558 if km.times.is_empty() {
559 return 1.0;
560 }
561
562 if t <= km.times[0] {
563 return 1.0;
564 }
565
566 for i in 0..km.times.len() {
567 if t <= km.times[i] {
568 return km.survival[i.saturating_sub(1)];
569 }
570 }
571
572 *km.survival.last().unwrap_or(&0.0)
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_calculate_calib_error() {
        let predicted = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let observed = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        assert_relative_eq!(
            calculate_calib_error(&predicted, &observed),
            0.0,
            epsilon = 1e-10
        );

        // Every entry is off by 0.1, so the mean squared error is 0.01.
        let observed_off = Array1::from_vec(vec![0.2, 0.3, 0.4, 0.5, 0.6]);
        let error = calculate_calib_error(&predicted, &observed_off);
        assert_relative_eq!(error, 0.01, epsilon = 1e-10);
    }

    #[test]
    fn test_polyfit_1() {
        // y = 2x + 1 exactly.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let y = Array1::from_vec(vec![1.0, 3.0, 5.0, 7.0, 9.0]);
        let (slope, intercept) = polyfit_1(&x, &y);
        assert_relative_eq!(slope, 2.0, epsilon = 1e-10);
        assert_relative_eq!(intercept, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pit_histogram() {
        let cdf_values =
            Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]);
        let result = pit_histogram(&cdf_values, 10);
        assert_eq!(result.densities.len(), 10);
        assert_eq!(result.bin_edges.len(), 11);
        assert_relative_eq!(result.expected_density, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pit_histogram_uniform_sample() {
        // One value in the middle of each of the 10 bins gives density 1.0 everywhere.
        let cdf_values = Array1::from_vec(vec![
            0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95,
        ]);
        let result = pit_histogram(&cdf_values, 10);
        for &d in result.densities.iter() {
            assert_relative_eq!(d, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_concordance_index_perfect() {
        // Higher risk goes with an earlier event for every pair.
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, true, true, true, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert_relative_eq!(c_index, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_reversed() {
        // Risk increases with time, so every comparable pair is discordant.
        let predictions = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, true, true, true, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert_relative_eq!(c_index, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_random() {
        // Constant predictions tie on every pair, scoring 0.5 each.
        let predictions = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, true, true, true, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert_relative_eq!(c_index, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_with_censoring() {
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, false, true, false, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert!((0.0..=1.0).contains(&c_index));
    }

    #[test]
    fn test_brier_score() {
        // Perfect predictions.
        let predicted = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        let outcomes = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        assert_relative_eq!(brier_score(&predicted, &outcomes), 0.0, epsilon = 1e-10);

        // Maximally wrong predictions.
        let predicted = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);
        let outcomes = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        assert_relative_eq!(brier_score(&predicted, &outcomes), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_log_loss() {
        // Confident and correct predictions give a small loss.
        let predicted = Array1::from_vec(vec![0.99, 0.01]);
        let outcomes = Array1::from_vec(vec![1.0, 0.0]);
        let loss = log_loss(&predicted, &outcomes, 1e-15);
        assert!(loss < 0.1);
    }

    #[test]
    fn test_mean_squared_error() {
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let actual = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert_relative_eq!(
            mean_squared_error(&predicted, &actual),
            0.0,
            epsilon = 1e-10
        );

        let actual = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert_relative_eq!(
            mean_squared_error(&predicted, &actual),
            1.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_mean_absolute_error() {
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let actual = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert_relative_eq!(
            mean_absolute_error(&predicted, &actual),
            1.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_root_mean_squared_error() {
        // Every entry is off by 2, so the RMSE is 2.
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let actual = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert_relative_eq!(
            root_mean_squared_error(&predicted, &actual),
            2.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_calibration_result() {
        let result = CalibrationResult {
            predicted: Array1::from_vec(vec![0.1, 0.5, 0.9]),
            observed: Array1::from_vec(vec![0.1, 0.5, 0.9]),
            slope: 1.0,
            intercept: 0.0,
        };

        assert!(result.is_well_calibrated(0.1, 0.1));
        assert_relative_eq!(result.calibration_error(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_calibration_regression_grid() {
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let eps = 0.05;
        // Every observation shares the same predicted q-quantile, 4q.
        let result = calibration_regression(|q| Array1::from_elem(4, q * 4.0), &y, 5, eps);

        assert_eq!(result.predicted.len(), 5);
        assert_eq!(result.observed.len(), 5);
        assert_relative_eq!(result.predicted[0], eps, epsilon = 1e-12);
        assert_relative_eq!(result.predicted[4], 1.0 - eps, epsilon = 1e-12);
        // q = 0.05 -> threshold 0.2 -> one of four observations is below it.
        assert_relative_eq!(result.observed[0], 0.25, epsilon = 1e-12);
        // q = 0.5 -> threshold 2.0 -> two of four.
        assert_relative_eq!(result.observed[2], 0.5, epsilon = 1e-12);
        // q = 0.95 -> threshold 3.8 -> all four.
        assert_relative_eq!(result.observed[4], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_calibration_time_to_event_grid() {
        let cdf_at_t = Array1::from_vec(vec![0.15, 0.35, 0.55, 0.75, 0.95]);
        let event = Array1::from_vec(vec![true, false, true, false, true]);

        let result = calibration_time_to_event(&cdf_at_t, &event);
        assert_eq!(result.predicted.len(), 11);
        assert_eq!(result.observed.len(), 11);
        assert_relative_eq!(result.predicted[0], 0.0, epsilon = 1e-12);
        assert_relative_eq!(result.predicted[10], 1.0, epsilon = 1e-12);
        assert!(result
            .observed
            .iter()
            .all(|&o| (0.0..=1.0).contains(&o)));
    }

    #[test]
    fn test_kaplan_meier() {
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, false, true, false, true]);

        let km = kaplan_meier(&times, &events);
        assert_eq!(km.times.len(), 5);
        assert!(km.survival[0] < 1.0);
        assert!(km.survival.last().unwrap() < &km.survival[0]);
    }

    #[test]
    fn test_concordance_uncensored_only() {
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let events = Array1::from_vec(vec![true, true, true]);

        let c_index = concordance_index_uncensored_only(&predictions, &times, &events);
        assert_relative_eq!(c_index, 1.0, epsilon = 1e-10);
    }
}