use crate::utilities::numpy_utils::{extract_vec_f64, extract_vec_i32};
use crate::utilities::statistical::chi2_sf;
use pyo3::prelude::*;
#[derive(Debug, Clone)]
#[pyclass]
pub struct SurvObrienResult {
#[pyo3(get)]
pub statistic: f64,
#[pyo3(get)]
pub p_value: f64,
#[pyo3(get)]
pub df: usize,
#[pyo3(get)]
pub scores: Vec<f64>,
#[pyo3(get)]
pub score_sum: f64,
#[pyo3(get)]
pub expected: f64,
#[pyo3(get)]
pub variance: f64,
}
#[pymethods]
impl SurvObrienResult {
#[new]
fn new(
statistic: f64,
p_value: f64,
df: usize,
scores: Vec<f64>,
score_sum: f64,
expected: f64,
variance: f64,
) -> Self {
Self {
statistic,
p_value,
df,
scores,
score_sum,
expected,
variance,
}
}
}
#[pyfunction]
#[pyo3(signature = (time, status, covariate, strata=None))]
pub fn survobrien(
time: &Bound<'_, PyAny>,
status: &Bound<'_, PyAny>,
covariate: &Bound<'_, PyAny>,
strata: Option<&Bound<'_, PyAny>>,
) -> PyResult<SurvObrienResult> {
let time_vec = extract_vec_f64(time)?;
let status_vec = extract_vec_i32(status)?;
let covariate_vec = extract_vec_f64(covariate)?;
let strata_vec = match strata {
Some(s) => extract_vec_i32(s)?,
None => vec![1; time_vec.len()],
};
let result = compute_survobrien(&time_vec, &status_vec, &covariate_vec, &strata_vec);
Ok(result)
}
fn compute_survobrien(
time: &[f64],
status: &[i32],
covariate: &[f64],
strata: &[i32],
) -> SurvObrienResult {
let n = time.len();
if n == 0 {
return SurvObrienResult {
statistic: 0.0,
p_value: 1.0,
df: 1,
scores: Vec::new(),
score_sum: 0.0,
expected: 0.0,
variance: 0.0,
};
}
let mut unique_strata: Vec<i32> = strata.to_vec();
unique_strata.sort();
unique_strata.dedup();
let mut scores = vec![0.0; n];
let mut total_score_sum = 0.0;
let total_expected = 0.0;
let mut total_variance = 0.0;
for &stratum in &unique_strata {
let stratum_indices: Vec<usize> = (0..n).filter(|&i| strata[i] == stratum).collect();
if stratum_indices.is_empty() {
continue;
}
let mut sorted_indices = stratum_indices.clone();
sorted_indices.sort_by(|&a, &b| {
time[a]
.partial_cmp(&time[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
let n_stratum = sorted_indices.len();
let mut at_risk: Vec<bool> = vec![true; n_stratum];
let mut i = 0;
while i < n_stratum {
let current_time = time[sorted_indices[i]];
let mut event_indices: Vec<usize> = Vec::new();
let mut j = i;
while j < n_stratum && time[sorted_indices[j]] == current_time {
if status[sorted_indices[j]] == 1 {
event_indices.push(j);
}
j += 1;
}
if !event_indices.is_empty() {
let mut at_risk_values: Vec<(usize, f64)> = Vec::new();
for (k, &idx) in sorted_indices.iter().enumerate() {
if at_risk[k] {
at_risk_values.push((k, covariate[idx]));
}
}
let n_at_risk = at_risk_values.len();
if n_at_risk > 0 {
at_risk_values
.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let mut ranks: std::collections::HashMap<usize, f64> =
std::collections::HashMap::new();
let mut k = 0;
while k < n_at_risk {
let current_value = at_risk_values[k].1;
let mut tie_count = 1;
let mut rank_sum = (k + 1) as f64;
while k + tie_count < n_at_risk
&& (at_risk_values[k + tie_count].1 - current_value).abs() < 1e-10
{
rank_sum += (k + tie_count + 1) as f64;
tie_count += 1;
}
let avg_rank = rank_sum / tie_count as f64;
for t in 0..tie_count {
ranks.insert(at_risk_values[k + t].0, avg_rank);
}
k += tie_count;
}
let mean_rank = (n_at_risk as f64 + 1.0) / 2.0;
let var_rank = (n_at_risk as f64 * n_at_risk as f64 - 1.0) / 12.0;
for &event_local_idx in &event_indices {
if let Some(&rank) = ranks.get(&event_local_idx) {
let orig_idx = sorted_indices[event_local_idx];
if var_rank > 0.0 {
scores[orig_idx] = (rank - mean_rank) / var_rank.sqrt();
} else {
scores[orig_idx] = 0.0;
}
total_score_sum += scores[orig_idx];
}
}
let n_events = event_indices.len() as f64;
if var_rank > 0.0 {
total_variance += n_events / var_rank * var_rank;
}
}
}
for item in at_risk.iter_mut().take(j).skip(i) {
*item = false;
}
i = j;
}
}
let statistic = if total_variance > 0.0 {
total_score_sum * total_score_sum / total_variance
} else {
0.0
};
let p_value = chi2_sf(statistic, 1);
SurvObrienResult {
statistic,
p_value,
df: 1,
scores,
score_sum: total_score_sum,
expected: total_expected,
variance: total_variance,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_survobrien_basic() {
let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let status = vec![1, 1, 0, 1, 0];
let covariate = vec![10.0, 20.0, 15.0, 30.0, 25.0];
let strata = vec![1, 1, 1, 1, 1];
let result = compute_survobrien(&time, &status, &covariate, &strata);
assert!(result.statistic >= 0.0);
assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
assert_eq!(result.df, 1);
}
#[test]
fn test_survobrien_empty() {
let result = compute_survobrien(&[], &[], &[], &[]);
assert_eq!(result.statistic, 0.0);
assert_eq!(result.p_value, 1.0);
}
#[test]
fn test_survobrien_stratified() {
let time = vec![1.0, 2.0, 1.0, 2.0];
let status = vec![1, 0, 1, 0];
let covariate = vec![10.0, 20.0, 30.0, 40.0];
let strata = vec![1, 1, 2, 2];
let result = compute_survobrien(&time, &status, &covariate, &strata);
assert!(result.statistic >= 0.0);
assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
}
}