survival 1.0.17

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use pyo3::prelude::*;
use rayon::prelude::*;
/// Borrowed input arrays for [`compute_cox_score_residuals`].
///
/// All per-row slices are indexed by the same row number `0..n`, where `n`
/// comes from [`CoxScoreParams`].
pub(crate) struct CoxScoreData<'a> {
    /// Two concatenated columns: `y[0..n]` is read as the time of each row and
    /// `y[n..2 * n]` as its status (a value of exactly `1.0` is counted as an event).
    pub y: &'a [f64],
    /// Stratum label of each row; runs of consecutive equal labels are treated
    /// as one stratum.
    pub strata: &'a [i32],
    /// Covariates in row-major order: the value of variable `var` for row `i`
    /// is `covar[i * nvar + var]`.
    pub covar: &'a [f64],
    /// Per-row multiplier that, times the row's weight, gives its contribution
    /// to the risk-set sums. Presumably `exp(linear predictor)` — TODO confirm
    /// against the caller.
    pub score: &'a [f64],
    /// Per-row case weight.
    pub weights: &'a [f64],
}
/// Scalar settings for [`compute_cox_score_residuals`].
pub(crate) struct CoxScoreParams {
    /// Tie handling: `0` applies one hazard step per group of tied times; any
    /// other value applies one down-weighted step per event when a group
    /// contains two or more events.
    pub method: i32,
    /// Number of rows to process.
    pub n: usize,
    /// Number of covariates per row.
    pub nvar: usize,
}
/// Splits the first `n` entries of `strata` into runs of consecutive equal labels.
///
/// Returns one inclusive `(first, last)` index pair per run, in ascending
/// order. When `n` is zero the result is empty.
fn find_strata_boundaries(strata: &[i32], n: usize) -> Vec<(usize, usize)> {
    let last = match n.checked_sub(1) {
        Some(last) => last,
        None => return Vec::new(),
    };

    let mut runs = Vec::new();
    let mut run_start = 0;
    let mut label = strata[0];

    for (idx, &here) in strata.iter().enumerate().take(n).skip(1) {
        if here == label {
            continue;
        }
        // A different label closes the run that ended on the previous index.
        runs.push((run_start, idx - 1));
        run_start = idx;
        label = here;
    }

    // The run still open when the scan stops always extends to index `n - 1`.
    runs.push((run_start, last));
    runs
}
/// Computes the score residuals for rows `start..=end`, treated as one stratum.
///
/// Rows are walked from `end` down to `start`, accumulating the weighted risk
/// sums as it goes, so the rows are assumed to be sorted by ascending `time`
/// within the stratum (NOTE(review): not checked here — confirm at the caller).
/// Consecutive rows with equal `time` form one tie group.
///
/// * `time`, `status`, `score`, `weights` — per-row values indexed by absolute
///   row number; a `status` of exactly `1.0` marks an event.
/// * `covar` — row-major covariates, `covar[row * nvar + var]`.
/// * `method` — `0` applies a single hazard step per tie group (Breslow-style);
///   any other value applies one down-weighted step per event when the group
///   holds two or more events (Efron-style).
///
/// Returns `(row, residuals)` pairs in ascending row order, each `residuals`
/// vector having `nvar` entries.
///
/// Only rows that are events receive the `covar - xbar` increment of their tie
/// group; censored rows tied with an event contribute to the risk set but get
/// no event term.
///
/// # Panics
///
/// Panics if any slice is too short for the rows `start..=end`.
// The argument list mirrors the flat slices the caller already holds.
#[allow(clippy::too_many_arguments)]
fn process_stratum(
    start: usize,
    end: usize,
    time: &[f64],
    status: &[f64],
    covar: &[f64],
    score: &[f64],
    weights: &[f64],
    nvar: usize,
    method: i32,
) -> Vec<(usize, Vec<f64>)> {
    // Running sums over the rows visited so far (times >= the current one).
    let mut a = vec![0.0; nvar]; // sum of risk * covariate
    let mut a2 = vec![0.0; nvar]; // same sum, events of the current tie group only
    let mut xhaz = vec![0.0; nvar]; // accumulated xbar * hazard
    let mut denom = 0.0; // sum of risk
    let mut cumhaz = 0.0; // accumulated hazard
    let mut partial_resid: Vec<Vec<f64>> = (start..=end).map(|_| vec![0.0; nvar]).collect();

    // `upper` is one past the highest row not yet processed; plain `usize`
    // bounds avoid narrowing row indices to `i32`.
    let mut upper = end + 1;
    while upper > start {
        let newtime = time[upper - 1];
        let mut deaths_count = 0usize;
        let mut e_denom = 0.0;
        let mut meanwt = 0.0;
        a2.fill(0.0);

        // Collect the tie group `lower..upper`. The first row is always taken,
        // so the walk advances even when `newtime` is NaN (NaN never compares
        // equal to itself and would otherwise stall the loop forever).
        let mut lower = upper;
        loop {
            let row = lower - 1;
            let risk = score[row] * weights[row];
            denom += risk;
            let is_death = status[row] == 1.0;
            let local_idx = row - start;
            for var in 0..nvar {
                let covar_val = covar[row * nvar + var];
                // Seed with minus the hazard already accumulated from later
                // times; the final pass adds the stratum totals, which leaves
                // only the hazard from times at or before this row's own time.
                partial_resid[local_idx][var] = score[row] * (covar_val * cumhaz - xhaz[var]);
                a[var] += risk * covar_val;
                if is_death {
                    a2[var] += risk * covar_val;
                }
            }
            if is_death {
                deaths_count += 1;
                e_denom += risk;
                meanwt += weights[row];
            }
            lower = row;
            if lower == start || time[lower - 1] != newtime {
                break;
            }
        }

        if deaths_count > 0 {
            let deaths = deaths_count as f64;
            if deaths_count < 2 || method == 0 {
                let hazard = meanwt / denom;
                cumhaz += hazard;
                for var in 0..nvar {
                    let xbar = a[var] / denom;
                    xhaz[var] += xbar * hazard;
                    for row in lower..upper {
                        // Censored rows in the tie group have no event term.
                        if status[row] != 1.0 {
                            continue;
                        }
                        partial_resid[row - start][var] += covar[row * nvar + var] - xbar;
                    }
                }
            } else {
                let meanwt_per_death = meanwt / deaths;
                for dd in 0..deaths_count {
                    // Fraction of the tied events' risk removed for this step.
                    let downwt = dd as f64 / deaths;
                    let temp = denom - downwt * e_denom;
                    let hazard = meanwt_per_death / temp;
                    cumhaz += hazard;
                    for var in 0..nvar {
                        let xbar = (a[var] - downwt * a2[var]) / temp;
                        xhaz[var] += xbar * hazard;
                        for row in lower..upper {
                            if status[row] != 1.0 {
                                continue;
                            }
                            let local_idx = row - start;
                            let temp2 = covar[row * nvar + var] - xbar;
                            partial_resid[local_idx][var] += temp2 / deaths;
                            partial_resid[local_idx][var] +=
                                temp2 * score[row] * hazard * downwt;
                        }
                    }
                }
            }
        }

        upper = lower;
    }

    // Add the stratum totals to complete the hazard term started when each
    // row was seeded, and hand the per-row vectors over without cloning.
    partial_resid
        .into_iter()
        .enumerate()
        .map(|(local_idx, mut row_resid)| {
            let row = start + local_idx;
            for var in 0..nvar {
                row_resid[var] += score[row] * (xhaz[var] - covar[row * nvar + var] * cumhaz);
            }
            (row, row_resid)
        })
        .collect()
}
/// Computes the score residuals for all `params.n` rows.
///
/// `data.y` is read as `n` times followed by `n` status values. The result is
/// row-major with `params.nvar` entries per row (`resid[row * nvar + var]`).
///
/// When the rows fall into more than one stratum (runs of consecutive equal
/// labels in `data.strata`), each stratum is handled by `process_stratum` on
/// the rayon pool and the results are scattered into the output. Otherwise the
/// rows are processed in place by a single backward sweep.
///
/// In the single-sweep path, only rows whose status is exactly `1.0` receive
/// the `covar - xbar` increment of their tie group; censored rows tied with an
/// event contribute to the risk set but get no event term.
///
/// # Panics
///
/// Panics if `data.y` is shorter than `2 * n` or any other slice is too short
/// for `n` rows of `nvar` covariates.
pub(crate) fn compute_cox_score_residuals(data: CoxScoreData, params: CoxScoreParams) -> Vec<f64> {
    let time = &data.y[0..params.n];
    let status = &data.y[params.n..2 * params.n];
    let boundaries = find_strata_boundaries(data.strata, params.n);
    if boundaries.len() > 1 {
        let all_results: Vec<Vec<(usize, Vec<f64>)>> = boundaries
            .par_iter()
            .map(|&(start, end)| {
                process_stratum(
                    start,
                    end,
                    time,
                    status,
                    data.covar,
                    data.score,
                    data.weights,
                    params.nvar,
                    params.method,
                )
            })
            .collect();
        let mut resid = vec![0.0; params.n * params.nvar];
        for stratum_results in all_results {
            for (idx, values) in stratum_results {
                for (var, &val) in values.iter().enumerate() {
                    resid[idx * params.nvar + var] = val;
                }
            }
        }
        return resid;
    }
    let mut resid = vec![0.0; params.n * params.nvar];
    // Running sums over the rows visited so far within the current stratum.
    let mut a = vec![0.0; params.nvar]; // sum of risk * covariate
    let mut a2 = vec![0.0; params.nvar]; // same sum, events of the current tie group only
    let mut xhaz = vec![0.0; params.nvar]; // accumulated xbar * hazard
    let mut denom = 0.0; // sum of risk
    let mut cumhaz = 0.0; // accumulated hazard
    let mut stratastart = params.n as i32 - 1;
    let mut currentstrata = if params.n > 0 {
        data.strata[params.n - 1]
    } else {
        0
    };
    let mut i = stratastart;
    while i >= 0 {
        let i_usize = i as usize;
        let newtime = time[i_usize];
        let mut deaths_count = 0;
        let mut e_denom = 0.0;
        let mut meanwt = 0.0;
        a2.fill(0.0);
        let mut j = i;
        while j >= 0 {
            let j_usize = j as usize;
            // Row `i` itself is always consumed so the sweep advances even when
            // `newtime` is NaN (NaN never compares equal to itself and would
            // otherwise stall the loop forever).
            if j != i && (time[j_usize] != newtime || data.strata[j_usize] != currentstrata) {
                break;
            }
            let risk = data.score[j_usize] * data.weights[j_usize];
            denom += risk;
            for (var, (resid_elem, xhaz_elem)) in resid
                .iter_mut()
                .skip(j_usize * params.nvar)
                .take(params.nvar)
                .zip(xhaz.iter())
                .enumerate()
            {
                let idx = j_usize * params.nvar + var;
                let covar_val = data.covar[idx];
                // Seed with minus the hazard already accumulated from later
                // times; the end-of-stratum pass adds the stratum totals.
                *resid_elem = data.score[j_usize] * (covar_val * cumhaz - xhaz_elem);
            }
            for (var, a_elem) in a.iter_mut().enumerate().take(params.nvar) {
                *a_elem += risk * data.covar[j_usize * params.nvar + var];
            }
            if status[j_usize] == 1.0 {
                deaths_count += 1;
                e_denom += risk;
                meanwt += data.weights[j_usize];
                for (var, a2_elem) in a2.iter_mut().enumerate().take(params.nvar) {
                    *a2_elem += risk * data.covar[j_usize * params.nvar + var];
                }
            }
            j -= 1;
        }
        // Rows `processed_start..=processed_end` form the tie group just read.
        let processed_start = j + 1;
        let processed_end = i;
        i = j;
        if deaths_count > 0 {
            let deaths = deaths_count as f64;
            if deaths < 2.0 || params.method == 0 {
                let hazard = meanwt / denom;
                cumhaz += hazard;
                for (var, (a_elem, xhaz_elem)) in
                    a.iter().zip(xhaz.iter_mut()).enumerate().take(params.nvar)
                {
                    let xbar = a_elem / denom;
                    *xhaz_elem += xbar * hazard;
                    for k in processed_start..=processed_end {
                        let k_usize = k as usize;
                        // Censored rows in the tie group have no event term.
                        if status[k_usize] != 1.0 {
                            continue;
                        }
                        let idx = k_usize * params.nvar + var;
                        resid[idx] += data.covar[idx] - xbar;
                    }
                }
            } else {
                let meanwt_per_death = meanwt / deaths;
                for dd in 0..deaths_count {
                    // Fraction of the tied events' risk removed for this step.
                    let downwt = dd as f64 / deaths;
                    let temp = denom - downwt * e_denom;
                    let hazard = meanwt_per_death / temp;
                    cumhaz += hazard;
                    for var in 0..params.nvar {
                        let xbar = (a[var] - downwt * a2[var]) / temp;
                        xhaz[var] += xbar * hazard;
                        for k in processed_start..=processed_end {
                            let k_usize = k as usize;
                            if status[k_usize] != 1.0 {
                                continue;
                            }
                            let idx = k_usize * params.nvar + var;
                            let temp2 = data.covar[idx] - xbar;
                            resid[idx] += temp2 / deaths;
                            resid[idx] += temp2 * data.score[k_usize] * hazard * downwt;
                        }
                    }
                }
            }
        }
        // End of a stratum: finish its rows with the stratum totals, then reset
        // the running sums for the next one.
        if i < 0 || data.strata[i as usize] != currentstrata {
            for k in (i + 1)..=stratastart {
                let k_usize = k as usize;
                for (var, (resid_elem, xhaz_elem)) in resid
                    .iter_mut()
                    .skip(k_usize * params.nvar)
                    .take(params.nvar)
                    .zip(xhaz.iter())
                    .enumerate()
                {
                    let idx = k_usize * params.nvar + var;
                    *resid_elem += data.score[k_usize] * (xhaz_elem - data.covar[idx] * cumhaz);
                }
            }
            denom = 0.0;
            cumhaz = 0.0;
            a.fill(0.0);
            xhaz.fill(0.0);
            stratastart = i;
            if i >= 0 {
                currentstrata = data.strata[i as usize];
            }
        }
    }
    resid
}
/// Python entry point for the Cox score residuals.
///
/// The number of rows `n` is taken from `score.len()`. `y` must hold `n` times
/// followed by `n` status values, `covar` must hold `n * nvar` values in
/// row-major order, and `strata` and `weights` must hold at least `n` values.
/// `method` defaults to `0`.
///
/// Returns the residuals as a flat row-major list with `nvar` entries per row.
///
/// # Errors
///
/// Returns `ValueError` when any input is shorter than required, including the
/// case where `n * nvar` does not fit in `usize`.
#[pyfunction]
#[pyo3(signature = (y, strata, covar, score, weights, nvar, method=0))]
pub fn cox_score_residuals(
    y: Vec<f64>,
    strata: Vec<i32>,
    covar: Vec<f64>,
    score: Vec<f64>,
    weights: Vec<f64>,
    nvar: usize,
    method: i32,
) -> PyResult<Vec<f64>> {
    let n = score.len();
    if y.len() < 2 * n {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "y array must have length >= 2 * n (time, status)",
        ));
    }
    if strata.len() < n {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "strata array length must match n",
        ));
    }
    // `nvar` comes straight from the caller, so `n * nvar` may overflow; an
    // overflowing product can never be satisfied by `covar` and is rejected
    // with the same error instead of wrapping or panicking.
    match n.checked_mul(nvar) {
        Some(needed) if covar.len() >= needed => {}
        _ => {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "covar array must have length >= n * nvar",
            ));
        }
    }
    if weights.len() < n {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "weights array length must match n",
        ));
    }
    let data = CoxScoreData {
        y: &y,
        strata: &strata,
        covar: &covar,
        score: &score,
        weights: &weights,
    };
    let params = CoxScoreParams { method, n, nvar };
    Ok(compute_cox_score_residuals(data, params))
}