survival 1.0.10

A high-performance survival analysis library written in Rust with Python bindings
Documentation
/// Borrowed per-row survival data consumed by [`compute_coxmart`].
///
/// Every slice is indexed by row.
#[derive(Debug)]
pub struct SurvivalData<'a> {
    /// Observation time of each row.
    pub time: &'a [f64],
    /// Event indicator of each row (`0` = censored); [`compute_coxmart`] sums
    /// these values as the event count of each tied-time group.
    pub status: &'a [i32],
    /// Stratum-end flags: a value of `1` marks the last row of a stratum.
    /// Mutable because [`compute_coxmart`] forces the final element to `1`.
    pub strata: &'a mut [i32],
}
/// Borrowed per-row risk scores and case weights consumed by [`compute_coxmart`].
///
/// Holds only shared borrows, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct Weights<'a> {
    /// Risk score of each row; multiplied by `wt` to form the risk-set sums.
    pub score: &'a [f64],
    /// Case weight of each row.
    pub wt: &'a [f64],
}
/// Python extension module `coxmart`; registers the [`coxmart`] function.
#[pymodule]
#[pyo3(name = "coxmart")]
fn coxmart_module(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
    let function = wrap_pyfunction!(coxmart, &m)?;
    m.add_function(function)
}
use pyo3::prelude::*;
/// Residuals for a Cox model, exposed to Python.
///
/// Arguments (per-row, all of length `n = len(time)`):
/// * `time` - observation times.
/// * `status` - event indicator (`0` = censored).
/// * `score` - per-row risk score.
/// * `weights` - optional case weights; every row gets `1.0` when omitted.
/// * `strata` - optional stratum-end flags (`1` marks the last row of a
///   stratum); when omitted, all rows form a single stratum.
/// * `method` - tie handling: `0` (the default) for a single Breslow-style
///   increment per tied time, any other value for an Efron-style correction
///   when a tied time has two or more events.
///
/// NOTE(review): rows are presumably expected to be sorted by ascending time
/// within each stratum; this is not checked here.
///
/// Returns one residual per row.
///
/// # Errors
///
/// Raises `ValueError` if `status`, `score`, `weights` or `strata` differ in
/// length from `time`.
#[pyfunction]
pub fn coxmart(
    time: Vec<f64>,
    status: Vec<i32>,
    score: Vec<f64>,
    weights: Option<Vec<f64>>,
    strata: Option<Vec<i32>>,
    method: Option<i32>,
) -> PyResult<Vec<f64>> {
    let value_error = |msg: &'static str| PyErr::new::<pyo3::exceptions::PyValueError, _>(msg);

    let n = time.len();
    if status.len() != n || score.len() != n {
        return Err(value_error("time, status, and score must have the same length"));
    }

    // Unit case weights by default.
    let wt = weights.unwrap_or_else(|| vec![1.0; n]);
    if wt.len() != n {
        return Err(value_error("weights must have the same length as time"));
    }

    // All-zero flags by default; compute_coxmart marks the final row as the
    // end of the (single) stratum itself.
    let mut stratum_ends = strata.unwrap_or_else(|| vec![0; n]);
    if stratum_ends.len() != n {
        return Err(value_error("strata must have the same length as time"));
    }

    let mut residuals = vec![0.0; n];
    compute_coxmart(
        n,
        method.unwrap_or(0),
        SurvivalData {
            time: &time,
            status: &status,
            strata: &mut stratum_ends,
        },
        Weights {
            score: &score,
            wt: &wt,
        },
        &mut residuals,
    );
    Ok(residuals)
}
/// Computes Cox-model residuals into `expect`, in place.
///
/// For each row `j` the result is `status[j]` (as `f64`) minus `score[j]`
/// times the cumulative hazard accumulated up to the end of row `j`'s group
/// of tied times. When `method != 0` and a group has two or more events, an
/// Efron-style adjustment is applied instead.
///
/// # Arguments
///
/// * `n` - number of rows to process. Every slice in `surv_data` and
///   `weights`, and `expect`, must have at least `n` elements.
/// * `method` - `0`: one hazard increment `sum(event wt) / risk-set sum` per
///   tied-time group (Breslow-style). Any other value: Efron-style
///   downweighting for groups with two or more events.
/// * `surv_data` - times, event indicators and stratum-end flags. A `strata`
///   value of `1` marks the last row of a stratum.
/// * `weights` - per-row risk scores (`score`) and case weights (`wt`).
/// * `expect` - output buffer. The first pass writes every element before it
///   is read, so incoming contents are ignored.
///
/// # Assumptions (not checked here)
///
/// NOTE(review): rows are presumably expected to be sorted by ascending time
/// within each stratum. The first pass builds risk-set sums as a backward
/// cumulative sum, which only matches "everyone still at risk" under that
/// ordering — confirm with callers.
///
/// A risk-set sum of exactly `0.0` is indistinguishable from the "no new
/// denominator" marker used between the passes. Such a group therefore
/// reuses the previous group's denominator. If the very first group's sum is
/// `0.0`, the initial `0.0` denominator is used and the division yields
/// inf/NaN.
///
/// # Side effects
///
/// Overwrites `surv_data.strata[n - 1]` with `1`.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `n` exceeds the length of any input
/// slice or of `expect`.
pub fn compute_coxmart(
    n: usize,
    method: i32,
    surv_data: SurvivalData,
    weights: Weights,
    expect: &mut [f64],
) {
    if n == 0 {
        return;
    }
    // Failsafe: the final row always ends a stratum, so pass 2 always flushes
    // its last tied-time group.
    surv_data.strata[n - 1] = 1;
    // Pass 1 (backward): running sum of score * wt, reset at each stratum end.
    // The sum is stored only at the first row of each run of tied times (or
    // the first row of a stratum); every other row gets 0.0, which pass 2
    // reads as "keep the current denominator".
    let mut denom = 0.0;
    for i in (0..n).rev() {
        if surv_data.strata[i] == 1 {
            denom = 0.0;
        }
        denom += weights.score[i] * weights.wt[i];
        let condition = if i == 0 {
            true
        } else {
            surv_data.strata[i - 1] == 1 || (surv_data.time[i - 1] != surv_data.time[i])
        };
        expect[i] = if condition { denom } else { 0.0 };
    }
    // Pass 2 (forward): tally events per tied-time group. At the group's last
    // row, add the group's hazard increment and subtract score * hazard from
    // every row of the group.
    // Per-group tallies: event count, sum of event weights, and sum of
    // score * wt over the events.
    let mut deaths = 0;
    let mut wtsum = 0.0;
    let mut e_denom = 0.0;
    // Cumulative hazard within the current stratum.
    let mut hazard = 0.0;
    // Index of the first row of the group not yet flushed.
    let mut lastone = 0;
    // Risk-set sum for the current group (taken from pass 1).
    let mut current_denom = 0.0;
    for i in 0..n {
        // Nonzero means pass 1 left a new risk-set sum here (start of a group).
        if expect[i] != 0.0 {
            current_denom = expect[i];
        }
        // Reuse the slot: each row's residual starts at its status.
        expect[i] = surv_data.status[i] as f64;
        deaths += surv_data.status[i];
        wtsum += surv_data.status[i] as f64 * weights.wt[i];
        e_denom += weights.score[i] * surv_data.status[i] as f64 * weights.wt[i];
        // This row closes its group if it ends the stratum or the next row
        // has a different time.
        let is_last =
            surv_data.strata[i] == 1 || (i < n - 1 && surv_data.time[i + 1] != surv_data.time[i]);
        if is_last {
            if deaths < 2 || method == 0 {
                // At most one event, or Breslow-style: one increment for the group.
                hazard += wtsum / current_denom;
                // Index loop kept: `expect` and `weights.score` are read at the same index.
                #[allow(clippy::needless_range_loop)]
                for j in lastone..=i {
                    expect[j] -= weights.score[j] * hazard;
                }
            } else {
                // Efron-style: split the tied events into `deaths` steps. At
                // step j a fraction j / deaths of the events' score * wt is
                // removed from the denominator. `hazard` collects the full
                // increments (applied to censored rows below). `temp` starts
                // from the previous cumulative hazard and collects increments
                // scaled by (1 - downwt) (applied to rows with events).
                let mut temp = hazard;
                let deaths_f = deaths as f64;
                // Average event weight in this group.
                wtsum /= deaths_f;
                for j in 0..deaths {
                    let j_f = j as f64;
                    let downwt = j_f / deaths_f;
                    hazard += wtsum / (current_denom - e_denom * downwt);
                    temp += wtsum * (1.0 - downwt) / (current_denom - e_denom * downwt);
                }
                // Index loop kept: several parallel slices are read at the same index.
                #[allow(clippy::needless_range_loop)]
                for j in lastone..=i {
                    if surv_data.status[j] == 0 {
                        expect[j] = -weights.score[j] * hazard;
                    } else {
                        expect[j] -= weights.score[j] * temp;
                    }
                }
            }
            // Start a new group after this row.
            lastone = i + 1;
            deaths = 0;
            wtsum = 0.0;
            e_denom = 0.0;
        }
        // The cumulative hazard does not carry across strata.
        if surv_data.strata[i] == 1 {
            hazard = 0.0;
        }
    }
    // Safeguard for rows after the last flushed group. Because strata[n - 1]
    // is forced to 1 above, the final row always flushes its group, so
    // lastone == n here and this loop does not execute.
    #[allow(clippy::needless_range_loop)]
    for j in lastone..n {
        expect[j] -= weights.score[j] * hazard;
    }
}