//! survival 1.0.10
//!
//! A high-performance survival analysis library written in Rust with Python bindings.
//!
//! Documentation
use crate::utilities::validation::{
    clamp_probability, validate_length, validate_no_nan, validate_non_empty, validate_non_negative,
};
use pyo3::prelude::*;
use rayon::prelude::*;
/// Result table of a Kaplan–Meier fit produced by `compute_survfitkm`.
///
/// All vectors have the same length: one entry per distinct event time
/// (an observation with `status > 0`), in ascending time order. Every field is
/// exposed to Python as a read-only attribute.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvFitKMOutput {
    /// Distinct event times, sorted ascending.
    #[pyo3(get)]
    pub time: Vec<f64>,
    /// Weighted size of the risk set reported for each `time`.
    #[pyo3(get)]
    pub n_risk: Vec<f64>,
    /// Sum of the case weights of the events at each `time`.
    #[pyo3(get)]
    pub n_event: Vec<f64>,
    /// Censorings reported at each `time`; only non-events whose `position` flag has
    /// the `2` bit set are counted.
    #[pyo3(get)]
    pub n_censor: Vec<f64>,
    /// Kaplan–Meier (product-limit) survival estimate at each `time`.
    #[pyo3(get)]
    pub estimate: Vec<f64>,
    /// Standard error of `estimate`: the estimate times the square root of the
    /// accumulated Greenwood variance term.
    #[pyo3(get)]
    pub std_err: Vec<f64>,
    /// Lower confidence limit, from a log-scale normal approximation with z = 1.96,
    /// passed through `clamp_probability`.
    #[pyo3(get)]
    pub conf_lower: Vec<f64>,
    /// Upper confidence limit, from a log-scale normal approximation with z = 1.96,
    /// passed through `clamp_probability`.
    #[pyo3(get)]
    pub conf_upper: Vec<f64>,
}
/// Kaplan–Meier survival curve for right-censored data.
///
/// Validates the inputs and delegates the estimation to `compute_survfitkm`.
///
/// Arguments:
///     time: follow-up times; must be non-empty, non-negative and free of NaN.
///     status: event indicators, same length as `time`, free of NaN; values > 0 are events.
///     weights: optional case weights, same length as `time`, non-negative and free of
///         NaN (default: all 1.0).
///     entry_times: optional entry times; only their length is checked, they are not
///         used by the estimator yet.
///     position: optional per-observation flags, same length as `time` (default: all 0);
///         a non-event is reported in `n_censor` only when the `2` bit is set.
///     reverse: accepted but not used by the estimator yet (default: False).
///     computation_type: accepted but not used by the estimator yet (default: 0).
///
/// Errors:
///     Propagates the error raised by the first failing validation check.
#[pyfunction]
// The argument list mirrors the Python-facing signature one-to-one.
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (time, status, weights=None, entry_times=None, position=None, reverse=None, computation_type=None))]
pub fn survfitkm(
    time: Vec<f64>,
    status: Vec<f64>,
    weights: Option<Vec<f64>>,
    entry_times: Option<Vec<f64>>,
    position: Option<Vec<i32>>,
    reverse: Option<bool>,
    computation_type: Option<i32>,
) -> PyResult<SurvFitKMOutput> {
    validate_non_empty(&time, "time")?;
    validate_length(time.len(), status.len(), "status")?;
    validate_non_negative(&time, "time")?;
    validate_no_nan(&time, "time")?;
    validate_no_nan(&status, "status")?;
    let weights = match weights {
        Some(w) => {
            validate_length(time.len(), w.len(), "weights")?;
            validate_non_negative(&w, "weights")?;
            // A single NaN weight would turn every risk-set sum (and thus the whole
            // curve) into NaN, so reject it up front like NaN times/statuses.
            validate_no_nan(&w, "weights")?;
            w
        }
        None => vec![1.0; time.len()],
    };
    let position = match position {
        Some(p) => {
            validate_length(time.len(), p.len(), "position")?;
            p
        }
        None => vec![0; time.len()],
    };
    if let Some(ref entry) = entry_times {
        validate_length(time.len(), entry.len(), "entry_times")?;
    }
    let reverse = reverse.unwrap_or(false);
    let computation_type = computation_type.unwrap_or(0);
    Ok(compute_survfitkm(
        &time,
        &status,
        &weights,
        entry_times.as_deref(),
        &position,
        reverse,
        computation_type,
    ))
}
/// Computes a weighted Kaplan–Meier survival curve with Greenwood standard errors and
/// log-scale 95% confidence limits.
///
/// One output row is produced per distinct event time `t` (an observation with
/// `status > 0`), in ascending order:
///
/// * `n_risk` — total weight of observations with `time >= t` (still under observation
///   just before `t`). Censored observations leave the risk set after their own time,
///   whether or not that time is an event time.
/// * `n_event` — total weight of events at exactly `t`.
/// * `n_censor` — total weight of non-events at exactly `t` whose `position` flag has the
///   `2` bit set.
/// * `estimate` — product-limit estimate `S(t) = Π (1 - d_i / n_i)` over event times `<= t`.
/// * `std_err` — `S(t) * sqrt(Σ d_i / (n_i (n_i - d_i)))` (Greenwood); `0` once `S(t)`
///   has dropped to zero.
/// * `conf_lower` / `conf_upper` — `exp(log S ± 1.96 · std_err / S)` passed through
///   `clamp_probability`; both equal the clamped estimate when `S` is 0 or 1 or the
///   standard error is 0.
///
/// `time`, `status`, `weights` and `position` are expected to have equal lengths and to
/// contain no NaN (the Python wrapper `survfitkm` validates both).
/// `_entry_times`, `_reverse` and `_computation_type` are accepted for API compatibility
/// but are not used: left truncation, the reverse (censoring) curve and alternative
/// computation types are not implemented.
pub fn compute_survfitkm(
    time: &[f64],
    status: &[f64],
    weights: &[f64],
    _entry_times: Option<&[f64]>,
    position: &[i32],
    _reverse: bool,
    _computation_type: i32,
) -> SurvFitKMOutput {
    use std::cmp::Ordering;

    // Distinct event times, ascending. The `Equal` fallback only keeps the comparator
    // total for direct Rust callers; NaN is rejected by the Python wrapper.
    let mut dtime: Vec<f64> = time
        .iter()
        .zip(status)
        .filter(|&(_, &s)| s > 0.0)
        .map(|(&t, _)| t)
        .collect();
    dtime.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
    dtime.dedup();
    let ntime = dtime.len();

    // Attribute events and reportable censorings to the event time they equal exactly.
    // Exact matching agrees with the exact `dedup` above, so no observation can be
    // counted at two neighbouring event times.
    let mut n_event = vec![0.0; ntime];
    let mut n_censor = vec![0.0; ntime];
    for (((&t, &s), &w), &p) in time.iter().zip(status).zip(weights).zip(position) {
        let slot = dtime.binary_search_by(|d| d.partial_cmp(&t).unwrap_or(Ordering::Equal));
        if let Ok(k) = slot {
            if s > 0.0 {
                n_event[k] += w;
            } else if p & 2 != 0 {
                n_censor[k] += w;
            }
        }
    }

    // Risk set at `t`: total weight of observations with time >= t. Observations are
    // admitted in descending time order while the event times are walked from last to
    // first, so each one is added exactly once.
    let mut by_time_desc: Vec<(f64, f64)> = time
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .collect();
    by_time_desc.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
    let mut pending = by_time_desc.into_iter().peekable();
    let mut n_risk = vec![0.0; ntime];
    let mut cumulative_weight = 0.0_f64;
    for (&t, risk) in dtime.iter().zip(n_risk.iter_mut()).rev() {
        while let Some((_, w)) = pending.next_if(|&(obs_t, _)| obs_t >= t) {
            cumulative_weight += w;
        }
        *risk = cumulative_weight;
    }

    // Product-limit estimate and Greenwood variance, accumulated forward in time.
    let mut estimate = Vec::with_capacity(ntime);
    let mut std_err = Vec::with_capacity(ntime);
    let mut surv = 1.0_f64;
    let mut greenwood = 0.0_f64;
    for (&at_risk, &events) in n_risk.iter().zip(&n_event) {
        if at_risk > 0.0 && events > 0.0 {
            // `max(0.0)` absorbs rounding when all at-risk weight fails at once.
            surv *= (1.0 - events / at_risk).max(0.0);
            let survivors = at_risk - events;
            // With no survivors the Greenwood term divides by zero; S(t) is already 0
            // there, so skip the term instead of producing inf/NaN.
            if survivors > 0.0 {
                greenwood += events / (at_risk * survivors);
            }
        }
        estimate.push(surv);
        std_err.push(if surv > 0.0 { surv * greenwood.sqrt() } else { 0.0 });
    }

    let z = 1.96; // two-sided 95% standard-normal quantile
    let (conf_lower, conf_upper): (Vec<f64>, Vec<f64>) = estimate
        .par_iter()
        .zip(std_err.par_iter())
        .map(|(&s, &se)| {
            if s <= 0.0 || s >= 1.0 || se <= 0.0 {
                (clamp_probability(s), clamp_probability(s))
            } else {
                // Log transform: se(log S) = se(S) / S.
                let log_s = s.ln();
                let log_se = se / s;
                (
                    clamp_probability((log_s - z * log_se).exp()),
                    clamp_probability((log_s + z * log_se).exp()),
                )
            }
        })
        .unzip();

    SurvFitKMOutput {
        time: dtime,
        n_risk,
        n_event,
        n_censor,
        estimate,
        std_err,
        conf_lower,
        conf_upper,
    }
}
/// Python extension module `survfitkm`: registers the `survfitkm` function and the
/// `SurvFitKMOutput` result class.
#[pymodule]
#[pyo3(name = "survfitkm")]
fn survfitkm_module(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
    let kaplan_meier = wrap_pyfunction!(survfitkm, &m)?;
    m.add_function(kaplan_meier)?;
    m.add_class::<SurvFitKMOutput>()?;
    Ok(())
}