// survival 1.0.10
//
// A high-performance survival analysis library written in Rust with Python bindings.
// Documentation
use ndarray::{Array1, Array2};
use ndarray_linalg::Solve;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use thiserror::Error;
/// Errors produced while building or solving the penalized spline system.
#[derive(Error, Debug)]
pub enum PSplineError {
    /// The configured penalty method name is not one of the supported criteria.
    /// Carries the offending method name.
    #[error("Unsupported penalty method: {0}. Supported methods are: GCV, UBRE, REML, AIC, BIC")]
    UnsupportedMethod(String),
    /// A matrix with the requested shape could not be built.
    #[error("Failed to create matrix with shape ({rows}, {cols}): {reason}")]
    MatrixCreationError {
        /// Requested number of rows.
        rows: usize,
        /// Requested number of columns.
        cols: usize,
        /// Human-readable description of why construction failed.
        reason: String,
    },
    /// The linear solver reported a failure.
    #[error("Failed to solve linear system: matrix may be singular or ill-conditioned")]
    LinearSolveError,
}
/// Lets `?` turn a [`PSplineError`] into a Python `ValueError` whose message
/// is the error's `Display` text.
impl From<PSplineError> for PyErr {
    fn from(err: PSplineError) -> Self {
        let message = err.to_string();
        PyValueError::new_err(message)
    }
}
/// Penalized spline model exposed to Python.
///
/// Holds the configuration supplied at construction time plus the state
/// written by `fit` (`coefficients`, `fitted`).
#[pyclass]
pub struct PSpline {
    /// Observed values the basis matrix is evaluated at (one basis row per entry).
    x: Vec<f64>,
    /// Degrees of freedom requested by the caller; `nterm` is derived as `df + 1`.
    df: u32,
    /// Multiplier applied to the penalty term in `apply_penalty`.
    theta: f64,
    /// Number of basis terms (columns of the basis matrix); set to `df + 1` in `new`.
    nterm: u32,
    /// Spline degree; fixed to 3 in `new`.
    degree: u32,
    /// Stored as given and exposed through a getter; not otherwise read in this impl.
    eps: f64,
    /// Name of the penalty criterion; `penalty_function` accepts
    /// "GCV", "UBRE", "REML", "AIC" and "BIC".
    method: String,
    /// `(lower, upper)` boundary knots used by `knots` and `basis_function`.
    boundary_knots: (f64, f64),
    /// When true, basis term 0 evaluates to 1.0; otherwise to 0.0.
    intercept: bool,
    /// When false, `apply_penalty` leaves the basis values unchanged.
    penalty: bool,
    /// Coefficients from the most recent successful `fit`; `None` before that.
    #[pyo3(get)]
    coefficients: Option<Vec<f64>>,
    /// Set to `true` once `fit` has completed successfully.
    #[pyo3(get)]
    fitted: bool,
}
#[pymethods]
impl PSpline {
    #[new]
    #[allow(clippy::too_many_arguments)]
    fn new(
        x: Vec<f64>,
        df: u32,
        theta: f64,
        eps: f64,
        method: String,
        boundary_knots: (f64, f64),
        intercept: bool,
        penalty: bool,
    ) -> Self {
        let nterm = df + 1;
        let degree = 3;
        PSpline {
            x,
            df,
            theta,
            nterm,
            degree,
            eps,
            method,
            boundary_knots,
            intercept,
            penalty,
            coefficients: None,
            fitted: false,
        }
    }
    /// Fits the spline and returns the estimated coefficients.
    ///
    /// Builds the basis matrix from the stored observations, applies the
    /// penalty (when enabled), solves the resulting linear system, and stores
    /// the coefficients on the instance before returning a copy of them.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` when
    /// * fewer than `nterm` (`df + 1`) observations were supplied — the
    ///   penalty and normal-equation steps index the first `nterm` basis rows
    ///   and would otherwise panic,
    /// * the configured penalty method is not supported, or
    /// * the linear system cannot be built or solved.
    pub fn fit(&mut self) -> PyResult<Vec<f64>> {
        let required = self.nterm as usize;
        if self.x.len() < required {
            return Err(PyValueError::new_err(format!(
                "x has {} observation(s) but at least {} (df + 1) are required to fit",
                self.x.len(),
                required
            )));
        }
        let basis = self.create_basis();
        // `PSplineError` converts into `PyErr` through its `From` impl, so `?`
        // yields the same `ValueError` message as an explicit `map_err`.
        let penalized_basis = self.apply_penalty(basis)?;
        let coefficients = self.optimize_fit(penalized_basis)?;
        self.coefficients = Some(coefficients.clone());
        self.fitted = true;
        Ok(coefficients)
    }
    /// Evaluates the fitted spline at each value in `new_x`.
    ///
    /// Each prediction is the sum over terms of `coefficient * basis_function(x, term)`.
    ///
    /// # Errors
    ///
    /// Returns a `ValueError` if no coefficients have been stored yet.
    pub fn predict(&self, new_x: Vec<f64>) -> PyResult<Vec<f64>> {
        let weights = match self.coefficients.as_deref() {
            Some(weights) => weights,
            None => return Err(PyValueError::new_err("Model not fitted. Call fit() first.")),
        };
        let predictions: Vec<f64> = new_x
            .iter()
            .map(|&point| {
                // Explicit fold from 0.0 keeps the left-to-right accumulation order.
                weights
                    .iter()
                    .enumerate()
                    .fold(0.0, |acc, (term, weight)| {
                        acc + weight * self.basis_function(point, term as u32)
                    })
            })
            .collect();
        Ok(predictions)
    }
    /// Returns the degrees of freedom supplied at construction time.
    #[getter]
    pub fn get_df(&self) -> u32 {
        self.df
    }
    /// Returns the `eps` value supplied at construction time.
    #[getter]
    pub fn get_eps(&self) -> f64 {
        self.eps
    }
    /// Builds the basis matrix: one row per stored observation, one column per
    /// basis term (`nterm` columns), each entry being
    /// `basis_function(observation, term)`.
    fn create_basis(&self) -> Vec<Vec<f64>> {
        self.x
            .iter()
            .map(|&observation| {
                (0..self.nterm)
                    .map(|term| self.basis_function(observation, term))
                    .collect()
            })
            .collect()
    }
    /// Evaluates basis term `j` at `x`.
    ///
    /// Term 0 is the intercept column: 1.0 when `intercept` is set, 0.0
    /// otherwise. For every other term, `x` is rescaled by the boundary knots,
    /// a triangular blending recursion is run over the first `degree + 1`
    /// knots, and the result is multiplied by the upper boundary knot.
    ///
    /// NOTE(review): for `j > 0` the value does not depend on `j`, and the
    /// final scale factor is the upper boundary knot rather than the intercept
    /// term — this mirrors the existing behaviour (the tuple destructuring
    /// previously shadowed the intercept accumulator). Confirm whether that is
    /// intended before relying on it.
    fn basis_function(&self, x: f64, j: u32) -> f64 {
        if j == 0 {
            return if self.intercept { 1.0 } else { 0.0 };
        }
        let (lower, upper) = self.boundary_knots;
        let scaled = (x - lower) / (upper - lower);
        let knots = self.knots();
        let degree = self.degree as usize;
        let mut work = vec![0.0; degree + 1];
        work[0] = 1.0;
        for span in 1..=degree {
            for idx in 0..=(degree - span) {
                let left = knots[idx];
                let width = knots[idx + span] - left;
                // Treat (near-)coincident knots as a zero blending weight to
                // avoid dividing by zero.
                let weight = if width.abs() > 1e-10 {
                    ((scaled - left) / width).clamp(0.0, 1.0)
                } else {
                    0.0
                };
                work[idx] = (1.0 - weight) * work[idx] + weight * work[idx + 1];
            }
        }
        upper * work[0]
    }
    /// Builds the knot vector of length `nterm + degree + 1`.
    ///
    /// The first `degree + 1` entries are set to the lower boundary knot and
    /// the last `degree + 1` entries to the upper boundary knot. Writes are
    /// interleaved front/back, so when the two ranges overlap the later write
    /// wins.
    ///
    /// NOTE(review): entries between the two boundary runs are left at 0.0 —
    /// no interior knots are placed here; confirm that is intended.
    fn knots(&self) -> Vec<f64> {
        let (lower, upper) = self.boundary_knots;
        let mut knots = vec![0.0; self.nterm as usize + self.degree as usize + 1];
        let last = self.nterm + self.degree;
        for offset in 0..=self.degree {
            knots[offset as usize] = lower;
            knots[(last - offset) as usize] = upper;
        }
        knots
    }
    fn apply_penalty(&self, basis: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, PSplineError> {
        let mut penalized_basis = basis.clone();
        if self.penalty {
            let mut penalty = vec![vec![0.0; self.nterm as usize]; self.nterm as usize];
            for i in 0..self.nterm {
                for j in 0..self.nterm {
                    penalty[i as usize][j as usize] = self.penalty_function(i, j)?;
                }
            }
            for i in 0..self.nterm {
                for j in 0..self.nterm {
                    for k in 0..self.nterm {
                        penalized_basis[i as usize][j as usize] -= self.theta
                            * penalty[i as usize][k as usize]
                            * basis[k as usize][j as usize];
                    }
                }
            }
        }
        Ok(penalized_basis)
    }
    fn penalty_function(&self, i: u32, j: u32) -> Result<f64, PSplineError> {
        match self.method.as_str() {
            "GCV" => Ok(self.gcv(i, j)),
            "UBRE" => Ok(self.ubre(i, j)),
            "REML" => Ok(self.reml(i, j)),
            "AIC" => Ok(self.aic(i, j)),
            "BIC" => Ok(self.bic(i, j)),
            _ => Err(PSplineError::UnsupportedMethod(self.method.clone())),
        }
    }
    /// Shared accumulation used by every penalty criterion.
    ///
    /// Evaluates an `nterm x nterm` matrix `B` with
    /// `B[k][l] = basis_function(x[k], l)` and returns `(df, trace)` where both
    /// values sum `B[k][i] * B[l][j] * B[k][l]` over all `k, l < nterm`; the
    /// two sums multiply the factors in a different order, so they may differ
    /// in the last floating-point bits.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `nterm` observations are stored, or if `i` or `j`
    /// is not below `nterm`.
    fn criterion_sums(&self, i: u32, j: u32) -> (f64, f64) {
        let nterm = self.nterm as usize;
        let (col_i, col_j) = (i as usize, j as usize);
        let design: Vec<Vec<f64>> = (0..nterm)
            .map(|k| {
                (0..self.nterm)
                    .map(|l| self.basis_function(self.x[k], l))
                    .collect()
            })
            .collect();
        let mut df = 0.0;
        let mut trace = 0.0;
        for row_k in &design {
            for (l, row_l) in design.iter().enumerate() {
                df += row_k[col_i] * row_l[col_j] * row_k[l];
                trace += row_k[col_i] * row_k[l] * row_l[col_j];
            }
        }
        (df, trace)
    }
    /// REML-labelled penalty entry `(i, j)`.
    ///
    /// Diagonal entries are `nterm`. Off-diagonal entries are
    /// `df / (1 - edf)^2 + ln(n - edf)` with `n = nterm` and `edf = trace / n`.
    fn reml(&self, i: u32, j: u32) -> f64 {
        let n = self.nterm as f64;
        if i == j {
            return n;
        }
        let (df, trace) = self.criterion_sums(i, j);
        let edf = trace / n;
        df / (1.0 - edf).powi(2) + (n - edf).ln()
    }
    /// AIC-labelled penalty entry `(i, j)`.
    ///
    /// Diagonal entries are `nterm`. Off-diagonal entries are
    /// `df / (1 - edf)^2 + 2 * edf` with `edf = trace / nterm`.
    fn aic(&self, i: u32, j: u32) -> f64 {
        let n = self.nterm as f64;
        if i == j {
            return n;
        }
        let (df, trace) = self.criterion_sums(i, j);
        let edf = trace / n;
        df / (1.0 - edf).powi(2) + 2.0 * edf
    }
    /// BIC-labelled penalty entry `(i, j)`.
    ///
    /// Diagonal entries are `nterm`. Off-diagonal entries are
    /// `df / (1 - edf)^2 + ln(n) * edf` with `n = nterm` and `edf = trace / n`.
    fn bic(&self, i: u32, j: u32) -> f64 {
        let n = self.nterm as f64;
        if i == j {
            return n;
        }
        let (df, trace) = self.criterion_sums(i, j);
        let edf = trace / n;
        df / (1.0 - edf).powi(2) + n.ln() * edf
    }
    /// GCV-labelled penalty entry `(i, j)`.
    ///
    /// Diagonal entries are `nterm`. Off-diagonal entries are
    /// `df / (1 - trace / nterm)^2`.
    fn gcv(&self, i: u32, j: u32) -> f64 {
        let n = self.nterm as f64;
        if i == j {
            return n;
        }
        let (df, trace) = self.criterion_sums(i, j);
        df / (1.0 - trace / n).powi(2)
    }
    /// UBRE-labelled penalty entry `(i, j)`.
    ///
    /// Diagonal entries are `nterm`. Off-diagonal entries are
    /// `df / (1 - trace / nterm)^2`.
    ///
    /// NOTE(review): this is the same formula as `gcv`; confirm whether a
    /// distinct UBRE expression was intended.
    fn ubre(&self, i: u32, j: u32) -> f64 {
        let n = self.nterm as f64;
        if i == j {
            return n;
        }
        let (df, trace) = self.criterion_sums(i, j);
        df / (1.0 - trace / n).powi(2)
    }
    fn optimize_fit(&self, basis: Vec<Vec<f64>>) -> Result<Vec<f64>, PSplineError> {
        let mut a = vec![vec![0.0; self.nterm as usize]; self.nterm as usize];
        let mut b = vec![0.0; self.nterm as usize];
        for i in 0..self.nterm {
            for j in 0..self.nterm {
                for k in 0..self.nterm {
                    a[i as usize][j as usize] +=
                        basis[k as usize][i as usize] * basis[k as usize][j as usize];
                }
            }
        }
        for i in 0..self.nterm {
            for j in 0..self.nterm {
                b[i as usize] += basis[j as usize][i as usize];
            }
        }
        self.solve(a, b)
    }
    /// Solves the square linear system `a * x = b` and returns `x`.
    ///
    /// `a` is given as a vector of rows and is flattened row by row into an
    /// `n x n` array, where `n` is the number of rows.
    ///
    /// # Errors
    ///
    /// * `MatrixCreationError` if the flattened data does not form an `n x n`
    ///   array.
    /// * `LinearSolveError` if the solver reports a failure.
    fn solve(&self, a: Vec<Vec<f64>>, b: Vec<f64>) -> Result<Vec<f64>, PSplineError> {
        let n = a.len();
        let flat: Vec<f64> = a.into_iter().flatten().collect();
        let matrix = match Array2::from_shape_vec((n, n), flat) {
            Ok(matrix) => matrix,
            Err(e) => {
                return Err(PSplineError::MatrixCreationError {
                    rows: n,
                    cols: n,
                    reason: e.to_string(),
                })
            }
        };
        let rhs = Array1::from_vec(b);
        match matrix.solve_into(rhs) {
            Ok(solution) => Ok(solution.to_vec()),
            Err(_) => Err(PSplineError::LinearSolveError),
        }
    }
}