use ndarray::Array2;
use pyo3::prelude::*;
use rayon::prelude::*;
/// Summary of a k-fold cross-validation run, exposed to Python through `#[pyclass]`.
///
/// All fields are readable from Python via `#[pyo3(get)]`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CVResult {
/// One score per evaluated fold (the C-index for `cv_cox`, the held-out
/// log-likelihood for `cv_survreg`).
#[pyo3(get)]
pub fold_scores: Vec<f64>,
/// Arithmetic mean of `fold_scores` as computed by the CV functions in this file.
#[pyo3(get)]
pub mean_score: f64,
/// Sample standard deviation of `fold_scores` (denominator `len - 1`) as computed
/// by the CV functions in this file.
#[pyo3(get)]
pub std_score: f64,
/// Coefficient vector fitted on the training part of each fold, in the same
/// order as `fold_scores`.
#[pyo3(get)]
pub fold_coefficients: Vec<Vec<f64>>,
}
#[pymethods]
impl CVResult {
#[new]
fn new(
fold_scores: Vec<f64>,
mean_score: f64,
std_score: f64,
fold_coefficients: Vec<Vec<f64>>,
) -> Self {
Self {
fold_scores,
mean_score,
std_score,
fold_coefficients,
}
}
}
/// Settings controlling how observations are split into cross-validation folds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CVConfig {
    /// Number of folds to split the observations into.
    pub n_folds: usize,
    /// Whether to permute the observation order before splitting.
    pub shuffle: bool,
    /// Seed for the permutation; when `None`, the fold builder in this file
    /// falls back to a fixed default seed.
    pub seed: Option<u64>,
}

impl Default for CVConfig {
    /// Five shuffled folds with no explicit seed.
    fn default() -> Self {
        Self {
            n_folds: 5,
            shuffle: true,
            seed: None,
        }
    }
}
/// Deterministically permutes `indices` in place for a given `seed`.
///
/// Walks the slice from the last position down to position 1 and swaps each
/// position `upper` with a position in `0..=upper`. The swap partner is derived
/// from one multiply-add step applied to `seed + upper`; no state is carried
/// from one position to the next. Slices of length 0 or 1 are left untouched.
fn simple_shuffle(indices: &mut [usize], seed: u64) {
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    let mut upper = indices.len();
    while upper > 1 {
        upper -= 1;
        let mixed = seed
            .wrapping_add(upper as u64)
            .wrapping_mul(LCG_MULTIPLIER)
            .wrapping_add(1);
        let partner = (mixed as usize) % (upper + 1);
        indices.swap(upper, partner);
    }
}
/// Splits the observation indices `0..n` into `n_folds` consecutive groups.
///
/// When `shuffle` is set, the indices are first permuted with `simple_shuffle`
/// using `seed`, or `42` when no seed is given. Every fold receives
/// `n / n_folds` indices and the first `n % n_folds` folds receive one extra,
/// so fold sizes differ by at most one.
///
/// # Panics
///
/// Panics with an integer division by zero when `n_folds` is `0`.
fn create_folds(n: usize, n_folds: usize, shuffle: bool, seed: Option<u64>) -> Vec<Vec<usize>> {
    let mut order: Vec<usize> = (0..n).collect();
    if shuffle {
        simple_shuffle(&mut order, seed.unwrap_or(42));
    }

    let base_len = n / n_folds;
    let leftover = n % n_folds;

    let mut cursor = 0;
    (0..n_folds)
        .map(|fold| {
            let len = if fold < leftover { base_len + 1 } else { base_len };
            let members = order[cursor..cursor + len].to_vec();
            cursor += len;
            members
        })
        .collect()
}
/// Concordance index (C-index) of the risk scores `risk` on one held-out fold.
///
/// A pair `(i, j)` is comparable when observation `i` has a non-zero status and
/// `times[j] > times[i]`. The pair is concordant when `risk[i] > risk[j]`,
/// discordant when `risk[i] < risk[j]`, and tied otherwise. Ties count as half.
/// Returns `0.5` when the fold contains no comparable pair.
fn cox_fold_concordance(times: &[f64], events: &[i32], risk: &[f64]) -> f64 {
    let mut concordant = 0.0;
    let mut discordant = 0.0;
    let mut tied = 0.0;
    for i in 0..times.len() {
        if events[i] == 0 {
            continue;
        }
        for j in 0..times.len() {
            // `times[i] > times[i]` is never true, so `i == j` needs no special case.
            if times[j] > times[i] {
                if risk[i] > risk[j] {
                    concordant += 1.0;
                } else if risk[i] < risk[j] {
                    discordant += 1.0;
                } else {
                    tied += 1.0;
                }
            }
        }
    }
    let total = concordant + discordant + tied;
    if total > 0.0 {
        (concordant + 0.5 * tied) / total
    } else {
        0.5
    }
}

/// K-fold cross-validation of a Cox model, scored by the held-out C-index.
///
/// Folds are processed in parallel. For each fold the model is fitted with
/// `CoxFit` (Breslow ties) on the remaining observations, and the fitted
/// coefficients are used to rank the held-out observations.
///
/// # Arguments
///
/// * `time` - observation times, one per observation.
/// * `status` - event indicators; `0` is treated as "no event" when scoring.
/// * `covariates` - matrix with one **row per covariate** and one **column per
///   observation**.
/// * `weights` - optional case weights; defaults to `1.0` for every observation.
/// * `config` - fold count, shuffling and seed.
///
/// # Returns
///
/// A [`CVResult`] holding the per-fold C-index, its mean and sample standard
/// deviation, and the per-fold coefficient vectors.
///
/// # Errors
///
/// Returns an error when `config.n_folds < 2`, or when `status`, `weights` or
/// the columns of `covariates` do not match `time.len()`.
pub fn cv_cox(
    time: &[f64],
    status: &[i32],
    covariates: &Array2<f64>,
    weights: Option<&[f64]>,
    config: &CVConfig,
) -> Result<CVResult, Box<dyn std::error::Error + Send + Sync>> {
    use crate::regression::coxfit6::{CoxFit, Method as CoxMethod};
    use ndarray::Array1;

    let n = time.len();
    let nvar = covariates.nrows();

    // Reject inputs that would otherwise panic (division by zero in the fold
    // builder, out-of-bounds indexing) or yield a meaningless NaN spread.
    if config.n_folds < 2 {
        return Err(format!("n_folds must be at least 2, got {}", config.n_folds).into());
    }
    if status.len() != n {
        return Err(format!(
            "status has {} entries but time has {}",
            status.len(),
            n
        )
        .into());
    }
    if nvar > 0 && covariates.ncols() != n {
        return Err(format!(
            "covariates has {} columns but time has {} entries",
            covariates.ncols(),
            n
        )
        .into());
    }
    if let Some(w) = weights {
        if w.len() != n {
            return Err(format!("weights has {} entries but time has {}", w.len(), n).into());
        }
    }

    let default_weights: Vec<f64> = vec![1.0; n];
    let weights = weights.unwrap_or(&default_weights);
    let folds = create_folds(n, config.n_folds, config.shuffle, config.seed);

    let results: Vec<(f64, Vec<f64>)> = (0..config.n_folds)
        .into_par_iter()
        .map(|fold_idx| {
            let test_indices = &folds[fold_idx];
            let mut train_indices: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != fold_idx)
                .flat_map(|(_, f)| f.iter().copied())
                .collect();
            // The training data is handed to CoxFit ordered by decreasing time
            // (presumably what CoxFit expects - confirm against coxfit6). The
            // sort is stable, so tied times keep their fold order.
            train_indices.sort_by(|&a, &b| {
                time[b]
                    .partial_cmp(&time[a])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let train_n = train_indices.len();

            let sorted_time: Vec<f64> = train_indices.iter().map(|&i| time[i]).collect();
            let sorted_status: Vec<i32> = train_indices.iter().map(|&i| status[i]).collect();
            let sorted_weights: Vec<f64> = train_indices.iter().map(|&i| weights[i]).collect();
            // CoxFit receives one row per observation, i.e. the transpose of `covariates`.
            let mut sorted_covariates = Array2::<f64>::zeros((train_n, nvar));
            for (row, &orig_idx) in train_indices.iter().enumerate() {
                for var in 0..nvar {
                    sorted_covariates[[row, var]] = covariates[[var, orig_idx]];
                }
            }

            let initial_beta: Vec<f64> = vec![0.0; nvar];
            let time_arr = Array1::from_vec(sorted_time);
            let status_arr = Array1::from_vec(sorted_status);
            let strata_arr = Array1::from_elem(train_n, 0i32);
            let offset_arr = Array1::from_elem(train_n, 0.0);
            let weights_arr = Array1::from_vec(sorted_weights);
            // Best effort: a fold whose model cannot be built or fitted falls back
            // to all-zero coefficients instead of aborting the whole run.
            let beta = match CoxFit::new(
                time_arr,
                status_arr,
                sorted_covariates,
                strata_arr,
                offset_arr,
                weights_arr,
                CoxMethod::Breslow,
                25,
                1e-9,
                1e-9,
                vec![true; nvar],
                initial_beta,
            ) {
                Ok(mut fit) => {
                    if fit.fit().is_ok() {
                        let (b, _, _, _, _, _, _, _) = fit.results();
                        b
                    } else {
                        vec![0.0; nvar]
                    }
                }
                Err(_) => vec![0.0; nvar],
            };

            let test_time: Vec<f64> = test_indices.iter().map(|&i| time[i]).collect();
            let test_status: Vec<i32> = test_indices.iter().map(|&i| status[i]).collect();
            // Linear predictor x'beta for every held-out observation.
            let linear_predictor: Vec<f64> = test_indices
                .iter()
                .map(|&i| (0..nvar).fold(0.0, |acc, var| acc + beta[var] * covariates[[var, i]]))
                .collect();

            let c_index = cox_fold_concordance(&test_time, &test_status, &linear_predictor);
            (c_index, beta)
        })
        .collect();

    let (fold_scores, fold_coefficients): (Vec<f64>, Vec<Vec<f64>>) = results.into_iter().unzip();
    // `n_folds >= 2` guarantees at least two scores, so the `len - 1` denominator is non-zero.
    let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
    let variance = fold_scores
        .iter()
        .map(|&s| (s - mean_score).powi(2))
        .sum::<f64>()
        / (fold_scores.len() - 1) as f64;
    let std_score = variance.sqrt();
    Ok(CVResult {
        fold_scores,
        mean_score,
        std_score,
        fold_coefficients,
    })
}
/// Python entry point for k-fold cross-validation of a Cox model.
///
/// `covariates` is given as one inner list per observation. It is reshaped to an
/// `(observations, covariates)` array and transposed before being passed to
/// `cv_cox`. The covariate count is taken from the first row; an empty list or an
/// empty first row yields a matrix with zero covariate rows. Omitted options
/// default to 5 folds with shuffling enabled.
///
/// Raises `ValueError` when the flattened covariates do not fit the expected
/// shape, and `RuntimeError` when `cv_cox` reports an error.
#[pyfunction]
#[pyo3(signature = (time, status, covariates, weights=None, n_folds=None, shuffle=None, seed=None))]
pub fn cv_cox_concordance(
    time: Vec<f64>,
    status: Vec<i32>,
    covariates: Vec<Vec<f64>>,
    weights: Option<Vec<f64>>,
    n_folds: Option<usize>,
    shuffle: Option<bool>,
    seed: Option<u64>,
) -> PyResult<CVResult> {
    let n_obs = time.len();
    let n_cov = covariates.first().map_or(0, |row| row.len());

    let design = if n_cov == 0 {
        Array2::<f64>::zeros((0, n_obs))
    } else {
        let values: Vec<f64> = covariates.into_iter().flatten().collect();
        let by_observation = Array2::from_shape_vec((n_obs, n_cov), values)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
        by_observation.t().to_owned()
    };

    let config = CVConfig {
        n_folds: n_folds.unwrap_or(5),
        shuffle: shuffle.unwrap_or(true),
        seed,
    };

    match cv_cox(&time, &status, &design, weights.as_deref(), &config) {
        Ok(result) => Ok(result),
        Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
            e.to_string(),
        )),
    }
}
/// K-fold cross-validation of a parametric survival regression, scored by the
/// log-likelihood reported for each held-out fold.
///
/// Folds are processed in parallel. For each fold, `survreg` is fitted on the
/// remaining observations. The fitted coefficients are then passed back into
/// `survreg` together with the held-out data and an iteration limit of 1, and the
/// `log_likelihood` of that second call is the fold score.
/// NOTE(review): the meaning of `survreg`'s positional arguments is not visible
/// here; the sixth is assumed to be the initial coefficients - confirm against
/// `survreg6`.
///
/// # Arguments
///
/// * `time` - observation times, one per observation.
/// * `status` - status values, one per observation.
/// * `covariates` - matrix with one **row per covariate** and one **column per
///   observation**.
/// * `distribution` - distribution name forwarded to `survreg`.
/// * `config` - fold count, shuffling and seed.
///
/// # Returns
///
/// A [`CVResult`] for the folds that succeeded. Folds where either `survreg`
/// call fails are dropped, so the result may hold fewer than `n_folds` scores;
/// with a single surviving fold the standard deviation is NaN.
///
/// # Errors
///
/// Returns an error when `config.n_folds < 2`, when `status` or the columns of
/// `covariates` do not match `time.len()`, or when every fold fails.
pub fn cv_survreg(
    time: &[f64],
    status: &[f64],
    covariates: &Array2<f64>,
    distribution: &str,
    config: &CVConfig,
) -> Result<CVResult, Box<dyn std::error::Error + Send + Sync>> {
    use crate::regression::survreg6::survreg;

    let n = time.len();
    // The matrix is indexed as `[[covariate, observation]]` below, so the number
    // of covariates is the number of ROWS. Using the column count here would
    // read out of bounds whenever there are more observations than covariates.
    let nvar = covariates.nrows();

    // Reject inputs that would otherwise panic (division by zero in the fold
    // builder, out-of-bounds indexing).
    if config.n_folds < 2 {
        return Err(format!("n_folds must be at least 2, got {}", config.n_folds).into());
    }
    if status.len() != n {
        return Err(format!(
            "status has {} entries but time has {}",
            status.len(),
            n
        )
        .into());
    }
    if nvar > 0 && covariates.ncols() != n {
        return Err(format!(
            "covariates has {} columns but time has {} entries",
            covariates.ncols(),
            n
        )
        .into());
    }

    // One covariate vector per observation, the layout `survreg` is called with.
    let cov_vecs: Vec<Vec<f64>> = (0..n)
        .map(|i| (0..nvar).map(|j| covariates[[j, i]]).collect())
        .collect();
    let folds = create_folds(n, config.n_folds, config.shuffle, config.seed);
    let dist = distribution.to_string();

    let results: Vec<(f64, Vec<f64>)> = (0..config.n_folds)
        .into_par_iter()
        .filter_map(|fold_idx| {
            let test_indices = &folds[fold_idx];
            let train_indices: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|(i, _)| *i != fold_idx)
                .flat_map(|(_, f)| f.iter().copied())
                .collect();

            let train_time: Vec<f64> = train_indices.iter().map(|&i| time[i]).collect();
            let train_status: Vec<f64> = train_indices.iter().map(|&i| status[i]).collect();
            let train_covariates: Vec<Vec<f64>> =
                train_indices.iter().map(|&i| cov_vecs[i].clone()).collect();
            // A failed fit drops this fold instead of aborting the whole run.
            let fit_result = survreg(
                train_time,
                train_status,
                train_covariates,
                None,
                None,
                None,
                None,
                Some(&dist),
                Some(25),
                Some(1e-5),
                Some(1e-9),
            )
            .ok()?;

            let test_time: Vec<f64> = test_indices.iter().map(|&i| time[i]).collect();
            let test_status: Vec<f64> = test_indices.iter().map(|&i| status[i]).collect();
            let test_covariates: Vec<Vec<f64>> =
                test_indices.iter().map(|&i| cov_vecs[i].clone()).collect();
            // Second call: held-out data, the training coefficients, iteration limit 1.
            let test_fit = survreg(
                test_time,
                test_status,
                test_covariates,
                None,
                None,
                Some(fit_result.coefficients.clone()),
                None,
                Some(&dist),
                Some(1),
                Some(1e-5),
                Some(1e-9),
            )
            .ok()?;
            Some((test_fit.log_likelihood, fit_result.coefficients))
        })
        .collect();

    if results.is_empty() {
        return Err("All CV folds failed".into());
    }

    let (fold_scores, fold_coefficients): (Vec<f64>, Vec<Vec<f64>>) = results.into_iter().unzip();
    let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
    // With one surviving fold this is 0.0 / 0.0, i.e. NaN: the sample spread is undefined.
    let variance = fold_scores
        .iter()
        .map(|&s| (s - mean_score).powi(2))
        .sum::<f64>()
        / (fold_scores.len() - 1) as f64;
    let std_score = variance.sqrt();
    Ok(CVResult {
        fold_scores,
        mean_score,
        std_score,
        fold_coefficients,
    })
}
/// Python entry point for k-fold cross-validation of a parametric survival
/// regression.
///
/// `covariates` is given as one inner list per observation. It is reshaped to an
/// `(observations, covariates)` array and transposed before being passed to
/// `cv_survreg`. The covariate count is taken from the first row; an empty list
/// or an empty first row yields a matrix with zero covariate rows. Omitted
/// options default to the `"weibull"` distribution and 5 folds with shuffling
/// enabled.
///
/// Raises `ValueError` when the flattened covariates do not fit the expected
/// shape, and `RuntimeError` when `cv_survreg` reports an error.
#[pyfunction]
#[pyo3(signature = (time, status, covariates, distribution=None, n_folds=None, shuffle=None, seed=None))]
pub fn cv_survreg_loglik(
    time: Vec<f64>,
    status: Vec<f64>,
    covariates: Vec<Vec<f64>>,
    distribution: Option<&str>,
    n_folds: Option<usize>,
    shuffle: Option<bool>,
    seed: Option<u64>,
) -> PyResult<CVResult> {
    let n_obs = time.len();
    let n_cov = covariates.first().map_or(0, |row| row.len());

    let design = if n_cov == 0 {
        Array2::<f64>::zeros((0, n_obs))
    } else {
        let values: Vec<f64> = covariates.into_iter().flatten().collect();
        let by_observation = Array2::from_shape_vec((n_obs, n_cov), values)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
        by_observation.t().to_owned()
    };

    let config = CVConfig {
        n_folds: n_folds.unwrap_or(5),
        shuffle: shuffle.unwrap_or(true),
        seed,
    };
    let family = distribution.unwrap_or("weibull");

    match cv_survreg(&time, &status, &design, family, &config) {
        Ok(result) => Ok(result),
        Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
            e.to_string(),
        )),
    }
}