Skip to main content

cyanea_stats/
testing.rs

//! Hypothesis testing.
//!
//! Provides parametric ([`t_test_one_sample`], [`t_test_two_sample`],
//! [`anova_oneway`]), non-parametric ([`mann_whitney_u`]) and
//! contingency-table ([`fisher_exact`], [`chi_squared_test`]) statistical tests.
5
6use cyanea_core::{CyaneaError, Result, Scored, Summarizable};
7
8use crate::descriptive;
9use crate::distribution::{betai, ln_gamma, ChiSquared, FDistribution, Normal, Distribution};
10use crate::rank::{rank, RankMethod};
11
/// Outcome of a statistical hypothesis test.
///
/// Every test in this module returns one of these, so callers can treat the
/// different tests uniformly.
#[derive(Debug, Clone)]
pub struct TestResult {
    /// Value of the test statistic (t, U, χ², F, ...).
    pub statistic: f64,
    /// p-value of the test.
    ///
    /// The t-tests, the Mann-Whitney U test and Fisher's exact test report a
    /// two-sided value; the chi-squared test and one-way ANOVA report the
    /// upper-tail probability of their statistic.
    pub p_value: f64,
    /// Degrees of freedom, or `None` for tests that do not define them.
    pub degrees_of_freedom: Option<f64>,
    /// Human-readable name of the test that produced this result.
    pub method: String,
}
24
25impl Scored for TestResult {
26    fn score(&self) -> f64 {
27        self.p_value
28    }
29}
30
31impl Summarizable for TestResult {
32    fn summary(&self) -> String {
33        match self.degrees_of_freedom {
34            Some(df) => format!(
35                "{}: statistic={:.4}, df={:.1}, p={:.6}",
36                self.method, self.statistic, df, self.p_value,
37            ),
38            None => format!(
39                "{}: statistic={:.4}, p={:.6}",
40                self.method, self.statistic, self.p_value,
41            ),
42        }
43    }
44}
45
46// ── t-distribution helpers ──────────────────────────────────────────────────
47
48/// Two-tailed p-value for the t-distribution.
49fn t_two_tailed_p(t: f64, df: f64) -> f64 {
50    let x = df / (df + t * t);
51    betai(df / 2.0, 0.5, x).unwrap_or(1.0)
52}
53
54// ── One-sample t-test ──────────────────────────────────────────────────────
55
56/// One-sample t-test: test whether the population mean equals `mu`.
57///
58/// Requires at least 2 observations.
59pub fn t_test_one_sample(data: &[f64], mu: f64) -> Result<TestResult> {
60    let n = data.len();
61    if n < 2 {
62        return Err(CyaneaError::InvalidInput(
63            "t_test_one_sample: need at least 2 observations".into(),
64        ));
65    }
66
67    let mean = descriptive::mean(data)?;
68    let se = descriptive::std_dev(data, 1)? / (n as f64).sqrt();
69    let t = (mean - mu) / se;
70    let df = (n - 1) as f64;
71    let p = t_two_tailed_p(t, df);
72
73    Ok(TestResult {
74        statistic: t,
75        p_value: p,
76        degrees_of_freedom: Some(df),
77        method: "One-sample t-test".into(),
78    })
79}
80
81// ── Two-sample t-test ──────────────────────────────────────────────────────
82
83/// Two-sample t-test: test whether two populations have the same mean.
84///
85/// When `equal_var` is `true`, uses pooled variance (Student's t-test).
86/// When `false`, uses Welch's t-test (unequal variances).
87pub fn t_test_two_sample(x: &[f64], y: &[f64], equal_var: bool) -> Result<TestResult> {
88    if x.len() < 2 || y.len() < 2 {
89        return Err(CyaneaError::InvalidInput(
90            "t_test_two_sample: each group needs at least 2 observations".into(),
91        ));
92    }
93
94    let nx = x.len() as f64;
95    let ny = y.len() as f64;
96    let mean_x = descriptive::mean(x)?;
97    let mean_y = descriptive::mean(y)?;
98    let var_x = descriptive::variance(x, 1)?;
99    let var_y = descriptive::variance(y, 1)?;
100
101    let (t, df) = if equal_var {
102        // Pooled variance
103        let sp2 = ((nx - 1.0) * var_x + (ny - 1.0) * var_y) / (nx + ny - 2.0);
104        let se = (sp2 * (1.0 / nx + 1.0 / ny)).sqrt();
105        let t = (mean_x - mean_y) / se;
106        let df = nx + ny - 2.0;
107        (t, df)
108    } else {
109        // Welch's approximation
110        let se = (var_x / nx + var_y / ny).sqrt();
111        let t = (mean_x - mean_y) / se;
112        let vn_x = var_x / nx;
113        let vn_y = var_y / ny;
114        let num = (vn_x + vn_y).powi(2);
115        let denom = vn_x.powi(2) / (nx - 1.0) + vn_y.powi(2) / (ny - 1.0);
116        let df = num / denom;
117        (t, df)
118    };
119
120    let p = t_two_tailed_p(t, df);
121    let method = if equal_var {
122        "Two-sample t-test (pooled)"
123    } else {
124        "Welch's t-test"
125    };
126
127    Ok(TestResult {
128        statistic: t,
129        p_value: p,
130        degrees_of_freedom: Some(df),
131        method: method.into(),
132    })
133}
134
135// ── Mann-Whitney U test ────────────────────────────────────────────────────
136
137/// Mann-Whitney U test (Wilcoxon rank-sum test).
138///
139/// Non-parametric test for whether two independent samples come from the
140/// same distribution. Uses normal approximation for the p-value.
141///
142/// Each group needs at least 1 observation, and total n must be >= 2.
143pub fn mann_whitney_u(x: &[f64], y: &[f64]) -> Result<TestResult> {
144    if x.is_empty() || y.is_empty() {
145        return Err(CyaneaError::InvalidInput(
146            "mann_whitney_u: each group must be non-empty".into(),
147        ));
148    }
149    let nx = x.len();
150    let ny = y.len();
151    let n = nx + ny;
152    if n < 2 {
153        return Err(CyaneaError::InvalidInput(
154            "mann_whitney_u: need at least 2 total observations".into(),
155        ));
156    }
157
158    // Combine, rank, and sum ranks for x.
159    let mut combined: Vec<f64> = Vec::with_capacity(n);
160    combined.extend_from_slice(x);
161    combined.extend_from_slice(y);
162    let ranks = rank(&combined, RankMethod::Average);
163
164    let r1: f64 = ranks[..nx].iter().sum();
165    let u1 = r1 - (nx * (nx + 1)) as f64 / 2.0;
166    let u2 = (nx * ny) as f64 - u1;
167    let u = u1.min(u2);
168
169    // Normal approximation
170    let mu_u = (nx * ny) as f64 / 2.0;
171    let sigma_u = ((nx * ny * (n + 1)) as f64 / 12.0).sqrt();
172
173    let p = if sigma_u > 0.0 {
174        let z = (u - mu_u) / sigma_u;
175        // Two-tailed p-value via standard normal
176        let normal = Normal::standard();
177        (2.0 * normal.cdf(z)).min(1.0) // z <= 0 since u = min(u1,u2) <= mu_u
178    } else {
179        1.0
180    };
181
182    Ok(TestResult {
183        statistic: u,
184        p_value: p,
185        degrees_of_freedom: None,
186        method: "Mann-Whitney U test".into(),
187    })
188}
189
190// ── Fisher's exact test (2×2) ─────────────────────────────────────────────
191
192/// Fisher's exact test for a 2×2 contingency table.
193///
194/// The table is specified as `[[a, b], [c, d]]`:
195///
196/// ```text
197///           Group 1   Group 2
198/// Outcome A    a         b
199/// Outcome B    c         d
200/// ```
201///
202/// Returns a two-tailed p-value based on the hypergeometric distribution.
203pub fn fisher_exact(table: &[[usize; 2]; 2]) -> Result<TestResult> {
204    let a = table[0][0];
205    let b = table[0][1];
206    let c = table[1][0];
207    let d = table[1][1];
208    let n = a + b + c + d;
209
210    if n == 0 {
211        return Err(CyaneaError::InvalidInput("fisher_exact: table is all zeros".into()));
212    }
213
214    let p_observed = hypergeometric_pmf(a, a + b, a + c, n);
215
216    // Two-tailed: sum probabilities of tables as or more extreme than observed
217    let row1 = a + b;
218    let col1 = a + c;
219    let min_a = if row1 + col1 > n { row1 + col1 - n } else { 0 };
220    let max_a = row1.min(col1);
221
222    let mut p_value = 0.0;
223    for k in min_a..=max_a {
224        let p_k = hypergeometric_pmf(k, row1, col1, n);
225        if p_k <= p_observed + 1e-12 {
226            p_value += p_k;
227        }
228    }
229
230    Ok(TestResult {
231        statistic: p_observed,
232        p_value: p_value.min(1.0),
233        degrees_of_freedom: None,
234        method: "Fisher's exact test".into(),
235    })
236}
237
238/// Hypergeometric PMF: P(X = k) where X ~ Hypergeometric(N, K, n).
239///
240/// Probability of drawing exactly `k` successes from a population of `total`
241/// containing `success_pop` successes, in a sample of size `sample_size`.
242pub(crate) fn hypergeometric_pmf(k: usize, sample_size: usize, success_pop: usize, total: usize) -> f64 {
243    // P = C(K,k) * C(N-K, n-k) / C(N, n)
244    // Compute in log-space to avoid overflow.
245    let log_p = ln_choose(success_pop, k)
246        + ln_choose(total - success_pop, sample_size - k)
247        - ln_choose(total, sample_size);
248    log_p.exp()
249}
250
251/// Log of binomial coefficient C(n, k) = ln(n!) - ln(k!) - ln((n-k)!).
252pub(crate) fn ln_choose(n: usize, k: usize) -> f64 {
253    if k > n {
254        return f64::NEG_INFINITY;
255    }
256    ln_gamma(n as f64 + 1.0) - ln_gamma(k as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0)
257}
258
259// ── Chi-squared test of independence ──────────────────────────────────────
260
261/// Chi-squared test of independence for an r×c contingency table.
262///
263/// `observed` is a row-major slice of `nrows × ncols` observed counts.
264///
265/// Uses Pearson's chi-squared statistic: χ² = Σ (O - E)² / E
266pub fn chi_squared_test(observed: &[f64], nrows: usize, ncols: usize) -> Result<TestResult> {
267    if nrows < 2 || ncols < 2 {
268        return Err(CyaneaError::InvalidInput(
269            "chi_squared_test: need at least 2×2 table".into(),
270        ));
271    }
272    if observed.len() != nrows * ncols {
273        return Err(CyaneaError::InvalidInput(
274            "chi_squared_test: observed length must equal nrows × ncols".into(),
275        ));
276    }
277
278    let total: f64 = observed.iter().sum();
279    if total == 0.0 {
280        return Err(CyaneaError::InvalidInput("chi_squared_test: all counts are zero".into()));
281    }
282
283    // Row and column sums
284    let mut row_sums = vec![0.0; nrows];
285    let mut col_sums = vec![0.0; ncols];
286    for i in 0..nrows {
287        for j in 0..ncols {
288            let val = observed[i * ncols + j];
289            row_sums[i] += val;
290            col_sums[j] += val;
291        }
292    }
293
294    // Compute chi-squared statistic
295    let mut chi2 = 0.0;
296    for i in 0..nrows {
297        for j in 0..ncols {
298            let expected = row_sums[i] * col_sums[j] / total;
299            if expected > 0.0 {
300                let diff = observed[i * ncols + j] - expected;
301                chi2 += diff * diff / expected;
302            }
303        }
304    }
305
306    let df = ((nrows - 1) * (ncols - 1)) as f64;
307    let chi2_dist = ChiSquared::new(df)?;
308    let p_value = 1.0 - chi2_dist.cdf(chi2);
309
310    Ok(TestResult {
311        statistic: chi2,
312        p_value,
313        degrees_of_freedom: Some(df),
314        method: "Chi-squared test of independence".into(),
315    })
316}
317
318// ── One-way ANOVA ─────────────────────────────────────────────────────────
319
320/// One-way analysis of variance (ANOVA).
321///
322/// Tests whether the means of k groups are equal. Each group must have at
323/// least 1 observation, and there must be at least 2 groups.
324pub fn anova_oneway(groups: &[&[f64]]) -> Result<TestResult> {
325    let k = groups.len();
326    if k < 2 {
327        return Err(CyaneaError::InvalidInput(
328            "anova_oneway: need at least 2 groups".into(),
329        ));
330    }
331    for (i, g) in groups.iter().enumerate() {
332        if g.is_empty() {
333            return Err(CyaneaError::InvalidInput(
334                format!("anova_oneway: group {} is empty", i),
335            ));
336        }
337    }
338
339    let n_total: usize = groups.iter().map(|g| g.len()).sum();
340    if n_total <= k {
341        return Err(CyaneaError::InvalidInput(
342            "anova_oneway: total observations must exceed number of groups".into(),
343        ));
344    }
345
346    // Grand mean
347    let grand_sum: f64 = groups.iter().flat_map(|g| g.iter()).sum();
348    let grand_mean = grand_sum / n_total as f64;
349
350    // Between-group sum of squares
351    let ss_between: f64 = groups
352        .iter()
353        .map(|g| {
354            let group_mean: f64 = g.iter().sum::<f64>() / g.len() as f64;
355            g.len() as f64 * (group_mean - grand_mean).powi(2)
356        })
357        .sum();
358
359    // Within-group sum of squares
360    let ss_within: f64 = groups
361        .iter()
362        .map(|g| {
363            let group_mean: f64 = g.iter().sum::<f64>() / g.len() as f64;
364            g.iter().map(|&x| (x - group_mean).powi(2)).sum::<f64>()
365        })
366        .sum();
367
368    let df_between = (k - 1) as f64;
369    let df_within = (n_total - k) as f64;
370
371    let ms_between = ss_between / df_between;
372    let ms_within = ss_within / df_within;
373
374    let f_stat = if ms_within > 0.0 {
375        ms_between / ms_within
376    } else {
377        f64::INFINITY
378    };
379
380    let f_dist = FDistribution::new(df_between, df_within)?;
381    let p_value = 1.0 - f_dist.cdf(f_stat);
382
383    Ok(TestResult {
384        statistic: f_stat,
385        p_value,
386        degrees_of_freedom: Some(df_between),
387        method: "One-way ANOVA".into(),
388    })
389}
390
391// ── Tests ──────────────────────────────────────────────────────────────────
392
#[cfg(test)]
mod tests {
    use super::*;

    // Samples shared by several tests.
    const BASELINE: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    const NUDGED: [f64; 5] = [1.5, 2.5, 3.5, 4.5, 5.5];
    const FAR_AWAY: [f64; 5] = [100.0, 101.0, 102.0, 103.0, 104.0];

    // ── t-tests ────────────────────────────────────────────────────────

    #[test]
    fn t_test_one_sample_mean_equals_mu() {
        // Sample is symmetric around mu = 0, so the p-value must be large.
        let centered = [-1.0, -0.5, 0.0, 0.5, 1.0];
        let p = t_test_one_sample(&centered, 0.0).unwrap().p_value;
        assert!(p > 0.9, "p={}", p);
    }

    #[test]
    fn t_test_one_sample_mean_far_from_mu() {
        let shifted = [10.0, 11.0, 12.0, 13.0, 14.0];
        let p = t_test_one_sample(&shifted, 0.0).unwrap().p_value;
        assert!(p < 0.001, "p={}", p);
    }

    #[test]
    fn t_test_one_sample_too_few() {
        let outcome = t_test_one_sample(&[1.0], 0.0);
        assert!(outcome.is_err());
    }

    #[test]
    fn t_test_two_sample_same_distribution() {
        // The means are close, so p should be moderate to large.
        let p = t_test_two_sample(&BASELINE, &NUDGED, true).unwrap().p_value;
        assert!(p > 0.3, "p={}", p);
    }

    #[test]
    fn t_test_two_sample_different_means() {
        let p = t_test_two_sample(&BASELINE, &FAR_AWAY, true).unwrap().p_value;
        assert!(p < 0.001, "p={}", p);
    }

    #[test]
    fn t_test_welch() {
        let outcome = t_test_two_sample(&BASELINE, &FAR_AWAY, false).unwrap();
        assert!(outcome.p_value < 0.001, "p={}", outcome.p_value);
        assert!(outcome.method.contains("Welch"));
    }

    #[test]
    fn t_test_two_sample_too_few() {
        let outcome = t_test_two_sample(&[1.0], &[2.0, 3.0], true);
        assert!(outcome.is_err());
    }

    // ── Mann-Whitney U ─────────────────────────────────────────────────

    #[test]
    fn mann_whitney_same() {
        let p = mann_whitney_u(&BASELINE, &NUDGED).unwrap().p_value;
        assert!(p > 0.3, "p={}", p);
    }

    #[test]
    fn mann_whitney_different() {
        let p = mann_whitney_u(&BASELINE, &FAR_AWAY).unwrap().p_value;
        assert!(p < 0.05, "p={}", p);
    }

    #[test]
    fn mann_whitney_empty() {
        let empty_first = mann_whitney_u(&[], &[1.0]);
        let empty_second = mann_whitney_u(&[1.0], &[]);
        assert!(empty_first.is_err());
        assert!(empty_second.is_err());
    }

    // ── TestResult traits ──────────────────────────────────────────────

    #[test]
    fn test_result_scored() {
        let outcome = t_test_one_sample(&[1.0, 2.0, 3.0], 2.0).unwrap();
        let gap = (outcome.score() - outcome.p_value).abs();
        assert!(gap < 1e-15);
    }

    #[test]
    fn test_result_summary() {
        let text = t_test_one_sample(&BASELINE, 0.0).unwrap().summary();
        for fragment in ["One-sample t-test", "statistic=", "p="] {
            assert!(text.contains(fragment));
        }
    }

    // ── Fisher's exact test ────────────────────────────────────────────

    #[test]
    fn fisher_exact_significant() {
        // Strong association between rows and columns.
        let p = fisher_exact(&[[8, 1], [1, 8]]).unwrap().p_value;
        assert!(p < 0.05, "p={}", p);
    }

    #[test]
    fn fisher_exact_not_significant() {
        // Perfectly balanced table: no association.
        let p = fisher_exact(&[[5, 5], [5, 5]]).unwrap().p_value;
        assert!(p > 0.5, "p={}", p);
    }

    #[test]
    fn fisher_exact_extreme() {
        // Perfect association.
        let p = fisher_exact(&[[10, 0], [0, 10]]).unwrap().p_value;
        assert!(p < 0.001, "p={}", p);
    }

    #[test]
    fn fisher_exact_zero_table() {
        let outcome = fisher_exact(&[[0, 0], [0, 0]]);
        assert!(outcome.is_err());
    }

    // ── Chi-squared test ───────────────────────────────────────────────

    #[test]
    fn chi_squared_test_independent() {
        // Observed counts equal the expected ones → not significant.
        let counts = [50.0, 50.0, 50.0, 50.0];
        let p = chi_squared_test(&counts, 2, 2).unwrap().p_value;
        assert!(p > 0.9, "p={}", p);
    }

    #[test]
    fn chi_squared_test_dependent() {
        // Counts concentrated on the diagonal → strong deviation.
        let counts = [90.0, 10.0, 10.0, 90.0];
        let outcome = chi_squared_test(&counts, 2, 2).unwrap();
        assert!(outcome.p_value < 0.001, "p={}", outcome.p_value);
        let df = outcome.degrees_of_freedom.unwrap();
        assert!((df - 1.0).abs() < 1e-10);
    }

    #[test]
    fn chi_squared_test_3x3() {
        #[rustfmt::skip]
        let counts = [
            10.0, 20.0, 30.0,
            20.0, 30.0, 10.0,
            30.0, 10.0, 20.0,
        ];
        let outcome = chi_squared_test(&counts, 3, 3).unwrap();
        let df = outcome.degrees_of_freedom.unwrap();
        assert!((df - 4.0).abs() < 1e-10);
        assert!(outcome.p_value < 0.05, "p={}", outcome.p_value);
    }

    #[test]
    fn chi_squared_test_invalid() {
        // Table smaller than 2×2.
        assert!(chi_squared_test(&[1.0], 1, 1).is_err());
        // Slice length does not match the declared dimensions.
        assert!(chi_squared_test(&[1.0, 2.0], 2, 2).is_err());
    }

    // ── ANOVA ──────────────────────────────────────────────────────────

    #[test]
    fn anova_same_groups() {
        let groups: [&[f64]; 3] = [&BASELINE, &NUDGED, &BASELINE];
        let p = anova_oneway(&groups).unwrap().p_value;
        assert!(p > 0.3, "p={}", p);
    }

    #[test]
    fn anova_different_groups() {
        let farther = [200.0, 201.0, 202.0, 203.0, 204.0];
        let groups: [&[f64]; 3] = [&BASELINE, &FAR_AWAY, &farther];
        let outcome = anova_oneway(&groups).unwrap();
        assert!(outcome.p_value < 0.001, "p={}", outcome.p_value);
        assert!(outcome.method.contains("ANOVA"));
    }

    #[test]
    fn anova_two_groups_matches_t() {
        // With two groups, ANOVA's F equals t² and the p-values agree.
        let second = [3.0, 4.0, 5.0, 6.0, 7.0];
        let groups: [&[f64]; 2] = [&BASELINE, &second];
        let anova_p = anova_oneway(&groups).unwrap().p_value;
        let t_p = t_test_two_sample(&BASELINE, &second, true).unwrap().p_value;
        assert!((anova_p - t_p).abs() < 0.01);
    }

    #[test]
    fn anova_too_few_groups() {
        let groups: [&[f64]; 1] = [&[1.0, 2.0]];
        assert!(anova_oneway(&groups).is_err());
    }

    #[test]
    fn anova_empty_group() {
        let groups: [&[f64]; 2] = [&[], &[1.0, 2.0]];
        assert!(anova_oneway(&groups).is_err());
    }
}