use pyo3::prelude::*;
#[derive(Debug, Clone)]
#[pyclass]
pub struct RttrightResult {
#[pyo3(get)]
pub weights: Vec<f64>,
#[pyo3(get)]
pub time: Vec<f64>,
#[pyo3(get)]
pub status: Vec<i32>,
#[pyo3(get)]
pub order: Vec<usize>,
}
#[pyfunction]
#[pyo3(signature = (time, status, weights=None))]
pub fn rttright(
time: Vec<f64>,
status: Vec<i32>,
weights: Option<Vec<f64>>,
) -> PyResult<RttrightResult> {
let n = time.len();
if status.len() != n {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"time and status must have same length",
));
}
let init_weights = weights.unwrap_or_else(|| vec![1.0; n]);
if init_weights.len() != n {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"weights must have same length as time",
));
}
if n == 0 {
return Ok(RttrightResult {
weights: vec![],
time: vec![],
status: vec![],
order: vec![],
});
}
let mut indices: Vec<usize> = (0..n).collect();
indices.sort_by(|&a, &b| {
time[a]
.partial_cmp(&time[b])
.unwrap_or(std::cmp::Ordering::Equal)
});
let sorted_time: Vec<f64> = indices.iter().map(|&i| time[i]).collect();
let sorted_status: Vec<i32> = indices.iter().map(|&i| status[i]).collect();
let sorted_weights: Vec<f64> = indices.iter().map(|&i| init_weights[i]).collect();
let mut rtt_weights = vec![0.0; n];
let mut cumulative_censored_weight = 0.0;
for i in 0..n {
if i > 0 && sorted_status[i - 1] == 0 {
cumulative_censored_weight += sorted_weights[i - 1];
}
if sorted_status[i] == 1 {
let base_weight = sorted_weights[i];
let n_at_risk = (n - i) as f64;
let redistributed = if n_at_risk > 0.0 {
cumulative_censored_weight / n_at_risk
} else {
0.0
};
rtt_weights[i] = base_weight + redistributed;
} else {
rtt_weights[i] = 0.0;
}
}
let km_weights = compute_km_weights(&sorted_time, &sorted_status, &sorted_weights);
Ok(RttrightResult {
weights: km_weights,
time: sorted_time,
status: sorted_status,
order: indices,
})
}
fn compute_km_weights(time: &[f64], status: &[i32], init_weights: &[f64]) -> Vec<f64> {
let n = time.len();
if n == 0 {
return vec![];
}
let mut g_t = vec![1.0; n];
let mut n_at_risk = init_weights.iter().sum::<f64>();
let mut prev_time = f64::NEG_INFINITY;
let mut current_g = 1.0;
for i in 0..n {
if time[i] > prev_time && i > 0 {
let mut n_censored = 0.0;
let mut j = i - 1;
while j < i && time[j] == prev_time {
if status[j] == 0 {
n_censored += init_weights[j];
}
if j == 0 {
break;
}
j -= 1;
}
if n_at_risk > 0.0 && n_censored > 0.0 {
current_g *= 1.0 - n_censored / n_at_risk;
}
}
g_t[i] = current_g;
n_at_risk -= init_weights[i];
prev_time = time[i];
}
let mut weights = Vec::with_capacity(n);
for i in 0..n {
if status[i] == 1 {
let g = if i > 0 { g_t[i - 1] } else { 1.0 };
let w = if g > 1e-10 {
init_weights[i] / g
} else {
init_weights[i]
};
weights.push(w);
} else {
weights.push(0.0);
}
}
weights
}
#[pyfunction]
#[pyo3(signature = (time, status, strata, weights=None))]
pub fn rttright_stratified(
time: Vec<f64>,
status: Vec<i32>,
strata: Vec<i32>,
weights: Option<Vec<f64>>,
) -> PyResult<RttrightResult> {
let n = time.len();
if status.len() != n || strata.len() != n {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"time, status, and strata must have same length",
));
}
let init_weights = weights.unwrap_or_else(|| vec![1.0; n]);
let mut strata_indices: std::collections::HashMap<i32, Vec<usize>> =
std::collections::HashMap::new();
for (i, &s) in strata.iter().enumerate() {
strata_indices.entry(s).or_default().push(i);
}
let mut final_weights = vec![0.0; n];
let mut final_order = vec![0; n];
let mut offset = 0;
for indices in strata_indices.values() {
let strata_time: Vec<f64> = indices.iter().map(|&i| time[i]).collect();
let strata_status: Vec<i32> = indices.iter().map(|&i| status[i]).collect();
let strata_weights: Vec<f64> = indices.iter().map(|&i| init_weights[i]).collect();
let result = rttright(strata_time, strata_status, Some(strata_weights))?;
for (j, &orig_idx) in indices.iter().enumerate() {
final_weights[orig_idx] = result.weights[j];
final_order[offset + j] = orig_idx;
}
offset += indices.len();
}
Ok(RttrightResult {
weights: final_weights,
time,
status,
order: final_order,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rttright_basic() {
let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let status = vec![1, 0, 1, 0, 1];
let result = rttright(time, status, None).unwrap();
assert!(result.weights[0] > 0.0);
assert!(result.weights[2] > 0.0);
assert!(result.weights[4] > 0.0);
assert_eq!(result.weights[1], 0.0);
assert_eq!(result.weights[3], 0.0);
}
#[test]
fn test_rttright_all_events() {
let time = vec![1.0, 2.0, 3.0];
let status = vec![1, 1, 1];
let result = rttright(time, status, None).unwrap();
for w in &result.weights {
assert!((*w - 1.0).abs() < 1e-10);
}
}
#[test]
fn test_rttright_empty() {
let time: Vec<f64> = vec![];
let status: Vec<i32> = vec![];
let result = rttright(time, status, None).unwrap();
assert!(result.weights.is_empty());
}
}