survival 1.1.27

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use crate::utilities::statistical::chi2_sf;
use pyo3::prelude::*;

#[derive(Debug, Clone)]
#[pyclass]
pub struct AnovaRow {
    #[pyo3(get)]
    pub model_name: String,
    #[pyo3(get)]
    pub loglik: f64,
    #[pyo3(get)]
    pub df: usize,
    #[pyo3(get)]
    pub chisq: Option<f64>,
    #[pyo3(get)]
    pub p_value: Option<f64>,
}

#[pymethods]
impl AnovaRow {
    #[new]
    pub fn new(
        model_name: String,
        loglik: f64,
        df: usize,
        chisq: Option<f64>,
        p_value: Option<f64>,
    ) -> Self {
        Self {
            model_name,
            loglik,
            df,
            chisq,
            p_value,
        }
    }

    fn __repr__(&self) -> String {
        match (self.chisq, self.p_value) {
            (Some(chi), Some(p)) => format!(
                "AnovaRow(model='{}', loglik={:.4}, df={}, chisq={:.4}, p={:.4})",
                self.model_name, self.loglik, self.df, chi, p
            ),
            _ => format!(
                "AnovaRow(model='{}', loglik={:.4}, df={})",
                self.model_name, self.loglik, self.df
            ),
        }
    }
}

#[derive(Debug, Clone)]
#[pyclass]
pub struct AnovaCoxphResult {
    #[pyo3(get)]
    pub rows: Vec<AnovaRow>,
    #[pyo3(get)]
    pub test_type: String,
}

#[pymethods]
impl AnovaCoxphResult {
    #[new]
    pub fn new(rows: Vec<AnovaRow>, test_type: String) -> Self {
        Self { rows, test_type }
    }

    fn __repr__(&self) -> String {
        let rows_str: Vec<String> = self.rows.iter().map(|r| r.__repr__()).collect();
        format!(
            "AnovaCoxphResult(test='{}', models=[\n  {}\n])",
            self.test_type,
            rows_str.join(",\n  ")
        )
    }

    pub fn to_table(&self) -> String {
        let mut table = String::new();
        table.push_str(&format!(
            "Analysis of Deviance Table ({})\n",
            self.test_type
        ));
        table.push_str(&format!(
            "{:<20} {:>12} {:>6} {:>12} {:>12}\n",
            "Model", "loglik", "df", "Chisq", "Pr(>|Chi|)"
        ));
        table.push_str(&"-".repeat(64));
        table.push('\n');

        for row in &self.rows {
            let chisq_str = row
                .chisq
                .map(|c| format!("{:.4}", c))
                .unwrap_or_else(|| "".to_string());
            let p_str = row
                .p_value
                .map(|p| format!("{:.4}", p))
                .unwrap_or_else(|| "".to_string());
            table.push_str(&format!(
                "{:<20} {:>12.4} {:>6} {:>12} {:>12}\n",
                row.model_name, row.loglik, row.df, chisq_str, p_str
            ));
        }
        table
    }
}

#[pyfunction]
#[pyo3(signature = (logliks, dfs, model_names=None, test="LRT".to_string()))]
pub fn anova_coxph(
    logliks: Vec<f64>,
    dfs: Vec<usize>,
    model_names: Option<Vec<String>>,
    test: String,
) -> PyResult<AnovaCoxphResult> {
    if logliks.len() != dfs.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "logliks and dfs must have the same length",
        ));
    }

    if logliks.len() < 2 {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "Need at least 2 models for comparison",
        ));
    }

    let names = model_names.unwrap_or_else(|| {
        (1..=logliks.len())
            .map(|i| format!("Model {}", i))
            .collect()
    });

    if names.len() != logliks.len() {
        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
            "model_names must match logliks length",
        ));
    }

    let mut rows = Vec::with_capacity(logliks.len());

    rows.push(AnovaRow {
        model_name: names[0].clone(),
        loglik: logliks[0],
        df: dfs[0],
        chisq: None,
        p_value: None,
    });

    for i in 1..logliks.len() {
        let chisq = 2.0 * (logliks[i] - logliks[i - 1]);
        let df_diff = dfs[i].abs_diff(dfs[i - 1]);

        let p_value = if df_diff > 0 && chisq >= 0.0 {
            chi2_sf(chisq, df_diff)
        } else {
            f64::NAN
        };

        rows.push(AnovaRow {
            model_name: names[i].clone(),
            loglik: logliks[i],
            df: dfs[i],
            chisq: Some(chisq),
            p_value: Some(p_value),
        });
    }

    Ok(AnovaCoxphResult {
        rows,
        test_type: test,
    })
}

#[pyfunction]
pub fn anova_coxph_single(
    loglik_null: f64,
    loglik_full: f64,
    df_null: usize,
    df_full: usize,
) -> PyResult<AnovaCoxphResult> {
    anova_coxph(
        vec![loglik_null, loglik_full],
        vec![df_null, df_full],
        Some(vec!["Null".to_string(), "Full".to_string()]),
        "LRT".to_string(),
    )
}

#[allow(dead_code)]
pub fn compute_anova_lrt(loglik1: f64, loglik2: f64, df1: usize, df2: usize) -> (f64, f64, usize) {
    let chisq = 2.0 * (loglik2 - loglik1).abs();
    let df_diff = df2.abs_diff(df1);
    let p_value = if df_diff > 0 {
        chi2_sf(chisq, df_diff)
    } else {
        1.0
    };
    (chisq, p_value, df_diff)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anova_coxph() {
        let logliks = vec![-100.0, -95.0, -90.0];
        let dfs = vec![0, 1, 2];
        let result = anova_coxph(logliks, dfs, None, "LRT".to_string()).unwrap();

        assert_eq!(result.rows.len(), 3);
        assert!(result.rows[0].chisq.is_none());
        assert!(result.rows[1].chisq.is_some());
        assert!((result.rows[1].chisq.unwrap() - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_anova_lrt() {
        let (chisq, p_value, df) = compute_anova_lrt(-100.0, -95.0, 0, 1);
        assert!((chisq - 10.0).abs() < 1e-10);
        assert_eq!(df, 1);
        assert!(p_value < 0.01);
    }
}