use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Array2;
use scirs2_stats::survival::{CoxPH, KaplanMeier as KMInner, NelsonAalen as NAInner};
/// Kaplan–Meier survival estimator.
///
/// Thin wrapper around `scirs2_stats::survival::KaplanMeier` that converts the
/// backend's errors into [`NumRs2Error`] and adds a log-log confidence interval.
pub struct KaplanMeier {
    /// Fitted backend estimator; every query is answered from its data.
    inner: KMInner,
}

impl KaplanMeier {
    /// Fits the estimator to the observed `times` and their `events` flags.
    ///
    /// Presumably `true` marks an observed event and `false` a censored
    /// observation — confirm against the backend documentation.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ComputationError`] carrying the backend's message
    /// when the backend rejects the input.
    pub fn fit(times: &[f64], events: &[bool]) -> Result<Self> {
        let inner = KMInner::fit(times, events)
            .map_err(|e| NumRs2Error::ComputationError(e.to_string()))?;
        Ok(Self { inner })
    }

    /// Returns the backend's survival estimate `S(t)`.
    pub fn survival_at(&self, t: f64) -> f64 {
        self.inner.survival_at(t)
    }

    /// Computes a two-sided `(1 - alpha)` confidence interval for `S(t)` on the
    /// complementary log-log scale, using Greenwood's variance sum.
    ///
    /// Returns `(lower, upper)`, both clamped to `[0, 1]`. The interval collapses
    /// to a single point when `S(t)` is not strictly inside `(0, 1)` or when the
    /// Greenwood sum is zero (no variance information up to `t`).
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidInput`] when `alpha` is NaN or lies outside
    /// the open interval `(0, 1)`.
    pub fn confidence_interval(&self, t: f64, alpha: f64) -> Result<(f64, f64)> {
        // NaN fails every ordered comparison, so it must be rejected explicitly;
        // otherwise it would propagate into a silent `(NaN, NaN)` result.
        if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
            return Err(NumRs2Error::InvalidInput(format!(
                "alpha must be in (0, 1), got {alpha}"
            )));
        }

        let s = self.survival_at(t);
        if s <= 0.0 || s >= 1.0 {
            // ln(S) or ln(-ln(S)) is undefined at the boundaries.
            let clamped = s.clamp(0.0, 1.0);
            return Ok((clamped, clamped));
        }

        // Greenwood sum: sum of d_k / (n_k * (n_k - d_k)) over event times <= t.
        // Zipping the parallel vectors avoids positional indexing.
        let greenwood: f64 = self
            .inner
            .times
            .iter()
            .zip(self.inner.n_at_risk.iter())
            .zip(self.inner.n_events.iter())
            .take_while(|&((&tk, _), _)| tk <= t)
            .map(|((_, &at_risk), &events)| {
                let n_k = at_risk as f64;
                let d_k = events as f64;
                if n_k > d_k {
                    d_k / (n_k * (n_k - d_k))
                } else {
                    // Everyone at risk failed: the term is undefined, skip it.
                    0.0
                }
            })
            .sum();
        if greenwood == 0.0 {
            return Ok((s, s));
        }

        let z = norm_ppf(1.0 - alpha / 2.0);
        let ln_s = s.ln();
        // Standard error of ln(-ln S(t)) by the delta method.
        let se_ll = (greenwood / (ln_s * ln_s)).sqrt();
        let log_log_s = (-ln_s).ln();
        let ll_lo = log_log_s - z * se_ll;
        let ll_hi = log_log_s + z * se_ll;
        // The back-transform exp(-exp(x)) is decreasing, so the bounds swap.
        let lower = (-ll_hi.exp()).exp().clamp(0.0, 1.0);
        let upper = (-ll_lo.exp()).exp().clamp(0.0, 1.0);
        Ok((lower.min(upper), lower.max(upper)))
    }

    /// Returns the backend's median survival time, if it reports one.
    pub fn median_survival(&self) -> Option<f64> {
        self.inner.median_survival()
    }

    /// Returns the backend's mean survival time.
    pub fn mean_survival(&self) -> f64 {
        self.inner.mean_survival()
    }

    /// Times stored by the fitted backend estimator.
    pub fn event_times(&self) -> &[f64] {
        &self.inner.times
    }

    /// Survival values stored by the fitted backend estimator.
    pub fn survival_probabilities(&self) -> &[f64] {
        &self.inner.survival
    }

    /// Number of subjects at risk, as stored by the backend.
    pub fn n_at_risk(&self) -> &[usize] {
        &self.inner.n_at_risk
    }

    /// Number of events, as stored by the backend.
    pub fn n_events(&self) -> &[usize] {
        &self.inner.n_events
    }
}
/// Nelson–Aalen cumulative-hazard estimator.
///
/// Thin wrapper around `scirs2_stats::survival::NelsonAalen` that converts the
/// backend's errors into [`NumRs2Error`] and adds a confidence interval for the
/// implied survival function.
pub struct NelsonAalen {
    /// Fitted backend estimator; every query is answered from its data.
    inner: NAInner,
}

impl NelsonAalen {
    /// Fits the estimator to the observed `times` and their `events` flags.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ComputationError`] carrying the backend's message
    /// when the backend rejects the input.
    pub fn fit(times: &[f64], events: &[bool]) -> Result<Self> {
        let inner = NAInner::fit(times, events)
            .map_err(|e| NumRs2Error::ComputationError(e.to_string()))?;
        Ok(Self { inner })
    }

    /// Returns the backend's cumulative hazard estimate `H(t)`.
    pub fn cumulative_hazard_at(&self, t: f64) -> f64 {
        self.inner.hazard_at(t)
    }

    /// Returns the backend's survival estimate `S(t)`.
    pub fn survival_at(&self, t: f64) -> f64 {
        self.inner.survival_at(t)
    }

    /// Computes a two-sided `(1 - alpha)` confidence interval for `S(t)`.
    ///
    /// The interval is built on the log scale of the cumulative hazard,
    /// `H * exp(±z * se / H)` with `H = -ln S(t)`, and mapped back through
    /// `exp(-H)`. Returns `(lower, upper)`, both clamped to `[0, 1]`. The
    /// interval collapses to a single point when `S(t)` is not strictly inside
    /// `(0, 1)` or when the accumulated variance is zero.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::InvalidInput`] when `alpha` is NaN or lies outside
    /// the open interval `(0, 1)`.
    pub fn confidence_interval(&self, t: f64, alpha: f64) -> Result<(f64, f64)> {
        // NaN fails every ordered comparison, so it must be rejected explicitly;
        // otherwise it would propagate into a silent `(NaN, NaN)` result.
        if alpha.is_nan() || alpha <= 0.0 || alpha >= 1.0 {
            return Err(NumRs2Error::InvalidInput(format!(
                "alpha must be in (0, 1), got {alpha}"
            )));
        }

        let s = self.survival_at(t);
        if s <= 0.0 || s >= 1.0 {
            // H = -ln(S) is zero or infinite here, so no interval can be formed.
            let clamped = s.clamp(0.0, 1.0);
            return Ok((clamped, clamped));
        }

        // Variance estimate: sum of squared increments of the stored cumulative
        // hazard over times <= t.
        // NOTE(review): with tied events the usual estimator is sum d_k / n_k^2,
        // which differs from the squared increment — confirm how the backend
        // stores ties.
        let mut var_h = 0.0_f64;
        let mut prev_h = 0.0_f64;
        for (&tk, &h_k) in self
            .inner
            .times
            .iter()
            .zip(self.inner.cumulative_hazard.iter())
        {
            if tk > t {
                break;
            }
            let increment = h_k - prev_h;
            var_h += increment * increment;
            prev_h = h_k;
        }
        if var_h == 0.0 {
            return Ok((s, s));
        }

        let h = -s.ln();
        let z = norm_ppf(1.0 - alpha / 2.0);
        let se = var_h.sqrt();
        // Multiplicative half-width of the log-transformed hazard interval.
        let c = (z * se / h).exp();
        let h_lo = h / c;
        let h_hi = h * c;
        // A larger hazard means lower survival, so the bounds swap.
        let upper = (-h_lo).exp().clamp(0.0, 1.0);
        let lower = (-h_hi).exp().clamp(0.0, 1.0);
        Ok((lower.min(upper), lower.max(upper)))
    }

    /// Times stored by the fitted backend estimator.
    pub fn event_times(&self) -> &[f64] {
        &self.inner.times
    }

    /// Cumulative hazard values stored by the fitted backend estimator.
    pub fn cumulative_hazard_values(&self) -> &[f64] {
        &self.inner.cumulative_hazard
    }
}
/// Runs the backend's log-rank test on two groups of survival observations.
///
/// `times1`/`events1` describe the first group and `times2`/`events2` the
/// second. On success the backend's pair of values is returned unchanged; the
/// tests in this file treat it as `(chi-square statistic, p-value)`.
///
/// # Errors
///
/// Returns [`NumRs2Error::ComputationError`] carrying the backend's message
/// when the backend reports a failure.
pub fn log_rank_test(
    times1: &[f64],
    events1: &[bool],
    times2: &[f64],
    events2: &[bool],
) -> Result<(f64, f64)> {
    let (statistic, p_value) = KMInner::log_rank_test(times1, events1, times2, events2)
        .map_err(|cause| NumRs2Error::ComputationError(cause.to_string()))?;
    Ok((statistic, p_value))
}
/// Cox proportional-hazards regression model.
///
/// Thin wrapper around `scirs2_stats::survival::CoxPH` that converts the
/// backend's errors into [`NumRs2Error`] and exposes results as plain vectors.
pub struct CoxProportionalHazards {
    /// Fitted backend model; every query is answered from its data.
    inner: CoxPH,
}

impl CoxProportionalHazards {
    /// Fits the model to `times`, `events` and the `covariates` matrix.
    ///
    /// The tests in this file pass one row per observation and one column per
    /// feature.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ComputationError`] carrying the backend's message
    /// when fitting fails.
    pub fn fit(times: &[f64], events: &[bool], covariates: &Array2<f64>) -> Result<Self> {
        CoxPH::fit(times, events, covariates)
            .map(|inner| Self { inner })
            .map_err(|cause| NumRs2Error::ComputationError(cause.to_string()))
    }

    /// Copies the fitted coefficients into a new vector.
    pub fn coefficients(&self) -> Vec<f64> {
        let betas = self.inner.coefficients.iter().copied();
        betas.collect::<Vec<f64>>()
    }

    /// Copies the backend's coefficient standard errors into a new vector.
    pub fn standard_errors(&self) -> Vec<f64> {
        let errors = self.inner.std_errors.iter().copied();
        errors.collect::<Vec<f64>>()
    }

    /// Copies the backend's coefficient p-values into a new vector.
    pub fn p_values(&self) -> Vec<f64> {
        let probabilities = self.inner.p_values.iter().copied();
        probabilities.collect::<Vec<f64>>()
    }

    /// Copies the backend's hazard ratios into a new vector.
    pub fn hazard_ratio(&self) -> Vec<f64> {
        let ratios = self.inner.hazard_ratio();
        ratios.iter().copied().collect::<Vec<f64>>()
    }

    /// Evaluates the backend's `predict_risk` for the covariate vector `x`.
    ///
    /// NOTE(review): despite the method name, the value comes from
    /// `predict_risk`; it is presumably a risk score rather than a survival
    /// probability — confirm against the backend documentation.
    pub fn predict_survival(&self, x: &[f64]) -> f64 {
        let features = scirs2_core::ndarray::Array1::from_vec(x.to_owned());
        self.inner.predict_risk(&features)
    }

    /// Returns the backend's concordance index for the supplied data.
    pub fn concordance_index(
        &self,
        times: &[f64],
        events: &[bool],
        covariates: &Array2<f64>,
    ) -> f64 {
        self.inner.concordance_index(times, events, covariates)
    }

    /// Log-likelihood stored by the fitted backend model.
    pub fn log_likelihood(&self) -> f64 {
        self.inner.log_likelihood
    }

    /// Iteration count stored by the fitted backend model.
    pub fn n_iterations(&self) -> usize {
        self.inner.n_iter
    }
}
/// Approximates the standard normal quantile function (inverse CDF) at `p`.
///
/// Implements the rational approximation of Algorithm AS 111 (Beasley &
/// Springer, 1977). `p` is clamped to `[1e-15, 1 - 1e-15]` so that the
/// endpoints give large finite values instead of infinities. A NaN input
/// yields NaN.
fn norm_ppf(p: f64) -> f64 {
    // Published AS 111 coefficients, lowest order first. Do not round them:
    // perturbing the 4th significant digit costs roughly four decimal places
    // of accuracy in the returned quantile.
    const A: [f64; 4] = [
        2.506_628_238_84,
        -18.615_000_625_29,
        41.391_197_735_34,
        -25.441_060_496_37,
    ];
    const B: [f64; 4] = [
        -8.473_510_930_90,
        23.083_367_437_43,
        -21.062_241_018_26,
        3.130_829_098_33,
    ];
    const C: [f64; 4] = [
        -2.787_189_311_38,
        -2.297_964_791_34,
        4.850_141_271_35,
        2.321_212_768_58,
    ];
    const D: [f64; 2] = [3.543_889_247_62, 1.637_067_818_97];

    let p = p.clamp(1e-15, 1.0 - 1e-15);
    let q = p - 0.5;
    if q.abs() <= 0.42 {
        // Central region: odd rational function in q.
        let r = q * q;
        q * (((A[3] * r + A[2]) * r + A[1]) * r + A[0])
            / ((((B[3] * r + B[2]) * r + B[1]) * r + B[0]) * r + 1.0)
    } else {
        // Tails: rational function in sqrt(-ln(tail probability)).
        let tail = if q < 0.0 { p } else { 1.0 - p };
        let r = (-tail.ln()).sqrt();
        let x = (((C[3] * r + C[2]) * r + C[1]) * r + C[0]) / ((D[1] * r + D[0]) * r + 1.0);
        if q < 0.0 {
            -x
        } else {
            x
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the survival-analysis wrappers defined in this file.

    use super::*;
    use scirs2_core::ndarray::Array2;

    // The Kaplan–Meier curve must start at 1 and never increase.
    #[test]
    fn test_km_basic_monotone_decreasing() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let events = [true, true, false, true, false, true];
        let km = KaplanMeier::fit(&times, &events).expect("KM fit should succeed");
        assert_eq!(km.survival_at(0.0), 1.0);
        let s_vals: Vec<f64> = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
            .iter()
            .map(|&t| km.survival_at(t))
            .collect();
        for pair in s_vals.windows(2) {
            assert!(
                pair[0] >= pair[1],
                "S(t) must be non-increasing; got {} then {}",
                pair[0],
                pair[1]
            );
        }
    }

    // Three uncensored events: S drops by 1/3 at each event time.
    #[test]
    fn test_km_exact_small_dataset() {
        let times = [1.0, 2.0, 3.0];
        let events = [true, true, true];
        let km = KaplanMeier::fit(&times, &events).expect("KM fit");
        let eps = 1e-10;
        assert!(
            (km.survival_at(1.0) - 2.0 / 3.0).abs() < eps,
            "S(1) should be 2/3, got {}",
            km.survival_at(1.0)
        );
        assert!(
            (km.survival_at(2.0) - 1.0 / 3.0).abs() < eps,
            "S(2) should be 1/3, got {}",
            km.survival_at(2.0)
        );
        assert!(
            (km.survival_at(3.0) - 0.0).abs() < eps,
            "S(3) should be 0, got {}",
            km.survival_at(3.0)
        );
    }

    // Mixed event flags: the curve stays within [0, 1] and drops below 1.
    #[test]
    fn test_km_with_censoring() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0];
        let events = [true, false, true, false, true];
        let km = KaplanMeier::fit(&times, &events).expect("KM fit with censoring");
        assert_eq!(km.survival_at(0.9), 1.0);
        assert!(km.survival_at(5.0) < 1.0);
        assert!(km.survival_at(10.0) >= 0.0);
    }

    // Confidence intervals must be ordered, lie in [0, 1] and contain S(t).
    #[test]
    fn test_km_confidence_interval_valid() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let events = [true, false, true, true, false, true, false, true];
        let km = KaplanMeier::fit(&times, &events).expect("KM fit");
        for &t in &[2.0, 4.0, 6.0, 8.0] {
            let (lo, hi) = km.confidence_interval(t, 0.05).expect("CI computation");
            assert!(
                lo <= hi,
                "CI lower {} must be <= upper {} at t={}",
                lo,
                hi,
                t
            );
            assert!(lo >= 0.0, "CI lower must be >= 0 at t={}", t);
            assert!(hi <= 1.0, "CI upper must be <= 1 at t={}", t);
            let s = km.survival_at(t);
            assert!(
                lo <= s + 1e-10 && s <= hi + 1e-10,
                "S({t}) = {s} should be inside CI [{lo}, {hi}]"
            );
        }
    }

    // Alpha outside the open interval (0, 1) must be rejected.
    #[test]
    fn test_km_ci_invalid_alpha() {
        let times = [1.0, 2.0];
        let events = [true, true];
        let km = KaplanMeier::fit(&times, &events).expect("KM fit");
        assert!(
            km.confidence_interval(1.0, 0.0).is_err(),
            "alpha=0 should fail"
        );
        assert!(
            km.confidence_interval(1.0, 1.0).is_err(),
            "alpha=1 should fail"
        );
        assert!(
            km.confidence_interval(1.0, -0.1).is_err(),
            "negative alpha should fail"
        );
    }

    // S(t) = exp(-H(t)) must hold, and H(t) must be non-decreasing.
    #[test]
    fn test_na_consistency_with_survival() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let events = [true, false, true, true, false, true];
        let na = NelsonAalen::fit(&times, &events).expect("NA fit");
        for &t in &[0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0] {
            let h = na.cumulative_hazard_at(t);
            let s = na.survival_at(t);
            assert!(
                (s - (-h).exp()).abs() < 1e-10,
                "S({t}) = {s} but exp(-H({t})) = {}",
                (-h).exp()
            );
        }
        let h_vals: Vec<f64> = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
            .iter()
            .map(|&t| na.cumulative_hazard_at(t))
            .collect();
        for pair in h_vals.windows(2) {
            assert!(
                pair[0] <= pair[1] + 1e-14,
                "H(t) must be non-decreasing; got {} then {}",
                pair[0],
                pair[1]
            );
        }
    }

    // Four uncensored events: H accumulates 1/4, 1/3, 1/2, ...
    #[test]
    fn test_na_exact_values_no_censoring() {
        let times = [1.0, 2.0, 3.0, 4.0];
        let events = [true, true, true, true];
        let na = NelsonAalen::fit(&times, &events).expect("NA fit");
        let h1_expected = 1.0 / 4.0;
        let h2_expected = 1.0 / 4.0 + 1.0 / 3.0;
        let h3_expected = 1.0 / 4.0 + 1.0 / 3.0 + 1.0 / 2.0;
        let eps = 1e-10;
        assert!(
            (na.cumulative_hazard_at(1.0) - h1_expected).abs() < eps,
            "H(1) expected {h1_expected}, got {}",
            na.cumulative_hazard_at(1.0)
        );
        assert!(
            (na.cumulative_hazard_at(2.0) - h2_expected).abs() < eps,
            "H(2) expected {h2_expected}, got {}",
            na.cumulative_hazard_at(2.0)
        );
        assert!(
            (na.cumulative_hazard_at(3.0) - h3_expected).abs() < eps,
            "H(3) expected {h3_expected}, got {}",
            na.cumulative_hazard_at(3.0)
        );
    }

    // Comparing a group with itself must give no evidence of a difference.
    #[test]
    fn test_log_rank_identical_groups() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0];
        let events = [true, false, true, false, true];
        let (chi2, pval) =
            log_rank_test(&times, &events, &times, &events).expect("log-rank identical");
        assert!(
            chi2 < 1e-10,
            "chi2 for identical groups should be ~0, got {chi2}"
        );
        assert!(
            pval > 0.9,
            "p-value for identical groups should be ~1, got {pval}"
        );
    }

    // Clearly separated groups must produce a significant result.
    #[test]
    fn test_log_rank_different_groups() {
        let t1 = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5];
        let e1 = [true, true, true, true, true, true, true, true];
        let t2 = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0];
        let e2 = [true, true, true, true, true, true, true, true];
        let (chi2, pval) = log_rank_test(&t1, &e1, &t2, &e2).expect("log-rank different");
        assert!(
            chi2 > 5.0,
            "chi2 should be large for clearly different groups, got {chi2}"
        );
        assert!(
            pval < 0.05,
            "p-value should be small for different groups, got {pval}"
        );
    }

    // An empty group is invalid input.
    #[test]
    fn test_log_rank_empty_group_error() {
        let result = log_rank_test(&[], &[], &[1.0], &[true]);
        assert!(result.is_err(), "empty group should return error");
    }

    // Larger covariate values pair with earlier events, so beta must be positive.
    #[test]
    fn test_cox_ph_coefficient_sign() {
        let times = [5.0, 4.0, 3.0, 2.0, 1.0, 0.5];
        let events = [true, true, true, true, true, true];
        let cov_data = vec![0.1_f64, 0.3, 0.5, 0.7, 0.9, 1.1];
        let cov = Array2::from_shape_vec((6, 1), cov_data).expect("cov matrix");
        let model =
            CoxProportionalHazards::fit(&times, &events, &cov).expect("Cox PH fit should succeed");
        let coeffs = model.coefficients();
        assert_eq!(coeffs.len(), 1, "one coefficient for one feature");
        assert!(
            coeffs[0] > 0.0,
            "coefficient should be positive (high covariate => high hazard), got {}",
            coeffs[0]
        );
    }

    // Hazard ratios must equal exp(beta) coefficient by coefficient.
    #[test]
    fn test_cox_ph_hazard_ratio_consistency() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let events = [true, false, true, true, false, true];
        let cov_data = vec![0.5_f64, -0.5, 0.0, 1.0, -1.0, 0.2];
        let cov = Array2::from_shape_vec((6, 1), cov_data).expect("cov");
        let model = CoxProportionalHazards::fit(&times, &events, &cov).expect("Cox PH fit");
        let coeffs = model.coefficients();
        let hrs = model.hazard_ratio();
        assert_eq!(coeffs.len(), hrs.len());
        for (b, hr) in coeffs.iter().zip(hrs.iter()) {
            assert!(
                (hr - b.exp()).abs() < 1e-10,
                "HR = exp(beta) should hold: exp({b}) = {} but got {hr}",
                b.exp()
            );
        }
    }

    // The concordance index is a proportion and must lie in [0, 1].
    #[test]
    fn test_cox_ph_concordance_index_range() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let events = [true, false, true, true, false, true, true, false];
        let cov_data = vec![0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6_f64];
        let cov = Array2::from_shape_vec((8, 1), cov_data).expect("cov");
        let model = CoxProportionalHazards::fit(&times, &events, &cov).expect("Cox PH fit");
        let c = model.concordance_index(&times, &events, &cov);
        assert!(
            (0.0..=1.0).contains(&c),
            "C-statistic must be in [0,1], got {c}"
        );
    }

    // Every reported p-value must be a valid probability.
    #[test]
    fn test_cox_ph_p_values_valid() {
        let times = [1.0, 2.0, 3.0, 4.0, 5.0];
        let events = [true, true, false, true, true];
        let cov_data = vec![0.1, 0.5, 1.0, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.9_f64];
        let cov = Array2::from_shape_vec((5, 2), cov_data).expect("cov");
        let model =
            CoxProportionalHazards::fit(&times, &events, &cov).expect("Cox PH fit with 2 features");
        for pv in model.p_values() {
            assert!(
                (0.0..=1.0).contains(&pv),
                "p-value must be in [0,1], got {pv}"
            );
        }
    }
}