survival 1.1.27

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use pyo3::prelude::*;
use rayon::prelude::*;

/// Result of a quality-adjusted life year (QALY) analysis, exposed to Python
/// as a read-only class.
///
/// Instances are produced by `qaly_calculation`; every field has a Python
/// getter via `#[pyo3(get)]`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct QALYResult {
    /// Total undiscounted QALYs accumulated over the analysis horizon.
    #[pyo3(get)]
    pub qaly: f64,
    /// Total undiscounted life years (area under the survival curve) over the
    /// analysis horizon.
    #[pyo3(get)]
    pub life_years: f64,
    /// `qaly / life_years`, or `0.0` when `life_years` is not positive.
    #[pyo3(get)]
    pub mean_utility: f64,
    /// Undiscounted QALY contribution of each consecutive integration interval,
    /// in chronological order.
    #[pyo3(get)]
    pub qaly_by_period: Vec<f64>,
    /// Total QALYs after discounting each interval at its midpoint.
    #[pyo3(get)]
    pub discounted_qaly: f64,
    /// Bootstrap standard error of the undiscounted QALY estimate.
    #[pyo3(get)]
    pub qaly_se: f64,
    /// Lower bound of the bootstrap percentile interval (2.5th percentile) for
    /// the undiscounted QALY estimate.
    #[pyo3(get)]
    pub qaly_ci_lower: f64,
    /// Upper bound of the bootstrap percentile interval (97.5th percentile) for
    /// the undiscounted QALY estimate.
    #[pyo3(get)]
    pub qaly_ci_upper: f64,
}

/// Integrates utility-weighted Kaplan-Meier survival from the start of the
/// time grid up to `tau`.
///
/// The integration grid is the sorted, de-duplicated union of `utility_times`,
/// `0.0` and `tau`, with every point above `tau` dropped. On each interval the
/// survival is taken as the average of its two endpoint values and the utility
/// is the one in force at the interval start.
///
/// # Arguments
/// * `time` / `status` - follow-up times and event indicators (`1` = event).
/// * `utility_values` / `utility_times` - utility weight `utility_values[i]`
///   applies from `utility_times[i]` onward. The lookup scans from the end, so
///   `utility_times` is presumably expected in ascending order.
/// * `discount_rate` - per-time-unit discount rate, applied at interval midpoints.
/// * `tau` - integration horizon.
///
/// # Returns
/// `Some((qaly, life_years, discounted_qaly, qaly_by_period))`. The current
/// implementation never returns `None`.
fn qaly_calculation_internal(
    time: &[f64],
    status: &[i32],
    utility_values: &[f64],
    utility_times: &[f64],
    discount_rate: f64,
    tau: f64,
) -> Option<(f64, f64, f64, Vec<f64>)> {
    // Build the integration grid: utility change points plus both endpoints.
    let mut grid: Vec<f64> = Vec::with_capacity(utility_times.len() + 2);
    grid.extend_from_slice(utility_times);
    grid.extend_from_slice(&[0.0, tau]);
    grid.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    grid.dedup();
    grid.retain(|&t| t <= tau);

    // Utility in force at time `t`: the value paired with the last change point
    // not exceeding `t`; otherwise the first utility value, or 1.0 if none exist.
    let utility_at = |t: f64| -> f64 {
        utility_times
            .iter()
            .zip(utility_values)
            .rev()
            .find(|&(&knot, _)| t >= knot)
            .map(|(_, &value)| value)
            .unwrap_or_else(|| utility_values.first().copied().unwrap_or(1.0))
    };

    let km = kaplan_meier_estimate(time, status, &grid);

    let mut qaly_sum = 0.0;
    let mut life_year_sum = 0.0;
    let mut discounted_sum = 0.0;
    let mut per_period = Vec::new();

    for (span, surv) in grid.windows(2).zip(km.windows(2)) {
        let (start, end) = (span[0], span[1]);
        let width = end - start;
        let mean_surv = (surv[0] + surv[1]) / 2.0;

        let interval_qaly = mean_surv * utility_at(start) * width;
        let interval_ly = mean_surv * width;

        // Discount the whole interval as if it occurred at its midpoint.
        let midpoint = (start + end) / 2.0;
        let discounted_interval_qaly = interval_qaly * (1.0 + discount_rate).powf(-midpoint);

        qaly_sum += interval_qaly;
        life_year_sum += interval_ly;
        discounted_sum += discounted_interval_qaly;
        per_period.push(interval_qaly);
    }

    Some((qaly_sum, life_year_sum, discounted_sum, per_period))
}

/// Computes quality-adjusted life years (QALYs) for a single cohort.
///
/// # Arguments
/// * `time` / `status` - follow-up times and event indicators; must have the
///   same length.
/// * `utility_values` / `utility_times` - utility weights and the times from
///   which each weight applies.
/// * `discount_rate` - discount rate (defaults to `0.03` from Python).
/// * `horizon` - integration horizon; when omitted, the largest value in `time`
///   (or `0.0` for empty input) is used.
///
/// # Returns
/// A `QALYResult` with the point estimates plus a standard error and percentile
/// interval from 500 bootstrap replicates.
///
/// # Errors
/// `ValueError` when `time` and `status` differ in length, or when the point
/// estimate cannot be computed.
#[pyfunction]
#[pyo3(signature = (
    time,
    status,
    utility_values,
    utility_times,
    discount_rate=0.03,
    horizon=None
))]
pub fn qaly_calculation(
    time: Vec<f64>,
    status: Vec<i32>,
    utility_values: Vec<f64>,
    utility_times: Vec<f64>,
    discount_rate: f64,
    horizon: Option<f64>,
) -> PyResult<QALYResult> {
    if time.len() != status.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "time and status must have same length",
        ));
    }

    // Without an explicit horizon, integrate up to the longest follow-up time.
    let tau = match horizon {
        Some(limit) => limit,
        None => time.iter().fold(0.0_f64, |longest, &t| longest.max(t)),
    };

    let point_estimate = qaly_calculation_internal(
        &time,
        &status,
        &utility_values,
        &utility_times,
        discount_rate,
        tau,
    );
    let (qaly, life_years, discounted_qaly, qaly_by_period) = match point_estimate {
        Some(parts) => parts,
        None => {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                "Failed to compute QALY",
            ))
        }
    };

    // Average utility per life year lived; defined as zero when no time accrues.
    let mean_utility = if life_years > 0.0 {
        qaly / life_years
    } else {
        0.0
    };

    let (qaly_se, qaly_ci_lower, qaly_ci_upper) = bootstrap_qaly_ci(
        &time,
        &status,
        &utility_values,
        &utility_times,
        tau,
        discount_rate,
        500,
    );

    Ok(QALYResult {
        qaly,
        life_years,
        mean_utility,
        qaly_by_period,
        discounted_qaly,
        qaly_se,
        qaly_ci_lower,
        qaly_ci_upper,
    })
}

/// Evaluates the Kaplan-Meier survival estimate at each of `eval_times`.
///
/// Observations sharing the same time are handled as one group: with `d` events
/// among `n` subjects still at risk, the survival is multiplied by `1 - d / n`,
/// and only afterwards are all subjects observed at that time (events and
/// censored alike) removed from the risk set. The result therefore does not
/// depend on the order in which tied events and censorings appear in the input.
///
/// # Arguments
/// * `time` - observed follow-up time per subject.
/// * `status` - event indicator per subject; `1` is an event, any other value
///   is treated as censored. Must be at least as long as `time`.
/// * `eval_times` - evaluation times. They are consumed in a single forward
///   pass, so they must be in ascending order.
///
/// # Returns
/// One survival probability per entry of `eval_times`, counting every
/// observation with `time <= t`.
fn kaplan_meier_estimate(time: &[f64], status: &[i32], eval_times: &[f64]) -> Vec<f64> {
    let n = time.len();

    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        time[a]
            .partial_cmp(&time[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut survival = Vec::with_capacity(eval_times.len());
    let mut cum_surv = 1.0;
    let mut at_risk = n;
    let mut pos = 0;

    for &t in eval_times {
        while pos < n && time[order[pos]] <= t {
            // Gather everyone observed at this exact time before touching the
            // risk set, so tied censorings cannot shrink the denominator.
            let current = time[order[pos]];
            let mut events = 0usize;
            let mut leaving = 0usize;
            while pos < n && time[order[pos]] == current {
                if status[order[pos]] == 1 {
                    events += 1;
                }
                leaving += 1;
                pos += 1;
            }

            if events > 0 && at_risk > 0 {
                cum_surv *= 1.0 - events as f64 / at_risk as f64;
            }
            at_risk = at_risk.saturating_sub(leaving);
        }
        survival.push(cum_surv);
    }

    survival
}

/// Bootstraps the undiscounted QALY estimate to obtain its standard error and
/// a percentile interval.
///
/// Each replicate resamples the `(time, status)` pairs with replacement using a
/// generator seeded from the replicate index, so results are reproducible.
/// Replicates are evaluated in parallel.
///
/// # Returns
/// `(standard_error, lower, upper)` where the bounds are the 2.5th and 97.5th
/// percentiles of the replicate estimates. Returns `(0.0, 0.0, 0.0)` when fewer
/// than two replicates are available.
fn bootstrap_qaly_ci(
    time: &[f64],
    status: &[i32],
    utility_values: &[f64],
    utility_times: &[f64],
    tau: f64,
    discount_rate: f64,
    n_bootstrap: usize,
) -> (f64, f64, f64) {
    let n = time.len();

    // One bootstrap replicate: draw `n` subjects with replacement and recompute.
    let replicate_qaly = |replicate: usize| -> Option<f64> {
        let mut rng = fastrand::Rng::with_seed(replicate as u64 + 98765);
        let (boot_time, boot_status): (Vec<f64>, Vec<i32>) = (0..n)
            .map(|_| {
                let pick = rng.usize(0..n);
                (time[pick], status[pick])
            })
            .unzip();

        let estimate = qaly_calculation_internal(
            &boot_time,
            &boot_status,
            utility_values,
            utility_times,
            discount_rate,
            tau,
        )?;
        Some(estimate.0)
    };

    let mut qalys: Vec<f64> = (0..n_bootstrap)
        .into_par_iter()
        .filter_map(replicate_qaly)
        .collect();

    let count = qalys.len();
    if count < 2 {
        return (0.0, 0.0, 0.0);
    }

    // Sample standard deviation of the replicates (n - 1 denominator).
    let mean = qalys.iter().sum::<f64>() / count as f64;
    let sum_sq: f64 = qalys.iter().map(|&q| (q - mean).powi(2)).sum();
    let se = (sum_sq / (count - 1) as f64).sqrt();

    // The moments above are already computed, so the replicates can be sorted
    // in place for the percentile lookup.
    qalys.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let lower_idx = (count as f64 * 0.025) as usize;
    let upper_idx = (count as f64 * 0.975) as usize;

    (se, qalys[lower_idx], qalys[upper_idx])
}

#[pyfunction]
#[pyo3(signature = (
    time_treated,
    status_treated,
    utility_treated,
    time_control,
    status_control,
    utility_control,
    utility_times,
    discount_rate=0.03,
    horizon=None,
    n_bootstrap=1000
))]
#[allow(clippy::too_many_arguments)]
pub fn qaly_comparison(
    time_treated: Vec<f64>,
    status_treated: Vec<i32>,
    utility_treated: Vec<f64>,
    time_control: Vec<f64>,
    status_control: Vec<i32>,
    utility_control: Vec<f64>,
    utility_times: Vec<f64>,
    discount_rate: f64,
    horizon: Option<f64>,
    n_bootstrap: usize,
) -> PyResult<(QALYResult, QALYResult, f64, f64, f64)> {
    let result_treated = qaly_calculation(
        time_treated.clone(),
        status_treated.clone(),
        utility_treated.clone(),
        utility_times.clone(),
        discount_rate,
        horizon,
    )?;

    let result_control = qaly_calculation(
        time_control.clone(),
        status_control.clone(),
        utility_control.clone(),
        utility_times.clone(),
        discount_rate,
        horizon,
    )?;

    let qaly_diff = result_treated.qaly - result_control.qaly;

    let n_treated = time_treated.len();
    let n_control = time_control.len();
    let tau = horizon.unwrap_or_else(|| {
        time_treated
            .iter()
            .chain(time_control.iter())
            .copied()
            .fold(0.0, f64::max)
    });

    let boot_diffs: Vec<f64> = (0..n_bootstrap)
        .into_par_iter()
        .filter_map(|b| {
            let mut rng = fastrand::Rng::with_seed(b as u64 + 11111);

            let idx_t: Vec<usize> = (0..n_treated).map(|_| rng.usize(0..n_treated)).collect();
            let idx_c: Vec<usize> = (0..n_control).map(|_| rng.usize(0..n_control)).collect();

            let boot_time_t: Vec<f64> = idx_t.iter().map(|&i| time_treated[i]).collect();
            let boot_status_t: Vec<i32> = idx_t.iter().map(|&i| status_treated[i]).collect();
            let boot_time_c: Vec<f64> = idx_c.iter().map(|&i| time_control[i]).collect();
            let boot_status_c: Vec<i32> = idx_c.iter().map(|&i| status_control[i]).collect();

            match (
                qaly_calculation_internal(
                    &boot_time_t,
                    &boot_status_t,
                    &utility_treated,
                    &utility_times,
                    discount_rate,
                    tau,
                ),
                qaly_calculation_internal(
                    &boot_time_c,
                    &boot_status_c,
                    &utility_control,
                    &utility_times,
                    discount_rate,
                    tau,
                ),
            ) {
                (Some((qaly_t, _, _, _)), Some((qaly_c, _, _, _))) => Some(qaly_t - qaly_c),
                _ => None,
            }
        })
        .collect();

    let mut sorted_diffs = boot_diffs.clone();
    sorted_diffs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let ci_lower = sorted_diffs[(boot_diffs.len() as f64 * 0.025) as usize];
    let ci_upper = sorted_diffs[(boot_diffs.len() as f64 * 0.975) as usize];

    Ok((
        result_treated,
        result_control,
        qaly_diff,
        ci_lower,
        ci_upper,
    ))
}

/// Computes the incremental cost-effectiveness ratio (ICER), the net monetary
/// benefit (NMB) and a cost-effectiveness verdict for treatment versus control.
///
/// # Arguments
/// * `qaly_treated` / `qaly_control` - QALYs per arm.
/// * `cost_treated` / `cost_control` - costs per arm.
/// * `wtp_threshold` - willingness to pay per QALY; defaults to `50000.0`.
///
/// # Returns
/// `(icer, nmb, cost_effective)` where
/// * `icer` is `delta_cost / delta_qaly`, or `+inf` when the QALY difference is
///   within `1e-10` of zero;
/// * `nmb` is `delta_qaly * threshold - delta_cost`;
/// * `cost_effective` is `nmb >= 0`.
///
/// The verdict is based on the NMB rather than on `icer <= threshold` because
/// the ratio is ambiguous outside the "more effective, more costly" quadrant:
/// a dominated treatment (fewer QALYs, higher cost) has a negative ICER, and an
/// equally effective but cheaper treatment has an undefined one. When the
/// treatment gains QALYs the two rules agree.
#[pyfunction]
pub fn incremental_cost_effectiveness(
    qaly_treated: f64,
    qaly_control: f64,
    cost_treated: f64,
    cost_control: f64,
    wtp_threshold: Option<f64>,
) -> PyResult<(f64, f64, bool)> {
    let delta_qaly = qaly_treated - qaly_control;
    let delta_cost = cost_treated - cost_control;
    let threshold = wtp_threshold.unwrap_or(50000.0);

    // Guard against dividing by a (near-)zero QALY difference.
    let icer = if delta_qaly.abs() > 1e-10 {
        delta_cost / delta_qaly
    } else {
        f64::INFINITY
    };

    let nmb = delta_qaly * threshold - delta_cost;
    let cost_effective = nmb >= 0.0;

    Ok((icer, nmb, cost_effective))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a small cohort yields non-negative totals and a mean utility
    /// within the unit interval.
    #[test]
    fn test_qaly_basic() {
        let follow_up = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let events = vec![1, 0, 1, 0, 1];
        let weights = vec![0.8, 0.6];
        let change_points = vec![0.0, 2.0];

        let outcome = qaly_calculation(follow_up, events, weights, change_points, 0.03, None)
            .unwrap();

        assert!(outcome.qaly >= 0.0);
        assert!(outcome.life_years >= 0.0);
        assert!((0.0..=1.0).contains(&outcome.mean_utility));
    }

    /// Two extra QALYs for 20000 extra cost is 10000 per QALY, which is below
    /// the 50000 threshold.
    #[test]
    fn test_icer() {
        let outcome = incremental_cost_effectiveness(10.0, 8.0, 50000.0, 30000.0, Some(50000.0));
        let (icer, _nmb, cost_effective) = outcome.unwrap();

        assert_eq!(icer, 10000.0);
        assert!(cost_effective);
    }
}