use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
/// Checks that the flattened inputs to a concordance computation have
/// mutually consistent lengths.
///
/// * `time_data_len` must be exactly `2 * n` (paired time/status values).
/// * `indices_len` and `weights_len` must each equal `n`.
/// * `n` itself must be non-zero.
///
/// # Errors
/// Returns a `PyRuntimeError` describing the first check that fails, in the
/// order listed above (the `n == 0` check runs first).
pub fn validate_concordance_inputs(
    time_data_len: usize,
    n: usize,
    indices_len: usize,
    weights_len: usize,
) -> PyResult<()> {
    let failure = if n == 0 {
        Some("No observations provided")
    } else if time_data_len != 2 * n {
        Some("Time data should have 2*n elements (time, status)")
    } else if indices_len != n {
        Some("Indices length does not match observations")
    } else if weights_len != n {
        Some("Weights length does not match observations")
    } else {
        None
    };
    match failure {
        Some(msg) => Err(PyRuntimeError::new_err(msg)),
        None => Ok(()),
    }
}
/// Length validation for the extended concordance variant.
///
/// Runs every check from `validate_concordance_inputs` first. It then requires
/// that `time_weights_len` and `sort_stop_len` each equal `n`.
///
/// # Errors
/// Returns a `PyRuntimeError` for the first failing check. Base checks run
/// first, then time weights, then sort stop.
pub fn validate_extended_concordance_inputs(
    time_data_len: usize,
    n: usize,
    indices_len: usize,
    weights_len: usize,
    time_weights_len: usize,
    sort_stop_len: usize,
) -> PyResult<()> {
    validate_concordance_inputs(time_data_len, n, indices_len, weights_len)?;
    // Extra per-observation arrays, checked in a fixed order.
    let extra_lengths = [
        (time_weights_len, "Time weights length does not match observations"),
        (sort_stop_len, "Sort stop length does not match observations"),
    ];
    for &(len, msg) in extra_lengths.iter() {
        if len != n {
            return Err(PyRuntimeError::new_err(msg));
        }
    }
    Ok(())
}
pub fn build_concordance_result(
py: Python<'_>,
count: &[f64],
imat: Option<&[f64]>,
resid: Option<&[f64]>,
n: Option<usize>,
) -> PyResult<Py<PyDict>> {
let concordant = count[0];
let discordant = count[1];
let tied_x = count[2];
let tied_y = count[3];
let tied_xy = count.get(4).copied().unwrap_or(0.0);
let variance = count.get(5).copied();
let total_pairs = concordant + discordant + tied_x + tied_y + tied_xy;
let concordance_index = if total_pairs > 0.0 {
(concordant + 0.5 * (tied_x + tied_y + tied_xy)) / total_pairs
} else {
0.0
};
let dict = PyDict::new(py);
dict.set_item("concordant", concordant)?;
dict.set_item("discordant", discordant)?;
dict.set_item("tied_x", tied_x)?;
dict.set_item("tied_y", tied_y)?;
dict.set_item("tied_xy", tied_xy)?;
dict.set_item("concordance_index", concordance_index)?;
dict.set_item("total_pairs", total_pairs)?;
if let Some(v) = variance {
dict.set_item("variance", v)?;
}
if let Some(imat_data) = imat {
dict.set_item("information_matrix", imat_data.to_vec())?;
}
if let Some(resid_data) = resid {
dict.set_item("residuals", resid_data.to_vec())?;
}
if let Some(n_obs) = n {
dict.set_item("n_observations", n_obs)?;
}
Ok(dict.into())
}
/// Sums tree weight on either side of node `index` in an array-embedded
/// binary tree.
///
/// Layout: the children of node `i` are `2*i + 1` (left) and `2*i + 2` (right).
/// `nwt[i]` holds the weight stored at node `i`. `twt[i]` is expected to hold
/// the total weight of the subtree rooted at `i`, i.e.
/// `nwt[i] + twt[left] + twt[right]`.
///
/// Returns `[right_side, left_side, at_node]`:
/// * `[0]`: the weight of the right subtree of `index`, plus, for every
///   ancestor reached from its left child, that ancestor's own weight and its
///   right subtree.
/// * `[1]`: the mirror image of `[0]`.
/// * `[2]`: `nwt[index]`.
///
/// Presumably the tree is in-order sorted, so `[0]`/`[1]` are the weights
/// greater/less than node `index` (confirm against callers).
///
/// Returns all zeros when `index >= ntree`.
///
/// # Panics
/// Panics if `nwt` or `twt` are shorter than the indices visited.
pub fn walkup_binary_tree(nwt: &[f64], twt: &[f64], index: usize, ntree: usize) -> [f64; 3] {
    if index >= ntree {
        return [0.0; 3];
    }
    let tied = nwt[index];
    let right = 2 * index + 2;
    let left = 2 * index + 1;
    let mut greater = 0.0;
    let mut lesser = 0.0;
    if right < ntree {
        greater += twt[right];
    }
    if left < ntree {
        lesser += twt[left];
    }
    // Climb to the root. The parent's subtree total minus ours is the weight
    // of the parent node plus our sibling's subtree, which lies on the side
    // opposite to the one we came from.
    let mut node = index;
    while node != 0 {
        let parent = (node - 1) / 2;
        let outside = twt[parent] - twt[node];
        if node % 2 == 1 {
            greater += outside;
        } else {
            lesser += outside;
        }
        node = parent;
    }
    [greater, lesser, tied]
}
/// Adds weight `wt` at node `index` of an array-embedded binary tree.
///
/// Layout: the parent of node `i > 0` is `(i - 1) / 2`.
///
/// Effects:
/// * `nwt[index]` is increased by `wt`.
/// * `twt` is increased by `wt` at `index` and at every ancestor up to and
///   including the root (node 0), exactly once each.
///
/// This maintains the invariant `twt[i] == nwt[i] + twt[2*i + 1] + twt[2*i + 2]`
/// (absent children count as zero). The walk-up query depends on that
/// invariant when it computes `twt[parent] - twt[node]`.
///
/// # Panics
/// Panics if `index` is out of bounds for `nwt` or `twt`.
pub fn add_to_binary_tree(nwt: &mut [f64], twt: &mut [f64], index: usize, wt: f64) {
    nwt[index] += wt;
    // Credit the node's own subtree total, then each ancestor's, stopping
    // after the root. The previous version skipped `twt[index]` and
    // double-counted `twt[0]` whenever `index > 0`.
    let mut node = index;
    loop {
        twt[node] += wt;
        if node == 0 {
            break;
        }
        node = (node - 1) / 2;
    }
}