survival 1.1.27

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use pyo3::prelude::*;

/// Summary of a Q-TWiST analysis (quality-adjusted time without symptoms of
/// disease or toxicity), exposed to Python as a class with read-only fields.
///
/// The duration fields are per-subject means over the restricted follow-up
/// window `min(time, tau)`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct QTWISTResult {
    /// Utility-weighted score: `utility_tox * tox + twistt + utility_rel * rel`.
    #[pyo3(get)]
    pub qtwist: f64,
    /// Mean time spent with toxicity (TOX state).
    #[pyo3(get)]
    pub tox: f64,
    /// Mean time without toxicity or relapse (TWiST state).
    #[pyo3(get)]
    pub twistt: f64,
    /// Mean time after relapse (REL state).
    #[pyo3(get)]
    pub rel: f64,
    /// Mean restricted follow-up time, i.e. the mean of `min(time, tau)`.
    #[pyo3(get)]
    pub total_time: f64,
    /// Utility weight applied to `tox`.
    #[pyo3(get)]
    pub utility_tox: f64,
    /// Utility weight applied to `rel`.
    #[pyo3(get)]
    pub utility_rel: f64,
    /// Between-group Q-TWiST difference; `qtwist_analysis` always leaves this `None`.
    #[pyo3(get)]
    pub qtwist_difference: Option<f64>,
    /// Lower confidence bound; `qtwist_analysis` always leaves this `None`.
    #[pyo3(get)]
    pub ci_lower: Option<f64>,
    /// Upper confidence bound; `qtwist_analysis` always leaves this `None`.
    #[pyo3(get)]
    pub ci_upper: Option<f64>,
}

/// Split one subject's restricted follow-up `[0, obs_time]` into the
/// sequential Q-TWiST states and return their durations as `(tox, twistt, rel)`.
fn partition_follow_up(
    obs_time: f64,
    tox_start: Option<f64>,
    tox_end: Option<f64>,
    relapse: Option<f64>,
) -> (f64, f64, f64) {
    // REL starts at the relapse time when it falls inside the window;
    // otherwise the subject is relapse-free until the end of follow-up.
    let (relapse_onset, rel) = match relapse {
        Some(r) if r < obs_time => (r, obs_time - r),
        _ => (obs_time, 0.0),
    };

    // Toxicity runs from its start to its recorded end (or is still ongoing at
    // the end of follow-up). It is cut off at relapse onset so that TOX and REL
    // never overlap and the same time is not counted in both states.
    let tox = match tox_start {
        Some(start) => {
            let end = tox_end.unwrap_or(obs_time).min(relapse_onset);
            (end - start.max(0.0)).max(0.0)
        }
        None => 0.0,
    };

    // Whatever is left is time without symptoms or toxicity; the clamp only
    // matters for degenerate inputs such as negative times.
    let twistt = (obs_time - tox - rel).max(0.0);
    (tox, twistt, rel)
}

/// Compute Q-TWiST for a single group of subjects.
///
/// Each subject's follow-up is restricted to `min(time[i], tau)` and split
/// sequentially into three health states:
///
/// * **TOX**: from `toxicity_start[i]` to `toxicity_end[i]`, or to the end of
///   follow-up when the end is `None`. It is truncated at relapse onset, and
///   subjects without a toxicity start contribute no TOX time.
/// * **REL**: from `relapse_time[i]` to the end of follow-up.
/// * **TWiST**: the remaining time, without symptoms or toxicity.
///
/// The reported durations are plain means over subjects. `status` is only
/// length-checked, so censored and uncensored subjects are treated alike.
/// The score is `qtwist = utility_tox * tox + twistt + utility_rel * rel`.
///
/// # Arguments
/// * `tau` - restriction time; defaults to the largest value in `time`.
///
/// # Errors
/// Returns `ValueError` if the input arrays differ in length or are empty.
#[pyfunction]
#[pyo3(signature = (
    time,
    status,
    toxicity_start,
    toxicity_end,
    relapse_time,
    utility_tox=0.5,
    utility_rel=0.5,
    tau=None
))]
#[allow(clippy::too_many_arguments)]
pub fn qtwist_analysis(
    time: Vec<f64>,
    status: Vec<i32>,
    toxicity_start: Vec<Option<f64>>,
    toxicity_end: Vec<Option<f64>>,
    relapse_time: Vec<Option<f64>>,
    utility_tox: f64,
    utility_rel: f64,
    tau: Option<f64>,
) -> PyResult<QTWISTResult> {
    let n = time.len();
    if status.len() != n
        || toxicity_start.len() != n
        || toxicity_end.len() != n
        || relapse_time.len() != n
    {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "All input arrays must have same length",
        ));
    }
    // Without subjects every mean below would silently be 0/0 = NaN.
    if n == 0 {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Input arrays must not be empty",
        ));
    }

    let tau_val = tau.unwrap_or_else(|| time.iter().copied().fold(0.0, f64::max));

    let mut total_tox = 0.0;
    let mut total_twistt = 0.0;
    let mut total_rel = 0.0;
    let mut total_time_sum = 0.0;

    let subjects = time
        .iter()
        .zip(&toxicity_start)
        .zip(&toxicity_end)
        .zip(&relapse_time);
    for (((&t, &tox_start), &tox_end), &relapse) in subjects {
        let obs_time = t.min(tau_val);
        let (tox, twistt, rel) = partition_follow_up(obs_time, tox_start, tox_end, relapse);

        total_tox += tox;
        total_twistt += twistt;
        total_rel += rel;
        total_time_sum += obs_time;
    }

    let n_f = n as f64;
    let mean_tox = total_tox / n_f;
    let mean_twistt = total_twistt / n_f;
    let mean_rel = total_rel / n_f;
    let mean_total = total_time_sum / n_f;

    let qtwist = utility_tox * mean_tox + mean_twistt + utility_rel * mean_rel;

    Ok(QTWISTResult {
        qtwist,
        tox: mean_tox,
        twistt: mean_twistt,
        rel: mean_rel,
        total_time: mean_total,
        utility_tox,
        utility_rel,
        qtwist_difference: None,
        ci_lower: None,
        ci_upper: None,
    })
}

#[pyfunction]
#[pyo3(signature = (
    time_treated,
    status_treated,
    tox_start_treated,
    tox_end_treated,
    relapse_treated,
    time_control,
    status_control,
    tox_start_control,
    tox_end_control,
    relapse_control,
    utility_tox=0.5,
    utility_rel=0.5,
    tau=None,
    n_bootstrap=1000
))]
#[allow(clippy::too_many_arguments)]
pub fn qtwist_comparison(
    time_treated: Vec<f64>,
    status_treated: Vec<i32>,
    tox_start_treated: Vec<Option<f64>>,
    tox_end_treated: Vec<Option<f64>>,
    relapse_treated: Vec<Option<f64>>,
    time_control: Vec<f64>,
    status_control: Vec<i32>,
    tox_start_control: Vec<Option<f64>>,
    tox_end_control: Vec<Option<f64>>,
    relapse_control: Vec<Option<f64>>,
    utility_tox: f64,
    utility_rel: f64,
    tau: Option<f64>,
    n_bootstrap: usize,
) -> PyResult<(QTWISTResult, QTWISTResult, f64, f64, f64)> {
    let result_treated = qtwist_analysis(
        time_treated.clone(),
        status_treated.clone(),
        tox_start_treated.clone(),
        tox_end_treated.clone(),
        relapse_treated.clone(),
        utility_tox,
        utility_rel,
        tau,
    )?;

    let result_control = qtwist_analysis(
        time_control.clone(),
        status_control.clone(),
        tox_start_control.clone(),
        tox_end_control.clone(),
        relapse_control.clone(),
        utility_tox,
        utility_rel,
        tau,
    )?;

    let qtwist_diff = result_treated.qtwist - result_control.qtwist;

    let n_treated = time_treated.len();
    let n_control = time_control.len();

    let mut boot_diffs: Vec<f64> = Vec::with_capacity(n_bootstrap);

    for b in 0..n_bootstrap {
        let mut rng = fastrand::Rng::with_seed(b as u64);

        let boot_treated_idx: Vec<usize> =
            (0..n_treated).map(|_| rng.usize(0..n_treated)).collect();
        let boot_control_idx: Vec<usize> =
            (0..n_control).map(|_| rng.usize(0..n_control)).collect();

        let boot_time_t: Vec<f64> = boot_treated_idx.iter().map(|&i| time_treated[i]).collect();
        let boot_status_t: Vec<i32> = boot_treated_idx
            .iter()
            .map(|&i| status_treated[i])
            .collect();
        let boot_tox_start_t: Vec<Option<f64>> = boot_treated_idx
            .iter()
            .map(|&i| tox_start_treated[i])
            .collect();
        let boot_tox_end_t: Vec<Option<f64>> = boot_treated_idx
            .iter()
            .map(|&i| tox_end_treated[i])
            .collect();
        let boot_relapse_t: Vec<Option<f64>> = boot_treated_idx
            .iter()
            .map(|&i| relapse_treated[i])
            .collect();

        let boot_time_c: Vec<f64> = boot_control_idx.iter().map(|&i| time_control[i]).collect();
        let boot_status_c: Vec<i32> = boot_control_idx
            .iter()
            .map(|&i| status_control[i])
            .collect();
        let boot_tox_start_c: Vec<Option<f64>> = boot_control_idx
            .iter()
            .map(|&i| tox_start_control[i])
            .collect();
        let boot_tox_end_c: Vec<Option<f64>> = boot_control_idx
            .iter()
            .map(|&i| tox_end_control[i])
            .collect();
        let boot_relapse_c: Vec<Option<f64>> = boot_control_idx
            .iter()
            .map(|&i| relapse_control[i])
            .collect();

        if let (Ok(res_t), Ok(res_c)) = (
            qtwist_analysis(
                boot_time_t,
                boot_status_t,
                boot_tox_start_t,
                boot_tox_end_t,
                boot_relapse_t,
                utility_tox,
                utility_rel,
                tau,
            ),
            qtwist_analysis(
                boot_time_c,
                boot_status_c,
                boot_tox_start_c,
                boot_tox_end_c,
                boot_relapse_c,
                utility_tox,
                utility_rel,
                tau,
            ),
        ) {
            boot_diffs.push(res_t.qtwist - res_c.qtwist);
        }
    }

    boot_diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let ci_lower = boot_diffs[(n_bootstrap as f64 * 0.025) as usize];
    let ci_upper = boot_diffs[(n_bootstrap as f64 * 0.975) as usize];

    Ok((
        result_treated,
        result_control,
        qtwist_diff,
        ci_lower,
        ci_upper,
    ))
}

#[pyfunction]
#[allow(clippy::too_many_arguments)]
pub fn qtwist_sensitivity(
    time: Vec<f64>,
    status: Vec<i32>,
    toxicity_start: Vec<Option<f64>>,
    toxicity_end: Vec<Option<f64>>,
    relapse_time: Vec<Option<f64>>,
    utility_tox_range: Vec<f64>,
    utility_rel_range: Vec<f64>,
    tau: Option<f64>,
) -> PyResult<Vec<(f64, f64, f64)>> {
    let mut results = Vec::new();

    for &u_tox in &utility_tox_range {
        for &u_rel in &utility_rel_range {
            if let Ok(result) = qtwist_analysis(
                time.clone(),
                status.clone(),
                toxicity_start.clone(),
                toxicity_end.clone(),
                relapse_time.clone(),
                u_tox,
                u_rel,
                tau,
            ) {
                results.push((u_tox, u_rel, result.qtwist));
            }
        }
    }

    Ok(results)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_qtwist_basic() {
        let time = vec![10.0, 15.0, 20.0, 25.0];
        let status = vec![1, 0, 1, 0];
        let tox_start = vec![Some(1.0), Some(2.0), None, Some(1.0)];
        let tox_end = vec![Some(5.0), Some(7.0), None, Some(3.0)];
        let relapse = vec![None, None, Some(15.0), None];

        let result =
            qtwist_analysis(time, status, tox_start, tox_end, relapse, 0.5, 0.5, None).unwrap();

        // tau defaults to the largest time (25), so no subject is truncated.
        // TOX = 4 + 5 + 0 + 2 = 11, REL = 0 + 0 + 5 + 0 = 5,
        // TWiST = 6 + 10 + 15 + 23 = 54, follow-up = 70; divide each by 4.
        assert!((result.tox - 2.75).abs() < EPS);
        assert!((result.twistt - 13.5).abs() < EPS);
        assert!((result.rel - 1.25).abs() < EPS);
        assert!((result.total_time - 17.5).abs() < EPS);
        // 0.5 * 2.75 + 13.5 + 0.5 * 1.25
        assert!((result.qtwist - 15.5).abs() < EPS);
        assert!(result.qtwist_difference.is_none());
        assert!(result.ci_lower.is_none());
        assert!(result.ci_upper.is_none());
    }

    #[test]
    fn test_qtwist_tau_truncates_follow_up() {
        let time = vec![10.0, 15.0, 20.0, 25.0];
        let status = vec![1, 0, 1, 0];
        let tox_start = vec![Some(1.0), Some(2.0), None, Some(1.0)];
        let tox_end = vec![Some(5.0), Some(7.0), None, Some(3.0)];
        let relapse = vec![None, None, Some(15.0), None];

        let result =
            qtwist_analysis(time, status, tox_start, tox_end, relapse, 0.5, 0.5, Some(12.0))
                .unwrap();

        // Follow-up becomes 10, 12, 12, 12; the relapse at 15 falls outside.
        assert!((result.tox - 2.75).abs() < EPS);
        assert!((result.twistt - 8.75).abs() < EPS);
        assert!(result.rel.abs() < EPS);
        assert!((result.total_time - 11.5).abs() < EPS);
        assert!((result.qtwist - 10.125).abs() < EPS);
    }

    #[test]
    fn test_qtwist_length_mismatch_is_error() {
        let result = qtwist_analysis(
            vec![1.0, 2.0],
            vec![1],
            vec![None, None],
            vec![None, None],
            vec![None, None],
            0.5,
            0.5,
            None,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_qtwist_sensitivity_matches_direct_calls() {
        let time = vec![10.0, 15.0, 20.0, 25.0];
        let status = vec![1, 0, 1, 0];
        let tox_start = vec![Some(1.0), Some(2.0), None, Some(1.0)];
        let tox_end = vec![Some(5.0), Some(7.0), None, Some(3.0)];
        let relapse = vec![None, None, Some(15.0), None];

        let grid = qtwist_sensitivity(
            time.clone(),
            status.clone(),
            tox_start.clone(),
            tox_end.clone(),
            relapse.clone(),
            vec![0.0, 1.0],
            vec![0.25, 0.75],
            None,
        )
        .unwrap();

        let expected_pairs = [(0.0, 0.25), (0.0, 0.75), (1.0, 0.25), (1.0, 0.75)];
        assert_eq!(grid.len(), expected_pairs.len());
        for (&(u_tox, u_rel, q), &(e_tox, e_rel)) in grid.iter().zip(&expected_pairs) {
            assert_eq!((u_tox, u_rel), (e_tox, e_rel));
            let direct = qtwist_analysis(
                time.clone(),
                status.clone(),
                tox_start.clone(),
                tox_end.clone(),
                relapse.clone(),
                u_tox,
                u_rel,
                None,
            )
            .unwrap();
            assert!((q - direct.qtwist).abs() < EPS);
        }
    }

    #[test]
    fn test_qtwist_comparison_identical_subjects() {
        // Every subject is identical, so every bootstrap replicate has the same
        // Q-TWiST in both arms and the interval collapses to zero.
        let time = vec![10.0; 3];
        let status = vec![1; 3];
        let tox_start = vec![Some(2.0); 3];
        let tox_end = vec![Some(6.0); 3];
        let relapse = vec![None; 3];

        let (treated, control, diff, lo, hi) = qtwist_comparison(
            time.clone(),
            status.clone(),
            tox_start.clone(),
            tox_end.clone(),
            relapse.clone(),
            time,
            status,
            tox_start,
            tox_end,
            relapse,
            0.5,
            0.5,
            None,
            50,
        )
        .unwrap();

        // TOX = 4, TWiST = 6 -> 0.5 * 4 + 6 = 8.
        assert!((treated.qtwist - 8.0).abs() < EPS);
        assert!((control.qtwist - 8.0).abs() < EPS);
        assert!(diff.abs() < EPS);
        assert!(lo.abs() < EPS);
        assert!(hi.abs() < EPS);
    }
}