1use crate::error::{StatsError, StatsResult};
14use crate::survival::{CoxPH, KaplanMeier, NelsonAalen};
15use scirs2_core::ndarray::Array2;
16
/// Approximate inverse CDF (quantile function) of the standard normal
/// distribution.
///
/// The probability is first clamped to `[1e-15, 1 - 1e-15]`, so `0.0` and
/// `1.0` yield large finite quantiles instead of infinities. A NaN input
/// propagates to a NaN result.
///
/// Two rational approximations are used: one in the centred variable
/// `p - 0.5` when `|p - 0.5| <= 0.42`, and one in `sqrt(-ln(tail))` otherwise,
/// mirrored for the lower tail.
///
/// NOTE(review): the layout resembles the Beasley-Springer AS 111 routine,
/// but the coefficients have not been checked against the published table.
fn norm_ppf(p: f64) -> f64 {
    /// Horner evaluation; `coeffs` is ordered from the highest power down.
    fn horner(coeffs: &[f64], x: f64) -> f64 {
        coeffs[1..].iter().fold(coeffs[0], |acc, &c| acc * x + c)
    }

    // Central-region rational approximation (polynomials in q^2).
    const CENTRAL_NUM: [f64; 4] = [-25.445_87, 41.391_663, -18.615_43, 2.506_628];
    const CENTRAL_DEN: [f64; 5] = [3.130_347, -21.060_244, 23.083_928, -8.476_377, 1.0];
    // Tail-region rational approximation (polynomials in sqrt(-ln(tail))).
    const TAIL_NUM: [f64; 4] = [2.321_213_5, 4.850_091_7, -2.297_460_0, -2.787_688_0];
    const TAIL_DEN: [f64; 3] = [1.637_547_9, 3.543_889_2, 1.0];

    let prob = p.clamp(1e-15, 1.0 - 1e-15);
    let centred = prob - 0.5;

    if centred.abs() <= 0.42 {
        let sq = centred * centred;
        return centred * (horner(&CENTRAL_NUM, sq) / horner(&CENTRAL_DEN, sq));
    }

    // Work with the probability mass of whichever tail we are in.
    let in_lower_tail = centred < 0.0;
    let tail_mass = if in_lower_tail { prob } else { 1.0 - prob };
    let w = (-tail_mass.ln()).sqrt();
    let magnitude = horner(&TAIL_NUM, w) / horner(&TAIL_DEN, w);

    if in_lower_tail {
        -magnitude
    } else {
        magnitude
    }
}
40
/// Survival curve backed by a fitted [`KaplanMeier`] estimator.
///
/// All queries on this type are answered from the wrapped estimator.
pub struct KMCurve {
    /// The underlying fitted estimator.
    km: KaplanMeier,
}
51
52impl KMCurve {
53 pub fn fit(times: &[f64], events: &[bool]) -> StatsResult<Self> {
59 let km = KaplanMeier::fit(times, events)?;
60 Ok(Self { km })
61 }
62
63 pub fn survival_function(&self, t: f64) -> f64 {
65 self.km.survival_at(t)
66 }
67
68 pub fn confidence_interval(&self, t: f64, alpha: f64) -> StatsResult<(f64, f64)> {
79 if alpha <= 0.0 || alpha >= 1.0 {
80 return Err(StatsError::InvalidArgument(format!(
81 "alpha must be in (0, 1), got {alpha}"
82 )));
83 }
84 let s = self.survival_function(t);
85 if s <= 0.0 || s >= 1.0 {
86 return Ok((s.clamp(0.0, 1.0), s.clamp(0.0, 1.0)));
87 }
88
89 let greenwood: f64 = self
91 .km
92 .times
93 .iter()
94 .enumerate()
95 .take_while(|(_, &tk)| tk <= t)
96 .map(|(k, _)| {
97 let n_k = self.km.n_at_risk[k] as f64;
98 let d_k = self.km.n_events[k] as f64;
99 if n_k > d_k {
100 d_k / (n_k * (n_k - d_k))
101 } else {
102 0.0
103 }
104 })
105 .sum();
106
107 if greenwood == 0.0 {
108 return Ok((s, s));
109 }
110
111 let z = norm_ppf(1.0 - alpha / 2.0);
112 let ln_s = s.ln();
113 let se_ll = (greenwood / (ln_s * ln_s)).sqrt();
114 let log_log_s = (-ln_s).ln();
115
116 let ll_lo = log_log_s - z * se_ll;
117 let ll_hi = log_log_s + z * se_ll;
118
119 let lower = (-ll_hi.exp()).exp().clamp(0.0, 1.0);
121 let upper = (-ll_lo.exp()).exp().clamp(0.0, 1.0);
122
123 Ok((lower.min(upper), lower.max(upper)))
124 }
125
126 pub fn median_survival(&self) -> Option<f64> {
128 self.km.median_survival()
129 }
130
131 pub fn mean_survival(&self) -> f64 {
133 self.km.mean_survival()
134 }
135}
136
/// Survival / cumulative-hazard curve backed by a fitted [`NelsonAalen`]
/// estimator, plus the per-event-time quantities needed for its confidence
/// interval.
pub struct NACurve {
    /// The underlying fitted estimator.
    na: NelsonAalen,
    /// Hazard increment `d / n` for each distinct time that had at least one
    /// event, in ascending time order.
    hazard_increments: Vec<f64>,
    /// Number at risk `n` matching each entry of `hazard_increments`.
    at_risk: Vec<usize>,
}
151
152impl NACurve {
153 pub fn fit(times: &[f64], events: &[bool]) -> StatsResult<Self> {
159 if times.is_empty() {
160 return Err(StatsError::InvalidArgument(
161 "times array cannot be empty".to_string(),
162 ));
163 }
164 if times.len() != events.len() {
165 return Err(StatsError::InvalidArgument(format!(
166 "times ({}) and events ({}) must have equal length",
167 times.len(),
168 events.len()
169 )));
170 }
171 for (i, &t) in times.iter().enumerate() {
172 if !t.is_finite() {
173 return Err(StatsError::InvalidArgument(format!(
174 "times[{i}] is not finite: {t}"
175 )));
176 }
177 if t < 0.0 {
178 return Err(StatsError::InvalidArgument(format!(
179 "times[{i}] is negative: {t}"
180 )));
181 }
182 }
183
184 let mut pairs: Vec<(f64, bool)> =
186 times.iter().copied().zip(events.iter().copied()).collect();
187 pairs.sort_by(|a, b| {
188 a.0.partial_cmp(&b.0)
189 .unwrap_or(std::cmp::Ordering::Equal)
190 .then(b.1.cmp(&a.1))
191 });
192
193 let n = pairs.len();
194 let mut at_risk_count = n;
195 let mut hazard_increments = Vec::new();
196 let mut at_risk_vec = Vec::new();
197 let mut idx = 0;
198
199 while idx < pairs.len() {
200 let t = pairs[idx].0;
201 let mut d = 0_usize;
202 let mut c = 0_usize;
203 while idx < pairs.len() && (pairs[idx].0 - t).abs() < 1e-12 {
204 if pairs[idx].1 {
205 d += 1;
206 } else {
207 c += 1;
208 }
209 idx += 1;
210 }
211 if d > 0 && at_risk_count > 0 {
212 hazard_increments.push((d as f64) / (at_risk_count as f64));
213 at_risk_vec.push(at_risk_count);
214 }
215 at_risk_count -= d + c;
216 }
217
218 let na = NelsonAalen::fit(times, events)?;
219 Ok(Self {
220 na,
221 hazard_increments,
222 at_risk: at_risk_vec,
223 })
224 }
225
226 pub fn survival_function(&self, t: f64) -> f64 {
228 self.na.survival_at(t)
229 }
230
231 pub fn cumulative_hazard(&self, t: f64) -> f64 {
233 self.na.hazard_at(t)
234 }
235
236 pub fn confidence_interval(&self, t: f64, alpha: f64) -> StatsResult<(f64, f64)> {
247 if alpha <= 0.0 || alpha >= 1.0 {
248 return Err(StatsError::InvalidArgument(format!(
249 "alpha must be in (0, 1), got {alpha}"
250 )));
251 }
252 let s = self.survival_function(t);
253 if s <= 0.0 || s >= 1.0 {
254 return Ok((s.clamp(0.0, 1.0), s.clamp(0.0, 1.0)));
255 }
256
257 let var_h: f64 = self
260 .na
261 .times
262 .iter()
263 .enumerate()
264 .take_while(|(_, &tk)| tk <= t)
265 .map(|(k, _)| {
266 if k < self.at_risk.len() && self.at_risk[k] > 0 {
267 self.hazard_increments[k] / self.at_risk[k] as f64
268 } else {
269 0.0
270 }
271 })
272 .sum();
273
274 if var_h == 0.0 {
275 return Ok((s, s));
276 }
277
278 let h = -s.ln();
279 let z = norm_ppf(1.0 - alpha / 2.0);
280 let se = var_h.sqrt();
281
282 let c = (z * se / h).exp();
284 let h_lo = h / c;
285 let h_hi = h * c;
286
287 let upper = (-h_lo).exp().clamp(0.0, 1.0);
288 let lower = (-h_hi).exp().clamp(0.0, 1.0);
289
290 Ok((lower.min(upper), lower.max(upper)))
291 }
292}
293
294pub fn log_rank_test(
309 group1_times: &[f64],
310 group1_events: &[bool],
311 group2_times: &[f64],
312 group2_events: &[bool],
313) -> StatsResult<(f64, f64)> {
314 let result =
315 KaplanMeier::log_rank_test(group1_times, group1_events, group2_times, group2_events)?;
316 Ok(result)
317}
318
/// Thin wrapper around a fitted [`CoxPH`] model that exposes its results
/// through plain `Vec<f64>` / slice based accessors.
pub struct CoxPHModel {
    /// The underlying fitted model.
    inner: CoxPH,
}
329
330impl CoxPHModel {
331 pub fn fit(times: &[f64], events: &[bool], covariates: &Array2<f64>) -> StatsResult<Self> {
338 let inner = CoxPH::fit(times, events, covariates)?;
339 Ok(Self { inner })
340 }
341
342 pub fn coefficients(&self) -> Vec<f64> {
344 self.inner.coefficients.iter().copied().collect()
345 }
346
347 pub fn standard_errors(&self) -> Vec<f64> {
349 self.inner.std_errors.iter().copied().collect()
350 }
351
352 pub fn p_values(&self) -> Vec<f64> {
354 self.inner.p_values.iter().copied().collect()
355 }
356
357 pub fn hazard_ratios(&self) -> Vec<f64> {
359 self.inner.hazard_ratio().iter().copied().collect()
360 }
361
362 pub fn predict_risk(&self, x: &[f64]) -> f64 {
364 use scirs2_core::ndarray::Array1;
365 let arr = Array1::from_vec(x.to_vec());
366 self.inner.predict_risk(&arr)
367 }
368
369 pub fn concordance_index(
373 &self,
374 times: &[f64],
375 events: &[bool],
376 covariates: &Array2<f64>,
377 ) -> f64 {
378 self.inner.concordance_index(times, events, covariates)
379 }
380
381 pub fn log_likelihood(&self) -> f64 {
383 self.inner.log_likelihood
384 }
385
386 pub fn n_iterations(&self) -> usize {
388 self.inner.n_iter
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Ten observations at times 1..=10 with a mix of events and censorings.
    fn sample_data() -> (Vec<f64>, Vec<bool>) {
        (
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            vec![
                true, true, false, true, false, true, true, false, true, true,
            ],
        )
    }

    /// Eight observations with a single covariate column, shared by every
    /// Cox model test.
    fn cox_data() -> (Vec<f64>, Vec<bool>, Array2<f64>) {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let events = vec![true, true, false, true, false, true, true, false];
        let x = Array2::from_shape_vec((8, 1), vec![0.1, 0.5, 0.2, 0.8, 0.3, 0.9, 0.4, 0.7])
            .expect("array failed");
        (times, events, x)
    }

    // ---- KMCurve ----

    #[test]
    fn test_kmcurve_survival_starts_at_one() {
        let (times, events) = sample_data();
        let km = KMCurve::fit(&times, &events).expect("fit failed");
        assert_eq!(km.survival_function(0.0), 1.0);
    }

    #[test]
    fn test_kmcurve_survival_bounded() {
        let (times, events) = sample_data();
        let km = KMCurve::fit(&times, &events).expect("fit failed");
        for t in [0.0, 1.5, 5.0, 10.0, 20.0] {
            let s = km.survival_function(t);
            assert!((0.0..=1.0).contains(&s), "S({t}) = {s} out of [0,1]");
        }
    }

    #[test]
    fn test_kmcurve_survival_non_increasing() {
        let (times, events) = sample_data();
        let km = KMCurve::fit(&times, &events).expect("fit failed");
        let ts = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 20.0];
        let mut prev = 1.0_f64;
        for &t in &ts {
            let s = km.survival_function(t);
            assert!(s <= prev + 1e-12, "S({t}) = {s} > S_prev = {prev}");
            prev = s;
        }
    }

    #[test]
    fn test_kmcurve_confidence_interval_ordering() {
        let (times, events) = sample_data();
        let km = KMCurve::fit(&times, &events).expect("fit failed");
        for t in [2.0, 5.0, 8.0] {
            let (lo, hi) = km.confidence_interval(t, 0.05).expect("CI failed");
            assert!(lo <= hi + 1e-10, "lo > hi at t={t}: {lo} {hi}");
            assert!(lo >= 0.0 && hi <= 1.0);
        }
    }

    #[test]
    fn test_kmcurve_ci_invalid_alpha() {
        let (times, events) = sample_data();
        let km = KMCurve::fit(&times, &events).expect("fit failed");
        assert!(km.confidence_interval(3.0, 0.0).is_err());
        assert!(km.confidence_interval(3.0, 1.0).is_err());
    }

    // ---- NACurve ----

    #[test]
    fn test_nacurve_survival_starts_at_one() {
        let (times, events) = sample_data();
        let na = NACurve::fit(&times, &events).expect("fit failed");
        assert_eq!(na.survival_function(0.0), 1.0);
    }

    #[test]
    fn test_nacurve_survival_bounded() {
        let (times, events) = sample_data();
        let na = NACurve::fit(&times, &events).expect("fit failed");
        for t in [0.0, 2.5, 6.0, 12.0] {
            let s = na.survival_function(t);
            assert!((0.0..=1.0).contains(&s), "S({t}) = {s} out of [0,1]");
        }
    }

    #[test]
    fn test_nacurve_confidence_interval_ordering() {
        let (times, events) = sample_data();
        let na = NACurve::fit(&times, &events).expect("fit failed");
        let (lo, hi) = na.confidence_interval(5.0, 0.05).expect("CI failed");
        assert!(lo <= hi + 1e-10, "lo > hi: {lo} {hi}");
        assert!(lo >= 0.0 && hi <= 1.0);
    }

    #[test]
    fn test_nacurve_ci_invalid_alpha() {
        let (times, events) = sample_data();
        let na = NACurve::fit(&times, &events).expect("fit failed");
        assert!(na.confidence_interval(3.0, 0.0).is_err());
        assert!(na.confidence_interval(3.0, 1.5).is_err());
    }

    // ---- log-rank test ----

    #[test]
    fn test_log_rank_different_groups() {
        let times1 = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let events1 = vec![true, true, true, true, true];
        let times2 = vec![6.0, 7.0, 8.0, 9.0, 10.0];
        let events2 = vec![true, true, true, true, true];
        let (stat, p) =
            log_rank_test(&times1, &events1, &times2, &events2).expect("log_rank_test failed");
        assert!(stat >= 0.0, "statistic should be non-negative");
        assert!((0.0..=1.0).contains(&p), "p-value out of range: {p}");
        assert!(p < 0.05, "expected significant difference, p = {p}");
    }

    #[test]
    fn test_log_rank_identical_groups() {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let events = vec![true, true, false, true, false, true];
        let (stat, p) =
            log_rank_test(&times, &events, &times, &events).expect("log_rank_test failed");
        assert!(stat < 1e-10, "identical groups: stat={stat}");
        assert!(p > 0.5, "identical groups should have large p={p}");
    }

    // ---- CoxPHModel ----

    #[test]
    fn test_coxph_fit_and_coefficients() {
        let (times, events, x) = cox_data();
        let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
        assert_eq!(model.coefficients().len(), 1);
        assert!(model.coefficients()[0].is_finite());
    }

    #[test]
    fn test_coxph_log_likelihood_finite() {
        let (times, events, x) = cox_data();
        let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
        assert!(model.log_likelihood().is_finite());
    }

    #[test]
    fn test_coxph_hazard_ratios_positive() {
        let (times, events, x) = cox_data();
        let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
        for &hr in model.hazard_ratios().iter() {
            assert!(hr > 0.0, "HR should be positive, got {hr}");
        }
    }

    #[test]
    fn test_coxph_predict_risk_positive() {
        let (times, events, x) = cox_data();
        let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
        let risk = model.predict_risk(&[0.5]);
        assert!(risk > 0.0, "risk should be positive, got {risk}");
    }

    #[test]
    fn test_coxph_concordance_index_in_range() {
        let (times, events, x) = cox_data();
        let model = CoxPHModel::fit(&times, &events, &x).expect("fit failed");
        let ci = model.concordance_index(&times, &events, &x);
        assert!((0.0..=1.0).contains(&ci), "C-index out of [0,1]: {ci}");
    }
}