1use cyanea_core::{CyaneaError, Result, Scored, Summarizable};
7
8use crate::descriptive;
9use crate::distribution::{betai, ln_gamma, ChiSquared, FDistribution, Normal, Distribution};
10use crate::rank::{rank, RankMethod};
11
/// Outcome of one of the hypothesis tests in this module.
#[derive(Debug, Clone)]
pub struct TestResult {
    /// The test statistic: t for the t-tests, U (= min(U1, U2)) for
    /// Mann–Whitney, χ² for the chi-squared test, F for one-way ANOVA and —
    /// for Fisher's exact test — the hypergeometric probability of the
    /// observed table.
    pub statistic: f64,
    /// p-value of the test; this is also the value `Scored::score` returns.
    pub p_value: f64,
    /// Degrees of freedom of the reference distribution, where one exists.
    /// `None` for Mann–Whitney and Fisher's exact test. For one-way ANOVA
    /// only the between-groups df is stored here.
    pub degrees_of_freedom: Option<f64>,
    /// Human-readable name of the test that produced this result.
    pub method: String,
}
24
25impl Scored for TestResult {
26 fn score(&self) -> f64 {
27 self.p_value
28 }
29}
30
31impl Summarizable for TestResult {
32 fn summary(&self) -> String {
33 match self.degrees_of_freedom {
34 Some(df) => format!(
35 "{}: statistic={:.4}, df={:.1}, p={:.6}",
36 self.method, self.statistic, df, self.p_value,
37 ),
38 None => format!(
39 "{}: statistic={:.4}, p={:.6}",
40 self.method, self.statistic, self.p_value,
41 ),
42 }
43 }
44}
45
46fn t_two_tailed_p(t: f64, df: f64) -> f64 {
50 let x = df / (df + t * t);
51 betai(df / 2.0, 0.5, x).unwrap_or(1.0)
52}
53
54pub fn t_test_one_sample(data: &[f64], mu: f64) -> Result<TestResult> {
60 let n = data.len();
61 if n < 2 {
62 return Err(CyaneaError::InvalidInput(
63 "t_test_one_sample: need at least 2 observations".into(),
64 ));
65 }
66
67 let mean = descriptive::mean(data)?;
68 let se = descriptive::std_dev(data, 1)? / (n as f64).sqrt();
69 let t = (mean - mu) / se;
70 let df = (n - 1) as f64;
71 let p = t_two_tailed_p(t, df);
72
73 Ok(TestResult {
74 statistic: t,
75 p_value: p,
76 degrees_of_freedom: Some(df),
77 method: "One-sample t-test".into(),
78 })
79}
80
81pub fn t_test_two_sample(x: &[f64], y: &[f64], equal_var: bool) -> Result<TestResult> {
88 if x.len() < 2 || y.len() < 2 {
89 return Err(CyaneaError::InvalidInput(
90 "t_test_two_sample: each group needs at least 2 observations".into(),
91 ));
92 }
93
94 let nx = x.len() as f64;
95 let ny = y.len() as f64;
96 let mean_x = descriptive::mean(x)?;
97 let mean_y = descriptive::mean(y)?;
98 let var_x = descriptive::variance(x, 1)?;
99 let var_y = descriptive::variance(y, 1)?;
100
101 let (t, df) = if equal_var {
102 let sp2 = ((nx - 1.0) * var_x + (ny - 1.0) * var_y) / (nx + ny - 2.0);
104 let se = (sp2 * (1.0 / nx + 1.0 / ny)).sqrt();
105 let t = (mean_x - mean_y) / se;
106 let df = nx + ny - 2.0;
107 (t, df)
108 } else {
109 let se = (var_x / nx + var_y / ny).sqrt();
111 let t = (mean_x - mean_y) / se;
112 let vn_x = var_x / nx;
113 let vn_y = var_y / ny;
114 let num = (vn_x + vn_y).powi(2);
115 let denom = vn_x.powi(2) / (nx - 1.0) + vn_y.powi(2) / (ny - 1.0);
116 let df = num / denom;
117 (t, df)
118 };
119
120 let p = t_two_tailed_p(t, df);
121 let method = if equal_var {
122 "Two-sample t-test (pooled)"
123 } else {
124 "Welch's t-test"
125 };
126
127 Ok(TestResult {
128 statistic: t,
129 p_value: p,
130 degrees_of_freedom: Some(df),
131 method: method.into(),
132 })
133}
134
135pub fn mann_whitney_u(x: &[f64], y: &[f64]) -> Result<TestResult> {
144 if x.is_empty() || y.is_empty() {
145 return Err(CyaneaError::InvalidInput(
146 "mann_whitney_u: each group must be non-empty".into(),
147 ));
148 }
149 let nx = x.len();
150 let ny = y.len();
151 let n = nx + ny;
152 if n < 2 {
153 return Err(CyaneaError::InvalidInput(
154 "mann_whitney_u: need at least 2 total observations".into(),
155 ));
156 }
157
158 let mut combined: Vec<f64> = Vec::with_capacity(n);
160 combined.extend_from_slice(x);
161 combined.extend_from_slice(y);
162 let ranks = rank(&combined, RankMethod::Average);
163
164 let r1: f64 = ranks[..nx].iter().sum();
165 let u1 = r1 - (nx * (nx + 1)) as f64 / 2.0;
166 let u2 = (nx * ny) as f64 - u1;
167 let u = u1.min(u2);
168
169 let mu_u = (nx * ny) as f64 / 2.0;
171 let sigma_u = ((nx * ny * (n + 1)) as f64 / 12.0).sqrt();
172
173 let p = if sigma_u > 0.0 {
174 let z = (u - mu_u) / sigma_u;
175 let normal = Normal::standard();
177 (2.0 * normal.cdf(z)).min(1.0) } else {
179 1.0
180 };
181
182 Ok(TestResult {
183 statistic: u,
184 p_value: p,
185 degrees_of_freedom: None,
186 method: "Mann-Whitney U test".into(),
187 })
188}
189
190pub fn fisher_exact(table: &[[usize; 2]; 2]) -> Result<TestResult> {
204 let a = table[0][0];
205 let b = table[0][1];
206 let c = table[1][0];
207 let d = table[1][1];
208 let n = a + b + c + d;
209
210 if n == 0 {
211 return Err(CyaneaError::InvalidInput("fisher_exact: table is all zeros".into()));
212 }
213
214 let p_observed = hypergeometric_pmf(a, a + b, a + c, n);
215
216 let row1 = a + b;
218 let col1 = a + c;
219 let min_a = if row1 + col1 > n { row1 + col1 - n } else { 0 };
220 let max_a = row1.min(col1);
221
222 let mut p_value = 0.0;
223 for k in min_a..=max_a {
224 let p_k = hypergeometric_pmf(k, row1, col1, n);
225 if p_k <= p_observed + 1e-12 {
226 p_value += p_k;
227 }
228 }
229
230 Ok(TestResult {
231 statistic: p_observed,
232 p_value: p_value.min(1.0),
233 degrees_of_freedom: None,
234 method: "Fisher's exact test".into(),
235 })
236}
237
238pub(crate) fn hypergeometric_pmf(k: usize, sample_size: usize, success_pop: usize, total: usize) -> f64 {
243 let log_p = ln_choose(success_pop, k)
246 + ln_choose(total - success_pop, sample_size - k)
247 - ln_choose(total, sample_size);
248 log_p.exp()
249}
250
251pub(crate) fn ln_choose(n: usize, k: usize) -> f64 {
253 if k > n {
254 return f64::NEG_INFINITY;
255 }
256 ln_gamma(n as f64 + 1.0) - ln_gamma(k as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0)
257}
258
259pub fn chi_squared_test(observed: &[f64], nrows: usize, ncols: usize) -> Result<TestResult> {
267 if nrows < 2 || ncols < 2 {
268 return Err(CyaneaError::InvalidInput(
269 "chi_squared_test: need at least 2×2 table".into(),
270 ));
271 }
272 if observed.len() != nrows * ncols {
273 return Err(CyaneaError::InvalidInput(
274 "chi_squared_test: observed length must equal nrows × ncols".into(),
275 ));
276 }
277
278 let total: f64 = observed.iter().sum();
279 if total == 0.0 {
280 return Err(CyaneaError::InvalidInput("chi_squared_test: all counts are zero".into()));
281 }
282
283 let mut row_sums = vec![0.0; nrows];
285 let mut col_sums = vec![0.0; ncols];
286 for i in 0..nrows {
287 for j in 0..ncols {
288 let val = observed[i * ncols + j];
289 row_sums[i] += val;
290 col_sums[j] += val;
291 }
292 }
293
294 let mut chi2 = 0.0;
296 for i in 0..nrows {
297 for j in 0..ncols {
298 let expected = row_sums[i] * col_sums[j] / total;
299 if expected > 0.0 {
300 let diff = observed[i * ncols + j] - expected;
301 chi2 += diff * diff / expected;
302 }
303 }
304 }
305
306 let df = ((nrows - 1) * (ncols - 1)) as f64;
307 let chi2_dist = ChiSquared::new(df)?;
308 let p_value = 1.0 - chi2_dist.cdf(chi2);
309
310 Ok(TestResult {
311 statistic: chi2,
312 p_value,
313 degrees_of_freedom: Some(df),
314 method: "Chi-squared test of independence".into(),
315 })
316}
317
318pub fn anova_oneway(groups: &[&[f64]]) -> Result<TestResult> {
325 let k = groups.len();
326 if k < 2 {
327 return Err(CyaneaError::InvalidInput(
328 "anova_oneway: need at least 2 groups".into(),
329 ));
330 }
331 for (i, g) in groups.iter().enumerate() {
332 if g.is_empty() {
333 return Err(CyaneaError::InvalidInput(
334 format!("anova_oneway: group {} is empty", i),
335 ));
336 }
337 }
338
339 let n_total: usize = groups.iter().map(|g| g.len()).sum();
340 if n_total <= k {
341 return Err(CyaneaError::InvalidInput(
342 "anova_oneway: total observations must exceed number of groups".into(),
343 ));
344 }
345
346 let grand_sum: f64 = groups.iter().flat_map(|g| g.iter()).sum();
348 let grand_mean = grand_sum / n_total as f64;
349
350 let ss_between: f64 = groups
352 .iter()
353 .map(|g| {
354 let group_mean: f64 = g.iter().sum::<f64>() / g.len() as f64;
355 g.len() as f64 * (group_mean - grand_mean).powi(2)
356 })
357 .sum();
358
359 let ss_within: f64 = groups
361 .iter()
362 .map(|g| {
363 let group_mean: f64 = g.iter().sum::<f64>() / g.len() as f64;
364 g.iter().map(|&x| (x - group_mean).powi(2)).sum::<f64>()
365 })
366 .sum();
367
368 let df_between = (k - 1) as f64;
369 let df_within = (n_total - k) as f64;
370
371 let ms_between = ss_between / df_between;
372 let ms_within = ss_within / df_within;
373
374 let f_stat = if ms_within > 0.0 {
375 ms_between / ms_within
376 } else {
377 f64::INFINITY
378 };
379
380 let f_dist = FDistribution::new(df_between, df_within)?;
381 let p_value = 1.0 - f_dist.cdf(f_stat);
382
383 Ok(TestResult {
384 statistic: f_stat,
385 p_value,
386 degrees_of_freedom: Some(df_between),
387 method: "One-way ANOVA".into(),
388 })
389}
390
// Unit tests for the hypothesis tests above. Most assertions check the
// direction of the result (clearly significant vs. clearly not significant)
// rather than exact p-values.
#[cfg(test)]
mod tests {
    use super::*;

    // Data symmetric around 0: the sample mean is exactly 0, so t = 0.
    #[test]
    fn t_test_one_sample_mean_equals_mu() {
        let data = [-1.0, -0.5, 0.0, 0.5, 1.0];
        let result = t_test_one_sample(&data, 0.0).unwrap();
        assert!(result.p_value > 0.9, "p={}", result.p_value);
    }

    // Sample mean 12 with sample sd ~1.58 lies many standard errors from 0.
    #[test]
    fn t_test_one_sample_mean_far_from_mu() {
        let data = [10.0, 11.0, 12.0, 13.0, 14.0];
        let result = t_test_one_sample(&data, 0.0).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
    }

    // A single observation is rejected.
    #[test]
    fn t_test_one_sample_too_few() {
        assert!(t_test_one_sample(&[1.0], 0.0).is_err());
    }

    // Means differ by 0.5 while each sample has sd ~1.58: not significant.
    #[test]
    fn t_test_two_sample_same_distribution() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [1.5, 2.5, 3.5, 4.5, 5.5];
        let result = t_test_two_sample(&x, &y, true).unwrap();
        assert!(result.p_value > 0.3, "p={}", result.p_value);
    }

    // Means 3 vs 102 with small spread: highly significant (pooled test).
    #[test]
    fn t_test_two_sample_different_means() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [100.0, 101.0, 102.0, 103.0, 104.0];
        let result = t_test_two_sample(&x, &y, true).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
    }

    // Same data through the Welch branch; also checks the method label.
    #[test]
    fn t_test_welch() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [100.0, 101.0, 102.0, 103.0, 104.0];
        let result = t_test_two_sample(&x, &y, false).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
        assert!(result.method.contains("Welch"));
    }

    // Each group needs at least 2 observations.
    #[test]
    fn t_test_two_sample_too_few() {
        assert!(t_test_two_sample(&[1.0], &[2.0, 3.0], true).is_err());
    }

    // Interleaved samples with no ties: U = 10 vs. the null mean of 12.5.
    #[test]
    fn mann_whitney_same() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [1.5, 2.5, 3.5, 4.5, 5.5];
        let result = mann_whitney_u(&x, &y).unwrap();
        assert!(result.p_value > 0.3, "p={}", result.p_value);
    }

    // Complete separation of the samples: U = 0.
    #[test]
    fn mann_whitney_different() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0];
        let y = [100.0, 101.0, 102.0, 103.0, 104.0];
        let result = mann_whitney_u(&x, &y).unwrap();
        assert!(result.p_value < 0.05, "p={}", result.p_value);
    }

    // An empty group on either side is rejected.
    #[test]
    fn mann_whitney_empty() {
        assert!(mann_whitney_u(&[], &[1.0]).is_err());
        assert!(mann_whitney_u(&[1.0], &[]).is_err());
    }

    // Scored::score mirrors the p-value.
    #[test]
    fn test_result_scored() {
        let result = t_test_one_sample(&[1.0, 2.0, 3.0], 2.0).unwrap();
        assert!((result.score() - result.p_value).abs() < 1e-15);
    }

    // Summary includes the method name and the labelled fields.
    #[test]
    fn test_result_summary() {
        let result = t_test_one_sample(&[1.0, 2.0, 3.0, 4.0, 5.0], 0.0).unwrap();
        let s = result.summary();
        assert!(s.contains("One-sample t-test"));
        assert!(s.contains("statistic="));
        assert!(s.contains("p="));
    }

    // Strong diagonal association in a 2×2 table.
    #[test]
    fn fisher_exact_significant() {
        let table = [[8, 1], [1, 8]];
        let result = fisher_exact(&table).unwrap();
        assert!(result.p_value < 0.05, "p={}", result.p_value);
    }

    // Uniform table: the observed table is the most probable one, so every
    // table counts toward the p-value and p is (up to rounding) 1.
    #[test]
    fn fisher_exact_not_significant() {
        let table = [[5, 5], [5, 5]];
        let result = fisher_exact(&table).unwrap();
        assert!(result.p_value > 0.5, "p={}", result.p_value);
    }

    // Only the two perfectly separated tables qualify: p = 2 / C(20, 10) ≈ 1.1e-5.
    #[test]
    fn fisher_exact_extreme() {
        let table = [[10, 0], [0, 10]];
        let result = fisher_exact(&table).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
    }

    // An all-zero table is rejected.
    #[test]
    fn fisher_exact_zero_table() {
        let table = [[0, 0], [0, 0]];
        assert!(fisher_exact(&table).is_err());
    }

    // Observed counts equal the expected counts, so chi² = 0.
    #[test]
    fn chi_squared_test_independent() {
        #[rustfmt::skip]
        let observed = [
            50.0, 50.0,
            50.0, 50.0,
        ];
        let result = chi_squared_test(&observed, 2, 2).unwrap();
        assert!(result.p_value > 0.9, "p={}", result.p_value);
    }

    // Expected 50 per cell, deviations of ±40: chi² = 128 on 1 df.
    #[test]
    fn chi_squared_test_dependent() {
        #[rustfmt::skip]
        let observed = [
            90.0, 10.0,
            10.0, 90.0,
        ];
        let result = chi_squared_test(&observed, 2, 2).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
        assert!((result.degrees_of_freedom.unwrap() - 1.0).abs() < 1e-10);
    }

    // Expected 20 per cell, six cells off by ±10: chi² = 30 on 4 df.
    #[test]
    fn chi_squared_test_3x3() {
        #[rustfmt::skip]
        let observed = [
            10.0, 20.0, 30.0,
            20.0, 30.0, 10.0,
            30.0, 10.0, 20.0,
        ];
        let result = chi_squared_test(&observed, 3, 3).unwrap();
        assert!((result.degrees_of_freedom.unwrap() - 4.0).abs() < 1e-10);
        assert!(result.p_value < 0.05, "p={}", result.p_value);
    }

    // A 1×1 table is below the 2×2 minimum; a 2-element slice cannot fill 2×2.
    #[test]
    fn chi_squared_test_invalid() {
        assert!(chi_squared_test(&[1.0], 1, 1).is_err());
        assert!(chi_squared_test(&[1.0, 2.0], 2, 2).is_err());
    }

    // Group means 3, 3.5 and 3 with within-group sd ~1.58: not significant.
    #[test]
    fn anova_same_groups() {
        let g1 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let g2 = [1.5, 2.5, 3.5, 4.5, 5.5];
        let g3 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result = anova_oneway(&[&g1, &g2, &g3]).unwrap();
        assert!(result.p_value > 0.3, "p={}", result.p_value);
    }

    // Group means 3, 102 and 202: highly significant.
    #[test]
    fn anova_different_groups() {
        let g1 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let g2 = [100.0, 101.0, 102.0, 103.0, 104.0];
        let g3 = [200.0, 201.0, 202.0, 203.0, 204.0];
        let result = anova_oneway(&[&g1, &g2, &g3]).unwrap();
        assert!(result.p_value < 0.001, "p={}", result.p_value);
        assert!(result.method.contains("ANOVA"));
    }

    // With two groups, F = t², so ANOVA and the pooled t-test p-values agree.
    #[test]
    fn anova_two_groups_matches_t() {
        let g1 = [1.0, 2.0, 3.0, 4.0, 5.0];
        let g2 = [3.0, 4.0, 5.0, 6.0, 7.0];
        let anova_result = anova_oneway(&[&g1, &g2]).unwrap();
        let t_result = t_test_two_sample(&g1, &g2, true).unwrap();
        assert!((anova_result.p_value - t_result.p_value).abs() < 0.01);
    }

    // A single group is rejected.
    #[test]
    fn anova_too_few_groups() {
        assert!(anova_oneway(&[&[1.0, 2.0]]).is_err());
    }

    // Any empty group is rejected.
    #[test]
    fn anova_empty_group() {
        let g1: [f64; 0] = [];
        assert!(anova_oneway(&[&g1, &[1.0, 2.0]]).is_err());
    }
}
604}