survival 1.1.29

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use pyo3::prelude::*;

/// Statistics of a Cox proportional hazards model at one distinct event time.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CoxphDetailRow {
    /// The event time this row describes.
    #[pyo3(get)]
    pub time: f64,
    /// Number of observations whose time is at or after `time`.
    #[pyo3(get)]
    pub n_risk: usize,
    /// Number of events (`status == 1`) recorded at `time`.
    #[pyo3(get)]
    pub n_event: usize,
    /// Number of non-event observations tied at `time`.
    #[pyo3(get)]
    pub n_censor: usize,
    /// Baseline hazard increment at `time`.
    #[pyo3(get)]
    pub hazard: f64,
    /// Running sum of `hazard` up to and including `time`.
    #[pyo3(get)]
    pub cumhaz: f64,
    /// Score-vector entries for this time, one per covariate.
    #[pyo3(get)]
    pub score: Vec<f64>,
    /// Schoenfeld residual summed over the events at `time`, one entry per
    /// covariate; `None` when it was not computed.
    #[pyo3(get)]
    pub schoenfeld: Option<Vec<f64>>,
    /// Risk-score-weighted covariate means over the risk set.
    #[pyo3(get)]
    pub means: Vec<f64>,
    /// Square matrix (covariates × covariates) of information entries for this time.
    #[pyo3(get)]
    pub imat: Vec<Vec<f64>>,
}

#[pymethods]
impl CoxphDetailRow {
    /// Short summary: time (4 decimals), risk-set size, events, hazard (6 decimals).
    fn __repr__(&self) -> String {
        let Self {
            time,
            n_risk,
            n_event,
            hazard,
            ..
        } = self;
        format!(
            "CoxphDetailRow(time={t:.4}, n_risk={r}, n_event={e}, hazard={h:.6})",
            t = time,
            r = n_risk,
            e = n_event,
            h = hazard
        )
    }
}

/// Per-event-time breakdown of a Cox proportional hazards fit.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CoxphDetail {
    /// One entry per distinct event time, in the order they were produced.
    #[pyo3(get)]
    pub rows: Vec<CoxphDetailRow>,
    /// Total number of events across all rows.
    #[pyo3(get)]
    pub n_events: usize,
    /// Number of observations the detail was computed from.
    #[pyo3(get)]
    pub n_observations: usize,
    /// Number of covariates (regression coefficients) in the model.
    #[pyo3(get)]
    pub n_covariates: usize,
}

impl CoxphDetail {
    /// Extracts one value from every row, preserving row order.
    ///
    /// Kept outside `#[pymethods]` because Python-exposed methods cannot be generic.
    fn column<T>(&self, field: impl Fn(&CoxphDetailRow) -> T) -> Vec<T> {
        self.rows.iter().map(field).collect()
    }
}

#[pymethods]
impl CoxphDetail {
    /// Short summary: event count, observation count, number of event times.
    fn __repr__(&self) -> String {
        format!(
            "CoxphDetail(n_events={events}, n_obs={obs}, n_times={times})",
            events = self.n_events,
            obs = self.n_observations,
            times = self.rows.len()
        )
    }

    /// Event times, one per row.
    pub fn times(&self) -> Vec<f64> {
        self.column(|row| row.time)
    }

    /// Hazard increments, one per row.
    pub fn hazards(&self) -> Vec<f64> {
        self.column(|row| row.hazard)
    }

    /// Cumulative hazards, one per row.
    pub fn cumulative_hazards(&self) -> Vec<f64> {
        self.column(|row| row.cumhaz)
    }

    /// Risk-set sizes, one per row.
    pub fn n_risk_at_times(&self) -> Vec<usize> {
        self.column(|row| row.n_risk)
    }

    /// Schoenfeld residual vectors of the rows that have one; rows whose
    /// `schoenfeld` is `None` are skipped.
    pub fn schoenfeld_residuals(&self) -> Vec<Vec<f64>> {
        self.rows
            .iter()
            .flat_map(|row| row.schoenfeld.iter().cloned())
            .collect()
    }
}

pub fn compute_coxph_detail(
    time: &[f64],
    status: &[i32],
    covariates: &[Vec<f64>],
    coefficients: &[f64],
    weights: Option<&[f64]>,
) -> CoxphDetail {
    let n = time.len();
    let nvar = coefficients.len();

    if n == 0 || nvar == 0 {
        return CoxphDetail {
            rows: vec![],
            n_events: 0,
            n_observations: 0,
            n_covariates: nvar,
        };
    }

    let risk_scores: Vec<f64> = covariates
        .iter()
        .map(|cov| {
            let mut lp = 0.0;
            for (j, &c) in cov.iter().enumerate() {
                if j < nvar {
                    lp += c * coefficients[j];
                }
            }
            lp.exp()
        })
        .collect();

    let wts: Vec<f64> = weights.map(|w| w.to_vec()).unwrap_or_else(|| vec![1.0; n]);

    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&a, &b| {
        time[a]
            .partial_cmp(&time[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut unique_times: Vec<f64> = Vec::new();
    let mut event_counts: Vec<usize> = Vec::new();
    let mut censor_counts: Vec<usize> = Vec::new();
    let mut risk_set_sizes: Vec<usize> = Vec::new();
    let mut risk_sums: Vec<f64> = Vec::new();
    let mut weighted_means: Vec<Vec<f64>> = Vec::new();
    let mut event_covariates: Vec<Vec<f64>> = Vec::new();

    let mut i = 0;
    while i < n {
        let current_time = time[indices[i]];
        let mut n_event = 0;
        let mut n_censor = 0;
        let mut event_cov_sum = vec![0.0; nvar];
        let start_i = i;

        while i < n && (time[indices[i]] - current_time).abs() < 1e-10 {
            if status[indices[i]] == 1 {
                n_event += 1;
                for (k, &c) in covariates[indices[i]].iter().enumerate() {
                    if k < nvar {
                        event_cov_sum[k] += c;
                    }
                }
            } else {
                n_censor += 1;
            }
            i += 1;
        }

        if n_event > 0 {
            let risk_set_size = n - start_i;
            let mut risk_sum = 0.0;
            let mut weighted_mean = vec![0.0; nvar];

            for &idx in indices.iter().take(n).skip(start_i) {
                let w = wts[idx] * risk_scores[idx];
                risk_sum += w;
                for (k, &c) in covariates[idx].iter().enumerate() {
                    if k < nvar {
                        weighted_mean[k] += w * c;
                    }
                }
            }

            if risk_sum > 0.0 {
                for wm in weighted_mean.iter_mut().take(nvar) {
                    *wm /= risk_sum;
                }
            }

            unique_times.push(current_time);
            event_counts.push(n_event);
            censor_counts.push(n_censor);
            risk_set_sizes.push(risk_set_size);
            risk_sums.push(risk_sum);
            weighted_means.push(weighted_mean);
            event_covariates.push(event_cov_sum);
        }
    }

    let n_unique = unique_times.len();
    let mut rows = Vec::with_capacity(n_unique);
    let mut cumhaz = 0.0;

    for t in 0..n_unique {
        let hazard = if risk_sums[t] > 0.0 {
            event_counts[t] as f64 / risk_sums[t]
        } else {
            0.0
        };
        cumhaz += hazard;

        let schoenfeld = if event_counts[t] > 0 {
            let mut scho = vec![0.0; nvar];
            for k in 0..nvar {
                scho[k] = event_covariates[t][k] - event_counts[t] as f64 * weighted_means[t][k];
            }
            Some(scho)
        } else {
            None
        };

        let score = vec![0.0; nvar];
        let imat = vec![vec![0.0; nvar]; nvar];

        rows.push(CoxphDetailRow {
            time: unique_times[t],
            n_risk: risk_set_sizes[t],
            n_event: event_counts[t],
            n_censor: censor_counts[t],
            hazard,
            cumhaz,
            score,
            schoenfeld,
            means: weighted_means[t].clone(),
            imat,
        });
    }

    let total_events = event_counts.iter().sum();

    CoxphDetail {
        rows,
        n_events: total_events,
        n_observations: n,
        n_covariates: nvar,
    }
}

/// Python entry point: validates the inputs and delegates to `compute_coxph_detail`.
///
/// Raises ValueError in any of these cases:
/// - `status` or `covariates` differs in length from `time`;
/// - `weights` is given with a different length than `time`;
/// - any `time` value is NaN (NaN has no place in the time ordering, so no risk
///   set can be defined for it).
#[pyfunction]
#[pyo3(signature = (time, status, covariates, coefficients, weights=None))]
pub fn coxph_detail(
    time: Vec<f64>,
    status: Vec<i32>,
    covariates: Vec<Vec<f64>>,
    coefficients: Vec<f64>,
    weights: Option<Vec<f64>>,
) -> PyResult<CoxphDetail> {
    let n = time.len();
    if status.len() != n || covariates.len() != n {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "time, status, and covariates must have the same length",
        ));
    }
    // Otherwise a short weight vector panics inside the computation, and a long
    // one has its extra entries silently ignored.
    if matches!(&weights, Some(w) if w.len() != n) {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "weights must have the same length as time",
        ));
    }
    if time.iter().any(|t| t.is_nan()) {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "time must not contain NaN values",
        ));
    }

    Ok(compute_coxph_detail(
        &time,
        &status,
        &covariates,
        &coefficients,
        weights.as_deref(),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within summation-order rounding.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_coxph_detail() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let status = vec![1, 0, 1, 0, 1];
        let covariates = vec![vec![1.0], vec![2.0], vec![1.5], vec![2.5], vec![3.0]];
        let coefficients = vec![0.5];

        let detail = compute_coxph_detail(&time, &status, &covariates, &coefficients, None);

        assert_eq!(detail.n_events, 3);
        assert_eq!(detail.n_observations, 5);
        assert_eq!(detail.rows.len(), 3);
    }

    #[test]
    fn test_coxph_detail_breslow_values() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let status = vec![1, 0, 1, 0, 1];
        let covariates = vec![vec![1.0], vec![2.0], vec![1.5], vec![2.5], vec![3.0]];
        let coefficients = vec![0.5];

        let detail = compute_coxph_detail(&time, &status, &covariates, &coefficients, None);

        // Risk scores exp(0.5 * x) and the risk-set totals at t = 1, 3, 5.
        let r: Vec<f64> = covariates.iter().map(|c| (0.5 * c[0]).exp()).collect();
        let s0_t1: f64 = r.iter().sum();
        let s0_t3: f64 = r[2..].iter().sum();
        let s0_t5 = r[4];

        assert_eq!(detail.times(), vec![1.0, 3.0, 5.0]);
        assert_eq!(detail.n_risk_at_times(), vec![5, 3, 1]);

        let hazards = detail.hazards();
        assert_close(hazards[0], 1.0 / s0_t1);
        assert_close(hazards[1], 1.0 / s0_t3);
        assert_close(hazards[2], 1.0 / s0_t5);

        let cumhaz = detail.cumulative_hazards();
        assert_close(cumhaz[2], 1.0 / s0_t1 + 1.0 / s0_t3 + 1.0 / s0_t5);

        // Risk-weighted covariate mean at t = 1 and the matching Schoenfeld residual.
        let mean_t1 = r
            .iter()
            .zip(&covariates)
            .map(|(ri, c)| ri * c[0])
            .sum::<f64>()
            / s0_t1;
        assert_close(detail.rows[0].means[0], mean_t1);
        assert_close(detail.rows[0].schoenfeld.as_ref().unwrap()[0], 1.0 - mean_t1);
        // Only the event subject remains at t = 5, so its residual vanishes.
        assert_close(detail.rows[2].schoenfeld.as_ref().unwrap()[0], 0.0);
    }

    #[test]
    fn test_coxph_detail_tied_times() {
        let time = vec![1.0, 1.0, 2.0];
        let status = vec![1, 0, 1];
        let covariates = vec![vec![0.0], vec![1.0], vec![2.0]];
        let coefficients = vec![0.0];

        let detail = compute_coxph_detail(&time, &status, &covariates, &coefficients, None);

        assert_eq!(detail.rows.len(), 2);
        assert_eq!(detail.rows[0].n_risk, 3);
        assert_eq!(detail.rows[0].n_event, 1);
        assert_eq!(detail.rows[0].n_censor, 1);
        assert_eq!(detail.rows[1].n_risk, 1);
        assert_eq!(detail.rows[1].n_censor, 0);
        assert_close(detail.rows[0].hazard, 1.0 / 3.0);
        assert_close(detail.rows[1].cumhaz, 1.0 / 3.0 + 1.0);
    }

    #[test]
    fn test_coxph_detail_empty_input() {
        let detail = compute_coxph_detail(&[], &[], &[], &[0.5], None);

        assert!(detail.rows.is_empty());
        assert_eq!(detail.n_events, 0);
        assert_eq!(detail.n_observations, 0);
        assert_eq!(detail.n_covariates, 1);
    }
}