survival 1.0.17

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use crate::utilities::validation::{ValidationError, validate_length};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// Converts a validation `Result` into a `PyResult`.
///
/// An `Ok` value is passed through unchanged; a `ValidationError` is rendered
/// with `to_string()` and wrapped in a Python `RuntimeError`.
fn validation_err_to_pyresult<T>(result: Result<T, ValidationError>) -> PyResult<T> {
    match result {
        Ok(value) => Ok(value),
        Err(err) => Err(PyRuntimeError::new_err(err.to_string())),
    }
}

/// Validates the input lengths of a concordance computation against `n`.
///
/// `validate_length` is invoked, in order, for:
/// * `2 * n` and `time_data_len` (labelled `"time_data"`),
/// * `n` and `indices_len` (labelled `"indices"`),
/// * `n` and `weights_len` (labelled `"weights"`).
///
/// NOTE(review): `validate_length` is defined elsewhere; presumably it checks
/// that the two lengths are equal — confirm against its definition.
///
/// # Errors
///
/// Returns a Python `RuntimeError` when `n` is zero, or when any
/// `validate_length` call fails (the first failure is returned).
pub fn validate_concordance_inputs(
    time_data_len: usize,
    n: usize,
    indices_len: usize,
    weights_len: usize,
) -> PyResult<()> {
    if n == 0 {
        return Err(PyRuntimeError::new_err("No observations provided"));
    }

    // (first length, second length, label) triples, checked in this order.
    let checks = [
        (2 * n, time_data_len, "time_data"),
        (n, indices_len, "indices"),
        (n, weights_len, "weights"),
    ];
    for (first, second, label) in checks {
        validation_err_to_pyresult(validate_length(first, second, label))?;
    }
    Ok(())
}

/// Validates the inputs of the extended concordance computation.
///
/// Runs every check of `validate_concordance_inputs` first, then calls
/// `validate_length` with `n` and `time_weights_len` (labelled
/// `"time_weights"`) and with `n` and `sort_stop_len` (labelled
/// `"sort_stop"`).
///
/// # Errors
///
/// Returns the first failure encountered, as a Python `RuntimeError`.
pub fn validate_extended_concordance_inputs(
    time_data_len: usize,
    n: usize,
    indices_len: usize,
    weights_len: usize,
    time_weights_len: usize,
    sort_stop_len: usize,
) -> PyResult<()> {
    validate_concordance_inputs(time_data_len, n, indices_len, weights_len)?;

    // Additional per-observation inputs, checked in this order.
    let extra = [
        (time_weights_len, "time_weights"),
        (sort_stop_len, "sort_stop"),
    ];
    for (len, label) in extra {
        validation_err_to_pyresult(validate_length(n, len, label))?;
    }
    Ok(())
}
pub fn build_concordance_result(
    py: Python<'_>,
    count: &[f64],
    imat: Option<&[f64]>,
    resid: Option<&[f64]>,
    n: Option<usize>,
) -> PyResult<Py<PyDict>> {
    let concordant = count[0];
    let discordant = count[1];
    let tied_x = count[2];
    let tied_y = count[3];
    let tied_xy = count.get(4).copied().unwrap_or(0.0);
    let variance = count.get(5).copied();
    let total_pairs = concordant + discordant + tied_x + tied_y + tied_xy;
    let concordance_index = if total_pairs > 0.0 {
        (concordant + 0.5 * (tied_x + tied_y + tied_xy)) / total_pairs
    } else {
        0.0
    };
    let dict = PyDict::new(py);
    dict.set_item("concordant", concordant)?;
    dict.set_item("discordant", discordant)?;
    dict.set_item("tied_x", tied_x)?;
    dict.set_item("tied_y", tied_y)?;
    dict.set_item("tied_xy", tied_xy)?;
    dict.set_item("concordance_index", concordance_index)?;
    dict.set_item("total_pairs", total_pairs)?;
    if let Some(v) = variance {
        dict.set_item("variance", v)?;
    }
    if let Some(imat_data) = imat {
        dict.set_item("information_matrix", imat_data.to_vec())?;
    }
    if let Some(resid_data) = resid {
        dict.set_item("residuals", resid_data.to_vec())?;
    }
    if let Some(n_obs) = n {
        dict.set_item("n_observations", n_obs)?;
    }
    Ok(dict.into())
}
/// Collects three weight sums for node `index` of an array-backed binary tree.
///
/// The tree uses the implicit layout where node `i` has children `2 * i + 1`
/// (left) and `2 * i + 2` (right) and parent `(i - 1) / 2`. Only nodes with an
/// index below `ntree` are considered to exist.
///
/// Returns `[s0, s1, s2]` where
/// * `s2` is `nwt[index]`,
/// * `s0` is `twt[right child]` (if it exists) plus `twt[parent] - twt[node]`
///   for every step of the walk to the root taken from an odd-indexed
///   (left-child) node,
/// * `s1` is `twt[left child]` (if it exists) plus `twt[parent] - twt[node]`
///   for every step taken from an even-indexed (right-child) node.
///
/// If `index >= ntree` all three sums are `0.0`.
///
/// NOTE(review): presumably `twt` holds subtree weight totals, making `s0` /
/// `s1` the weight ranked above / below the node — confirm against callers.
///
/// # Panics
///
/// Panics if `nwt` or `twt` is too short for an index that is accessed.
pub fn walkup_binary_tree(nwt: &[f64], twt: &[f64], index: usize, ntree: usize) -> [f64; 3] {
    if index >= ntree {
        return [0.0; 3];
    }

    let at_node = nwt[index];
    let mut via_right = 0.0;
    let mut via_left = 0.0;

    // Direct children contribute their whole `twt` entry.
    let right = 2 * index + 2;
    if right < ntree {
        via_right += twt[right];
    }
    let left = right - 1;
    if left < ntree {
        via_left += twt[left];
    }

    // Climb to the root; each step adds what the parent holds beyond this node.
    let mut node = index;
    while node != 0 {
        let parent = (node - 1) / 2;
        let outside = twt[parent] - twt[node];
        if node % 2 == 1 {
            via_right += outside;
        } else {
            via_left += outside;
        }
        node = parent;
    }

    [via_right, via_left, at_node]
}
/// Adds weight `wt` at node `index` of an array-backed binary tree.
///
/// The tree uses the implicit layout where the parent of node `i` is
/// `(i - 1) / 2` and node `0` is the root.
///
/// * `nwt[index]` (the node's own weight) is increased by `wt`.
/// * `twt` is increased by `wt` on `index` itself and on every ancestor up to
///   and including the root, each exactly once, so `twt[i]` stays the total
///   weight stored in the subtree rooted at `i`.
///
/// Previously the loop credited the *parent* at each step and then the root
/// again after the loop, which counted the root twice and never updated
/// `twt[index]` for `index > 0`.
///
/// # Panics
///
/// Panics if `index` is out of bounds for `nwt` or `twt`, or if `twt` is empty.
pub fn add_to_binary_tree(nwt: &mut [f64], twt: &mut [f64], index: usize, wt: f64) {
    nwt[index] += wt;

    // Credit the node itself and each non-root ancestor on the way up.
    let mut current = index;
    while current > 0 {
        twt[current] += wt;
        current = (current - 1) / 2;
    }
    // The root is credited exactly once, whatever the starting node was.
    twt[0] += wt;
}