use crate::constants::{COX_MAX_ITER, PARALLEL_THRESHOLD_SMALL};
use crate::utilities::numpy_utils::{extract_optional_vec_f64, extract_vec_f64, extract_vec_i32};
use ndarray::Array2;
use pyo3::prelude::*;
use rayon::prelude::*;
/// Outcome of a K-fold cross-validation run, exposed to Python as a read-only object.
///
/// Each field is readable from Python through a generated getter.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CVResult {
    /// One score per evaluated fold: a held-out C-index for `cv_cox`, a held-out
    /// log-likelihood for `cv_survreg`.
    #[pyo3(get)]
    pub fold_scores: Vec<f64>,
    /// Mean of `fold_scores` (as computed by the CV routines in this module).
    #[pyo3(get)]
    pub mean_score: f64,
    /// Sample standard deviation (n - 1 denominator) of `fold_scores`
    /// (as computed by the CV routines in this module).
    #[pyo3(get)]
    pub std_score: f64,
    /// Coefficients fitted on each fold's training data, aligned with `fold_scores`.
    #[pyo3(get)]
    pub fold_coefficients: Vec<Vec<f64>>,
}
#[pymethods]
impl CVResult {
    /// Python constructor: assembles a result from already-computed components.
    #[new]
    fn new(
        fold_scores: Vec<f64>,
        mean_score: f64,
        std_score: f64,
        fold_coefficients: Vec<Vec<f64>>,
    ) -> Self {
        CVResult {
            fold_coefficients,
            std_score,
            mean_score,
            fold_scores,
        }
    }
}
/// Fold layout used by the K-fold cross-validation routines in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CVConfig {
    /// Number of folds the observations are split into.
    pub n_folds: usize,
    /// Whether observation indices are shuffled before they are split into folds.
    pub shuffle: bool,
    /// Seed for that shuffle; `create_folds` falls back to 42 when this is `None`.
    pub seed: Option<u64>,
}
impl Default for CVConfig {
    /// Five shuffled folds with no explicit seed.
    fn default() -> Self {
        Self {
            n_folds: 5,
            shuffle: true,
            seed: None,
        }
    }
}
/// Permutes `indices` in place with a Fisher–Yates pass whose swap partners are
/// derived deterministically from `seed`.
///
/// Walking from the last position down to position 1, the partner of `pos` is
/// `(seed + pos) * 6364136223846793005 + 1` (wrapping) reduced modulo `pos + 1`.
/// NOTE(review): no generator state is carried between positions (each one is
/// hashed independently), so this is a reproducible scrambler rather than a
/// statistically strong shuffle.
fn simple_shuffle(indices: &mut [usize], seed: u64) {
    const MULTIPLIER: u64 = 6364136223846793005;
    let mut pos = indices.len();
    while pos > 1 {
        pos -= 1;
        let mixed = seed
            .wrapping_add(pos as u64)
            .wrapping_mul(MULTIPLIER)
            .wrapping_add(1);
        let partner = (mixed as usize) % (pos + 1);
        indices.swap(pos, partner);
    }
}
/// Splits the indices `0..n` into `n_folds` contiguous folds of near-equal size.
///
/// When `shuffle` is true the indices are first permuted with `simple_shuffle`
/// using `seed`, or 42 when `seed` is `None`, so fold assignment is reproducible.
/// The first `n % n_folds` folds receive one extra index. Some folds are empty
/// when `n_folds > n`. `n_folds == 0` yields no folds at all rather than
/// panicking on a division by zero.
fn create_folds(n: usize, n_folds: usize, shuffle: bool, seed: Option<u64>) -> Vec<Vec<usize>> {
    if n_folds == 0 {
        return Vec::new();
    }
    let mut indices: Vec<usize> = (0..n).collect();
    if shuffle {
        simple_shuffle(&mut indices, seed.unwrap_or(42));
    }
    let base = n / n_folds;
    let remainder = n % n_folds;
    let mut folds = Vec::with_capacity(n_folds);
    let mut rest = indices.as_slice();
    for fold in 0..n_folds {
        // Spread the remainder over the leading folds, one extra index each.
        let len = base + usize::from(fold < remainder);
        let (head, tail) = rest.split_at(len);
        folds.push(head.to_vec());
        rest = tail;
    }
    folds
}
/// Runs K-fold cross-validation of a Cox proportional-hazards model (Breslow
/// handling of ties) and scores every held-out fold with Harrell's C-index.
///
/// # Arguments
/// * `time` - follow-up time for each of the `n` subjects.
/// * `status` - event indicator per subject; only the value `1` counts as an event
///   when scoring.
/// * `covariates` - covariate matrix of shape `(nvar, n)`: one row per variable,
///   one column per subject.
/// * `weights` - optional case weights of length `n`; every subject weighs 1.0
///   when `None`.
/// * `config` - number of folds, whether to shuffle, and the shuffle seed.
///
/// # Returns
/// A [`CVResult`] holding one C-index and one coefficient vector per fold, the
/// mean C-index, and its sample standard deviation (`n_folds - 1` denominator).
///
/// # Errors
/// Returns an error when:
/// * `status`, `weights`, or the covariate columns do not match `time` in length;
/// * `time` contains NaN, which cannot be ordered when sorting the training data;
/// * `config.n_folds` is not in `2..=n`. Fewer than two folds leaves an empty
///   training set. More folds than subjects leaves empty test folds that would
///   silently score 0.5.
///
/// A fold whose Cox model cannot be built, or whose fit returns an error, is not
/// treated as a failure. Its coefficients fall back to all zeros, which makes
/// every held-out prediction tie and scores the fold at 0.5.
pub fn cv_cox(
    time: &[f64],
    status: &[i32],
    covariates: &Array2<f64>,
    weights: Option<&[f64]>,
    config: &CVConfig,
) -> Result<CVResult, Box<dyn std::error::Error + Send + Sync>> {
    use crate::regression::coxfit6::{CoxFitBuilder, Method as CoxMethod};
    use ndarray::Array1;

    let n = time.len();
    let nvar = covariates.nrows();
    if status.len() != n {
        return Err(format!("status has length {} but time has length {}", status.len(), n).into());
    }
    if covariates.ncols() != n {
        return Err(format!(
            "covariates have {} columns (subjects) but time has length {}",
            covariates.ncols(),
            n
        )
        .into());
    }
    if let Some(w) = weights {
        if w.len() != n {
            return Err(format!("weights has length {} but time has length {}", w.len(), n).into());
        }
    }
    if time.iter().any(|t| t.is_nan()) {
        return Err("time contains NaN values".into());
    }
    if config.n_folds < 2 || config.n_folds > n {
        return Err(format!(
            "n_folds must be between 2 and the number of observations ({}), got {}",
            n, config.n_folds
        )
        .into());
    }

    let default_weights: Vec<f64> = vec![1.0; n];
    let weights = weights.unwrap_or(&default_weights);
    let folds = create_folds(n, config.n_folds, config.shuffle, config.seed);

    let results: Vec<(f64, Vec<f64>)> = (0..config.n_folds)
        .into_par_iter()
        .map(|fold_idx| {
            let test_indices = &folds[fold_idx];

            // Training subjects (original indices), ordered by descending time.
            // NOTE(review): presumably the order coxfit6 expects; confirm there.
            // The stable sort keeps fold-concatenation order among equal times.
            let mut train_order: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|&(i, _)| i != fold_idx)
                .flat_map(|(_, f)| f.iter().copied())
                .collect();
            train_order.sort_by(|&a, &b| {
                time[b]
                    .partial_cmp(&time[a])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let train_n = train_order.len();

            let time_arr = Array1::from_vec(train_order.iter().map(|&i| time[i]).collect());
            let status_arr = Array1::from_vec(train_order.iter().map(|&i| status[i]).collect());
            let weights_arr = Array1::from_vec(train_order.iter().map(|&i| weights[i]).collect());
            // Subject-major (train_n, nvar) matrix for the fitter.
            let train_covariates = Array2::from_shape_fn((train_n, nvar), |(row, var)| {
                covariates[[var, train_order[row]]]
            });

            let beta = match CoxFitBuilder::new(time_arr, status_arr, train_covariates)
                .weights(weights_arr)
                .method(CoxMethod::Breslow)
                .max_iter(COX_MAX_ITER)
                .eps(1e-9)
                .toler(1e-9)
                .build()
            {
                Ok(mut fit) => {
                    if fit.fit().is_ok() {
                        let (b, _, _, _, _, _, _, _) = fit.results();
                        b
                    } else {
                        vec![0.0; nvar]
                    }
                }
                Err(_) => vec![0.0; nvar],
            };

            let test_time: Vec<f64> = test_indices.iter().map(|&i| time[i]).collect();
            let test_status: Vec<i32> = test_indices.iter().map(|&i| status[i]).collect();
            // Linear predictor x_i · beta for each held-out subject.
            let linear_predictor: Vec<f64> = test_indices
                .iter()
                .map(|&i| {
                    (0..nvar)
                        .map(|var| beta[var] * covariates[[var, i]])
                        .sum::<f64>()
                })
                .collect();

            let c_index = harrell_concordance(&test_time, &test_status, &linear_predictor);
            (c_index, beta)
        })
        .collect();

    let (fold_scores, fold_coefficients): (Vec<f64>, Vec<Vec<f64>>) = results.into_iter().unzip();
    // n_folds >= 2 was validated above, so the sample variance is well defined.
    let folds_f = fold_scores.len() as f64;
    let mean_score = fold_scores.iter().sum::<f64>() / folds_f;
    let variance = fold_scores
        .iter()
        .map(|&s| (s - mean_score).powi(2))
        .sum::<f64>()
        / (folds_f - 1.0);
    let std_score = variance.sqrt();
    Ok(CVResult {
        fold_scores,
        mean_score,
        std_score,
        fold_coefficients,
    })
}

/// Harrell's C-index of the risk scores `risk` against right-censored
/// (`time`, `status`) data.
///
/// A pair (i, j) is comparable when subject `i` had an event (`status[i] == 1`)
/// and `time[j] > time[i]`. The pair is concordant when `risk[i] > risk[j]` and
/// discordant when `risk[i] < risk[j]`. Otherwise it is tied, including NaN
/// risks. Ties count one half. Returns 0.5 when no pair is comparable.
/// Subjects are processed in parallel once there are more than
/// `PARALLEL_THRESHOLD_SMALL` of them.
fn harrell_concordance(time: &[f64], status: &[i32], risk: &[f64]) -> f64 {
    debug_assert!(time.len() == status.len() && time.len() == risk.len());

    // (concordant, discordant, tied) counts for pairs anchored at event subject `i`.
    let anchored = |i: usize| -> (f64, f64, f64) {
        let (t_i, r_i) = (time[i], risk[i]);
        time.iter()
            .zip(risk)
            .filter(|&(&t_j, _)| t_j > t_i)
            .fold((0.0, 0.0, 0.0), |(c, d, t), (_, &r_j)| {
                if r_i > r_j {
                    (c + 1.0, d, t)
                } else if r_i < r_j {
                    (c, d + 1.0, t)
                } else {
                    (c, d, t + 1.0)
                }
            })
    };
    let add = |a: (f64, f64, f64), b: (f64, f64, f64)| (a.0 + b.0, a.1 + b.1, a.2 + b.2);

    let n = time.len();
    let (concordant, discordant, tied) = if n > PARALLEL_THRESHOLD_SMALL {
        (0..n)
            .into_par_iter()
            .filter(|&i| status[i] == 1)
            .map(anchored)
            .reduce(|| (0.0, 0.0, 0.0), add)
    } else {
        (0..n)
            .filter(|&i| status[i] == 1)
            .map(anchored)
            .fold((0.0, 0.0, 0.0), add)
    };

    let total = concordant + discordant + tied;
    if total > 0.0 {
        (concordant + 0.5 * tied) / total
    } else {
        0.5
    }
}
#[pyfunction]
#[pyo3(signature = (time, status, covariates, weights=None, n_folds=None, shuffle=None, seed=None))]
pub fn cv_cox_concordance(
time: &Bound<'_, PyAny>,
status: &Bound<'_, PyAny>,
covariates: Vec<Vec<f64>>,
weights: Option<&Bound<'_, PyAny>>,
n_folds: Option<usize>,
shuffle: Option<bool>,
seed: Option<u64>,
) -> PyResult<CVResult> {
let time = extract_vec_f64(time)?;
let status = extract_vec_i32(status)?;
let weights = extract_optional_vec_f64(weights)?;
let n = time.len();
let nvar = if !covariates.is_empty() {
covariates[0].len()
} else {
0
};
let cov_array = if nvar > 0 {
let flat: Vec<f64> = covariates.into_iter().flatten().collect();
let temp = Array2::from_shape_vec((n, nvar), flat)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{}", e)))?;
temp.t().to_owned()
} else {
Array2::zeros((0, n))
};
let config = CVConfig {
n_folds: n_folds.unwrap_or(5),
shuffle: shuffle.unwrap_or(true),
seed,
};
let weights_ref = weights.as_deref();
cv_cox(&time, &status, &cov_array, weights_ref, &config)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{}", e)))
}
/// Runs K-fold cross-validation of a parametric survival regression (`survreg`)
/// and scores each held-out fold by the log-likelihood `survreg` reports on it.
///
/// # Arguments
/// * `time` - follow-up time for each of the `n` subjects.
/// * `status` - per-subject status values, passed to `survreg` unchanged.
/// * `covariates` - covariate matrix of shape `(nvar, n)`: one row per variable,
///   one column per subject (the same layout `cv_cox` uses).
/// * `distribution` - distribution name forwarded to `survreg`.
/// * `config` - number of folds, whether to shuffle, and the shuffle seed.
///
/// # Returns
/// A [`CVResult`] with one score and one coefficient vector for every fold whose
/// training fit and test evaluation both succeeded. Failed folds are skipped, so
/// the vectors can be shorter than `config.n_folds`. When only one fold
/// succeeds, `std_score` is NaN (sample standard deviation is undefined).
///
/// The test score comes from a second `survreg` call on the held-out fold. That
/// call receives the training coefficients and an iteration limit of 1.
/// NOTE(review): confirm whether `survreg` reports the log-likelihood at those
/// starting coefficients or after the single iteration.
///
/// # Errors
/// Returns an error when `status` or the covariate columns do not match `time`
/// in length, when `config.n_folds` is not in `2..=n`, or when every fold fails.
pub fn cv_survreg(
    time: &[f64],
    status: &[f64],
    covariates: &Array2<f64>,
    distribution: &str,
    config: &CVConfig,
) -> Result<CVResult, Box<dyn std::error::Error + Send + Sync>> {
    use crate::regression::survreg6::survreg;

    let n = time.len();
    if status.len() != n {
        return Err(format!("status has length {} but time has length {}", status.len(), n).into());
    }
    if covariates.ncols() != n {
        return Err(format!(
            "covariates have {} columns (subjects) but time has length {}",
            covariates.ncols(),
            n
        )
        .into());
    }
    if config.n_folds < 2 || config.n_folds > n {
        return Err(format!(
            "n_folds must be between 2 and the number of observations ({}), got {}",
            n, config.n_folds
        )
        .into());
    }

    // Subjects are the columns of `covariates`; gather each into a per-subject row
    // of `nrows()` covariate values.
    let cov_vecs: Vec<Vec<f64>> = (0..n).map(|i| covariates.column(i).to_vec()).collect();
    let folds = create_folds(n, config.n_folds, config.shuffle, config.seed);

    let results: Vec<(f64, Vec<f64>)> = (0..config.n_folds)
        .into_par_iter()
        .filter_map(|fold_idx| {
            let test_indices = &folds[fold_idx];
            let train_indices: Vec<usize> = folds
                .iter()
                .enumerate()
                .filter(|&(i, _)| i != fold_idx)
                .flat_map(|(_, f)| f.iter().copied())
                .collect();

            let train_time: Vec<f64> = train_indices.iter().map(|&i| time[i]).collect();
            let train_status: Vec<f64> = train_indices.iter().map(|&i| status[i]).collect();
            let train_covariates: Vec<Vec<f64>> =
                train_indices.iter().map(|&i| cov_vecs[i].clone()).collect();
            // A failed training fit drops this fold (best-effort CV).
            let fit_result = survreg(
                train_time,
                train_status,
                train_covariates,
                None,
                None,
                None,
                None,
                Some(distribution),
                Some(COX_MAX_ITER),
                Some(1e-5),
                Some(1e-9),
            )
            .ok()?;

            let test_time: Vec<f64> = test_indices.iter().map(|&i| time[i]).collect();
            let test_status: Vec<f64> = test_indices.iter().map(|&i| status[i]).collect();
            let test_covariates: Vec<Vec<f64>> =
                test_indices.iter().map(|&i| cov_vecs[i].clone()).collect();
            // Evaluate the held-out fold from the training coefficients with a
            // single iteration allowed.
            let test_fit = survreg(
                test_time,
                test_status,
                test_covariates,
                None,
                None,
                Some(fit_result.coefficients.clone()),
                None,
                Some(distribution),
                Some(1),
                Some(1e-5),
                Some(1e-9),
            )
            .ok()?;
            Some((test_fit.log_likelihood, fit_result.coefficients))
        })
        .collect();

    if results.is_empty() {
        return Err("All CV folds failed".into());
    }
    let (fold_scores, fold_coefficients): (Vec<f64>, Vec<Vec<f64>>) = results.into_iter().unzip();
    let count = fold_scores.len();
    let mean_score = fold_scores.iter().sum::<f64>() / count as f64;
    let std_score = if count > 1 {
        let variance = fold_scores
            .iter()
            .map(|&s| (s - mean_score).powi(2))
            .sum::<f64>()
            / (count - 1) as f64;
        variance.sqrt()
    } else {
        // Sample standard deviation is undefined for a single successful fold.
        f64::NAN
    };
    Ok(CVResult {
        fold_scores,
        mean_score,
        std_score,
        fold_coefficients,
    })
}
#[pyfunction]
#[pyo3(signature = (time, status, covariates, distribution=None, n_folds=None, shuffle=None, seed=None))]
pub fn cv_survreg_loglik(
time: &Bound<'_, PyAny>,
status: &Bound<'_, PyAny>,
covariates: Vec<Vec<f64>>,
distribution: Option<&str>,
n_folds: Option<usize>,
shuffle: Option<bool>,
seed: Option<u64>,
) -> PyResult<CVResult> {
let time = extract_vec_f64(time)?;
let status = extract_vec_f64(status)?;
let n = time.len();
let nvar = if !covariates.is_empty() {
covariates[0].len()
} else {
0
};
let cov_array = if nvar > 0 {
let flat: Vec<f64> = covariates.into_iter().flatten().collect();
let temp = Array2::from_shape_vec((n, nvar), flat)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{}", e)))?;
temp.t().to_owned()
} else {
Array2::zeros((0, n))
};
let config = CVConfig {
n_folds: n_folds.unwrap_or(5),
shuffle: shuffle.unwrap_or(true),
seed,
};
let dist = distribution.unwrap_or("weibull");
cv_survreg(&time, &status, &cov_array, dist, &config)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{}", e)))
}