survival 1.1.27

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use crate::utilities::statistical::normal_cdf;
use pyo3::prelude::*;

/// Signature shared by the scalar distribution helpers in this file (the
/// `*_cdf` / `*_pdf` functions), so that one of them can be selected at
/// runtime from a distribution name and stored in a plain function pointer.
#[allow(dead_code)]
type DistFn = fn(f64) -> f64;

/// Kinds of residual that can be requested by name.
///
/// Each variant corresponds to the lowercase spelling of its own name; see
/// [`SurvregResidType::from_str`] for the parsing rules.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(dead_code)]
pub enum SurvregResidType {
    Response,
    Deviance,
    Dfbeta,
    Dfbetas,
    Working,
    Ldcase,
    Ldresp,
    Ldshape,
    Matrix,
}

impl SurvregResidType {
    /// Parses a residual-type name.
    ///
    /// The input is lowercased before comparison, so matching is
    /// case-insensitive. No trimming is performed. Returns `None` when the
    /// name is not one of the nine known spellings.
    pub fn from_str(s: &str) -> Option<Self> {
        // Name -> variant table; the names are mutually distinct, so lookup
        // order does not matter.
        const NAMES: [(&str, SurvregResidType); 9] = [
            ("response", SurvregResidType::Response),
            ("deviance", SurvregResidType::Deviance),
            ("dfbeta", SurvregResidType::Dfbeta),
            ("dfbetas", SurvregResidType::Dfbetas),
            ("working", SurvregResidType::Working),
            ("ldcase", SurvregResidType::Ldcase),
            ("ldresp", SurvregResidType::Ldresp),
            ("ldshape", SurvregResidType::Ldshape),
            ("matrix", SurvregResidType::Matrix),
        ];

        let wanted = s.to_lowercase();
        NAMES
            .iter()
            .find(|(name, _)| *name == wanted.as_str())
            .map(|&(_, kind)| kind)
    }
}

/// Result object handed back to Python by `residuals_survreg`.
///
/// All three fields are exposed to Python as read-only attributes.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvregResiduals {
    /// The computed residual values.
    #[pyo3(get)]
    pub residuals: Vec<f64>,
    /// Name of the residual type, stored exactly as the caller supplied it.
    #[pyo3(get)]
    pub residual_type: String,
    /// Number of observations the residuals were computed for.
    #[pyo3(get)]
    pub n: usize,
}

#[pymethods]
impl SurvregResiduals {
    /// Short textual summary showing the residual type and observation count.
    fn __repr__(&self) -> String {
        let Self {
            residual_type, n, ..
        } = self;
        format!("SurvregResiduals(type='{}', n={})", residual_type, n)
    }
}

/// CDF of the standard (minimum) extreme-value distribution,
/// `F(z) = 1 - exp(-exp(z))`.
///
/// Evaluated as `-expm1(-exp(z))` so that the lower tail keeps its relative
/// precision: forming `1 - exp(-e^z)` directly rounds to exactly `0.0` once
/// `e^z` drops below roughly `1e-16`, even though `F(z) ~ e^z` is still
/// perfectly representable.
fn extreme_value_cdf(z: f64) -> f64 {
    let ez = z.exp();
    -(-ez).exp_m1()
}

/// PDF of the standard (minimum) extreme-value distribution,
/// `f(z) = exp(z) * exp(-exp(z))`.
///
/// Once `exp(z)` overflows to infinity the naive product is `inf * 0 = NaN`;
/// the density there is zero for all practical purposes, so `0.0` is returned
/// instead. For every `z` where `exp(z)` is finite the result is unchanged.
fn extreme_value_pdf(z: f64) -> f64 {
    let ez = z.exp();
    if ez.is_infinite() {
        // Far upper tail: avoid inf * 0.
        return 0.0;
    }
    ez * (-ez).exp()
}

/// CDF of the standard logistic distribution, `1 / (1 + exp(-z))`.
fn logistic_cdf(z: f64) -> f64 {
    let odds_against = (-z).exp();
    (1.0 + odds_against).recip()
}

/// PDF of the standard logistic distribution,
/// `f(z) = exp(-z) / (1 + exp(-z))^2`.
///
/// The density is symmetric in `z`, so it is evaluated at `-|z|`: that keeps
/// the exponential in `(0, 1]`. Evaluating at `-z` directly overflows for
/// strongly negative `z`, yielding `inf / inf = NaN` (or a premature `0.0`
/// when only the squared denominator overflows).
fn logistic_pdf(z: f64) -> f64 {
    let ez = (-z.abs()).exp();
    let one_plus = 1.0 + ez;
    ez / (one_plus * one_plus)
}

/// CDF used for the Gaussian family.
///
/// Thin adapter that forwards to `normal_cdf` (imported from
/// `crate::utilities::statistical`) so the Gaussian case has the same
/// `fn(f64) -> f64` shape and naming as the other `*_cdf` helpers here.
fn gaussian_cdf(z: f64) -> f64 {
    normal_cdf(z)
}

/// PDF of the standard normal distribution, `exp(-z^2 / 2) / sqrt(2 * pi)`.
fn gaussian_pdf(z: f64) -> f64 {
    use std::f64::consts::PI;

    let normaliser = (2.0 * PI).sqrt();
    let kernel = (-0.5 * z * z).exp();
    kernel / normaliser
}

/// Standardized residuals `(ln(time[i]) - linear_pred[i]) / scale`, one per
/// entry of `time`.
///
/// `_status` and `_distribution` are accepted but not used.
///
/// # Panics
///
/// Panics if `linear_pred` has fewer elements than `time`.
#[allow(dead_code)]
pub fn compute_standardized_residuals(
    time: &[f64],
    _status: &[i32],
    linear_pred: &[f64],
    scale: f64,
    _distribution: &str,
) -> Vec<f64> {
    (0..time.len())
        .map(|idx| {
            let log_time = time[idx].ln();
            (log_time - linear_pred[idx]) / scale
        })
        .collect()
}

/// Response residuals on the log-time scale: `ln(time[i]) - linear_pred[i]`.
///
/// The two slices are walked in lockstep; if their lengths differ, the output
/// has the length of the shorter one.
pub fn compute_response_residuals(time: &[f64], linear_pred: &[f64]) -> Vec<f64> {
    let mut residuals = Vec::with_capacity(time.len().min(linear_pred.len()));
    for (t, lp) in time.iter().zip(linear_pred) {
        residuals.push(t.ln() - *lp);
    }
    residuals
}

pub fn compute_deviance_residuals_survreg(
    time: &[f64],
    status: &[i32],
    linear_pred: &[f64],
    scale: f64,
    distribution: &str,
) -> Vec<f64> {
    let n = time.len();
    let mut residuals = Vec::with_capacity(n);

    let (cdf_fn, pdf_fn): (DistFn, DistFn) = match distribution.to_lowercase().as_str() {
        "weibull" | "extreme" | "extreme_value" | "extremevalue" => {
            (extreme_value_cdf, extreme_value_pdf)
        }
        "logistic" | "loglogistic" => (logistic_cdf, logistic_pdf),
        "gaussian" | "lognormal" | "normal" => (gaussian_cdf, gaussian_pdf),
        _ => (extreme_value_cdf, extreme_value_pdf),
    };

    for i in 0..n {
        let y = time[i].ln();
        let z = (y - linear_pred[i]) / scale;

        let surv = 1.0 - cdf_fn(z);
        let dens = pdf_fn(z) / scale;

        let dev = if status[i] == 1 {
            let log_dens = if dens > 1e-300 { dens.ln() } else { -690.0 };
            -2.0 * log_dens
        } else {
            let log_surv = if surv > 1e-300 { surv.ln() } else { -690.0 };
            -2.0 * log_surv
        };

        let sign = if z >= 0.0 { 1.0 } else { -1.0 };
        residuals.push(sign * dev.abs().sqrt());
    }

    residuals
}

pub fn compute_working_residuals(
    time: &[f64],
    status: &[i32],
    linear_pred: &[f64],
    scale: f64,
    distribution: &str,
) -> Vec<f64> {
    let n = time.len();
    let mut residuals = Vec::with_capacity(n);

    let (cdf_fn, pdf_fn): (DistFn, DistFn) = match distribution.to_lowercase().as_str() {
        "weibull" | "extreme" | "extreme_value" | "extremevalue" => {
            (extreme_value_cdf, extreme_value_pdf)
        }
        "logistic" | "loglogistic" => (logistic_cdf, logistic_pdf),
        "gaussian" | "lognormal" | "normal" => (gaussian_cdf, gaussian_pdf),
        _ => (extreme_value_cdf, extreme_value_pdf),
    };

    for i in 0..n {
        let y = time[i].ln();
        let z = (y - linear_pred[i]) / scale;

        let resid = if status[i] == 1 {
            let f = pdf_fn(z);
            let f_prime = match distribution.to_lowercase().as_str() {
                "weibull" | "extreme" | "extreme_value" | "extremevalue" => {
                    let ez = z.exp();
                    ez * (-ez).exp() * (1.0 - ez)
                }
                "logistic" | "loglogistic" => {
                    let ez = (-z).exp();
                    let denom = (1.0 + ez).powi(3);
                    ez * (ez - 1.0) / denom
                }
                _ => -z * f,
            };
            if f.abs() > 1e-300 { -f_prime / f } else { 0.0 }
        } else {
            let surv = 1.0 - cdf_fn(z);
            let f = pdf_fn(z);
            if surv.abs() > 1e-300 { f / surv } else { 0.0 }
        };

        residuals.push(resid);
    }

    residuals
}

/// Builds an `n x nvar` matrix where
/// `out[i][j] = sum_k var_matrix[k][j] * covariates[i][k] * working[i]`.
///
/// `working` is the output of `compute_working_residuals` for the same
/// inputs, `n` is `time.len()` and `nvar` is the length of the first covariate
/// row. Terms whose `var_matrix[k][j]` entry does not exist are skipped.
///
/// Returns an empty vector when `time` or `covariates` is empty.
///
/// # Panics
///
/// Panics if a needed `covariates[i][k]` is out of range (fewer than `n` rows,
/// or a row shorter than the first one).
pub fn compute_dfbeta_survreg(
    time: &[f64],
    status: &[i32],
    covariates: &[Vec<f64>],
    linear_pred: &[f64],
    scale: f64,
    var_matrix: &[Vec<f64>],
    distribution: &str,
) -> Vec<Vec<f64>> {
    let n = time.len();
    if n == 0 || covariates.is_empty() {
        return vec![];
    }
    let nvar = covariates[0].len();

    let working = compute_working_residuals(time, status, linear_pred, scale, distribution);

    (0..n)
        .map(|i| {
            (0..nvar)
                .map(|j| {
                    (0..nvar)
                        // Only terms backed by an existing variance entry contribute.
                        .filter(|&k| var_matrix.get(k).map_or(false, |col| j < col.len()))
                        .fold(0.0, |acc, k| {
                            acc + var_matrix[k][j] * covariates[i][k] * working[i]
                        })
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}

pub fn compute_ldcase(
    time: &[f64],
    status: &[i32],
    linear_pred: &[f64],
    scale: f64,
    distribution: &str,
) -> Vec<f64> {
    let n = time.len();
    let mut ld = Vec::with_capacity(n);

    let (cdf_fn, pdf_fn): (DistFn, DistFn) = match distribution.to_lowercase().as_str() {
        "weibull" | "extreme" | "extreme_value" | "extremevalue" => {
            (extreme_value_cdf, extreme_value_pdf)
        }
        "logistic" | "loglogistic" => (logistic_cdf, logistic_pdf),
        "gaussian" | "lognormal" | "normal" => (gaussian_cdf, gaussian_pdf),
        _ => (extreme_value_cdf, extreme_value_pdf),
    };

    for i in 0..n {
        let y = time[i].ln();
        let z = (y - linear_pred[i]) / scale;

        let contrib = if status[i] == 1 {
            let f = pdf_fn(z);
            if f > 1e-300 {
                f.ln() - scale.ln()
            } else {
                -690.0
            }
        } else {
            let surv = 1.0 - cdf_fn(z);
            if surv > 1e-300 { surv.ln() } else { -690.0 }
        };

        ld.push(contrib);
    }

    ld
}

/// Python entry point: computes one residual per observation.
///
/// `time`, `status` and `linear_pred` must have equal length. `residual_type`
/// is parsed case-insensitively by `SurvregResidType::from_str` and defaults
/// to `"deviance"`; `distribution` is forwarded unchanged to the computation
/// helpers.
///
/// Dispatch:
/// * `response` -> `compute_response_residuals`
/// * `working` -> `compute_working_residuals`
/// * `ldcase`, `ldresp`, `ldshape` -> `compute_ldcase`
/// * `deviance`, and also `dfbeta`, `dfbetas`, `matrix` ->
///   `compute_deviance_residuals_survreg`
///
/// The returned object stores `residual_type` exactly as supplied and
/// `n = time.len()`.
///
/// # Errors
///
/// Raises `ValueError` when the input lengths differ, when `residual_type` is
/// not recognised, or when a scale-dependent residual is requested with a
/// `scale` that is not a positive finite number.
#[pyfunction]
#[pyo3(signature = (time, status, linear_pred, scale, distribution, residual_type="deviance".to_string()))]
pub fn residuals_survreg(
    time: Vec<f64>,
    status: Vec<i32>,
    linear_pred: Vec<f64>,
    scale: f64,
    distribution: String,
    residual_type: String,
) -> PyResult<SurvregResiduals> {
    let n = time.len();
    if status.len() != n || linear_pred.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "time, status, and linear_pred must have the same length",
        ));
    }

    let resid_type = SurvregResidType::from_str(&residual_type).ok_or_else(|| {
        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
            "Unknown residual type: {}. Valid types: response, deviance, working, ldcase",
            residual_type
        ))
    })?;

    // Every residual except `response` divides by `scale` (and `ldcase` also
    // takes its logarithm); a zero, negative, infinite or NaN scale would
    // otherwise flow through silently as inf/NaN residuals.
    if resid_type != SurvregResidType::Response && (!scale.is_finite() || scale <= 0.0) {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
            "scale must be a positive finite number, got {}",
            scale
        )));
    }

    // Exhaustive on purpose: adding a variant must force a decision here.
    let residuals = match resid_type {
        SurvregResidType::Response => compute_response_residuals(&time, &linear_pred),
        SurvregResidType::Working => {
            compute_working_residuals(&time, &status, &linear_pred, scale, &distribution)
        }
        SurvregResidType::Ldcase | SurvregResidType::Ldresp | SurvregResidType::Ldshape => {
            compute_ldcase(&time, &status, &linear_pred, scale, &distribution)
        }
        // NOTE(review): dfbeta / dfbetas / matrix have no implementation
        // reachable from this signature (no covariates or variance matrix are
        // passed in), so they keep their historical fallback to deviance
        // residuals. Consider rejecting them with an error instead.
        SurvregResidType::Deviance
        | SurvregResidType::Dfbeta
        | SurvregResidType::Dfbetas
        | SurvregResidType::Matrix => {
            compute_deviance_residuals_survreg(&time, &status, &linear_pred, scale, &distribution)
        }
    };

    Ok(SurvregResiduals {
        residuals,
        residual_type,
        n,
    })
}

/// Python entry point wrapping `compute_dfbeta_survreg`.
///
/// `time`, `status`, `linear_pred` and `covariates` (one row per observation)
/// must have equal length, and every covariate row must have the same number
/// of columns. Returns one row per observation with one value per covariate
/// column.
///
/// # Errors
///
/// Raises `ValueError` when the input lengths differ, when the covariate rows
/// are ragged, or when `scale` is not a positive finite number.
#[pyfunction]
#[pyo3(signature = (time, status, covariates, linear_pred, scale, var_matrix, distribution))]
pub fn dfbeta_survreg(
    time: Vec<f64>,
    status: Vec<i32>,
    covariates: Vec<Vec<f64>>,
    linear_pred: Vec<f64>,
    scale: f64,
    var_matrix: Vec<Vec<f64>>,
    distribution: String,
) -> PyResult<Vec<Vec<f64>>> {
    let n = time.len();
    if status.len() != n || linear_pred.len() != n || covariates.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "All inputs must have the same length",
        ));
    }

    // The helper indexes every row up to the width of the first one; a shorter
    // row would otherwise surface in Python as a Rust panic rather than a
    // ValueError.
    if let Some(first_row) = covariates.first() {
        let nvar = first_row.len();
        if covariates.iter().any(|row| row.len() != nvar) {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "All covariate rows must have the same number of columns",
            ));
        }
    }

    // The working residuals divide by `scale`; reject values that would turn
    // the whole result into inf/NaN.
    if !scale.is_finite() || scale <= 0.0 {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
            "scale must be a positive finite number, got {}",
            scale
        )));
    }

    Ok(compute_dfbeta_survreg(
        &time,
        &status,
        &covariates,
        &linear_pred,
        scale,
        &var_matrix,
        &distribution,
    ))
}

// Unit tests for the pure-Rust residual helpers (no Python interpreter needed).
#[cfg(test)]
mod tests {
    use super::*;

    // Response residuals are ln(time) - linear_pred; the first observation has
    // time = 1 and linear_pred = 0, so its residual must be 0.
    #[test]
    fn test_response_residuals() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let linear_pred = vec![0.0, 0.5, 1.0, 1.2, 1.5];
        let resid = compute_response_residuals(&time, &linear_pred);

        assert_eq!(resid.len(), 5);
        assert!((resid[0] - 0.0).abs() < 1e-10);
    }

    // Smoke test: a mix of events and censored observations yields one
    // deviance residual per observation. Only the output length is checked.
    #[test]
    fn test_deviance_residuals() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let status = vec![1, 0, 1, 0, 1];
        let linear_pred = vec![0.0, 0.5, 1.0, 1.2, 1.5];
        let scale = 1.0;

        let resid =
            compute_deviance_residuals_survreg(&time, &status, &linear_pred, scale, "weibull");

        assert_eq!(resid.len(), 5);
    }
}