use crate::utilities::statistical::normal_cdf;
use pyo3::prelude::*;
use std::collections::HashMap;
#[derive(Debug, Clone)]
#[pyclass]
pub struct YatesResult {
#[pyo3(get)]
pub levels: Vec<String>,
#[pyo3(get)]
pub means: Vec<f64>,
#[pyo3(get)]
pub se: Vec<f64>,
#[pyo3(get)]
pub lower: Vec<f64>,
#[pyo3(get)]
pub upper: Vec<f64>,
#[pyo3(get)]
pub n: Vec<usize>,
#[pyo3(get)]
pub predict_type: String,
}
#[pyfunction]
#[pyo3(signature = (predictions, factor, weights=None, conf_level=None))]
pub fn yates(
predictions: Vec<f64>,
factor: Vec<String>,
weights: Option<Vec<f64>>,
conf_level: Option<f64>,
) -> PyResult<YatesResult> {
let n = predictions.len();
if factor.len() != n {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"predictions and factor must have same length",
));
}
let wts = weights.unwrap_or_else(|| vec![1.0; n]);
if wts.len() != n {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"weights must have same length as predictions",
));
}
let conf = conf_level.unwrap_or(0.95);
let z = z_score(conf);
let mut groups: HashMap<String, Vec<usize>> = HashMap::new();
for (i, level) in factor.iter().enumerate() {
groups.entry(level.clone()).or_default().push(i);
}
let mut levels: Vec<String> = groups.keys().cloned().collect();
levels.sort();
let mut means = Vec::with_capacity(levels.len());
let mut ses = Vec::with_capacity(levels.len());
let mut lowers = Vec::with_capacity(levels.len());
let mut uppers = Vec::with_capacity(levels.len());
let mut ns = Vec::with_capacity(levels.len());
for level in &levels {
let indices = groups.get(level).unwrap();
let group_n = indices.len();
if group_n == 0 {
means.push(f64::NAN);
ses.push(f64::NAN);
lowers.push(f64::NAN);
uppers.push(f64::NAN);
ns.push(0);
continue;
}
let mut sum_w = 0.0;
let mut sum_wx = 0.0;
for &i in indices {
let w = wts[i];
let x = predictions[i];
sum_w += w;
sum_wx += w * x;
}
let mean = if sum_w > 0.0 { sum_wx / sum_w } else { 0.0 };
let mut sum_w2_dev2 = 0.0;
for &i in indices {
let w = wts[i];
let dev = predictions[i] - mean;
sum_w2_dev2 += w * w * dev * dev;
}
let variance = if sum_w > 0.0 {
sum_w2_dev2 / (sum_w * sum_w)
} else {
0.0
};
let se = variance.sqrt();
means.push(mean);
ses.push(se);
lowers.push(mean - z * se);
uppers.push(mean + z * se);
ns.push(group_n);
}
Ok(YatesResult {
levels,
means,
se: ses,
lower: lowers,
upper: uppers,
n: ns,
predict_type: "linear".to_string(),
})
}
#[pyfunction]
#[pyo3(signature = (x, coef, n_obs, n_vars, factor_col, factor_levels, predict_type=None))]
pub fn yates_contrast(
x: Vec<f64>,
coef: Vec<f64>,
n_obs: usize,
n_vars: usize,
factor_col: usize,
factor_levels: Vec<f64>,
predict_type: Option<&str>,
) -> PyResult<YatesResult> {
if x.len() != n_obs * n_vars {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"x length must equal n_obs * n_vars",
));
}
if coef.len() != n_vars {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"coef length must equal n_vars",
));
}
if factor_col >= n_vars {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"factor_col must be < n_vars",
));
}
let pred_type = predict_type.unwrap_or("linear");
let mut levels = Vec::with_capacity(factor_levels.len());
let mut means = Vec::with_capacity(factor_levels.len());
let mut ses = Vec::with_capacity(factor_levels.len());
let ns = vec![n_obs; factor_levels.len()];
for &level in &factor_levels {
levels.push(format!("{}", level));
let mut sum_pred = 0.0;
let mut sum_pred2 = 0.0;
for i in 0..n_obs {
let mut eta = 0.0;
for j in 0..n_vars {
let x_val = if j == factor_col {
level
} else {
x[i * n_vars + j]
};
eta += x_val * coef[j];
}
let pred = match pred_type {
"risk" => eta.exp(),
"survival" => (-eta.exp()).exp(),
_ => eta,
};
sum_pred += pred;
sum_pred2 += pred * pred;
}
let mean = sum_pred / n_obs as f64;
let variance = sum_pred2 / n_obs as f64 - mean * mean;
let se = (variance / n_obs as f64).sqrt();
means.push(mean);
ses.push(se);
}
let z = 1.96;
let lower: Vec<f64> = means
.iter()
.zip(ses.iter())
.map(|(&m, &s)| m - z * s)
.collect();
let upper: Vec<f64> = means
.iter()
.zip(ses.iter())
.map(|(&m, &s)| m + z * s)
.collect();
Ok(YatesResult {
levels,
means,
se: ses,
lower,
upper,
n: ns,
predict_type: pred_type.to_string(),
})
}
#[pyfunction]
pub fn yates_pairwise(yates_result: &YatesResult) -> PyResult<YatesPairwiseResult> {
let k = yates_result.levels.len();
if k < 2 {
return Ok(YatesPairwiseResult {
level1: vec![],
level2: vec![],
difference: vec![],
se: vec![],
z: vec![],
p_value: vec![],
});
}
let mut level1 = Vec::new();
let mut level2 = Vec::new();
let mut difference = Vec::new();
let mut se = Vec::new();
let mut z_scores = Vec::new();
let mut p_values = Vec::new();
for i in 0..k {
for j in (i + 1)..k {
level1.push(yates_result.levels[i].clone());
level2.push(yates_result.levels[j].clone());
let diff = yates_result.means[i] - yates_result.means[j];
difference.push(diff);
let se_diff = (yates_result.se[i].powi(2) + yates_result.se[j].powi(2)).sqrt();
se.push(se_diff);
let z = if se_diff > 0.0 { diff / se_diff } else { 0.0 };
z_scores.push(z);
let p = 2.0 * (1.0 - normal_cdf(z.abs()));
p_values.push(p);
}
}
Ok(YatesPairwiseResult {
level1,
level2,
difference,
se,
z: z_scores,
p_value: p_values,
})
}
#[derive(Debug, Clone)]
#[pyclass]
pub struct YatesPairwiseResult {
#[pyo3(get)]
pub level1: Vec<String>,
#[pyo3(get)]
pub level2: Vec<String>,
#[pyo3(get)]
pub difference: Vec<f64>,
#[pyo3(get)]
pub se: Vec<f64>,
#[pyo3(get)]
pub z: Vec<f64>,
#[pyo3(get)]
pub p_value: Vec<f64>,
}
fn z_score(conf_level: f64) -> f64 {
match conf_level {
c if (c - 0.90).abs() < 0.001 => 1.645,
c if (c - 0.95).abs() < 0.001 => 1.96,
c if (c - 0.99).abs() < 0.001 => 2.576,
_ => 1.96,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_yates_basic() {
let predictions = vec![1.0, 1.5, 2.0, 2.5, 3.0, 3.5];
let factor = vec![
"A".to_string(),
"A".to_string(),
"A".to_string(),
"B".to_string(),
"B".to_string(),
"B".to_string(),
];
let result = yates(predictions, factor, None, None).unwrap();
assert_eq!(result.levels.len(), 2);
assert_eq!(result.means.len(), 2);
let a_idx = result.levels.iter().position(|l| l == "A").unwrap();
let b_idx = result.levels.iter().position(|l| l == "B").unwrap();
assert!((result.means[a_idx] - 1.5).abs() < 0.01);
assert!((result.means[b_idx] - 3.0).abs() < 0.01);
}
#[test]
fn test_yates_weighted() {
let predictions = vec![1.0, 2.0, 3.0];
let factor = vec!["A".to_string(), "A".to_string(), "A".to_string()];
let weights = vec![1.0, 2.0, 1.0];
let result = yates(predictions, factor, Some(weights), None).unwrap();
assert!((result.means[0] - 2.0).abs() < 0.01);
}
#[test]
fn test_yates_contrast() {
let x = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0]; let coef = vec![0.5, 1.0];
let factor_levels = vec![0.0, 1.0];
let result = yates_contrast(x, coef, 3, 2, 0, factor_levels, Some("linear")).unwrap();
assert_eq!(result.levels.len(), 2);
}
#[test]
fn test_yates_pairwise() {
let result = YatesResult {
levels: vec!["A".to_string(), "B".to_string(), "C".to_string()],
means: vec![1.0, 2.0, 3.0],
se: vec![0.1, 0.1, 0.1],
lower: vec![0.8, 1.8, 2.8],
upper: vec![1.2, 2.2, 3.2],
n: vec![10, 10, 10],
predict_type: "linear".to_string(),
};
let pairwise = yates_pairwise(&result).unwrap();
assert_eq!(pairwise.level1.len(), 3); assert_eq!(pairwise.difference[0], -1.0); }
}