survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use super::common::{
    apply_deltas_add, apply_deltas_set, build_score_result, validate_scoring_inputs,
};
use ndarray::{Array2, ArrayView2};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

#[inline]
pub fn agscore3(
    y: &[f64],
    covar: &[f64],
    strata: &[i32],
    score: &[f64],
    weights: &[f64],
    method: i32,
    sort1: &[i32],
) -> Result<Vec<f64>, String> {
    let n = y.len() / 3;
    let nvar = covar.len() / n;
    let tstart = &y[0..n];
    let tstop = &y[n..2 * n];
    let event = &y[2 * n..3 * n];

    let covar_matrix = ArrayView2::from_shape((nvar, n), covar).map_err(|e| {
        format!(
            "Failed to create covariate view with shape ({}, {}): {}",
            nvar, n, e
        )
    })?;

    let mut resid_matrix = Array2::zeros((nvar, n));
    let mut a = vec![0.0; nvar];
    let mut a2 = vec![0.0; nvar];
    let mut mean = vec![0.0; nvar];
    let mut mh1 = vec![0.0; nvar];
    let mut mh2 = vec![0.0; nvar];
    let mut mh3 = vec![0.0; nvar];
    let mut xhaz = vec![0.0; nvar];

    let mut cumhaz = 0.0;
    let mut denom = 0.0;
    let mut current_stratum = *strata.last().unwrap_or(&0);
    let mut i1 = n - 1;
    let sort1: Vec<usize> = sort1.iter().map(|&x| (x - 1) as usize).collect();

    let mut person = n - 1;
    while person > 0 {
        let dtime = tstop[person];

        if strata[person] != current_stratum {
            let mut exit_indices = Vec::new();
            while i1 > 0 && sort1[i1] > person {
                exit_indices.push(sort1[i1]);
                i1 -= 1;
            }

            apply_deltas_add(&exit_indices, nvar, &mut resid_matrix, |k| {
                (0..nvar)
                    .map(|j| -score[k] * (cumhaz * covar_matrix[[j, k]] - xhaz[j]))
                    .collect()
            });

            cumhaz = 0.0;
            denom = 0.0;
            a.fill(0.0);
            xhaz.fill(0.0);
            current_stratum = strata[person];
        } else {
            let mut exit_indices = Vec::new();
            while i1 > 0 && tstart[sort1[i1]] >= dtime {
                let k = sort1[i1];
                if strata[k] != current_stratum {
                    break;
                }
                exit_indices.push(k);
                let risk = score[k] * weights[k];
                denom -= risk;
                for j in 0..nvar {
                    a[j] -= risk * covar_matrix[[j, k]];
                }
                i1 -= 1;
            }

            apply_deltas_add(&exit_indices, nvar, &mut resid_matrix, |k| {
                (0..nvar)
                    .map(|j| -score[k] * (cumhaz * covar_matrix[[j, k]] - xhaz[j]))
                    .collect()
            });
        }

        let mut e_denom = 0.0;
        let mut deaths = 0.0;
        let mut meanwt = 0.0;
        a2.fill(0.0);

        let mut processed_indices = Vec::new();
        while person > 0 && tstop[person] == dtime {
            if strata[person] != current_stratum {
                break;
            }
            processed_indices.push(person);

            let risk = score[person] * weights[person];
            denom += risk;
            for j in 0..nvar {
                a[j] += risk * covar_matrix[[j, person]];
            }

            if event[person] > 0.5 {
                deaths += 1.0;
                e_denom += risk;
                meanwt += weights[person];
                for j in 0..nvar {
                    a2[j] += risk * covar_matrix[[j, person]];
                }
            }
            person -= 1;
        }

        apply_deltas_set(&processed_indices, nvar, &mut resid_matrix, |p| {
            (0..nvar)
                .map(|j| (covar_matrix[[j, p]] * cumhaz - xhaz[j]) * score[p])
                .collect()
        });

        if deaths > 0.0 {
            if deaths < 2.0 || method == 0 {
                let hazard = meanwt / denom;
                cumhaz += hazard;
                for j in 0..nvar {
                    mean[j] = a[j] / denom;
                    xhaz[j] += mean[j] * hazard;
                }

                apply_deltas_add(&processed_indices, nvar, &mut resid_matrix, |p| {
                    (0..nvar).map(|j| covar_matrix[[j, p]] - mean[j]).collect()
                });
            } else {
                mh1.fill(0.0);
                mh2.fill(0.0);
                mh3.fill(0.0);
                let meanwt_norm = meanwt / deaths;

                for dd in 0..deaths as i32 {
                    let downwt = dd as f64 / deaths;
                    let d2 = denom - downwt * e_denom;
                    let hazard = meanwt_norm / d2;
                    cumhaz += hazard;
                    for j in 0..nvar {
                        mean[j] = (a[j] - downwt * a2[j]) / d2;
                        xhaz[j] += mean[j] * hazard;
                        mh1[j] += hazard * downwt;
                        mh2[j] += mean[j] * hazard * downwt;
                        mh3[j] += mean[j] / deaths;
                    }
                }

                apply_deltas_add(&processed_indices, nvar, &mut resid_matrix, |p| {
                    (0..nvar)
                        .map(|j| {
                            (covar_matrix[[j, p]] - mh3[j])
                                + score[p] * (covar_matrix[[j, p]] * mh1[j] - mh2[j])
                        })
                        .collect()
                });
            }
        }
    }

    let mut final_indices = Vec::new();
    while i1 > 0 {
        final_indices.push(sort1[i1]);
        i1 -= 1;
    }

    apply_deltas_add(&final_indices, nvar, &mut resid_matrix, |k| {
        (0..nvar)
            .map(|j| -score[k] * (cumhaz * covar_matrix[[j, k]] - xhaz[j]))
            .collect()
    });

    Ok(resid_matrix.into_raw_vec_and_offset().0)
}

#[pyfunction]
pub fn perform_agscore3_calculation(
    time_data: Vec<f64>,
    covariates: Vec<f64>,
    strata: Vec<i32>,
    score: Vec<f64>,
    weights: Vec<f64>,
    method: i32,
    sort1: Vec<i32>,
) -> PyResult<Py<PyAny>> {
    let n = weights.len();
    validate_scoring_inputs(
        n,
        time_data.len(),
        covariates.len(),
        strata.len(),
        score.len(),
        weights.len(),
    )?;
    if sort1.len() != n {
        return Err(PyRuntimeError::new_err(
            "Sort1 length does not match observations",
        ));
    }
    let residuals = agscore3(
        &time_data,
        &covariates,
        &strata,
        &score,
        &weights,
        method,
        &sort1,
    )
    .map_err(PyRuntimeError::new_err)?;
    let nvar = covariates.len() / n;
    Python::attach(|py| build_score_result(py, residuals, n, nvar, method).map(|d| d.into()))
}