// ngboost_rs/evaluation.rs

1//! Evaluation metrics for NGBoost models.
2//!
3//! This module provides functions for evaluating probabilistic predictions,
4//! including calibration metrics and concordance indices for survival analysis.
5
6use ndarray::Array1;
7
/// Result of calibration analysis.
///
/// Produced by [`calibration_regression`] and [`calibration_time_to_event`];
/// pairs predicted percentiles with observed proportions and carries the
/// least-squares line fitted through them (ideal: slope 1, intercept 0).
#[derive(Debug, Clone)]
pub struct CalibrationResult {
    /// The predicted quantiles/percentiles (x-axis of the calibration curve).
    pub predicted: Array1<f64>,
    /// The observed proportions of data falling below each predicted quantile.
    pub observed: Array1<f64>,
    /// The slope of the fitted calibration line (1.0 indicates perfect calibration).
    pub slope: f64,
    /// The intercept of the fitted calibration line (0.0 indicates perfect calibration).
    pub intercept: f64,
}
20
21impl CalibrationResult {
22    /// Calculate the calibration error (sum of squared differences).
23    pub fn calibration_error(&self) -> f64 {
24        calculate_calib_error(&self.predicted, &self.observed)
25    }
26
27    /// Check if the model is well-calibrated (slope close to 1, intercept close to 0).
28    pub fn is_well_calibrated(&self, slope_tol: f64, intercept_tol: f64) -> bool {
29        (self.slope - 1.0).abs() <= slope_tol && self.intercept.abs() <= intercept_tol
30    }
31}
32
33/// Calculate calibration in the regression setting.
34///
35/// Computes how well-calibrated the predicted distributions are by comparing
36/// predicted quantiles to observed proportions.
37///
38/// # Arguments
39/// * `ppf_fn` - Function that computes the percent point function (inverse CDF)
40///              given a percentile value. Should return an Array1<f64> of quantiles.
41/// * `y` - Observed values.
42/// * `bins` - Number of bins/percentiles to evaluate (default: 11).
43/// * `eps` - Small value to avoid edge effects (default: 1e-3).
44///
45/// # Returns
46/// A `CalibrationResult` containing predicted percentiles, observed proportions,
47/// and the fitted calibration line parameters.
48///
49/// # Example
50/// ```ignore
51/// use ngboost_rs::evaluation::calibration_regression;
52/// use ndarray::Array1;
53///
54/// let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
55/// let result = calibration_regression(
56///     |p| ppf_values_at_percentile_p,
57///     &y,
58///     11,
59///     1e-3
60/// );
61/// println!("Slope: {}, Intercept: {}", result.slope, result.intercept);
62/// ```
63pub fn calibration_regression<F>(
64    ppf_fn: F,
65    y: &Array1<f64>,
66    bins: usize,
67    eps: f64,
68) -> CalibrationResult
69where
70    F: Fn(f64) -> Array1<f64>,
71{
72    let pctles: Vec<f64> = (0..bins)
73        .map(|i| eps + (1.0 - 2.0 * eps) * (i as f64) / ((bins - 1) as f64))
74        .collect();
75
76    let mut observed = Vec::with_capacity(bins);
77
78    for &pctle in &pctles {
79        let icdfs = ppf_fn(pctle);
80        let count_below: usize = y
81            .iter()
82            .zip(icdfs.iter())
83            .filter(|&(yi, qi)| yi < qi)
84            .count();
85        observed.push(count_below as f64 / y.len() as f64);
86    }
87
88    let pctles_arr = Array1::from_vec(pctles);
89    let observed_arr = Array1::from_vec(observed);
90
91    let (slope, intercept) = polyfit_1(&pctles_arr, &observed_arr);
92
93    CalibrationResult {
94        predicted: pctles_arr,
95        observed: observed_arr,
96        slope,
97        intercept,
98    }
99}
100
101/// Calculate calibration in the time-to-event (survival) setting.
102///
103/// Uses the probability integral transform and Kaplan-Meier estimation
104/// to assess calibration of survival predictions.
105///
106/// # Arguments
107/// * `cdf_at_t` - CDF values at the observed times (F(T) for each observation).
108/// * `event` - Event indicators (true = event occurred, false = censored).
109///
110/// # Returns
111/// A `CalibrationResult` containing the calibration analysis.
112pub fn calibration_time_to_event(
113    cdf_at_t: &Array1<f64>,
114    event: &Array1<bool>,
115) -> CalibrationResult {
116    // Compute Kaplan-Meier estimate on the CDF values
117    // The idea: if well-calibrated, CDF(T) should be uniform on [0,1] for uncensored
118    let km_result = kaplan_meier(cdf_at_t, event);
119
120    // Sample at 11 evenly spaced points
121    let n_points = 11;
122    let predicted: Vec<f64> = (0..n_points)
123        .map(|i| i as f64 / (n_points - 1) as f64)
124        .collect();
125
126    let mut observed = Vec::with_capacity(n_points);
127    for &p in &predicted {
128        // Find the survival probability at this CDF value
129        let survival = interpolate_km(&km_result, p);
130        observed.push(1.0 - survival);
131    }
132
133    let predicted_arr = Array1::from_vec(predicted);
134    let observed_arr = Array1::from_vec(observed);
135
136    let (slope, intercept) = polyfit_1(&predicted_arr, &observed_arr);
137
138    CalibrationResult {
139        predicted: predicted_arr,
140        observed: observed_arr,
141        slope,
142        intercept,
143    }
144}
145
146/// Calculate calibration error as sum of squared differences.
147///
148/// # Arguments
149/// * `predicted` - Predicted values/quantiles.
150/// * `observed` - Observed proportions.
151///
152/// # Returns
153/// The mean squared calibration error.
154pub fn calculate_calib_error(predicted: &Array1<f64>, observed: &Array1<f64>) -> f64 {
155    let n = predicted.len();
156    if n == 0 {
157        return 0.0;
158    }
159    let sum_sq: f64 = predicted
160        .iter()
161        .zip(observed.iter())
162        .map(|(p, o)| (p - o).powi(2))
163        .sum();
164    sum_sq / n as f64
165}
166
/// Data for a PIT (Probability Integral Transform) histogram.
#[derive(Debug, Clone)]
pub struct PITHistogramData {
    /// Bin edges (`n_bins + 1` values spanning [0, 1]).
    pub bin_edges: Array1<f64>,
    /// Density values for each bin: counts normalized by sample size and bin
    /// width, so a perfectly uniform PIT yields 1.0 in every bin.
    pub densities: Array1<f64>,
    /// Expected density for a perfectly calibrated model. Always 1.0, since
    /// `densities` are already bin-width normalized (see [`pit_histogram`]).
    pub expected_density: f64,
}
177
178/// Compute PIT histogram data.
179///
180/// The PIT histogram shows how well-calibrated a probabilistic forecast is.
181/// For a well-calibrated model, the histogram should be approximately uniform.
182///
183/// # Arguments
184/// * `cdf_values` - CDF evaluated at the observed values (F(y) for each y).
185/// * `n_bins` - Number of bins for the histogram (default: 10).
186///
187/// # Returns
188/// PIT histogram data including bin edges and densities.
189pub fn pit_histogram(cdf_values: &Array1<f64>, n_bins: usize) -> PITHistogramData {
190    let bin_edges: Vec<f64> = (0..=n_bins).map(|i| i as f64 / n_bins as f64).collect();
191
192    let mut counts = vec![0usize; n_bins];
193    let n = cdf_values.len();
194
195    for &cdf in cdf_values.iter() {
196        let bin_idx = ((cdf * n_bins as f64).floor() as usize).min(n_bins - 1);
197        counts[bin_idx] += 1;
198    }
199
200    let densities: Vec<f64> = counts
201        .iter()
202        .map(|&c| c as f64 / n as f64 * n_bins as f64)
203        .collect();
204
205    PITHistogramData {
206        bin_edges: Array1::from_vec(bin_edges),
207        densities: Array1::from_vec(densities),
208        expected_density: 1.0,
209    }
210}
211
/// Data for a calibration curve plot.
#[derive(Debug, Clone)]
pub struct CalibrationCurveData {
    /// Predicted probabilities/quantiles (x-axis scatter points).
    pub predicted: Array1<f64>,
    /// Observed proportions (y-axis scatter points).
    pub observed: Array1<f64>,
    /// x-values of the fitted least-squares line (50 points on [0, 1]).
    pub fit_x: Array1<f64>,
    /// y-values of the fitted line evaluated at each `fit_x`.
    pub fit_y: Array1<f64>,
    /// Slope of the calibration line (1.0 indicates perfect calibration).
    pub slope: f64,
    /// Intercept of the calibration line (0.0 indicates perfect calibration).
    pub intercept: f64,
}
228
229/// Compute calibration curve data for plotting.
230///
231/// # Arguments
232/// * `predicted` - Predicted probabilities/quantiles.
233/// * `observed` - Observed proportions.
234///
235/// # Returns
236/// Data for plotting a calibration curve.
237pub fn calibration_curve_data(
238    predicted: &Array1<f64>,
239    observed: &Array1<f64>,
240) -> CalibrationCurveData {
241    let (slope, intercept) = polyfit_1(predicted, observed);
242
243    let fit_x = Array1::linspace(0.0, 1.0, 50);
244    let fit_y = fit_x.mapv(|x| slope * x + intercept);
245
246    CalibrationCurveData {
247        predicted: predicted.clone(),
248        observed: observed.clone(),
249        fit_x,
250        fit_y,
251        slope,
252        intercept,
253    }
254}
255
256/// Calculate Harrell's C-statistic (concordance index) with censoring support.
257///
258/// The concordance index measures the ability of a model to correctly rank
259/// pairs of observations by their predicted risk/time.
260///
261/// # Comparable Pairs
262/// - Both uncensored: can compare
263/// - One censored, one not: can compare if censored time > uncensored time
264/// - Both censored: cannot compare
265///
266/// # Arguments
267/// * `predictions` - Predicted risk scores or times (higher = higher risk).
268/// * `times` - Observed times to event or censoring.
269/// * `events` - Event indicators (true = event occurred, false = censored).
270///
271/// # Returns
272/// The concordance index in [0, 1]. A value of 0.5 indicates random predictions,
273/// while 1.0 indicates perfect concordance.
274pub fn concordance_index(
275    predictions: &Array1<f64>,
276    times: &Array1<f64>,
277    events: &Array1<bool>,
278) -> f64 {
279    let n = times.len();
280    let mut concordant = 0.0;
281    let mut total_comparable = 0.0;
282
283    for i in 0..n {
284        for j in (i + 1)..n {
285            let e_i = events[i];
286            let e_j = events[j];
287            let t_i = times[i];
288            let t_j = times[j];
289            let p_i = predictions[i];
290            let p_j = predictions[j];
291
292            // Determine if this pair is comparable
293            let comparable = if e_i && e_j {
294                // Both uncensored: always comparable
295                true
296            } else if e_i && !e_j && t_i < t_j {
297                // i uncensored, j censored, and i's event time < j's censoring time
298                true
299            } else if !e_i && e_j && t_i > t_j {
300                // i censored, j uncensored, and i's censoring time > j's event time
301                true
302            } else {
303                false
304            };
305
306            if comparable {
307                total_comparable += 1.0;
308
309                // Compare predictions based on true ordering
310                // For survival: lower predicted time (or higher risk) = earlier event
311                if (t_i < t_j && p_i > p_j) || (t_i > t_j && p_i < p_j) {
312                    concordant += 1.0;
313                } else if (p_i - p_j).abs() < 1e-10 {
314                    // Tie in predictions
315                    concordant += 0.5;
316                }
317            }
318        }
319    }
320
321    if total_comparable == 0.0 {
322        return 0.5; // No comparable pairs
323    }
324
325    concordant / total_comparable
326}
327
328/// Calculate concordance index considering only uncensored observations.
329///
330/// This is a simplified version that ignores censored observations entirely.
331///
332/// # Arguments
333/// * `predictions` - Predicted risk scores or times.
334/// * `times` - Observed times to event.
335/// * `events` - Event indicators (true = event occurred, false = censored).
336///
337/// # Returns
338/// The concordance index computed only on uncensored pairs.
339pub fn concordance_index_uncensored_only(
340    predictions: &Array1<f64>,
341    times: &Array1<f64>,
342    events: &Array1<bool>,
343) -> f64 {
344    // Filter to only uncensored observations
345    let uncensored_indices: Vec<usize> = events
346        .iter()
347        .enumerate()
348        .filter(|&(_, e)| *e)
349        .map(|(i, _)| i)
350        .collect();
351
352    let n = uncensored_indices.len();
353    if n < 2 {
354        return 0.5;
355    }
356
357    let mut concordant = 0.0;
358    let mut total = 0.0;
359
360    for i in 0..n {
361        for j in (i + 1)..n {
362            let idx_i = uncensored_indices[i];
363            let idx_j = uncensored_indices[j];
364
365            let t_i = times[idx_i];
366            let t_j = times[idx_j];
367            let p_i = predictions[idx_i];
368            let p_j = predictions[idx_j];
369
370            total += 1.0;
371
372            if (t_i < t_j && p_i > p_j) || (t_i > t_j && p_i < p_j) {
373                concordant += 1.0;
374            } else if (p_i - p_j).abs() < 1e-10 {
375                concordant += 0.5;
376            }
377        }
378    }
379
380    if total == 0.0 {
381        return 0.5;
382    }
383
384    concordant / total
385}
386
387/// Compute the Brier score for probabilistic predictions.
388///
389/// The Brier score measures the accuracy of probabilistic predictions.
390/// Lower is better, with 0 being perfect predictions.
391///
392/// # Arguments
393/// * `predicted_probs` - Predicted probabilities.
394/// * `outcomes` - Binary outcomes (0 or 1).
395///
396/// # Returns
397/// The Brier score.
398pub fn brier_score(predicted_probs: &Array1<f64>, outcomes: &Array1<f64>) -> f64 {
399    let n = predicted_probs.len();
400    if n == 0 {
401        return 0.0;
402    }
403
404    let sum_sq: f64 = predicted_probs
405        .iter()
406        .zip(outcomes.iter())
407        .map(|(p, o)| (p - o).powi(2))
408        .sum();
409
410    sum_sq / n as f64
411}
412
413/// Compute the log loss (cross-entropy) for probabilistic predictions.
414///
415/// # Arguments
416/// * `predicted_probs` - Predicted probabilities (should be in (0, 1)).
417/// * `outcomes` - Binary outcomes (0 or 1).
418/// * `eps` - Small value to avoid log(0) (default: 1e-15).
419///
420/// # Returns
421/// The log loss.
422pub fn log_loss(predicted_probs: &Array1<f64>, outcomes: &Array1<f64>, eps: f64) -> f64 {
423    let n = predicted_probs.len();
424    if n == 0 {
425        return 0.0;
426    }
427
428    let sum: f64 = predicted_probs
429        .iter()
430        .zip(outcomes.iter())
431        .map(|(&p, &o)| {
432            let p_clamped = p.clamp(eps, 1.0 - eps);
433            -o * p_clamped.ln() - (1.0 - o) * (1.0 - p_clamped).ln()
434        })
435        .sum();
436
437    sum / n as f64
438}
439
440/// Compute the mean absolute error.
441pub fn mean_absolute_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
442    let n = predicted.len();
443    if n == 0 {
444        return 0.0;
445    }
446    let sum: f64 = predicted
447        .iter()
448        .zip(actual.iter())
449        .map(|(p, a)| (p - a).abs())
450        .sum();
451    sum / n as f64
452}
453
454/// Compute the mean squared error.
455pub fn mean_squared_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
456    let n = predicted.len();
457    if n == 0 {
458        return 0.0;
459    }
460    let sum: f64 = predicted
461        .iter()
462        .zip(actual.iter())
463        .map(|(p, a)| (p - a).powi(2))
464        .sum();
465    sum / n as f64
466}
467
468/// Compute the root mean squared error.
469pub fn root_mean_squared_error(predicted: &Array1<f64>, actual: &Array1<f64>) -> f64 {
470    mean_squared_error(predicted, actual).sqrt()
471}
472
473// ============================================================================
474// Helper functions
475// ============================================================================
476
477/// Simple linear regression to fit a line y = slope * x + intercept.
478fn polyfit_1(x: &Array1<f64>, y: &Array1<f64>) -> (f64, f64) {
479    let n = x.len() as f64;
480    if n < 2.0 {
481        return (1.0, 0.0);
482    }
483
484    let sum_x: f64 = x.iter().sum();
485    let sum_y: f64 = y.iter().sum();
486    let sum_xy: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
487    let sum_x2: f64 = x.iter().map(|xi| xi * xi).sum();
488
489    let mean_x = sum_x / n;
490    let mean_y = sum_y / n;
491
492    let denom = sum_x2 - n * mean_x * mean_x;
493    if denom.abs() < 1e-15 {
494        return (1.0, mean_y - mean_x);
495    }
496
497    let slope = (sum_xy - n * mean_x * mean_y) / denom;
498    let intercept = mean_y - slope * mean_x;
499
500    (slope, intercept)
501}
502
/// Kaplan-Meier estimate result (internal helper type).
struct KaplanMeierResult {
    /// Unique observed times (events and censorings), in ascending order.
    times: Vec<f64>,
    /// Survival probability after accounting for all deaths up to and
    /// including the corresponding entry in `times`.
    survival: Vec<f64>,
}
510
511/// Compute Kaplan-Meier survival estimate.
512fn kaplan_meier(times: &Array1<f64>, events: &Array1<bool>) -> KaplanMeierResult {
513    // Sort by time
514    let mut indices: Vec<usize> = (0..times.len()).collect();
515    indices.sort_by(|&a, &b| times[a].partial_cmp(&times[b]).unwrap());
516
517    let mut unique_times = Vec::new();
518    let mut survival_probs = Vec::new();
519
520    let mut at_risk = times.len();
521    let mut survival = 1.0;
522
523    let mut i = 0;
524    while i < indices.len() {
525        let idx = indices[i];
526        let t = times[idx];
527
528        // Count events and censored at this time
529        let mut n_events = 0;
530        let mut n_at_time = 0;
531
532        while i < indices.len() && (times[indices[i]] - t).abs() < 1e-10 {
533            if events[indices[i]] {
534                n_events += 1;
535            }
536            n_at_time += 1;
537            i += 1;
538        }
539
540        if n_events > 0 && at_risk > 0 {
541            survival *= 1.0 - (n_events as f64 / at_risk as f64);
542        }
543
544        unique_times.push(t);
545        survival_probs.push(survival);
546
547        at_risk -= n_at_time;
548    }
549
550    KaplanMeierResult {
551        times: unique_times,
552        survival: survival_probs,
553    }
554}
555
556/// Interpolate Kaplan-Meier survival function at a given time.
557fn interpolate_km(km: &KaplanMeierResult, t: f64) -> f64 {
558    if km.times.is_empty() {
559        return 1.0;
560    }
561
562    if t <= km.times[0] {
563        return 1.0;
564    }
565
566    for i in 0..km.times.len() {
567        if t <= km.times[i] {
568            return km.survival[i.saturating_sub(1)];
569        }
570    }
571
572    *km.survival.last().unwrap_or(&0.0)
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_calculate_calib_error() {
        // Identical predicted/observed: zero error.
        let predicted = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        let observed = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        assert_relative_eq!(
            calculate_calib_error(&predicted, &observed),
            0.0,
            epsilon = 1e-10
        );

        // Constant offset of 0.1: mean squared error is 0.01.
        let observed_off = Array1::from_vec(vec![0.2, 0.3, 0.4, 0.5, 0.6]);
        let error = calculate_calib_error(&predicted, &observed_off);
        assert_relative_eq!(error, 0.01, epsilon = 1e-10);
    }

    #[test]
    fn test_polyfit_1() {
        // Exact line y = 2x + 1 must be recovered exactly.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0]);
        let y = Array1::from_vec(vec![1.0, 3.0, 5.0, 7.0, 9.0]);
        let (slope, intercept) = polyfit_1(&x, &y);
        assert_relative_eq!(slope, 2.0, epsilon = 1e-10);
        assert_relative_eq!(intercept, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_pit_histogram() {
        // Well-calibrated predictions should give uniform PIT.
        let cdf_values = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]);
        let result = pit_histogram(&cdf_values, 10);
        assert_eq!(result.densities.len(), 10);
        assert_eq!(result.bin_edges.len(), 11);
        assert_relative_eq!(result.expected_density, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_perfect() {
        // Perfect concordance: predictions exactly reverse the time ordering.
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, true, true, true, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert_relative_eq!(c_index, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_random() {
        // All-tied predictions earn half credit per pair: exactly 0.5.
        let predictions = Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, true, true, true, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        assert_relative_eq!(c_index, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_concordance_index_with_censoring() {
        // With censored observations the index must still be a valid proportion.
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, false, true, false, true]);

        let c_index = concordance_index(&predictions, &times, &events);
        // Idiomatic range check (clippy::manual_range_contains).
        assert!((0.0..=1.0).contains(&c_index));
    }

    #[test]
    fn test_brier_score() {
        // Perfect predictions score 0.
        let predicted = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        let outcomes = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        assert_relative_eq!(brier_score(&predicted, &outcomes), 0.0, epsilon = 1e-10);

        // Maximally wrong predictions score 1.
        let predicted = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);
        let outcomes = Array1::from_vec(vec![0.0, 1.0, 0.0, 1.0]);
        assert_relative_eq!(brier_score(&predicted, &outcomes), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_log_loss() {
        // Confident, correct predictions yield a small loss.
        let predicted = Array1::from_vec(vec![0.99, 0.01]);
        let outcomes = Array1::from_vec(vec![1.0, 0.0]);
        let loss = log_loss(&predicted, &outcomes, 1e-15);
        assert!(loss < 0.1);
    }

    #[test]
    fn test_mean_squared_error() {
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let actual = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert_relative_eq!(
            mean_squared_error(&predicted, &actual),
            0.0,
            epsilon = 1e-10
        );

        // Constant offset of 1: MSE is 1.
        let actual = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert_relative_eq!(
            mean_squared_error(&predicted, &actual),
            1.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_mean_absolute_error() {
        let predicted = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let actual = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert_relative_eq!(
            mean_absolute_error(&predicted, &actual),
            1.0,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_calibration_result() {
        let result = CalibrationResult {
            predicted: Array1::from_vec(vec![0.1, 0.5, 0.9]),
            observed: Array1::from_vec(vec![0.1, 0.5, 0.9]),
            slope: 1.0,
            intercept: 0.0,
        };

        assert!(result.is_well_calibrated(0.1, 0.1));
        assert_relative_eq!(result.calibration_error(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kaplan_meier() {
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let events = Array1::from_vec(vec![true, false, true, false, true]);

        let km = kaplan_meier(&times, &events);
        assert_eq!(km.times.len(), 5);
        // Survival drops at the first event and is monotone non-increasing.
        assert!(km.survival[0] < 1.0);
        assert!(*km.survival.last().unwrap() < km.survival[0]);
    }

    #[test]
    fn test_concordance_uncensored_only() {
        let predictions = Array1::from_vec(vec![5.0, 4.0, 3.0]);
        let times = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let events = Array1::from_vec(vec![true, true, true]);

        let c_index = concordance_index_uncensored_only(&predictions, &times, &events);
        assert_relative_eq!(c_index, 1.0, epsilon = 1e-10);
    }
}
731}