use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2};
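/// Abramowitz & Stegun 7.1.26 rational approximation of the error function,
/// accurate to roughly 1.5e-7 in absolute error.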
fn erf_approx(x: f64) -> f64 {
let sign = if x >= 0.0 { 1.0 } else { -1.0 };
let x = x.abs();
let t = 1.0 / (1.0 + 0.3275911 * x);
let poly = t
* (0.254_829_592
+ t * (-0.284_496_736
+ t * (1.421_413_741 + t * (-1.453_152_027 + t * 1.061_405_429))));
sign * (1.0 - poly * (-x * x).exp())
}
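/// Standard normal CDF, Phi(z) = (1 + erf(z / sqrt(2))) / 2.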
fn norm_cdf(z: f64) -> f64 {
0.5 * (1.0 + erf_approx(z / std::f64::consts::SQRT_2))
}
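/// Natural log of the gamma function via the Lanczos approximation (g = 7, 9 coefficients).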
fn lgamma(x: f64) -> f64 {
let c = [
0.999_999_999_999_809_93,
676.520_368_121_885_10,
-1_259.139_216_722_402_8,
771.323_428_777_653_10,
-176.615_029_162_140_60,
12.507_343_278_686_905,
-0.138_571_095_265_720_12,
9.984_369_578_019_572e-6,
1.505_632_735_149_311_6e-7,
];
let x = x - 1.0;
let mut ser = c[0];
for (i, &ci) in c[1..].iter().enumerate() {
ser += ci / (x + i as f64 + 1.0);
}
let tmp = x + 7.5;
0.5 * std::f64::consts::TAU.ln() + (x + 0.5) * tmp.ln() - tmp + ser.ln()
}
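/// Regularized upper incomplete gamma function Q(a, x) = Gamma(a, x) / Gamma(a),
/// using the series expansion for x < a + 1 and a continued fraction otherwise.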
fn gamma_q(a: f64, x: f64) -> f64 {
if x <= 0.0 {
return 1.0;
}
if x < a + 1.0 {
let mut ap = a;
let mut sum = 1.0 / a;
let mut del = sum;
for _ in 0..200 {
ap += 1.0;
del *= x / ap;
sum += del;
if del.abs() < sum.abs() * 3e-15 {
break;
}
}
let p = sum * (-x + a * x.ln() - lgamma(a)).exp();
1.0 - p
} else {
let mut b = x + 1.0 - a;
let mut c = 1.0 / 1e-300;
let mut d = 1.0 / b;
let mut h = d;
for i in 1_i64..200 {
let an = -(i as f64) * (i as f64 - a);
b += 2.0;
d = an * d + b;
if d.abs() < 1e-300 {
d = 1e-300;
}
c = b + an / c;
if c.abs() < 1e-300 {
c = 1e-300;
}
d = 1.0 / d;
let del = d * c;
h *= del;
if (del - 1.0).abs() < 3e-15 {
break;
}
}
(-x + a * x.ln() - lgamma(a)).exp() * h
}
}
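/// Chi-squared survival function P(X > x) with `df` degrees of freedom, i.e. Q(df/2, x/2).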
fn chi2_sf(x: f64, df: f64) -> f64 {
if x <= 0.0 {
return 1.0;
}
gamma_q(df / 2.0, x / 2.0)
}
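/// Fitted Cox proportional hazards model: coefficients on the partial-likelihood scale,
/// the Breslow baseline cumulative hazard evaluated at event times, Wald standard
/// errors / z-scores / p-values, and the global score, likelihood-ratio and Wald test
/// statistics.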
#[derive(Debug, Clone)]
pub struct CoxPHModel {
pub coefficients: Array1<f64>,
pub baseline_hazard: Vec<(f64, f64)>,
pub feature_names: Vec<String>,
pub std_errors: Array1<f64>,
pub z_scores: Array1<f64>,
pub p_values: Array1<f64>,
pub log_likelihood: f64,
pub n_iter: usize,
pub converged: bool,
pub score_test: f64,
pub lr_test: f64,
pub wald_test: f64,
}
impl CoxPHModel {
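/// Fit a Cox proportional hazards model on `times`, `events` (`true` = event observed,
/// `false` = censored) and the covariate matrix `x`; equivalent to
/// [`CoxPHModel::fit_with_names`] with default feature names `x0`, `x1`, ...
///
/// A minimal usage sketch (the shapes and values below are illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let times = vec![2.0, 3.0, 5.0, 7.0];
/// let events = vec![true, false, true, true];
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
/// let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
/// println!("hazard ratios: {:?}", model.hazard_ratio());
/// ```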
pub fn fit(times: &[f64], events: &[bool], x: &Array2<f64>) -> StatsResult<Self> {
Self::fit_with_names(times, events, x, None)
}
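/// Like [`CoxPHModel::fit`], but with caller-supplied feature names (defaults to
/// `x0`, `x1`, ... when `None`). Fitting uses Newton-Raphson on the Breslow partial
/// log-likelihood with mean-centered covariates, a small ridge penalty (1e-3),
/// coefficient clamping to [-20, 20], and a backtracking line search; Wald standard
/// errors and the global score, likelihood-ratio and Wald statistics are also computed.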
pub fn fit_with_names(
times: &[f64],
events: &[bool],
x: &Array2<f64>,
feature_names: Option<Vec<String>>,
) -> StatsResult<Self> {
let n = times.len();
let p = x.ncols();
if n == 0 {
return Err(StatsError::InvalidArgument(
"times must not be empty".to_string(),
));
}
if events.len() != n {
return Err(StatsError::DimensionMismatch(format!(
"times length {} != events length {}",
n,
events.len()
)));
}
if x.nrows() != n {
return Err(StatsError::DimensionMismatch(format!(
"x rows {} != times length {}",
x.nrows(),
n
)));
}
if p == 0 {
return Err(StatsError::InvalidArgument(
"x must have at least one column".to_string(),
));
}
let n_events: usize = events.iter().filter(|&&e| e).count();
if n_events == 0 {
return Err(StatsError::InvalidArgument(
"No events observed".to_string(),
));
}
for &t in times {
if !t.is_finite() || t < 0.0 {
return Err(StatsError::InvalidArgument(format!(
"times must be finite and non-negative; got {t}"
)));
}
}
let mut order: Vec<usize> = (0..n).collect();
order.sort_by(|&a, &b| {
times[a]
.partial_cmp(&times[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
let sorted_times: Vec<f64> = order.iter().map(|&i| times[i]).collect();
let sorted_events: Vec<bool> = order.iter().map(|&i| events[i]).collect();
let sorted_x: Vec<Vec<f64>> = order
.iter()
.map(|&i| (0..p).map(|j| x[[i, j]]).collect())
.collect();
let x_mean: Vec<f64> = (0..p)
.map(|j| sorted_x.iter().map(|row| row[j]).sum::<f64>() / n as f64)
.collect();
let xc: Vec<Vec<f64>> = sorted_x
.iter()
.map(|row| (0..p).map(|j| row[j] - x_mean[j]).collect())
.collect();
let mut beta = vec![0.0_f64; p];
let max_iter = 200;
let tol = 1e-8;
let mut converged = false;
let mut n_iter = 0usize;
let ridge_lambda = 1e-3;
let ll_null =
partial_log_likelihood_breslow(&sorted_times, &sorted_events, &xc, &vec![0.0; p]);
for iter in 0..max_iter {
let (_ll, mut grad, mut hess) =
partial_ll_gradient_hessian(&sorted_times, &sorted_events, &xc, &beta);
for j in 0..p {
grad[j] -= ridge_lambda * beta[j];
hess[j * p + j] += ridge_lambda;
}
if grad.iter().any(|v| !v.is_finite()) || hess.iter().any(|v| !v.is_finite()) {
break;
}
let delta = solve_linear_system(&hess, &grad)?;
let step = backtrack_step(&sorted_times, &sorted_events, &xc, &beta, &delta, 20);
let max_delta = delta.iter().map(|d| d.abs()).fold(0.0_f64, f64::max);
for j in 0..p {
beta[j] += step * delta[j];
beta[j] = beta[j].clamp(-20.0, 20.0);
}
n_iter = iter + 1;
if max_delta * step < tol {
converged = true;
break;
}
}
let ll_final = partial_log_likelihood_breslow(&sorted_times, &sorted_events, &xc, &beta);
let (_, _grad_final, hess_final) =
partial_ll_gradient_hessian(&sorted_times, &sorted_events, &xc, &beta);
let mut hess_reg = hess_final.clone();
for j in 0..p {
hess_reg[j * p + j] += ridge_lambda;
}
let vcov = invert_matrix(&hess_reg)?;
let std_errors: Vec<f64> = (0..p).map(|j| vcov[j * p + j].max(0.0).sqrt()).collect();
let z_scores: Vec<f64> = (0..p)
.map(|j| beta[j] / std_errors[j].max(1e-300))
.collect();
let p_values: Vec<f64> = z_scores
.iter()
.map(|&z| 2.0 * (1.0 - norm_cdf(z.abs())))
.collect();
let (_, grad_null, hess_null) =
partial_ll_gradient_hessian(&sorted_times, &sorted_events, &xc, &vec![0.0; p]);
let vcov_null = invert_matrix(&hess_null).unwrap_or_else(|_| vec![0.0; p * p]);
let score_test = quadratic_form(&grad_null, &vcov_null, p);
let lr_test = 2.0 * (ll_final - ll_null);
let wald_test = quadratic_form_vec_mat(&beta, &hess_final, p);
// Risk scores on the original (uncentered) covariate scale, so the Breslow baseline
// hazard below corresponds to covariates equal to zero and is consistent with
// `predict_hazard` / `predict_survival`, which use raw covariates.
let risk_scores: Vec<f64> = (0..n)
.map(|i| {
let xb: f64 = (0..p).map(|j| sorted_x[i][j] * beta[j]).sum();
xb.exp()
})
.collect();
let pairs: Vec<(f64, bool)> = sorted_times
.iter()
.copied()
.zip(sorted_events.iter().copied())
.collect();
let (bt, bh) = breslow_baseline_hazard(&risk_scores, &pairs);
let baseline_hazard: Vec<(f64, f64)> = bt.into_iter().zip(bh.into_iter()).collect();
let names = feature_names.unwrap_or_else(|| (0..p).map(|j| format!("x{j}")).collect());
Ok(Self {
coefficients: Array1::from_vec(beta),
baseline_hazard,
feature_names: names,
std_errors: Array1::from_vec(std_errors),
z_scores: Array1::from_vec(z_scores),
p_values: Array1::from_vec(p_values),
log_likelihood: ll_final,
n_iter,
converged,
score_test,
lr_test: lr_test.max(0.0),
wald_test: wald_test.max(0.0),
})
}
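/// Relative hazard exp(x . beta) for each row of `x_new`, i.e. the hazard relative to
/// the baseline at covariates equal to zero.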
pub fn predict_hazard(&self, x_new: &Array2<f64>) -> Array1<f64> {
let n = x_new.nrows();
let p = self.coefficients.len();
let mut hazards = Array1::zeros(n);
for i in 0..n {
let xb: f64 = (0..p).map(|j| x_new[[i, j]] * self.coefficients[j]).sum();
hazards[i] = xb.exp();
}
hazards
}
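/// Predicted survival probability S(t | x) = exp(-H0(t) * exp(x . beta)), where H0 is
/// the Breslow baseline cumulative hazard treated as a right-continuous step function.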
pub fn predict_survival(&self, x_new: &Array2<f64>, t: f64) -> Array1<f64> {
let hazards = self.predict_hazard(x_new);
let h0 = self.baseline_cumulative_hazard_at(t);
let n = x_new.nrows();
let mut surv = Array1::zeros(n);
for i in 0..n {
surv[i] = (-h0 * hazards[i]).exp();
}
surv
}
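/// Baseline cumulative hazard H0(t): the value at the largest recorded event time <= t,
/// or 0.0 before the first event time.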
fn baseline_cumulative_hazard_at(&self, t: f64) -> f64 {
if self.baseline_hazard.is_empty() || t < self.baseline_hazard[0].0 {
return 0.0;
}
let idx = self
.baseline_hazard
.partition_point(|&(tk, _)| tk <= t)
.saturating_sub(1);
self.baseline_hazard[idx].1
}
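/// Hazard ratios exp(beta_j) for each coefficient.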
pub fn hazard_ratio(&self) -> Array1<f64> {
self.coefficients.mapv(f64::exp)
}
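/// Harrell's concordance index: the fraction of comparable (event, later-time) pairs in
/// which the earlier event has the higher linear predictor; predictor ties count as 0.5.
/// Returns 0.5 when no comparable pairs exist.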
pub fn concordance_index(&self, times: &[f64], events: &[bool], x: &Array2<f64>) -> f64 {
let p = self.coefficients.len();
let n = times.len();
if n == 0 {
return 0.5;
}
let risk: Vec<f64> = (0..n)
.map(|i| {
(0..p)
.map(|j| x[[i, j]] * self.coefficients[j])
.sum::<f64>()
})
.collect();
let mut concordant = 0.0_f64;
let mut total = 0.0_f64;
for i in 0..n {
if !events[i] {
continue;
}
for j in 0..n {
if i == j {
continue;
}
if times[j] <= times[i] {
continue;
}
total += 1.0;
if risk[i] > risk[j] {
concordant += 1.0;
} else if (risk[i] - risk[j]).abs() < 1e-14 {
concordant += 0.5;
}
}
}
if total < 1.0 {
0.5
} else {
concordant / total
}
}
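/// Linear predictor x . beta (log relative hazard) for each row of `x_new`.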
pub fn predict_log_hazard(&self, x_new: &Array2<f64>) -> Array1<f64> {
let n = x_new.nrows();
let p = self.coefficients.len();
let mut lp = Array1::zeros(n);
for i in 0..n {
lp[i] = (0..p).map(|j| x_new[[i, j]] * self.coefficients[j]).sum();
}
lp
}
}
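/// Breslow partial log-likelihood at coefficients `beta`. Observations must be sorted in
/// ascending time order; tied event times share a single risk-set term (Breslow
/// approximation).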
fn partial_log_likelihood_breslow(
sorted_times: &[f64],
sorted_events: &[bool],
xc: &[Vec<f64>],
beta: &[f64],
) -> f64 {
let n = sorted_times.len();
let p = beta.len();
let exp_xb: Vec<f64> = (0..n)
.map(|i| {
let xb: f64 = (0..p).map(|j| xc[i][j] * beta[j]).sum();
xb.exp().max(1e-300)
})
.collect();
let mut ll = 0.0_f64;
let mut risk_set_sum = exp_xb.iter().sum::<f64>();
let mut i = 0usize;
while i < n {
let t_cur = sorted_times[i];
let mut j = i;
let mut d = 0usize;
let mut xb_sum = 0.0_f64;
while j < n && (sorted_times[j] - t_cur).abs() < 1e-14 {
if sorted_events[j] {
d += 1;
let xb: f64 = (0..p).map(|k| xc[j][k] * beta[k]).sum();
xb_sum += xb;
}
j += 1;
}
if d > 0 {
ll += xb_sum - d as f64 * risk_set_sum.ln();
}
for k in i..j {
risk_set_sum -= exp_xb[k];
}
risk_set_sum = risk_set_sum.max(1e-300);
i = j;
}
ll
}
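/// Breslow partial log-likelihood together with its gradient and the negative Hessian
/// (observed information), maintained via running risk-set sums S0, S1 and S2.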
fn partial_ll_gradient_hessian(
sorted_times: &[f64],
sorted_events: &[bool],
xc: &[Vec<f64>],
beta: &[f64],
) -> (f64, Vec<f64>, Vec<f64>) {
let n = sorted_times.len();
let p = beta.len();
let exp_xb: Vec<f64> = (0..n)
.map(|i| {
let xb: f64 = (0..p).map(|j| xc[i][j] * beta[j]).sum();
xb.exp().max(1e-300)
})
.collect();
let mut ll = 0.0_f64;
let mut grad = vec![0.0_f64; p];
let mut neg_hess = vec![0.0_f64; p * p];
let mut s0 = exp_xb.iter().sum::<f64>();
let mut s1: Vec<f64> = (0..p)
.map(|j| (0..n).map(|i| xc[i][j] * exp_xb[i]).sum::<f64>())
.collect();
let mut s2: Vec<f64> = {
let mut s = vec![0.0_f64; p * p];
for i in 0..n {
for j in 0..p {
for k in 0..p {
s[j * p + k] += xc[i][j] * xc[i][k] * exp_xb[i];
}
}
}
s
};
let mut i = 0usize;
while i < n {
let t_cur = sorted_times[i];
let mut j = i;
let mut d = 0usize;
while j < n && (sorted_times[j] - t_cur).abs() < 1e-14 {
if sorted_events[j] {
d += 1;
}
j += 1;
}
if d > 0 && s0 > 1e-300 {
ll += {
let mut xb_sum = 0.0_f64;
for k in i..j {
if sorted_events[k] {
xb_sum += (0..p).map(|l| xc[k][l] * beta[l]).sum::<f64>();
}
}
xb_sum - d as f64 * s0.ln()
};
for jj in 0..p {
let mut xb_col = 0.0_f64;
for k in i..j {
if sorted_events[k] {
xb_col += xc[k][jj];
}
}
grad[jj] += xb_col - d as f64 * s1[jj] / s0;
}
let e1: Vec<f64> = (0..p).map(|jj| s1[jj] / s0).collect();
for jj in 0..p {
for kk in 0..p {
let e2 = s2[jj * p + kk] / s0;
neg_hess[jj * p + kk] += d as f64 * (e2 - e1[jj] * e1[kk]);
}
}
}
for k in i..j {
s0 -= exp_xb[k];
for jj in 0..p {
s1[jj] -= xc[k][jj] * exp_xb[k];
for kk in 0..p {
s2[jj * p + kk] -= xc[k][jj] * xc[k][kk] * exp_xb[k];
}
}
}
s0 = s0.max(1e-300);
i = j;
}
(ll, grad, neg_hess)
}
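/// Solve H * delta = grad for the Newton step via Cholesky factorisation with a tiny
/// diagonal shift; falls back to a scaled gradient step if H is not positive definite.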
fn solve_linear_system(hess: &[f64], grad: &[f64]) -> StatsResult<Vec<f64>> {
let p = grad.len();
if p == 0 {
return Ok(vec![]);
}
let mut h = hess.to_vec();
let lambda = 1e-8
* hess
.iter()
.map(|&v| v.abs())
.fold(0.0_f64, f64::max)
.max(1e-6);
for j in 0..p {
h[j * p + j] += lambda;
}
let mut l = vec![0.0_f64; p * p];
for i in 0..p {
for j in 0..=i {
let mut s: f64 = h[i * p + j];
for k in 0..j {
s -= l[i * p + k] * l[j * p + k];
}
if i == j {
if s < 1e-300 {
let scale = h.iter().map(|&v| v.abs()).fold(0.0_f64, f64::max).max(1.0);
return Ok(grad.iter().map(|&g| g / scale).collect());
}
l[i * p + j] = s.sqrt();
} else {
l[i * p + j] = s / l[j * p + j];
}
}
}
let mut y = vec![0.0_f64; p];
for i in 0..p {
let mut s = grad[i];
for k in 0..i {
s -= l[i * p + k] * y[k];
}
y[i] = s / l[i * p + i];
}
let mut delta = vec![0.0_f64; p];
for i in (0..p).rev() {
let mut s = y[i];
for k in (i + 1)..p {
s -= l[k * p + i] * delta[k];
}
delta[i] = s / l[i * p + i];
}
Ok(delta)
}
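/// Invert a symmetric positive (semi-)definite matrix by Cholesky, retrying with an
/// escalating diagonal regularisation if the factorisation fails.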
fn invert_matrix(hess: &[f64]) -> StatsResult<Vec<f64>> {
let p = (hess.len() as f64).sqrt() as usize;
if p * p != hess.len() {
return Err(StatsError::DimensionMismatch(
"Hessian is not square".to_string(),
));
}
let max_abs = hess
.iter()
.map(|&v| v.abs())
.fold(0.0_f64, f64::max)
.max(1e-12);
let regularisations = [
1e-8 * max_abs,
1e-6 * max_abs,
1e-4 * max_abs,
1e-2 * max_abs,
0.1 * max_abs,
max_abs,
10.0 * max_abs,
1.0_f64,
10.0_f64,
100.0_f64,
];
for (attempt, &lambda) in regularisations.iter().enumerate() {
let mut h = hess.to_vec();
for j in 0..p {
h[j * p + j] += lambda;
}
match cholesky_invert(&h, p) {
Ok(inv) => return Ok(inv),
Err(_) if attempt < regularisations.len() - 1 => continue,
Err(e) => return Err(e),
}
}
Err(StatsError::ComputationError(
"Hessian is not positive definite after escalating regularisation".to_string(),
))
}
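/// Cholesky factorisation H = L L^T followed by inversion, returning H^{-1} = L^{-T} L^{-1}.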
fn cholesky_invert(h: &[f64], p: usize) -> StatsResult<Vec<f64>> {
let mut l = vec![0.0_f64; p * p];
for i in 0..p {
for j in 0..=i {
let mut s = h[i * p + j];
for k in 0..j {
s -= l[i * p + k] * l[j * p + k];
}
if i == j {
if s <= 1e-300 {
return Err(StatsError::ComputationError(
"Hessian is not positive definite (singular)".to_string(),
));
}
l[i * p + j] = s.sqrt();
} else {
if l[j * p + j].abs() < 1e-300 {
return Err(StatsError::ComputationError(
"Cholesky: near-zero diagonal".to_string(),
));
}
l[i * p + j] = s / l[j * p + j];
}
}
}
let mut l_inv = vec![0.0_f64; p * p];
for k in 0..p {
for i in 0..p {
let mut s = if i == k { 1.0 } else { 0.0 };
for j in 0..i {
s -= l[i * p + j] * l_inv[j * p + k];
}
l_inv[i * p + k] = s / l[i * p + i];
}
}
let mut inv = vec![0.0_f64; p * p];
for i in 0..p {
for j in 0..p {
let mut s = 0.0_f64;
for k in 0..p {
s += l_inv[k * p + i] * l_inv[k * p + j];
}
inv[i * p + j] = s;
}
}
Ok(inv)
}
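/// Quadratic form v^T A v for a dense p x p matrix stored in row-major order.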
fn quadratic_form_vec_mat(v: &[f64], a: &[f64], p: usize) -> f64 {
let mut result = 0.0_f64;
for i in 0..p {
let mut av_i = 0.0_f64;
for j in 0..p {
av_i += a[i * p + j] * v[j];
}
result += v[i] * av_i;
}
result
}
fn quadratic_form(v: &[f64], a: &[f64], p: usize) -> f64 {
quadratic_form_vec_mat(v, a, p)
}
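/// Backtracking line search: halve the step until the partial log-likelihood does not drop
/// by more than a small amount proportional to the step size and the Newton direction.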
fn backtrack_step(
sorted_times: &[f64],
sorted_events: &[bool],
xc: &[Vec<f64>],
beta: &[f64],
delta: &[f64],
max_halve: usize,
) -> f64 {
let ll_cur = partial_log_likelihood_breslow(sorted_times, sorted_events, xc, beta);
let p = beta.len();
let c = 1e-4;
let mut step = 1.0_f64;
for _ in 0..max_halve {
let beta_new: Vec<f64> = (0..p).map(|j| beta[j] + step * delta[j]).collect();
let ll_new = partial_log_likelihood_breslow(sorted_times, sorted_events, xc, &beta_new);
if ll_new > ll_cur - c * step * delta.iter().map(|d| d.abs()).sum::<f64>() {
return step;
}
step *= 0.5;
}
step
}
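/// Breslow estimator of the baseline cumulative hazard: at each distinct event time, add
/// (number of events) / (sum of risk scores over the risk set).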
fn breslow_baseline_hazard(risk_scores: &[f64], pairs: &[(f64, bool)]) -> (Vec<f64>, Vec<f64>) {
let n = pairs.len();
let mut idx: Vec<usize> = (0..n).collect();
idx.sort_by(|&a, &b| {
pairs[a]
.0
.partial_cmp(&pairs[b].0)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut times_out = Vec::new();
let mut hazard_out = Vec::new();
let mut cum_h = 0.0_f64;
let mut pos = 0usize;
while pos < n {
let t_cur = pairs[idx[pos]].0;
let mut d_k = 0usize;
let mut end = pos;
while end < n && (pairs[idx[end]].0 - t_cur).abs() < 1e-14 {
if pairs[idx[end]].1 {
d_k += 1;
}
end += 1;
}
if d_k > 0 {
let risk_set_sum: f64 = idx[pos..].iter().map(|&i| risk_scores[i]).sum();
if risk_set_sum > 1e-300 {
cum_h += d_k as f64 / risk_set_sum;
}
times_out.push(t_cur);
hazard_out.push(cum_h);
}
pos = end;
}
(times_out, hazard_out)
}
#[cfg(test)]
mod tests {
use super::*;
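/// Ten subjects whose single covariate decreases with survival time (so the true effect
/// is positive); two observations are censored.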
fn simple_cox_data() -> (Vec<f64>, Vec<bool>, Array2<f64>) {
let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let events = vec![true, true, true, true, true, false, true, true, false, true];
let mut cov = Array2::zeros((10, 1));
for i in 0..10_usize {
cov[[i, 0]] = (10 - i) as f64;
}
(times, events, cov)
}
#[test]
fn test_cox_fit_basic() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit failed");
assert_eq!(model.coefficients.len(), 1);
assert!(model.n_iter > 0);
assert!(model.log_likelihood.is_finite());
}
#[test]
fn test_cox_coefficients_finite() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
for &c in model.coefficients.iter() {
assert!(c.is_finite(), "coefficient {c} must be finite");
}
}
#[test]
fn test_cox_std_errors_positive() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
for &se in model.std_errors.iter() {
assert!(se >= 0.0, "std error {se} must be non-negative");
}
}
#[test]
fn test_cox_p_values_valid() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
for &p in model.p_values.iter() {
assert!((0.0..=1.0).contains(&p), "p-value {p} must be in [0,1]");
}
}
#[test]
fn test_cox_hazard_ratio_positive() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
for &hr in model.hazard_ratio().iter() {
assert!(hr > 0.0, "hazard ratio {hr} must be positive");
}
}
#[test]
fn test_cox_predict_hazard() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
let pred = model.predict_hazard(&cov);
assert_eq!(pred.len(), 10);
for &h in pred.iter() {
assert!(h > 0.0, "hazard {h} must be positive");
}
}
#[test]
fn test_cox_predict_survival_bounded() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
let surv = model.predict_survival(&cov, 5.0);
for &s in surv.iter() {
assert!((0.0..=1.0 + 1e-12).contains(&s), "survival {s} out of [0,1]");
}
}
#[test]
fn test_cox_concordance_index() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
let c = model.concordance_index(&times, &events, &cov);
assert!((0.0..=1.0).contains(&c), "concordance {c} must be in [0,1]");
}
#[test]
fn test_cox_multivariate() {
let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let events = vec![true, true, false, true, true, false, true, true];
let mut cov = Array2::zeros((8, 2));
for i in 0..8_usize {
cov[[i, 0]] = i as f64;
cov[[i, 1]] = (i % 3) as f64;
}
let model = CoxPHModel::fit(&times, &events, &cov).expect("multivariate Cox fit");
assert_eq!(model.coefficients.len(), 2);
assert_eq!(model.std_errors.len(), 2);
assert_eq!(model.p_values.len(), 2);
}
#[test]
fn test_cox_error_empty() {
let cov: Array2<f64> = Array2::zeros((0, 1));
let result = CoxPHModel::fit(&[], &[], &cov);
assert!(result.is_err());
}
#[test]
fn test_cox_error_dimension_mismatch() {
let times = vec![1.0, 2.0];
let events = vec![true];
let cov = Array2::zeros((2, 1));
let result = CoxPHModel::fit(&times, &events, &cov);
assert!(result.is_err());
}
#[test]
fn test_cox_score_lr_wald_tests() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
assert!(model.score_test >= 0.0, "score test {}", model.score_test);
assert!(model.lr_test >= 0.0, "lr test {}", model.lr_test);
assert!(model.wald_test >= 0.0, "wald test {}", model.wald_test);
}
#[test]
fn test_cox_baseline_hazard_monotone() {
let (times, events, cov) = simple_cox_data();
let model = CoxPHModel::fit(&times, &events, &cov).expect("Cox fit");
for i in 1..model.baseline_hazard.len() {
assert!(
model.baseline_hazard[i].1 >= model.baseline_hazard[i - 1].1 - 1e-12,
"Baseline hazard not monotone at index {i}"
);
}
}
}