survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use super::common::{apply_deltas_add, build_score_result, validate_scoring_inputs};
use ndarray::{Array2, ArrayView2};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

/// Computes per-observation score residuals for a counting-process
/// `(start, stop, event)` survival model.
///
/// Rows are presumably expected to be grouped by stratum and, within a
/// stratum, ordered so that every row at risk at an event time appears at or
/// after the first row carrying that event time (NOTE(review): confirm the
/// sort order against the caller). Risk sets are built by scanning forward
/// from the current row until the stratum identifier changes.
///
/// # Arguments
/// * `y` - three consecutive runs of `n = y.len() / 3` values: start times,
///   stop times and event indicators. Any trailing `y.len() % 3` values are
///   ignored.
/// * `covar` - covariates, viewed as an `(nvar, n)` matrix with
///   `nvar = covar.len() / n`.
/// * `strata` - stratum identifier of each row.
/// * `score` - per-row risk score.
/// * `weights` - per-row case weight.
/// * `method` - `0` always uses the single-hazard (Breslow-style) update; any
///   other value switches to the tie-corrected (Efron-style) update whenever
///   two or more events share a stop time.
///
/// # Returns
/// The `(nvar, n)` residual matrix flattened in the storage order of the
/// array created by `Array2::zeros`. An input without a single complete
/// `(start, stop, event)` triple and without covariates yields an empty vector.
///
/// # Errors
/// Returns a message when `strata`, `score` or `weights` hold fewer than `n`
/// entries, when covariates are supplied without any observation, or when
/// `covar` cannot be viewed as an `(nvar, n)` matrix.
#[inline]
pub fn agscore2(
    y: &[f64],
    covar: &[f64],
    strata: &[i32],
    score: &[f64],
    weights: &[f64],
    method: i32,
) -> Result<Vec<f64>, String> {
    let n = y.len() / 3;
    if n == 0 {
        // Guard the `covar.len() / n` division below: with no observations
        // there is nothing to compute.
        if covar.is_empty() {
            return Ok(Vec::new());
        }
        return Err(format!(
            "Received {} covariate values but no observations",
            covar.len()
        ));
    }
    if strata.len() < n || score.len() < n || weights.len() < n {
        // Every row index below n is used to index these slices; report the
        // mismatch instead of panicking on an out-of-bounds access.
        return Err(format!(
            "Input length mismatch: expected at least {} entries, got strata={}, score={}, weights={}",
            n,
            strata.len(),
            score.len(),
            weights.len()
        ));
    }

    let nvar = covar.len() / n;
    let tstart = &y[0..n];
    let tstop = &y[n..2 * n];
    let event = &y[2 * n..3 * n];

    let covar_matrix: ArrayView2<f64> = ArrayView2::from_shape((nvar, n), covar).map_err(|e| {
        format!(
            "Failed to create covariate matrix view with shape ({}, {}): {}",
            nvar, n, e
        )
    })?;

    let mut resid_matrix = Array2::zeros((nvar, n));
    let mut a = vec![0.0; nvar];
    let mut a2 = vec![0.0; nvar];
    let mut mean = vec![0.0; nvar];
    let mut mh1 = vec![0.0; nvar];
    let mut mh2 = vec![0.0; nvar];
    let mut mh3 = vec![0.0; nvar];

    // Reused across event times so the risk-set buffer is allocated once.
    let mut at_risk_indices: Vec<usize> = Vec::with_capacity(n);

    let mut person = 0;
    while person < n {
        // Only rows with a non-zero event indicator open a new event time.
        if event[person] == 0.0 {
            person += 1;
            continue;
        }

        let time = tstop[person];
        let stratum = strata[person];
        let mut denom = 0.0;
        let mut e_denom = 0.0;
        let mut deaths = 0.0;
        let mut meanwt = 0.0;
        a.fill(0.0);
        a2.fill(0.0);
        at_risk_indices.clear();

        // Accumulate risk-set totals over the rest of the current stratum.
        let mut k = person;
        while k < n && strata[k] == stratum {
            if tstart[k] < time {
                at_risk_indices.push(k);
                let risk = score[k] * weights[k];
                denom += risk;
                for (i, acc) in a.iter_mut().enumerate() {
                    *acc += risk * covar_matrix[[i, k]];
                }
                if tstop[k] == time && event[k] == 1.0 {
                    deaths += 1.0;
                    e_denom += risk;
                    meanwt += weights[k];
                    for (i, acc) in a2.iter_mut().enumerate() {
                        *acc += risk * covar_matrix[[i, k]];
                    }
                }
            }
            k += 1;
        }

        if deaths < 2.0 || method == 0 {
            // Single hazard increment shared by everyone at risk.
            let hazard = meanwt / denom;
            for (i, m) in mean.iter_mut().enumerate() {
                *m = a[i] / denom;
            }

            apply_deltas_add(&at_risk_indices, nvar, &mut resid_matrix, |k| {
                let risk = score[k];
                let is_event = tstop[k] == time && event[k] == 1.0;
                (0..nvar)
                    .map(|i| {
                        let diff = covar_matrix[[i, k]] - mean[i];
                        let mut delta = -diff * risk * hazard;
                        if is_event {
                            delta += diff;
                        }
                        delta
                    })
                    .collect()
            });
        } else {
            // Tie-corrected update: the tied events are progressively
            // down-weighted out of the risk-set denominator.
            let meanwt_norm = meanwt / deaths;
            let mut temp1 = 0.0;
            let mut temp2 = 0.0;
            mh1.fill(0.0);
            mh2.fill(0.0);
            mh3.fill(0.0);

            for dd in 0..deaths as usize {
                let downwt = dd as f64 / deaths;
                let d2 = denom - downwt * e_denom;
                let hazard = meanwt_norm / d2;
                temp1 += hazard;
                temp2 += (1.0 - downwt) * hazard;
                for i in 0..nvar {
                    mean[i] = (a[i] - downwt * a2[i]) / d2;
                    mh1[i] += mean[i] * hazard;
                    mh2[i] += mean[i] * (1.0 - downwt) * hazard;
                    mh3[i] += mean[i] / deaths;
                }
            }

            apply_deltas_add(&at_risk_indices, nvar, &mut resid_matrix, |k| {
                let risk = score[k];
                let is_event = tstop[k] == time && event[k] == 1.0;
                (0..nvar)
                    .map(|i| {
                        if is_event {
                            (covar_matrix[[i, k]] - mh3[i]) - risk * covar_matrix[[i, k]] * temp2
                                + risk * mh2[i]
                        } else {
                            -risk * (covar_matrix[[i, k]] * temp1 - mh1[i])
                        }
                    })
                    .collect()
            });
        }

        // The current row always belongs to its own tie group, so step past it
        // unconditionally: this guarantees progress even when `time` is NaN,
        // which never compares equal to itself.
        person += 1;
        // Skip the remaining rows tied at this stop time, but never cross into
        // the next stratum, whose rows need their own risk set.
        while person < n && strata[person] == stratum && tstop[person] == time {
            person += 1;
        }
    }

    Ok(resid_matrix.into_raw_vec_and_offset().0)
}

// Python entry point for the score-residual computation.
//
// The observation count is taken from `weights.len()`. The input lengths are
// handed to `validate_scoring_inputs`, the residuals are computed by
// `agscore2` (its error message is raised as a Python `RuntimeError`), and the
// Python return value is assembled by `build_score_result`.
#[pyfunction]
pub fn perform_score_calculation(
    time_data: Vec<f64>,
    covariates: Vec<f64>,
    strata: Vec<i32>,
    score: Vec<f64>,
    weights: Vec<f64>,
    method: i32,
) -> PyResult<Py<PyAny>> {
    let n = weights.len();
    let (time_len, covar_len) = (time_data.len(), covariates.len());
    validate_scoring_inputs(
        n,
        time_len,
        covar_len,
        strata.len(),
        score.len(),
        weights.len(),
    )?;

    let residuals = match agscore2(&time_data, &covariates, &strata, &score, &weights, method) {
        Ok(values) => values,
        Err(message) => return Err(PyRuntimeError::new_err(message)),
    };

    // Number of covariate columns implied by the flat covariate buffer.
    let nvar = covar_len / n;
    Python::attach(|py| {
        let result = build_score_result(py, residuals, n, nvar, method)?;
        Ok(result.into())
    })
}