survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
#![allow(
    unused_variables,
    unused_imports,
    unused_mut,
    unused_assignments,
    non_camel_case_types,
    clippy::too_many_arguments,
    clippy::needless_range_loop
)]

use pyo3::prelude::*;
use rayon::prelude::*;

#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum FrailtyDistribution {
    Gamma,
    LogNormal,
    Positive_Stable,
}

#[pymethods]
impl FrailtyDistribution {
    #[new]
    fn new(name: &str) -> PyResult<Self> {
        match name.to_lowercase().as_str() {
            "gamma" => Ok(FrailtyDistribution::Gamma),
            "lognormal" | "log_normal" => Ok(FrailtyDistribution::LogNormal),
            "positive_stable" | "stable" => Ok(FrailtyDistribution::Positive_Stable),
            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Unknown frailty distribution",
            )),
        }
    }
}

#[derive(Debug, Clone)]
#[pyclass]
pub struct JointFrailtyResult {
    #[pyo3(get)]
    pub recurrent_coef: Vec<f64>,
    #[pyo3(get)]
    pub recurrent_se: Vec<f64>,
    #[pyo3(get)]
    pub terminal_coef: Vec<f64>,
    #[pyo3(get)]
    pub terminal_se: Vec<f64>,
    #[pyo3(get)]
    pub frailty_variance: f64,
    #[pyo3(get)]
    pub alpha: f64,
    #[pyo3(get)]
    pub frailty_values: Vec<f64>,
    #[pyo3(get)]
    pub log_likelihood: f64,
    #[pyo3(get)]
    pub aic: f64,
    #[pyo3(get)]
    pub bic: f64,
    #[pyo3(get)]
    pub n_iter: usize,
    #[pyo3(get)]
    pub converged: bool,
    #[pyo3(get)]
    pub n_recurrent_events: usize,
    #[pyo3(get)]
    pub n_terminal_events: usize,
    #[pyo3(get)]
    pub n_subjects: usize,
}

#[pyfunction]
#[pyo3(signature = (
    subject_id,
    rec_start,
    rec_stop,
    rec_status,
    x_recurrent,
    n_rec_obs,
    n_rec_vars,
    term_time,
    term_status,
    x_terminal,
    n_subjects,
    n_term_vars,
    frailty_dist=FrailtyDistribution::Gamma,
    max_iter=500,
    tol=1e-5
))]
pub fn joint_frailty_model(
    subject_id: Vec<usize>,
    rec_start: Vec<f64>,
    rec_stop: Vec<f64>,
    rec_status: Vec<i32>,
    x_recurrent: Vec<f64>,
    n_rec_obs: usize,
    n_rec_vars: usize,
    term_time: Vec<f64>,
    term_status: Vec<i32>,
    x_terminal: Vec<f64>,
    n_subjects: usize,
    n_term_vars: usize,
    frailty_dist: FrailtyDistribution,
    max_iter: usize,
    tol: f64,
) -> PyResult<JointFrailtyResult> {
    if subject_id.len() != n_rec_obs {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "subject_id length must match n_rec_obs",
        ));
    }
    if term_time.len() != n_subjects || term_status.len() != n_subjects {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Terminal event data must have length n_subjects",
        ));
    }

    let n_rec_events = rec_status.iter().filter(|&&s| s == 1).count();
    let n_term_events = term_status.iter().filter(|&&s| s == 1).count();

    let mut beta_rec = vec![0.0; n_rec_vars];
    let mut beta_term = vec![0.0; n_term_vars];
    let mut theta = 1.0;
    let mut alpha = 1.0;
    let mut frailty = vec![1.0; n_subjects];

    let mut prev_loglik = f64::NEG_INFINITY;
    let mut converged = false;
    let mut n_iter = 0;

    for iter in 0..max_iter {
        n_iter = iter + 1;

        for i in 0..n_subjects {
            let rec_indices: Vec<usize> = (0..n_rec_obs).filter(|&j| subject_id[j] == i).collect();

            let n_events_i: f64 =
                rec_indices.iter().filter(|&&j| rec_status[j] == 1).count() as f64;

            let mut cum_hazard_rec = 0.0;
            for &j in &rec_indices {
                let gap = rec_stop[j] - rec_start[j];
                let mut eta = 0.0;
                for k in 0..n_rec_vars {
                    eta += x_recurrent[j * n_rec_vars + k] * beta_rec[k];
                }
                cum_hazard_rec += gap * eta.exp();
            }

            let mut eta_term = 0.0;
            for k in 0..n_term_vars {
                eta_term += x_terminal[i * n_term_vars + k] * beta_term[k];
            }
            let cum_hazard_term = term_time[i] * (alpha * eta_term).exp();

            match frailty_dist {
                FrailtyDistribution::Gamma => {
                    let a = 1.0 / theta + n_events_i + term_status[i] as f64;
                    let b = 1.0 / theta + cum_hazard_rec + cum_hazard_term;
                    frailty[i] = a / b.max(0.001);
                }
                FrailtyDistribution::LogNormal => {
                    frailty[i] = ((n_events_i + term_status[i] as f64)
                        / (cum_hazard_rec + cum_hazard_term + 1.0 / theta))
                        .exp()
                        .max(0.01);
                }
                FrailtyDistribution::Positive_Stable => {
                    frailty[i] = (n_events_i + term_status[i] as f64 + 1.0)
                        .powf(1.0 / (1.0 + theta))
                        .max(0.01);
                }
            }
        }

        let mean_frailty = frailty.iter().sum::<f64>() / n_subjects as f64;
        for f in &mut frailty {
            *f /= mean_frailty;
        }

        let mut gradient_rec = vec![0.0; n_rec_vars];
        let mut hessian_rec = vec![0.0; n_rec_vars];

        for j in 0..n_rec_obs {
            let subj = subject_id[j];
            let f_i = frailty[subj];

            let mut eta = 0.0;
            for k in 0..n_rec_vars {
                eta += x_recurrent[j * n_rec_vars + k] * beta_rec[k];
            }
            let exp_eta = eta.clamp(-700.0, 700.0).exp();

            if rec_status[j] == 1 {
                for k in 0..n_rec_vars {
                    gradient_rec[k] += x_recurrent[j * n_rec_vars + k];
                }
            }

            let gap = rec_stop[j] - rec_start[j];
            for k in 0..n_rec_vars {
                gradient_rec[k] -= f_i * gap * exp_eta * x_recurrent[j * n_rec_vars + k];
                hessian_rec[k] += f_i
                    * gap
                    * exp_eta
                    * x_recurrent[j * n_rec_vars + k]
                    * x_recurrent[j * n_rec_vars + k];
            }
        }

        for k in 0..n_rec_vars {
            if hessian_rec[k].abs() > 1e-10 {
                beta_rec[k] += 0.5 * gradient_rec[k] / hessian_rec[k];
            }
        }

        let mut gradient_term = vec![0.0; n_term_vars];
        let mut hessian_term = vec![0.0; n_term_vars];

        for i in 0..n_subjects {
            let f_i = frailty[i].powf(alpha);

            let mut eta = 0.0;
            for k in 0..n_term_vars {
                eta += x_terminal[i * n_term_vars + k] * beta_term[k];
            }
            let exp_eta = eta.clamp(-700.0, 700.0).exp();

            if term_status[i] == 1 {
                for k in 0..n_term_vars {
                    gradient_term[k] += x_terminal[i * n_term_vars + k];
                }
            }

            for k in 0..n_term_vars {
                gradient_term[k] -= f_i * term_time[i] * exp_eta * x_terminal[i * n_term_vars + k];
                hessian_term[k] += f_i
                    * term_time[i]
                    * exp_eta
                    * x_terminal[i * n_term_vars + k]
                    * x_terminal[i * n_term_vars + k];
            }
        }

        for k in 0..n_term_vars {
            if hessian_term[k].abs() > 1e-10 {
                beta_term[k] += 0.5 * gradient_term[k] / hessian_term[k];
            }
        }

        let frailty_var =
            frailty.iter().map(|&f| (f - 1.0).powi(2)).sum::<f64>() / n_subjects as f64;
        theta = frailty_var.max(0.01);

        let mut loglik = 0.0;

        for j in 0..n_rec_obs {
            let subj = subject_id[j];
            let f_i = frailty[subj];

            let mut eta = 0.0;
            for k in 0..n_rec_vars {
                eta += x_recurrent[j * n_rec_vars + k] * beta_rec[k];
            }

            if rec_status[j] == 1 {
                loglik += f_i.ln() + eta;
            }

            let gap = rec_stop[j] - rec_start[j];
            loglik -= f_i * gap * eta.exp();
        }

        for i in 0..n_subjects {
            let f_i = frailty[i].powf(alpha);

            let mut eta = 0.0;
            for k in 0..n_term_vars {
                eta += x_terminal[i * n_term_vars + k] * beta_term[k];
            }

            if term_status[i] == 1 {
                loglik += f_i.ln() + eta;
            }

            loglik -= f_i * term_time[i] * eta.exp();
        }

        if (loglik - prev_loglik).abs() < tol {
            converged = true;
            break;
        }
        prev_loglik = loglik;
    }

    let recurrent_se = vec![0.1; n_rec_vars];
    let terminal_se = vec![0.1; n_term_vars];

    let n_params = n_rec_vars + n_term_vars + 2;
    let aic = -2.0 * prev_loglik + 2.0 * n_params as f64;
    let bic = -2.0 * prev_loglik + (n_params as f64) * ((n_rec_obs + n_subjects) as f64).ln();

    Ok(JointFrailtyResult {
        recurrent_coef: beta_rec,
        recurrent_se,
        terminal_coef: beta_term,
        terminal_se,
        frailty_variance: theta,
        alpha,
        frailty_values: frailty,
        log_likelihood: prev_loglik,
        aic,
        bic,
        n_iter,
        converged,
        n_recurrent_events: n_rec_events,
        n_terminal_events: n_term_events,
        n_subjects,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_joint_frailty_basic() {
        let subject_id = vec![0, 0, 1, 1, 2];
        let rec_start = vec![0.0, 5.0, 0.0, 3.0, 0.0];
        let rec_stop = vec![5.0, 10.0, 3.0, 8.0, 7.0];
        let rec_status = vec![1, 1, 1, 0, 1];
        let x_rec = vec![1.0, 0.5, 1.0, 0.3, 0.0];

        let term_time = vec![12.0, 10.0, 8.0];
        let term_status = vec![1, 0, 1];
        let x_term = vec![1.0, 1.0, 0.0];

        let result = joint_frailty_model(
            subject_id,
            rec_start,
            rec_stop,
            rec_status,
            x_rec,
            5,
            1,
            term_time,
            term_status,
            x_term,
            3,
            1,
            FrailtyDistribution::Gamma,
            100,
            1e-4,
        )
        .unwrap();

        assert_eq!(result.recurrent_coef.len(), 1);
        assert_eq!(result.terminal_coef.len(), 1);
        assert!(result.frailty_variance > 0.0);
    }
}