survival 1.0.17

A high-performance survival analysis library written in Rust with Python bindings
Documentation
use ndarray::{Array1, Array2};
use ndarray_linalg::{Cholesky, Solve, UPLO};

/// Solves the symmetric positive-definite system `matrix * x = vector`.
///
/// `matrix` is factorised as `L Lᵀ` (lower Cholesky factor, computed from the
/// lower triangle of `matrix`). `x` is then recovered from the two triangular
/// systems `L y = vector` and `Lᵀ x = y`.
///
/// The solver is deliberately forgiving. It returns a zero vector of length
/// `vector.len()`, instead of an error, when:
/// * `matrix` has no rows or no columns;
/// * every entry of `matrix` has magnitude below `1e-10`;
/// * the factorisation fails both as given and after adding a ridge of
///   `1e-6 * max |matrix[i, j]|` to the diagonal.
///
/// [`cholesky_solve_with_fallback`] uses an all-zero result as its signal to
/// retry with another matrix.
///
/// `_tol` is currently unused and is kept only for API compatibility.
///
/// # Errors
///
/// Returns an error if a factorisation succeeds but one of the triangular
/// solves fails. The message starts with `"Cholesky solve failed"`.
pub fn cholesky_solve(
    matrix: &Array2<f64>,
    vector: &Array1<f64>,
    _tol: f64,
) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
    if matrix.nrows() == 0 || matrix.ncols() == 0 {
        return Ok(Array1::zeros(vector.len()));
    }

    // Treat an (almost) all-zero matrix as degenerate rather than attempting
    // a factorisation.
    let max_val = matrix.iter().map(|&x| x.abs()).fold(0.0f64, f64::max);
    if max_val < 1e-10 {
        return Ok(Array1::zeros(vector.len()));
    }

    match matrix.cholesky(UPLO::Lower) {
        Ok(lower) => solve_from_lower_factor(&lower, vector, "Cholesky solve failed"),
        Err(_) => {
            // Retry once with a small ridge, scaled to the largest entry, to
            // rescue matrices that are only marginally non-positive-definite.
            let mut reg_matrix = matrix.clone();
            let ridge = max_val * 1e-6;
            // `diag_mut` spans only min(nrows, ncols) entries, so a non-square
            // matrix cannot cause an out-of-bounds index here.
            reg_matrix.diag_mut().mapv_inplace(|d| d + ridge);
            match reg_matrix.cholesky(UPLO::Lower) {
                Ok(lower) => solve_from_lower_factor(
                    &lower,
                    vector,
                    "Cholesky solve failed after regularization",
                ),
                Err(_) => Ok(Array1::zeros(vector.len())),
            }
        }
    }
}

/// Solves `A x = rhs` given the lower Cholesky factor `lower` of `A`
/// (`A = L Lᵀ`), by solving `L y = rhs` and then `Lᵀ x = y`.
///
/// Any failure is reported as `"{context}: {error}"`.
fn solve_from_lower_factor(
    lower: &Array2<f64>,
    rhs: &Array1<f64>,
    context: &str,
) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
    lower
        .solve(rhs)
        .and_then(|y| lower.t().solve(&y))
        .map_err(|e| format!("{}: {}", context, e).into())
}

/// Solves `primary * x = vector` with [`cholesky_solve`], retrying with
/// `fallback` when the primary solve fails.
///
/// The primary solve counts as failed if it returns an error, or if it
/// returns an all-zero vector while `vector` itself is non-zero.
/// [`cholesky_solve`] uses the all-zero vector to signal a degenerate system.
///
/// When `vector` is all zeros (or empty), zero already solves the system
/// exactly. The primary result is then returned without consulting `fallback`.
///
/// # Errors
///
/// If the fallback solve is attempted and fails, its error is returned.
/// Errors from the primary solve are discarded in favour of the fallback.
pub fn cholesky_solve_with_fallback(
    primary: &Array2<f64>,
    fallback: &Array2<f64>,
    vector: &Array1<f64>,
    tol: f64,
) -> Result<Array1<f64>, Box<dyn std::error::Error>> {
    // An all-zero answer is only suspicious when the right-hand side is non-zero.
    let rhs_is_zero = vector.iter().all(|&x| x == 0.0);
    match cholesky_solve(primary, vector, tol) {
        Ok(result) if rhs_is_zero || result.iter().any(|&x| x != 0.0) => Ok(result),
        _ => cholesky_solve(fallback, vector, tol),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Absolute tolerance for comparing solver output with exact values.
    const EPS: f64 = 1e-10;

    /// Solving against the 2x2 identity returns the right-hand side unchanged.
    #[test]
    fn test_cholesky_solve_identity() {
        let identity = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let rhs = Array1::from(vec![1.0, 2.0]);

        let solution = cholesky_solve(&identity, &rhs, 1e-9).unwrap();

        assert!((solution[0] - 1.0).abs() < EPS);
        assert!((solution[1] - 2.0).abs() < EPS);
    }

    /// A 0x0 system yields an empty solution rather than an error.
    #[test]
    fn test_cholesky_solve_empty() {
        let matrix = Array2::<f64>::zeros((0, 0));
        let rhs = Array1::<f64>::zeros(0);

        let solution = cholesky_solve(&matrix, &rhs, 1e-9).unwrap();

        assert!(solution.is_empty());
    }
}