#![allow(clippy::redundant_closure)]
use crate::regression::survregc1::{SurvivalDist, survregc1};
use ndarray::{Array1, Array2, ArrayView1};
use ndarray_linalg::{Cholesky, Solve, UPLO};
use pyo3::prelude::*;
/// Result of a `survreg` fit, exposed to Python as a class with read-only attributes.
///
/// Every field is copied from the internal `SurvivalFitComputed` returned by
/// `compute_survreg`; the variance matrix is converted from an `Array2` into nested vectors.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvivalFit {
/// Final coefficient vector (the last accepted `beta`), of length `nvar + nstrat`.
#[pyo3(get)]
pub coefficients: Vec<f64>,
/// Number of passes through the fitting loop that were executed.
#[pyo3(get)]
pub iterations: usize,
/// Row-major copy of the matrix produced by `calculate_variance_matrix`: the inverse of the
/// final `imat`, or `imat` itself when it is (near) zero or inversion fails.
#[pyo3(get)]
pub variance_matrix: Vec<Vec<f64>>,
/// Log-likelihood at the last accepted coefficients.
#[pyo3(get)]
pub log_likelihood: f64,
/// `0` when the fit is reported as converged, `-1` otherwise.
#[pyo3(get)]
pub convergence_flag: i32,
/// Score vector saved right after the first likelihood evaluation, i.e. at the initial
/// coefficients rather than at the final ones.
#[pyo3(get)]
pub score_vector: Vec<f64>,
}
/// Evaluates the log-likelihood at `beta` through `survregc1` and copies the derivative
/// information it returns into the caller-provided buffers.
///
/// * `u` receives the score vector returned by `survregc1`.
/// * `imat` receives the **negated** `imat` returned by `survregc1`.
/// * `jj` receives the returned `jj` matrix unchanged.
///
/// Only the leading `nvar + nstrat` entries (further limited by the sizes of both the
/// destination and the returned arrays) are written; anything outside that overlap keeps its
/// previous contents.
///
/// # Errors
///
/// Propagates any error reported by `survregc1`.
#[allow(clippy::too_many_arguments)]
fn calculate_likelihood(
    n: usize,
    nvar: usize,
    nstrat: usize,
    beta: &[f64],
    distribution: &DistributionType,
    strata: &[usize],
    offsets: &Array1<f64>,
    time1: &ArrayView1<f64>,
    time2: Option<&ArrayView1<f64>>,
    status: &ArrayView1<f64>,
    weights: &Array1<f64>,
    covariates: &Array2<f64>,
    imat: &mut Array2<f64>,
    jj: &mut Array2<f64>,
    u: &mut Array1<f64>,
) -> Result<f64, Box<dyn std::error::Error>> {
    let dist = match *distribution {
        DistributionType::ExtremeValue => SurvivalDist::ExtremeValue,
        DistributionType::Logistic => SurvivalDist::Logistic,
        DistributionType::Gaussian => SurvivalDist::Gaussian,
        DistributionType::Weibull => SurvivalDist::Weibull,
        DistributionType::LogNormal => SurvivalDist::LogNormal,
    };

    // survregc1 is handed i32 arrays: strata shifted up by one, status truncated to integer.
    let strat_arr: Array1<i32> = strata.iter().map(|&s| (s + 1) as i32).collect();
    let status_arr: Array1<i32> = status.iter().map(|&s| s as i32).collect();
    let beta_arr = Array1::from_vec(beta.to_vec());
    // An all-zero frailty index array together with the literal 0 passed before it.
    let frail_arr = Array1::<i32>::zeros(n);

    let result = survregc1(
        n,
        nvar,
        nstrat,
        false,
        &beta_arr.view(),
        dist,
        &strat_arr.view(),
        &offsets.view(),
        time1,
        time2,
        &status_arr.view(),
        &weights.view(),
        &covariates.view(),
        0,
        &frail_arr.view(),
    )?;

    let nvar2 = nvar + nstrat;

    // Score vector: copy the common prefix.
    let score_len = nvar2.min(u.len()).min(result.u.len());
    for k in 0..score_len {
        u[k] = result.u[k];
    }

    // Information matrix: sign-flipped copy of the common top-left corner.
    let imat_rows = nvar2.min(imat.nrows()).min(result.imat.nrows());
    let imat_cols = nvar2.min(imat.ncols()).min(result.imat.ncols());
    for r in 0..imat_rows {
        for c in 0..imat_cols {
            imat[[r, c]] = -result.imat[[r, c]];
        }
    }

    // Alternate matrix: straight copy of the common top-left corner.
    let jj_rows = nvar2.min(jj.nrows()).min(result.jj.nrows());
    let jj_cols = nvar2.min(jj.ncols()).min(result.jj.ncols());
    for r in 0..jj_rows {
        for c in 0..jj_cols {
            jj[[r, c]] = result.jj[[r, c]];
        }
    }

    Ok(result.loglik)
}
/// Solves `matrix * x = vector` for a matrix that is expected to be symmetric positive
/// definite, with graceful fallbacks for degenerate input.
///
/// Behaviour:
/// * an empty matrix, or one whose largest absolute entry is below `1e-10`, yields a zero
///   vector of the same length as `vector`;
/// * if the matrix passes a Cholesky factorization (i.e. is positive definite) the system is
///   solved directly;
/// * otherwise a ridge of `1e-6 * max|a_ij|` is added to the diagonal and the solve is retried;
/// * if even the regularised matrix is not positive definite, a zero vector is returned.
///
/// `_tol` is accepted for interface compatibility and is currently unused.
///
/// # Errors
///
/// Returns an error only when the linear solve itself fails after the positive-definiteness
/// check succeeded.
fn cholesky_solve(
    matrix: &Array2<f64>,
    vector: &Array1<f64>,
    _tol: f64,
) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
    if matrix.nrows() == 0 || matrix.ncols() == 0 {
        return Ok(Array1::zeros(vector.len()));
    }
    let max_val = matrix.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
    if max_val < 1e-10 {
        return Ok(Array1::zeros(vector.len()));
    }

    // `Cholesky::cholesky` returns the triangular factor L, not a solver object. Calling
    // `Solve::solve` on that factor would solve `L x = b` and silently return L^{-1} b instead
    // of A^{-1} b. The factorization is therefore used only as a positive-definiteness test,
    // and the system is solved against the matrix itself.
    if matrix.cholesky(UPLO::Lower).is_ok() {
        return matrix
            .solve(vector)
            .map_err(|e| format!("Cholesky solve failed: {}", e).into());
    }

    // Not positive definite: retry with a small ridge on the diagonal.
    let ridge = max_val * 1e-6;
    let mut reg_matrix = matrix.clone();
    for d in reg_matrix.diag_mut().iter_mut() {
        *d += ridge;
    }
    if reg_matrix.cholesky(UPLO::Lower).is_ok() {
        return reg_matrix
            .solve(vector)
            .map_err(|e| format!("Cholesky solve failed: {}", e).into());
    }

    // Best effort: a zero step leaves the caller's coefficients unchanged.
    Ok(Array1::zeros(vector.len()))
}
/// Reports whether two successive log-likelihood values are close enough to stop iterating.
///
/// The values are considered converged when either the relative change `|1 - new / old|` or
/// the absolute change `|old - new|` is at most `eps`. A NaN in either test simply makes that
/// test fail.
fn check_convergence(old: f64, new: f64, eps: f64) -> bool {
    let relative_change = (1.0 - new / old).abs();
    if relative_change <= eps {
        return true;
    }
    let absolute_change = (old - new).abs();
    absolute_change <= eps
}
/// Limits how far the per-stratum entries of `newbeta` may fall below `beta` in one step.
///
/// The entries at positions `nvar..nvar + nstrat` are inspected; whenever the candidate is
/// more than `1.1` below the current value it is reset to exactly `beta - 1.1`. All other
/// entries are left untouched.
///
/// # Panics
///
/// Panics if either slice is shorter than `nvar + nstrat`.
fn adjust_strata(newbeta: &mut [f64], beta: &[f64], nvar: usize, nstrat: usize) {
    const MAX_DROP: f64 = 1.1;
    let mut offset = 0;
    while offset < nstrat {
        let pos = nvar + offset;
        let current = beta[pos];
        if current - newbeta[pos] > MAX_DROP {
            newbeta[pos] = current - MAX_DROP;
        }
        offset += 1;
    }
}
/// Turns the information matrix into a variance matrix by inverting it.
///
/// The input is handed back unchanged when it is empty, when its largest absolute entry is
/// below `1e-10`, or when the inversion fails. `_nvar2` and `_tol_chol` are accepted for
/// interface compatibility and are not used.
fn calculate_variance_matrix(
    imat: Array2<f64>,
    _nvar2: usize,
    _tol_chol: f64,
) -> Result<Array2<f64>, Box<dyn std::error::Error>> {
    use ndarray_linalg::Inverse;

    // An empty matrix folds to 0.0, so this single guard also covers the zero-sized case.
    let largest = imat.iter().fold(0.0f64, |acc, &v| acc.max(v.abs()));
    if largest < 1e-10 {
        return Ok(imat);
    }

    let inverted = imat.inv();
    Ok(inverted.unwrap_or(imat))
}
/// Error distribution selector for the survival regression, exposed to Python as an enum.
///
/// Each variant is translated to the `SurvivalDist` variant of the same name before the
/// likelihood routine is called. The `#[pyo3(name = ...)]` attributes set the member names
/// seen from Python.
#[derive(Debug, Clone, Copy)]
#[pyclass]
pub enum DistributionType {
/// Extreme value distribution; also what `survreg` uses when no recognised name is given.
#[pyo3(name = "extreme_value")]
ExtremeValue,
/// Logistic distribution.
#[pyo3(name = "logistic")]
Logistic,
/// Gaussian distribution (`survreg` also accepts the name "normal" for it).
#[pyo3(name = "gaussian")]
Gaussian,
/// Weibull distribution.
#[pyo3(name = "weibull")]
Weibull,
/// Log-normal distribution (`survreg` also accepts the name "lognorm" for it).
#[pyo3(name = "lognormal")]
LogNormal,
}
/// Fits a parametric survival regression and returns the result as a `SurvivalFit`.
///
/// # Arguments
///
/// * `time` - observed times, one per observation.
/// * `status` - status indicator per observation; must have the same length as `time`.
/// * `covariates` - one row per observation, every row with the same number of columns.
///   An empty list means a model without covariates.
/// * `weights` - optional case weights (default: all `1.0`).
/// * `offsets` - optional offsets (default: all `0.0`).
/// * `initial_beta` - optional starting coefficients of length `nvar + nstrat`
///   (default: all zeros).
/// * `strata` - optional zero-based stratum index per observation (default: a single stratum).
///   The number of strata is `max(strata) + 1`.
/// * `distribution` - one of `"logistic"`, `"gaussian"`/`"normal"`, `"weibull"`,
///   `"lognormal"`/`"lognorm"` (each also accepted with its capitalised spelling). `None` or any
///   other string selects the extreme value distribution.
/// * `max_iter` - maximum number of iterations (default `20`).
/// * `eps` - convergence tolerance on the log-likelihood (default `1e-5`).
/// * `tol_chol` - tolerance forwarded to the linear-algebra helpers (default `1e-9`).
///
/// # Errors
///
/// Raises `ValueError` when the input lengths are inconsistent (including ragged covariate
/// rows and a wrongly sized `initial_beta`), and `RuntimeError` when the fit itself fails.
#[pyfunction]
#[allow(clippy::too_many_arguments)]
pub fn survreg(
    time: Vec<f64>,
    status: Vec<f64>,
    covariates: Vec<Vec<f64>>,
    weights: Option<Vec<f64>>,
    offsets: Option<Vec<f64>>,
    initial_beta: Option<Vec<f64>>,
    strata: Option<Vec<usize>>,
    distribution: Option<&str>,
    max_iter: Option<usize>,
    eps: Option<f64>,
    tol_chol: Option<f64>,
) -> PyResult<SurvivalFit> {
    let value_error = |msg: String| PyErr::new::<pyo3::exceptions::PyValueError, _>(msg);

    let n = time.len();
    if status.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "time and status must have the same length",
        ));
    }
    let nvar = covariates.first().map_or(0, Vec::len);
    if !covariates.is_empty() && covariates.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "covariates must have the same number of rows as time",
        ));
    }
    // Ragged rows could otherwise flatten to the right total length and silently shift values
    // between observations.
    if covariates.iter().any(|row| row.len() != nvar) {
        return Err(value_error(
            "all covariate rows must have the same number of columns".to_string(),
        ));
    }

    let weights = weights.unwrap_or_else(|| vec![1.0; n]);
    let offsets = offsets.unwrap_or_else(|| vec![0.0; n]);
    let strata = strata.unwrap_or_else(|| vec![0; n]);
    if weights.len() != n {
        return Err(value_error(
            "weights must have the same length as time".to_string(),
        ));
    }
    if offsets.len() != n {
        return Err(value_error(
            "offsets must have the same length as time".to_string(),
        ));
    }
    if strata.len() != n {
        return Err(value_error(
            "strata must have the same length as time".to_string(),
        ));
    }

    let max_iter = max_iter.unwrap_or(20);
    let eps = eps.unwrap_or(1e-5);
    let tol_chol = tol_chol.unwrap_or(1e-9);
    // Unrecognised names deliberately fall back to the extreme value distribution.
    let dist_type = match distribution {
        Some("logistic") | Some("Logistic") => DistributionType::Logistic,
        Some("gaussian") | Some("Gaussian") | Some("normal") | Some("Normal") => {
            DistributionType::Gaussian
        }
        Some("weibull") | Some("Weibull") => DistributionType::Weibull,
        Some("lognormal") | Some("LogNormal") | Some("lognorm") | Some("LogNorm") => {
            DistributionType::LogNormal
        }
        _ => DistributionType::ExtremeValue,
    };

    let nstrat = strata.iter().max().map_or(1, |&largest| largest + 1);
    let nvar2 = nvar + nstrat;
    let initial_beta = initial_beta.unwrap_or_else(|| vec![0.0; nvar2]);
    if initial_beta.len() != nvar2 {
        return Err(value_error(format!(
            "initial_beta must have length {} (covariates + strata), got {}",
            nvar2,
            initial_beta.len()
        )));
    }

    // Response matrix with columns (time, status), stored row-major.
    let mut y_flat = Vec::with_capacity(2 * n);
    for (&t, &s) in time.iter().zip(status.iter()) {
        y_flat.push(t);
        y_flat.push(s);
    }
    let y = Array2::from_shape_vec((n, 2), y_flat).map_err(|e| value_error(e.to_string()))?;

    // The fitting routine receives covariates as (nvar, n), i.e. one row per variable.
    let cov_array = if nvar > 0 {
        let flat: Vec<f64> = covariates.into_iter().flatten().collect();
        let by_observation =
            Array2::from_shape_vec((n, nvar), flat).map_err(|e| value_error(e.to_string()))?;
        by_observation.t().to_owned()
    } else {
        Array2::zeros((0, n))
    };
    let weights_arr = Array1::from_vec(weights);
    let offsets_arr = Array1::from_vec(offsets);

    let result = compute_survreg(
        max_iter,
        nvar,
        &y,
        &cov_array,
        &weights_arr,
        &offsets_arr,
        initial_beta,
        nstrat,
        &strata,
        eps,
        tol_chol,
        dist_type,
    )
    .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{}", e)))?;

    let variance_matrix = result
        .variance_matrix
        .outer_iter()
        .map(|row| row.to_vec())
        .collect();
    Ok(SurvivalFit {
        coefficients: result.coefficients,
        iterations: result.iterations,
        variance_matrix,
        log_likelihood: result.log_likelihood,
        convergence_flag: result.convergence_flag,
        score_vector: result.score_vector,
    })
}
/// Runs the Newton–Raphson iteration with step halving for the survival regression.
///
/// # Arguments
///
/// * `max_iter` - maximum number of likelihood evaluations inside the loop.
/// * `nvar` / `nstrat` - number of covariate coefficients and of strata; `beta` must have
///   length `nvar + nstrat`.
/// * `y` - response matrix with 2 columns `(time, status)` or 3 columns
///   `(time1, time2, status)`.
/// * `covariates`, `weights`, `offsets`, `strata`, `distribution` - forwarded unchanged to
///   `calculate_likelihood`.
/// * `eps` - tolerance used by `check_convergence`.
/// * `tol_chol` - forwarded to `cholesky_solve` and `calculate_variance_matrix`.
///
/// # Algorithm
///
/// A Newton step `beta + imat^{-1} u` is proposed (falling back to `jj` when the solve on
/// `imat` errors). If the candidate lowers the log-likelihood or yields NaN, the candidate is
/// pulled back towards `beta` (`(candidate + 2 * beta) / 3`) and *that shrunken candidate* is
/// evaluated on the next pass; a new Newton step is only computed after a candidate has been
/// accepted. Convergence is only declared on a full (non-halved) step.
///
/// # Errors
///
/// Returns an error when `y` has an unsupported number of columns, when `beta` has the wrong
/// length, or when a likelihood evaluation or linear solve fails.
#[allow(clippy::too_many_arguments)]
fn compute_survreg(
    max_iter: usize,
    nvar: usize,
    y: &Array2<f64>,
    covariates: &Array2<f64>,
    weights: &Array1<f64>,
    offsets: &Array1<f64>,
    mut beta: Vec<f64>,
    nstrat: usize,
    strata: &[usize],
    eps: f64,
    tol_chol: f64,
    distribution: DistributionType,
) -> Result<SurvivalFitComputed, Box<dyn std::error::Error>> {
    let n = y.nrows();
    let ny = y.ncols();
    if ny != 2 && ny != 3 {
        return Err(format!("response matrix must have 2 or 3 columns, got {}", ny).into());
    }
    let nvar2 = nvar + nstrat;
    if beta.len() != nvar2 {
        return Err(format!(
            "coefficient vector must have length {} (nvar + nstrat), got {}",
            nvar2,
            beta.len()
        )
        .into());
    }

    let mut imat = Array2::<f64>::zeros((nvar2, nvar2));
    let mut jj = Array2::<f64>::zeros((nvar2, nvar2));
    let mut u = Array1::<f64>::zeros(nvar2);
    let mut newbeta = beta.clone();

    // Copy the response columns into contiguous arrays: status is the last column, and the
    // middle column (if any) is the second time.
    let status_col = if ny == 2 { 1 } else { 2 };
    let time1_arr: Array1<f64> = y.column(0).iter().copied().collect();
    let status_arr: Array1<f64> = y.column(status_col).iter().copied().collect();
    let time2_arr: Option<Array1<f64>> = if ny == 3 {
        Some(y.column(1).iter().copied().collect())
    } else {
        None
    };
    let time1 = time1_arr.view();
    let status = status_arr.view();
    let time2_view: Option<ArrayView1<f64>> = time2_arr.as_ref().map(|v| v.view());

    let mut loglik = calculate_likelihood(
        n,
        nvar,
        nstrat,
        &beta,
        &distribution,
        strata,
        offsets,
        &time1,
        time2_view.as_ref(),
        &status,
        weights,
        covariates,
        &mut imat,
        &mut jj,
        &mut u,
    )?;
    // Score at the initial coefficients, reported back to the caller.
    let usave = u.clone();

    let mut iter = 0;
    let mut halving = 0;
    let mut converged = false;
    while iter < max_iter {
        // Only propose a fresh Newton step when the previous candidate was accepted. While
        // halving, `imat`/`u` describe the rejected point and `newbeta` already holds the
        // shrunken candidate that still has to be evaluated.
        if halving == 0 {
            let delta = match cholesky_solve(&imat, &u, tol_chol) {
                Ok(d) => d,
                Err(_) => cholesky_solve(&jj, &u, tol_chol)?,
            };
            newbeta
                .iter_mut()
                .zip(beta.iter().zip(delta.iter()))
                .for_each(|(nb, (b, d))| *nb = b + d);
        }

        let newlik = calculate_likelihood(
            n,
            nvar,
            nstrat,
            &newbeta,
            &distribution,
            strata,
            offsets,
            &time1,
            time2_view.as_ref(),
            &status,
            weights,
            covariates,
            &mut imat,
            &mut jj,
            &mut u,
        )?;

        if halving == 0 && check_convergence(loglik, newlik, eps) {
            loglik = newlik;
            beta.clone_from(&newbeta);
            iter += 1;
            // Tracked explicitly so that converging on the last allowed iteration still counts.
            converged = true;
            break;
        }

        if newlik.is_nan() || newlik < loglik {
            // Rejected: retreat two thirds of the way back towards the accepted coefficients.
            halving += 1;
            newbeta
                .iter_mut()
                .zip(&beta)
                .for_each(|(nb, b)| *nb = (*nb + 2.0 * b) / 3.0);
            if halving == 1 {
                adjust_strata(&mut newbeta, &beta, nvar, nstrat);
            }
        } else {
            halving = 0;
            loglik = newlik;
            beta.clone_from(&newbeta);
        }
        iter += 1;
    }

    if halving > 0 {
        // The iteration budget ran out on a rejected candidate, so `imat` describes that
        // candidate. Re-evaluate at the returned coefficients so the variance matches them.
        calculate_likelihood(
            n,
            nvar,
            nstrat,
            &beta,
            &distribution,
            strata,
            offsets,
            &time1,
            time2_view.as_ref(),
            &status,
            weights,
            covariates,
            &mut imat,
            &mut jj,
            &mut u,
        )?;
    }

    let convergence_flag = if converged { 0 } else { -1 };
    let variance = calculate_variance_matrix(imat, nvar2, tol_chol)?;
    Ok(SurvivalFitComputed {
        coefficients: beta,
        iterations: iter,
        variance_matrix: variance,
        log_likelihood: loglik,
        convergence_flag,
        score_vector: usave.to_vec(),
    })
}
/// Internal result of `compute_survreg`, kept in ndarray form before `survreg` converts it
/// into the Python-facing `SurvivalFit`.
pub(crate) struct SurvivalFitComputed {
// Last accepted coefficient vector.
coefficients: Vec<f64>,
// Number of loop passes executed.
iterations: usize,
// Output of `calculate_variance_matrix` applied to the final `imat`.
variance_matrix: Array2<f64>,
// Log-likelihood at `coefficients`.
log_likelihood: f64,
// 0 when reported as converged, -1 otherwise.
convergence_flag: i32,
// Score vector captured after the first likelihood evaluation (initial coefficients).
score_vector: Vec<f64>,
}
/// Python module definition, exported under the name `survreg`.
///
/// Registers the `survreg` function together with the `SurvivalFit` and `DistributionType`
/// classes, in that order.
#[pymodule]
#[pyo3(name = "survreg")]
fn survreg_module(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(survreg, &m)?)?;
m.add_class::<SurvivalFit>()?;
m.add_class::<DistributionType>()?;
Ok(())
}