use pyo3::prelude::*;
#[pyclass]
#[derive(Clone)]
pub struct Surv2DataResult {
#[pyo3(get)]
pub id: Vec<i32>,
#[pyo3(get)]
pub time1: Vec<f64>,
#[pyo3(get)]
pub time2: Vec<f64>,
#[pyo3(get)]
pub status: Vec<i32>,
#[pyo3(get)]
pub row_index: Vec<usize>,
}
#[pyfunction]
#[pyo3(signature = (id, time, event_time=None, event_status=None))]
pub fn surv2data(
id: Vec<i32>,
time: Vec<f64>,
event_time: Option<Vec<f64>>,
event_status: Option<Vec<i32>>,
) -> Surv2DataResult {
let n = id.len();
if n == 0 {
return Surv2DataResult {
id: Vec::new(),
time1: Vec::new(),
time2: Vec::new(),
status: Vec::new(),
row_index: Vec::new(),
};
}
let mut indices: Vec<usize> = (0..n).collect();
indices.sort_by(|&a, &b| match id[a].cmp(&id[b]) {
std::cmp::Ordering::Equal => time[a]
.partial_cmp(&time[b])
.unwrap_or(std::cmp::Ordering::Equal),
other => other,
});
let mut subject_event: std::collections::HashMap<i32, (f64, i32)> =
std::collections::HashMap::new();
if let (Some(etimes), Some(estatus)) = (&event_time, &event_status) {
for i in 0..n {
let subj_id = id[i];
subject_event
.entry(subj_id)
.or_insert((etimes[i], estatus[i]));
}
}
let mut result = Surv2DataResult {
id: Vec::with_capacity(n),
time1: Vec::with_capacity(n),
time2: Vec::with_capacity(n),
status: Vec::with_capacity(n),
row_index: Vec::with_capacity(n),
};
let mut i = 0;
while i < n {
let start_idx = indices[i];
let current_id = id[start_idx];
let mut subject_times: Vec<(f64, usize)> = Vec::new();
let mut j = i;
while j < n && id[indices[j]] == current_id {
subject_times.push((time[indices[j]], indices[j]));
j += 1;
}
let (subj_event_time, subj_event_status) = subject_event
.get(¤t_id)
.copied()
.unwrap_or((f64::INFINITY, 0));
for k in 0..subject_times.len() {
let (t1, orig_idx) = subject_times[k];
let t2 = if k + 1 < subject_times.len() {
subject_times[k + 1].0
} else if subj_event_time > t1 {
subj_event_time
} else {
t1
};
if t2 <= t1 {
continue;
}
let interval_status = if k + 1 >= subject_times.len() {
subj_event_status
} else {
0
};
result.id.push(current_id);
result.time1.push(t1);
result.time2.push(t2);
result.status.push(interval_status);
result.row_index.push(orig_idx + 1);
}
i = j;
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_surv2data_basic() {
let id = vec![1, 1, 1];
let time = vec![0.0, 5.0, 10.0];
let event_time = Some(vec![15.0, 15.0, 15.0]);
let event_status = Some(vec![1, 1, 1]);
let result = surv2data(id, time, event_time, event_status);
assert_eq!(result.id.len(), 3);
assert_eq!(result.time1, vec![0.0, 5.0, 10.0]);
assert_eq!(result.time2, vec![5.0, 10.0, 15.0]);
assert_eq!(result.status, vec![0, 0, 1]);
}
#[test]
fn test_surv2data_multiple_subjects() {
let id = vec![1, 1, 2, 2];
let time = vec![0.0, 5.0, 0.0, 3.0];
let event_time = Some(vec![10.0, 10.0, 8.0, 8.0]);
let event_status = Some(vec![1, 1, 0, 0]);
let result = surv2data(id, time, event_time, event_status);
assert_eq!(result.id.len(), 4);
}
#[test]
fn test_surv2data_no_event_info() {
let id = vec![1, 1, 1];
let time = vec![0.0, 5.0, 10.0];
let result = surv2data(id, time, None, None);
assert!(result.id.len() >= 2);
assert_eq!(result.time1[0], 0.0);
assert_eq!(result.time2[0], 5.0);
}
}