survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

type CoxCallbackResult = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>, Vec<i32>);

#[pyfunction]
pub fn cox_callback(
    which: i32,
    mut coef: Vec<f64>,
    mut first: Vec<f64>,
    mut second: Vec<f64>,
    mut penalty: Vec<f64>,
    mut flag: Vec<i32>,
    fexpr: &Bound<PyAny>,
) -> PyResult<CoxCallbackResult> {
    let py = fexpr.py();
    let coef_vec: Vec<f64> = coef.to_vec();
    let coef_list = PyList::new(py, &coef_vec)?;
    let kwargs = PyDict::new(py);
    kwargs.set_item("which", which)?;
    let result = fexpr.call((coef_list.as_any(),), Some(&kwargs))?;
    let dict = result.cast::<PyDict>()?;
    macro_rules! extract_values {
        ($key:expr, $rust_slice:expr, $pytype:ty) => {
            let item = dict.get_item($key)?.ok_or_else(|| {
                PyErr::new::<pyo3::exceptions::PyKeyError, _>(format!("Missing key: {}", $key))
            })?;
            let py_values = item.cast::<PyList>()?;
            if py_values.len() != $rust_slice.len() {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Invalid length for {}",
                    $key
                )));
            }
            for (i, item) in py_values.iter().enumerate() {
                $rust_slice[i] = item.extract::<$pytype>()?;
            }
        };
    }
    extract_values!("coef", coef, f64);
    extract_values!("first", first, f64);
    extract_values!("second", second, f64);
    extract_values!("penalty", penalty, f64);
    let flag_item = dict
        .get_item("flag")?
        .ok_or_else(|| PyErr::new::<pyo3::exceptions::PyKeyError, _>("Missing key: flag"))?;
    let py_flags = flag_item.cast::<PyList>()?;
    if py_flags.len() != flag.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Invalid length for flag",
        ));
    }
    for (i, item) in py_flags.iter().enumerate() {
        flag[i] = match item.extract::<bool>() {
            Ok(b) => b as i32,
            Err(_) => item.extract::<i32>()?,
        };
    }
    Ok((coef, first, second, penalty, flag))
}