use crate::error::{InferustError, Result};
use statrs::distribution::{ChiSquared, ContinuousCDF, Normal};
/// Kaplan-Meier (product-limit) estimator for right-censored survival data.
///
/// Stateless marker type: all inputs are passed to [`KaplanMeier::fit`].
#[derive(Debug, Clone, Default)]
pub struct KaplanMeier;
/// One step (drop) of the Kaplan-Meier curve, recorded at a time with at
/// least one observed event.
#[derive(Debug, Clone)]
pub struct KmStep {
    /// Event time at which the survival estimate drops.
    pub time: f64,
    /// Number of subjects still at risk just before `time`.
    pub n_at_risk: usize,
    /// Number of events observed exactly at `time`.
    pub n_events: usize,
    /// Survival estimate S(t) immediately after this time.
    pub survival: f64,
    /// Lower bound of the 95% confidence interval for `survival`.
    pub ci_lower: f64,
    /// Upper bound of the 95% confidence interval for `survival`.
    pub ci_upper: f64,
}
/// Fitted Kaplan-Meier curve together with summary statistics.
#[derive(Debug, Clone)]
pub struct KaplanMeierResult {
    /// Step curve: one entry per distinct time with at least one event.
    pub curve: Vec<KmStep>,
    /// Total number of subjects in the sample.
    pub n: usize,
    /// Total number of observed (uncensored) events.
    pub n_events: usize,
    /// Restricted mean survival time: area under the step curve up to the
    /// last event time.
    pub rmst: f64,
    /// First time at which estimated survival reaches 0.5 or below, if the
    /// curve ever gets that low.
    pub median_survival: Option<f64>,
}
impl KaplanMeier {
pub fn new() -> Self {
Self
}
pub fn fit(&self, times: &[f64], events: &[usize]) -> Result<KaplanMeierResult> {
let n = times.len();
if n < 1 {
return Err(InferustError::InsufficientData { needed: 1, got: 0 });
}
if events.len() != n {
return Err(InferustError::DimensionMismatch {
x_rows: events.len(),
y_len: n,
});
}
let mut order: Vec<usize> = (0..n).collect();
order.sort_by(|&a, &b| {
times[a]
.partial_cmp(×[b])
.unwrap_or(std::cmp::Ordering::Equal)
.then(events[b].cmp(&events[a])) });
let mut curve: Vec<KmStep> = Vec::new();
let mut survival = 1.0_f64;
let mut greenwood_sum = 0.0_f64; let mut n_at_risk = n;
let total_events: usize = events.iter().sum();
let mut i = 0;
while i < n {
let t = times[order[i]];
let mut d = 0usize; let mut c = 0usize; while i < n && (times[order[i]] - t).abs() < f64::EPSILON {
if events[order[i]] == 1 {
d += 1;
} else {
c += 1;
}
i += 1;
}
if d > 0 {
let factor = 1.0 - d as f64 / n_at_risk as f64;
survival *= factor;
if n_at_risk > d {
greenwood_sum += d as f64 / (n_at_risk as f64 * (n_at_risk - d) as f64);
}
let se = survival * (greenwood_sum.sqrt());
let z = Normal::new(0.0, 1.0).unwrap().inverse_cdf(0.975);
let (ci_lower, ci_upper) = if survival > 0.0 && survival < 1.0 {
let log_s = survival.ln();
let log_se = se / (survival * log_s.abs().max(f64::EPSILON));
let (lo, hi) = (log_s * (1.0 + z * log_se), log_s * (1.0 - z * log_se));
(lo.exp().clamp(0.0, 1.0), hi.exp().clamp(0.0, 1.0))
} else {
(0.0, 1.0)
};
curve.push(KmStep {
time: t,
n_at_risk,
n_events: d,
survival,
ci_lower,
ci_upper,
});
}
n_at_risk -= d + c;
}
let rmst = compute_rmst(&curve);
let median_survival = curve.iter().find(|s| s.survival <= 0.5).map(|s| s.time);
Ok(KaplanMeierResult {
curve,
n,
n_events: total_events,
rmst,
median_survival,
})
}
}
impl KaplanMeierResult {
    /// Step-function lookup: survival estimate at time `t`.
    ///
    /// Before the first event time this is 1.0; afterwards it is the value
    /// of the last step whose time does not exceed `t`.
    pub fn survival_at(&self, t: f64) -> f64 {
        self.curve
            .iter()
            .take_while(|step| step.time <= t)
            .last()
            .map_or(1.0, |step| step.survival)
    }

    /// Prints a human-readable table of the fitted curve to stdout.
    pub fn print_summary(&self) {
        // Format the median up front so the header println stays simple.
        let median_text = self
            .median_survival
            .map_or("undefined".to_string(), |m| format!("{m:.3}"));
        println!();
        println!("── Kaplan-Meier Survival Estimate ─────────────────────────────────");
        println!(
            " n = {} events = {} median survival = {}",
            self.n, self.n_events, median_text
        );
        println!(" RMST = {:.4}", self.rmst);
        println!();
        println!(
            "{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}",
            "Time", "N.risk", "N.event", "Survival", "CI lower", "CI upper"
        );
        println!("{}", "─".repeat(65));
        for step in &self.curve {
            println!(
                "{:>10.3} {:>10} {:>10} {:>10.4} {:>10.4} {:>10.4}",
                step.time,
                step.n_at_risk,
                step.n_events,
                step.survival,
                step.ci_lower,
                step.ci_upper
            );
        }
        println!();
    }
}
/// Restricted mean survival time: area under the step curve from t = 0 out
/// to the last recorded event time.
fn compute_rmst(curve: &[KmStep]) -> f64 {
    // Each segment contributes the survival level that held *before* the
    // step, times the segment width; the curve starts at S(0) = 1.
    let (area, _, _) = curve.iter().fold(
        (0.0_f64, 0.0_f64, 1.0_f64),
        |(area, last_time, last_surv), step| {
            (
                area + last_surv * (step.time - last_time),
                step.time,
                step.survival,
            )
        },
    );
    area
}
/// Outcome of a two-sample log-rank test.
#[derive(Debug, Clone)]
pub struct LogRankResult {
    /// Chi-squared test statistic (1 degree of freedom).
    pub statistic: f64,
    /// Upper-tail p-value of `statistic` under χ²(1).
    pub p_value: f64,
}
impl LogRankResult {
    /// Prints a short report of the test result, with a verdict at the
    /// conventional α = 0.05 level.
    pub fn print(&self) {
        println!();
        println!("── Log-Rank Test ──────────────────────────────────────");
        println!(
            " χ²({}) = {:.4} p = {:.6}",
            1, self.statistic, self.p_value
        );
        let verdict = if self.p_value < 0.05 {
            " ✓ Reject H₀ (p < 0.05): survival curves differ."
        } else {
            " ✗ Fail to reject H₀ (p ≥ 0.05)."
        };
        println!("{verdict}");
        println!();
    }
}
/// Two-sample log-rank test comparing the survival curves of group 1
/// (`times1`/`events1`) and group 2 (`times2`/`events2`).
///
/// `events*[i] == 1` marks an observed event; other values are censoring.
///
/// # Errors
///
/// Returns `InsufficientData` when either group is empty.
pub fn log_rank_test(
    times1: &[f64],
    events1: &[usize],
    times2: &[f64],
    events2: &[usize],
) -> Result<LogRankResult> {
    let n1 = times1.len();
    let n2 = times2.len();
    if n1 < 1 || n2 < 1 {
        return Err(InferustError::InsufficientData {
            needed: 1,
            got: n1.min(n2),
        });
    }
    // Pool the event times of both groups (censored times do not define
    // comparison points), then sort and collapse near-duplicates.
    let mut distinct_times: Vec<f64> = Vec::new();
    for (&t, &e) in times1.iter().zip(events1).chain(times2.iter().zip(events2)) {
        if e == 1 {
            distinct_times.push(t);
        }
    }
    distinct_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    distinct_times.dedup_by(|a, b| (*a - *b).abs() < f64::EPSILON);
    // Accumulate O - E for group 1 and the hypergeometric variance.
    let mut observed_minus_expected = 0.0_f64;
    let mut variance = 0.0_f64;
    for &t in &distinct_times {
        let risk1 = times1.iter().filter(|&&ti| ti >= t).count();
        let risk2 = times2.iter().filter(|&&ti| ti >= t).count();
        let deaths1 = times1
            .iter()
            .zip(events1)
            .filter(|(&ti, &ei)| (ti - t).abs() < f64::EPSILON && ei == 1)
            .count();
        let deaths2 = times2
            .iter()
            .zip(events2)
            .filter(|(&ti, &ei)| (ti - t).abs() < f64::EPSILON && ei == 1)
            .count();
        let total_risk = (risk1 + risk2) as f64;
        let total_deaths = (deaths1 + deaths2) as f64;
        // A variance term needs at least two subjects at risk.
        if total_risk < 2.0 {
            continue;
        }
        let expected1 = risk1 as f64 * total_deaths / total_risk;
        observed_minus_expected += deaths1 as f64 - expected1;
        let r1 = risk1 as f64;
        let r2 = risk2 as f64;
        variance += r1 * r2 * total_deaths * (total_risk - total_deaths)
            / (total_risk * total_risk * (total_risk - 1.0));
    }
    // Guard against a zero-variance degenerate case.
    let statistic = if variance > f64::EPSILON {
        observed_minus_expected * observed_minus_expected / variance
    } else {
        0.0
    };
    let chi = ChiSquared::new(1.0)
        .map_err(|_| InferustError::InvalidInput("chi-squared distribution error".into()))?;
    Ok(LogRankResult {
        statistic,
        p_value: 1.0 - chi.cdf(statistic),
    })
}
/// Cox proportional-hazards model, fitted by Newton-Raphson on the partial
/// likelihood.
#[derive(Debug, Clone)]
pub struct CoxPh {
    // Covariate labels for printed summaries; auto-generated ("x1", "x2", …)
    // at fit time when empty or mismatched with the design-matrix width.
    feature_names: Vec<String>,
    // Maximum number of Newton-Raphson iterations (default 200).
    max_iter: usize,
    // Convergence threshold on the largest absolute Newton step (default 1e-8).
    tolerance: f64,
}
/// Fitted Cox model: per-covariate estimates plus model-level diagnostics.
/// All per-covariate vectors are indexed in the same order as `feature_names`.
#[derive(Debug, Clone)]
pub struct CoxPhResult {
    /// Estimated log hazard-ratio coefficients β.
    pub coefficients: Vec<f64>,
    /// Hazard ratios exp(β).
    pub hazard_ratios: Vec<f64>,
    /// Standard errors from the inverse observed information matrix.
    pub std_errors: Vec<f64>,
    /// Wald z statistics (β / SE).
    pub z_statistics: Vec<f64>,
    /// Two-sided p-values of the Wald statistics.
    pub p_values: Vec<f64>,
    /// 95% confidence intervals for the hazard ratios.
    pub hr_ci: Vec<(f64, f64)>,
    /// Partial log-likelihood at the fitted coefficients.
    pub log_likelihood: f64,
    /// Likelihood-ratio statistic against the null (all-zero β) model.
    pub lr_statistic: f64,
    /// p-value of `lr_statistic` under χ² with one df per covariate.
    pub lr_p_value: f64,
    /// Covariate labels used in summaries.
    pub feature_names: Vec<String>,
    /// Number of subjects.
    pub n: usize,
    /// Number of observed (uncensored) events.
    pub n_events: usize,
    /// Newton-Raphson iterations actually performed.
    pub iterations: usize,
}
impl Default for CoxPh {
    /// Equivalent to [`CoxPh::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl CoxPh {
    /// Creates a Cox model with default settings: 200 Newton iterations,
    /// convergence tolerance 1e-8, auto-generated feature names.
    pub fn new() -> Self {
        Self {
            feature_names: Vec::new(),
            max_iter: 200,
            tolerance: 1e-8,
        }
    }

    /// Sets the covariate labels used in summaries (builder style).
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = names;
        self
    }

    /// Sets the maximum number of Newton-Raphson iterations (builder style).
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Fits the proportional-hazards model by Newton-Raphson on the partial
    /// likelihood. `events[i] == 1` marks an observed event; other values
    /// are censoring. `x` is the row-major design matrix (one row per
    /// subject).
    ///
    /// # Errors
    ///
    /// Returns `InsufficientData` for fewer than 2 subjects,
    /// `DimensionMismatch` for inconsistent input lengths, and
    /// `InvalidInput` when there are no covariates or no events.
    pub fn fit(&self, times: &[f64], events: &[usize], x: &[Vec<f64>]) -> Result<CoxPhResult> {
        let n = times.len();
        if n < 2 {
            return Err(InferustError::InsufficientData { needed: 2, got: n });
        }
        if events.len() != n || x.len() != n {
            return Err(InferustError::DimensionMismatch {
                x_rows: x.len(),
                y_len: n,
            });
        }
        let p = x[0].len();
        if p == 0 {
            return Err(InferustError::InvalidInput(
                "CoxPh requires at least one covariate".into(),
            ));
        }
        // Every row of the design matrix must have the same width.
        for row in x.iter() {
            if row.len() != p {
                return Err(InferustError::DimensionMismatch {
                    x_rows: row.len(),
                    y_len: p,
                });
            }
        }
        let n_events: usize = events.iter().sum();
        if n_events == 0 {
            return Err(InferustError::InvalidInput("no events in the data".into()));
        }
        // Sort by time ascending (events before censorings on ties) so the
        // risk set at sorted index i is simply i..n.
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            times[a]
                .partial_cmp(&times[b])
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(events[b].cmp(&events[a]))
        });
        let t_sorted: Vec<f64> = order.iter().map(|&i| times[i]).collect();
        let e_sorted: Vec<usize> = order.iter().map(|&i| events[i]).collect();
        let x_sorted: Vec<Vec<f64>> = order.iter().map(|&i| x[i].clone()).collect();
        // Newton-Raphson starting from beta = 0 (the null model).
        let mut beta = vec![0.0_f64; p];
        let null_ll = cox_partial_ll(&t_sorted, &e_sorted, &x_sorted, &vec![0.0_f64; p]);
        let mut iterations = 0;
        for iter in 0..self.max_iter {
            let (score, hessian) = cox_score_hessian(&t_sorted, &e_sorted, &x_sorted, &beta);
            // Newton step: solve I(beta) * step = U(beta).
            let step = solve_linear_regularized(hessian, score)?;
            let max_step: f64 = step.iter().map(|s| s.abs()).fold(0.0_f64, f64::max);
            for j in 0..p {
                beta[j] += step[j];
            }
            iterations = iter + 1;
            if max_step < self.tolerance {
                break;
            }
        }
        let ll = cox_partial_ll(&t_sorted, &e_sorted, &x_sorted, &beta);
        // Likelihood-ratio test against the null model.
        let lr_stat = -2.0 * (null_ll - ll);
        let chi = ChiSquared::new(p as f64)
            .map_err(|_| InferustError::InvalidInput("chi-squared distribution error".into()))?;
        let lr_p = 1.0 - chi.cdf(lr_stat.max(0.0));
        // Standard errors from the inverse observed information at the optimum.
        let (_, hessian) = cox_score_hessian(&t_sorted, &e_sorted, &x_sorted, &beta);
        let var_cov = invert_symmetric_regularized(hessian)?;
        let se: Vec<f64> = (0..p).map(|j| var_cov[j][j].abs().sqrt()).collect();
        let normal = Normal::new(0.0, 1.0)
            .map_err(|_| InferustError::InvalidInput("normal distribution error".into()))?;
        let z: Vec<f64> = beta.iter().zip(se.iter()).map(|(b, s)| b / s).collect();
        let pv: Vec<f64> = z
            .iter()
            .map(|&z| 2.0 * (1.0 - normal.cdf(z.abs())))
            .collect();
        let hr: Vec<f64> = beta.iter().map(|b| b.exp()).collect();
        let z196 = normal.inverse_cdf(0.975);
        // 95% CI on each hazard ratio: exp(beta ± z * se).
        let hr_ci: Vec<(f64, f64)> = beta
            .iter()
            .zip(se.iter())
            .map(|(b, s)| ((b - z196 * s).exp(), (b + z196 * s).exp()))
            .collect();
        // Fall back to generated names when none (or the wrong number) were given.
        let feat = if self.feature_names.len() == p {
            self.feature_names.clone()
        } else {
            (1..=p).map(|i| format!("x{i}")).collect()
        };
        Ok(CoxPhResult {
            coefficients: beta,
            hazard_ratios: hr,
            std_errors: se,
            z_statistics: z,
            p_values: pv,
            hr_ci,
            log_likelihood: ll,
            lr_statistic: lr_stat,
            lr_p_value: lr_p,
            feature_names: feat,
            n,
            n_events,
            iterations,
        })
    }
}
impl CoxPhResult {
    /// Prints the fitted regression table to stdout.
    pub fn print_summary(&self) {
        println!();
        println!("══════════════════════════════════════════════════════════════════");
        println!(" Cox Proportional Hazards Model");
        println!("══════════════════════════════════════════════════════════════════");
        println!(
            " n = {} events = {} iterations = {}",
            self.n, self.n_events, self.iterations
        );
        println!(" Log-likelihood: {:.4}", self.log_likelihood);
        println!(
            " LR χ²({}) = {:.4} p = {:.6}",
            self.feature_names.len(),
            self.lr_statistic,
            self.lr_p_value
        );
        println!("──────────────────────────────────────────────────────────────────");
        println!(
            "{:<18} {:>9} {:>9} {:>8} {:>9} {:>12}",
            "Variable", "coef", "HR", "SE", "z", "P>|z|"
        );
        println!("{}", "─".repeat(66));
        // One row per covariate, in fitted order.
        for (i, coef) in self.coefficients.iter().enumerate() {
            let (hr_lo, hr_hi) = self.hr_ci[i];
            let p = self.p_values[i];
            println!(
                "{:<18} {:>9.4} {:>9.4} {:>8.4} {:>9.4} {:>9.6} {} [{:.4}, {:.4}]",
                self.feature_names[i],
                coef,
                self.hazard_ratios[i],
                self.std_errors[i],
                self.z_statistics[i],
                p,
                sig_stars(p),
                hr_lo,
                hr_hi
            );
        }
        println!("──────────────────────────────────────────────────────────────────");
        println!(" Significance: *** p<0.001 ** p<0.01 * p<0.05 . p<0.1");
        println!("══════════════════════════════════════════════════════════════════");
        println!();
    }
}
/// Partial log-likelihood of a Cox model at coefficient vector `beta`.
///
/// `times` must be sorted ascending so that the risk set of the subject at
/// index `i` is `i..n` (only its length is used here).
///
/// Runs in O(n·p): the risk-set denominators are suffix sums of exp(x·β),
/// computed once, instead of being re-summed per event (previously O(n²)).
fn cox_partial_ll(times: &[f64], events: &[usize], x: &[Vec<f64>], beta: &[f64]) -> f64 {
    let n = times.len();
    // Linear predictors x_i · beta.
    let xb: Vec<f64> = x
        .iter()
        .map(|row| row.iter().zip(beta).map(|(xi, b)| xi * b).sum())
        .collect();
    // risk[i] = sum_{j >= i} exp(xb[j]), built right-to-left.
    let mut risk = vec![0.0_f64; n];
    let mut acc = 0.0_f64;
    for i in (0..n).rev() {
        acc += xb[i].exp();
        risk[i] = acc;
    }
    let mut ll = 0.0;
    for i in 0..n {
        // Only observed events contribute terms to the partial likelihood.
        if events[i] == 1 {
            ll += xb[i] - risk[i].ln();
        }
    }
    ll
}
/// Score vector U(β) and observed information matrix I(β) of the Cox
/// partial likelihood.
///
/// `times` must be sorted ascending so that the risk set of the subject at
/// index `i` is `i..n` (only its length is used here).
///
/// Runs in O(n·p²): the risk-set sums S0, S1, S2 are accumulated from the
/// tail while scanning `i` from `n-1` down to 0, instead of being rebuilt
/// for every event (previously O(n²·p²)).
fn cox_score_hessian(
    times: &[f64],
    events: &[usize],
    x: &[Vec<f64>],
    beta: &[f64],
) -> (Vec<f64>, Vec<Vec<f64>>) {
    let n = times.len();
    let p = beta.len();
    // Per-subject weights w_i = exp(x_i · beta).
    let exp_xb: Vec<f64> = x
        .iter()
        .map(|row| {
            row.iter()
                .zip(beta)
                .map(|(xi, b)| xi * b)
                .sum::<f64>()
                .exp()
        })
        .collect();
    let mut score = vec![0.0_f64; p];
    let mut info = vec![vec![0.0_f64; p]; p];
    // Running sums over the current risk set {j : j >= i}:
    //   s0 = Σ w_j, s1[k] = Σ w_j x_jk, s2[k][l] = Σ w_j x_jk x_jl.
    let mut s0 = 0.0_f64;
    let mut s1 = vec![0.0_f64; p];
    let mut s2 = vec![vec![0.0_f64; p]; p];
    for i in (0..n).rev() {
        // Subject i enters the risk set (censored subjects contribute too).
        let w = exp_xb[i];
        s0 += w;
        for k in 0..p {
            s1[k] += w * x[i][k];
            for l in 0..p {
                s2[k][l] += w * x[i][k] * x[i][l];
            }
        }
        // Only events add score/information terms; skip degenerate sums.
        if events[i] != 1 || s0 < f64::EPSILON {
            continue;
        }
        for k in 0..p {
            score[k] += x[i][k] - s1[k] / s0;
            for l in 0..p {
                info[k][l] += s2[k][l] / s0 - (s1[k] / s0) * (s1[l] / s0);
            }
        }
    }
    (score, info)
}
/// Solves the dense linear system `a * x = b` by Gaussian elimination with
/// partial pivoting, consuming both inputs.
///
/// # Errors
///
/// Returns `InvalidInput` when a pivot is numerically zero, i.e. the matrix
/// is (near-)singular.
fn solve_linear(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Result<Vec<f64>> {
    let n = b.len();
    // Forward elimination.
    for col in 0..n {
        // Partial pivoting: take the row at or below `col` with the largest
        // absolute entry in this column (last one on ties).
        let mut best = col;
        for row in (col + 1)..n {
            if a[row][col].abs() >= a[best][col].abs() {
                best = row;
            }
        }
        a.swap(col, best);
        b.swap(col, best);
        let pivot = a[col][col];
        if pivot.abs() < f64::EPSILON {
            return Err(InferustError::InvalidInput(
                "singular information matrix — model may be under-identified".into(),
            ));
        }
        for row in (col + 1)..n {
            let factor = a[row][col] / pivot;
            // Copy the pivot tail so we can mutate the target row in place.
            let pivot_tail: Vec<f64> = a[col][col..n].to_vec();
            for (value, pivot_value) in a[row][col..n].iter_mut().zip(&pivot_tail) {
                *value -= factor * pivot_value;
            }
            b[row] -= factor * b[col];
        }
    }
    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut value = b[i];
        for j in (i + 1)..n {
            value -= a[i][j] * x[j];
        }
        x[i] = value / a[i][i];
    }
    Ok(x)
}
/// Solves `a * x = b`, retrying with a small ridge term on the diagonal when
/// the plain solve reports a singular matrix.
///
/// # Errors
///
/// Propagates the solver error if even the regularized system is singular.
fn solve_linear_regularized(a: Vec<Vec<f64>>, b: Vec<f64>) -> Result<Vec<f64>> {
    match solve_linear(a.clone(), b.clone()) {
        Ok(solution) => Ok(solution),
        Err(_) => {
            let mut regularized = a;
            // Scale the ridge to the matrix magnitude so it stays negligible
            // relative to the well-conditioned directions.
            let ridge = diagonal_scale(&regularized) * 1e-8;
            for (i, row) in regularized.iter_mut().enumerate() {
                row[i] += ridge;
            }
            solve_linear(regularized, b)
        }
    }
}
/// Inverts a square matrix by Gauss-Jordan elimination on the augmented
/// matrix `[A | I]` with partial pivoting.
///
/// # Errors
///
/// Returns `InvalidInput` when a pivot is numerically zero.
fn invert_symmetric(a: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>> {
    let n = a.len();
    // Build [A | I]: each row gets the matching identity row appended.
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut extended = row.clone();
            extended.extend((0..n).map(|j| if i == j { 1.0 } else { 0.0 }));
            extended
        })
        .collect();
    for col in 0..n {
        // Partial pivoting (last row on ties, matching max_by semantics).
        let mut best = col;
        for row in (col + 1)..n {
            if aug[row][col].abs() >= aug[best][col].abs() {
                best = row;
            }
        }
        aug.swap(col, best);
        let pivot = aug[col][col];
        if pivot.abs() < f64::EPSILON {
            return Err(InferustError::InvalidInput(
                "information matrix is singular".into(),
            ));
        }
        // Normalize the pivot row so the pivot becomes exactly 1.
        let inv_pivot = 1.0 / pivot;
        for value in aug[col].iter_mut() {
            *value *= inv_pivot;
        }
        // Eliminate this column from every other row.
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[row][col];
            let pivot_row = aug[col].clone();
            for (value, pivot_value) in aug[row].iter_mut().zip(&pivot_row) {
                *value -= factor * pivot_value;
            }
        }
    }
    // The right half of [I | A^{-1}] is the inverse.
    Ok(aug.into_iter().map(|row| row[n..].to_vec()).collect())
}
/// Inverts `a`, retrying with a small ridge term on the diagonal when the
/// plain inversion reports a singular matrix.
///
/// # Errors
///
/// Propagates the inversion error if even the regularized matrix is singular.
fn invert_symmetric_regularized(a: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>> {
    match invert_symmetric(a.clone()) {
        Ok(inverse) => Ok(inverse),
        Err(_) => {
            let mut regularized = a;
            // Scale the ridge to the matrix magnitude so it stays negligible.
            let ridge = diagonal_scale(&regularized) * 1e-8;
            for (i, row) in regularized.iter_mut().enumerate() {
                row[i] += ridge;
            }
            invert_symmetric(regularized)
        }
    }
}
/// Magnitude of the largest absolute diagonal entry of `a`, floored at 1.0
/// so the derived ridge term is never zero.
fn diagonal_scale(a: &[Vec<f64>]) -> f64 {
    let mut largest = 1.0_f64;
    for (i, row) in a.iter().enumerate() {
        largest = largest.max(row[i].abs());
    }
    largest
}
/// Conventional significance markers for a p-value.
/// Thresholds mirror the legend printed by `CoxPhResult::print_summary`.
fn sig_stars(p: f64) -> &'static str {
    const LEVELS: [(f64, &str); 4] = [(0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, ".")];
    for (cutoff, stars) in LEVELS {
        if p < cutoff {
            return stars;
        }
    }
    ""
}
#[cfg(test)]
mod tests {
    use super::{log_rank_test, CoxPh, KaplanMeier};

    #[test]
    fn km_basic_curve() {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let events = vec![1, 1, 0, 1, 0, 1];
        // `&times` (mojibake `×` fixed): borrow the vectors, don't multiply.
        let km = KaplanMeier::new().fit(&times, &events).unwrap();
        assert_eq!(km.n, 6);
        assert_eq!(km.n_events, 4);
        assert!(km.curve[0].survival < 1.0);
        // Survival must be non-increasing across steps.
        for w in km.curve.windows(2) {
            assert!(w[0].survival >= w[1].survival);
        }
    }

    #[test]
    fn km_all_censored_still_works() {
        let times = vec![5.0, 10.0, 15.0];
        let events = vec![0, 0, 0];
        let km = KaplanMeier::new().fit(&times, &events).unwrap();
        assert_eq!(km.n_events, 0);
        // No events means no steps in the curve.
        assert!(km.curve.is_empty());
    }

    #[test]
    fn log_rank_different_groups() {
        let t1 = vec![10.0, 20.0, 30.0, 40.0, 50.0];
        let e1 = vec![1, 1, 1, 1, 1];
        let t2 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let e2 = vec![1, 1, 1, 1, 1];
        let res = log_rank_test(&t1, &e1, &t2, &e2).unwrap();
        assert!(res.p_value < 0.05, "expected p < 0.05, got {}", res.p_value);
    }

    #[test]
    fn cox_ph_fits_simple_covariate() {
        let times = vec![5.0, 8.0, 12.0, 20.0, 3.0, 7.0, 15.0, 22.0, 9.0, 11.0];
        let events = vec![1, 1, 0, 1, 1, 0, 1, 1, 1, 0];
        let x: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![if i < 5 { 0.0 } else { 1.0 }])
            .collect();
        let cox = CoxPh::new()
            .with_feature_names(vec!["treatment".to_string()])
            .fit(&times, &events, &x)
            .unwrap();
        assert_eq!(cox.feature_names[0], "treatment");
        assert_eq!(cox.coefficients.len(), 1);
        assert!(cox.coefficients[0].is_finite());
        assert!(cox.hazard_ratios[0] > 0.0);
    }

    #[test]
    fn cox_ph_lr_test_statistic_positive() {
        let times: Vec<f64> = (1..=20).map(|i| i as f64).collect();
        let events: Vec<usize> = (0..20).map(|i| if i % 3 != 0 { 1 } else { 0 }).collect();
        let x: Vec<Vec<f64>> = (0..20).map(|i| vec![(i as f64) / 10.0]).collect();
        let cox = CoxPh::new().fit(&times, &events, &x).unwrap();
        assert!(cox.lr_statistic >= 0.0);
    }
}