survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use super::column_major_index;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyDict;

pub struct PyearsExpectedParams<'a> {
    pub dim: usize,
    pub fac: &'a [i32],
    pub dims: &'a [usize],
    pub cut: &'a [f64],
    pub rates: &'a [f64],
    pub data: &'a [f64],
}

pub struct PyearsObservedParams<'a> {
    pub dim: usize,
    pub fac: &'a [i32],
    pub dims: &'a [usize],
    pub cut: &'a [f64],
    pub data: &'a [f64],
}

pub struct PyearsOutput<'a> {
    pub pyears: &'a mut [f64],
    pub pn: &'a mut [f64],
    pub pcount: &'a mut [f64],
    pub pexpect: &'a mut [f64],
    pub offtable: &'a mut f64,
}

#[allow(clippy::too_many_arguments)]
pub fn pyears3b(
    n: usize,
    ny: usize,
    doevent: i32,
    y: &[f64],
    weight: &[f64],
    expected: PyearsExpectedParams<'_>,
    observed: PyearsObservedParams<'_>,
    method: i32,
    output: &mut PyearsOutput<'_>,
) {
    let PyearsExpectedParams {
        dim: edim,
        fac: efac,
        dims: edims,
        cut: ecut,
        rates: expect,
        data: edata,
    } = expected;
    let PyearsObservedParams {
        dim: odim,
        fac: ofac,
        dims: odims,
        cut: ocut,
        data: odata,
    } = observed;
    let PyearsOutput {
        pyears,
        pn,
        pcount,
        pexpect,
        offtable,
    } = output;
    let (start, stop, event) = if ny == 3 || (ny == 2 && doevent == 0) {
        let start = &y[0..n];
        let stop = &y[n..2 * n];
        let event = if ny == 3 { &y[2 * n..3 * n] } else { &[] };
        (start, stop, event)
    } else {
        let stop = &y[0..n];
        let event = &y[n..2 * n];
        (&[] as &[f64], stop, event)
    };
    let mut ecut_slices = Vec::with_capacity(edim);
    let mut ecut_ptr = ecut;
    for j in 0..edim {
        let len = if efac[j] == 0 {
            edims[j]
        } else if efac[j] > 1 {
            1 + (efac[j] - 1) as usize * edims[j]
        } else {
            0
        };
        if len > 0 {
            ecut_slices.push(&ecut_ptr[0..len]);
            ecut_ptr = &ecut_ptr[len..];
        } else {
            ecut_slices.push(&[]);
        }
    }
    let mut ocut_slices = Vec::with_capacity(odim);
    let mut ocut_ptr = ocut;
    for j in 0..odim {
        if ofac[j] == 0 {
            let len = odims[j] + 1;
            ocut_slices.push(&ocut_ptr[0..len]);
            ocut_ptr = &ocut_ptr[len..];
        } else {
            ocut_slices.push(&[]);
        }
    }
    let mut eps = 0.0;
    for i in 0..n {
        let timeleft = if start.is_empty() {
            stop[i]
        } else {
            stop[i] - start[i]
        };
        if timeleft > 0.0 {
            eps = timeleft;
            break;
        }
    }
    for i in 0..n {
        let timeleft = if start.is_empty() {
            stop[i]
        } else {
            stop[i] - start[i]
        };
        if timeleft > 0.0 && timeleft < eps {
            eps = timeleft;
        }
    }
    eps *= 1e-8;
    **offtable = 0.0;
    for i in 0..n {
        let mut data = vec![0.0; odim];
        let mut data2 = vec![0.0; edim];
        for j in 0..odim {
            if ofac[j] == 1 || start.is_empty() {
                data[j] = odata[j * n + i];
            } else {
                data[j] = odata[j * n + i] + start[i];
            }
        }
        for j in 0..edim {
            if efac[j] == 1 || start.is_empty() {
                data2[j] = edata[j * n + i];
            } else {
                data2[j] = edata[j * n + i] + start[i];
            }
        }
        let mut timeleft = if start.is_empty() {
            stop[i]
        } else {
            stop[i] - start[i]
        };
        let mut cumhaz = 0.0;
        if timeleft <= eps && doevent == 1 {
            let (_, _idx, _, _) =
                pystep(odim, &mut data.clone(), ofac, odims, &ocut_slices, timeleft);
        }
        while timeleft > eps {
            let mut data_current = data.clone();
            let (thiscell, idx, _idx2, _lwt) =
                pystep(odim, &mut data_current, ofac, odims, &ocut_slices, timeleft);
            data.copy_from_slice(&data_current);
            pyears[idx] += thiscell * weight[i];
            pn[idx] += 1.0;
            if doevent == 1 && !event.is_empty() && event[i] > 0.0 {
                pcount[idx] += weight[i];
            }
            let mut etime = thiscell;
            let mut hazard = 0.0;
            let mut temp = 0.0;
            let mut data2_current = data2.clone();
            while etime > 0.0 {
                let (et2, edx, edx2, elwt) =
                    pystep(edim, &mut data2_current, efac, edims, &ecut_slices, etime);
                let lambda = if elwt < 1.0 {
                    elwt * expect[edx] + (1.0 - elwt) * expect[edx2]
                } else {
                    expect[edx]
                };
                if method == 0 {
                    let neg_hazard: f64 = -hazard;
                    let neg_lambda_et2: f64 = -lambda * et2;
                    temp += neg_hazard.exp() * (1.0 - neg_lambda_et2.exp()) / lambda;
                }
                hazard += lambda * et2;
                for j in 0..edim {
                    if efac[j] != 1 {
                        data2_current[j] += et2;
                    }
                }
                etime -= et2;
            }
            if method == 1 {
                pexpect[idx] += hazard * weight[i];
            } else {
                let neg_cumhaz: f64 = -cumhaz;
                pexpect[idx] += neg_cumhaz.exp() * temp * weight[i];
            }
            cumhaz += hazard;
            timeleft -= thiscell;
        }
    }
}

fn pystep(
    edim: usize,
    data: &mut [f64],
    efac: &[i32],
    edims: &[usize],
    ecut: &[&[f64]],
    tmax: f64,
) -> (f64, usize, usize, f64) {
    let mut et2 = tmax;
    let mut wt = 1.0;
    let mut limiting_dim = None;
    for j in 0..edim {
        if efac[j] != 0 {
            continue;
        }
        let cuts = ecut[j];
        let current = data[j];
        let pos = cuts.partition_point(|&x| x <= current);
        if pos < cuts.len() {
            let next_cut = cuts[pos];
            let delta = (next_cut - current).max(0.0);
            if delta < et2 {
                et2 = delta;
                limiting_dim = Some(j);
            }
        }
    }
    et2 = et2.min(tmax);
    let mut indices_current = vec![0; edim];
    let mut indices_next = vec![0; edim];
    for j in 0..edim {
        if efac[j] == 0 {
            data[j] += et2;
            let cuts = ecut[j];
            let pos = cuts.partition_point(|&x| x <= data[j]) - 1;
            indices_current[j] = pos.min(edims[j] - 1);
            indices_next[j] = (pos + 1).min(edims[j] - 1);
        } else {
            indices_current[j] = data[j] as usize - 1;
            indices_next[j] = indices_current[j];
        }
    }
    let indx = column_major_index(&indices_current, edims);
    let indx2 = column_major_index(&indices_next, edims);
    if let Some(dim) = limiting_dim {
        let current = data[dim] - et2;
        let cuts = ecut[dim];
        let pos = cuts.partition_point(|&x| x <= current) - 1;
        let next_cut = if pos + 1 < cuts.len() {
            cuts[pos + 1]
        } else {
            cuts[pos]
        };
        let prev_cut = cuts[pos];
        let width = next_cut - prev_cut;
        wt = if width > 0.0 {
            (current + et2 - prev_cut) / width
        } else {
            1.0
        };
        wt = wt.clamp(0.0, 1.0);
    }
    (et2, indx, indx2, wt)
}
#[pyfunction]
#[allow(clippy::too_many_arguments)]
pub fn perform_pyears_calculation(
    time_data: Vec<f64>,
    weights: Vec<f64>,
    expected_dim: usize,
    expected_factors: Vec<i32>,
    expected_dims: Vec<usize>,
    expected_cuts: Vec<f64>,
    expected_rates: Vec<f64>,
    expected_data: Vec<f64>,
    observed_dim: usize,
    observed_factors: Vec<i32>,
    observed_dims: Vec<usize>,
    observed_cuts: Vec<f64>,
    method: i32,
    observed_data: Vec<f64>,
    do_event: Option<i32>,
    ny: Option<usize>,
) -> PyResult<Py<PyAny>> {
    let n = weights.len();
    if n == 0 {
        return Err(PyRuntimeError::new_err("No observations provided"));
    }
    let ny = ny.unwrap_or(2);
    let doevent = do_event.unwrap_or(1);
    let mut total_observed = 1;
    for &dim in &observed_dims {
        total_observed *= dim;
    }
    let _total_expected = {
        let mut result = 1;
        for &dim in &expected_dims {
            result *= dim;
        }
        result
    };
    let mut pyears = vec![0.0; total_observed];
    let mut pn = vec![0.0; total_observed];
    let mut pcount = vec![0.0; total_observed];
    let mut pexpect = vec![0.0; total_observed];
    let mut offtable = 0.0;
    let expected = PyearsExpectedParams {
        dim: expected_dim,
        fac: &expected_factors,
        dims: &expected_dims,
        cut: &expected_cuts,
        rates: &expected_rates,
        data: &expected_data,
    };
    let observed = PyearsObservedParams {
        dim: observed_dim,
        fac: &observed_factors,
        dims: &observed_dims,
        cut: &observed_cuts,
        data: &observed_data,
    };
    let mut output = PyearsOutput {
        pyears: &mut pyears,
        pn: &mut pn,
        pcount: &mut pcount,
        pexpect: &mut pexpect,
        offtable: &mut offtable,
    };
    pyears3b(
        n,
        ny,
        doevent,
        &time_data,
        &weights,
        expected,
        observed,
        method,
        &mut output,
    );
    Python::attach(|py| {
        let dict = PyDict::new(py);
        dict.set_item("pyears", pyears)?;
        dict.set_item("pn", pn)?;
        dict.set_item("pcount", pcount)?;
        dict.set_item("pexpect", pexpect)?;
        dict.set_item("offtable", offtable)?;
        Ok(dict.into())
    })
}