//! survival 1.0.10
//!
//! A high-performance survival analysis library written in Rust, with Python bindings.
//!
//! Documentation example: checking a combined Cox fit against separately fitted sub-models.
use ndarray::{Array, Array2, Axis, concatenate};
use polars::prelude::*;
use approx::assert_abs_diff_eq;

use survival::prelude::*;

/// Approximate equality for 1-D `f64` arrays.
///
/// Returns `true` only when `x` and `y` have the same length and every pair of
/// corresponding elements differs by strictly less than `tol`.
///
/// The length check is essential: `Iterator::zip` stops at the shorter input, so
/// without it two arrays of different lengths would compare equal whenever their
/// common prefix matched, silently hiding shape mismatches. A `NaN` in either
/// array makes the result `false`, because `NaN < tol` is `false`.
fn aeq(x: &Array<f64, ndarray::Ix1>, y: &Array<f64, ndarray::Ix1>, tol: f64) -> bool {
    x.len() == y.len() && x.iter().zip(y.iter()).all(|(a, b)| (a - b).abs() < tol)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut tdata = DataFrame::from_csv("myeloid.csv")?;
    
    process_tmerge(&mut tdata)?;
    create_event_factor(&mut tdata)?;

    let formula = "Surv(tstart, tstop, event) ~ trt + sex";
    let fit = CoxPH::fit(formula, &tdata)
        .id("id")
        .iter(4)
        .robust(false)
        .build()?;

    let fit12 = CoxPH::fit("Surv(tstart, tstop, sct) ~ trt + sex", &tdata)
        .subset("priortx == 0")
        .method(Method::Breslow)
        .build()?;

    let fit13 = CoxPH::fit("Surv(tstart, tstop, death) ~ trt + sex", &tdata)
        .subset("priortx == 0")
        .method(Method::Breslow)
        .build()?;

    let fit23 = CoxPH::fit("Surv(tstart, tstop, death) ~ trt + sex", &tdata)
        .subset("priortx == 1")
        .method(Method::Breslow)
        .build()?;

    let combined_coef = concatenate![
        Axis(0),
        fit12.coefficients().view(),
        fit13.coefficients().view(),
        fit23.coefficients().view()
    ];
    assert!(aeq(&fit.coefficients(), &combined_coef, 1e-6));

    let total_ll = fit12.loglik() + fit13.loglik() + fit23.loglik();
    assert_abs_diff_eq!(fit.loglik(), total_ll, epsilon = 1e-6);

    let mut combined_var = Array2::zeros((6, 6));
    combined_var.slice_mut(s![0..2, 0..2]).assign(&fit12.var());
    combined_var.slice_mut(s![2..4, 2..4]).assign(&fit13.var());
    combined_var.slice_mut(s![4..6, 4..6]).assign(&fit23.var());
    assert!(aeq(&fit.var().into_owned(), &combined_var.into_raw_vec(), 1e-6));

    let res = fit.residuals(ResidualType::Martingale);
    let res12 = fit12.residuals(ResidualType::Martingale);
    let res13 = fit13.residuals(ResidualType::Martingale);
    let res23 = fit23.residuals(ResidualType::Martingale);
    let combined_res = concatenate![Axis(0), res12, res13, res23];
    assert!(aeq(&res, &combined_res, 1e-6));

    let score_res = fit.score_residuals();
    let x = fit.model_matrix();
    let x12 = fit12.model_matrix();
    let x13 = fit13.model_matrix();
    let x23 = fit23.model_matrix();

    Ok(())
}

fn process_tmerge(df: &mut DataFrame) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}

fn create_event_factor(df: &mut DataFrame) -> Result<(), Box<dyn std::error::Error>> {
    Ok(())
}