mini_mcmc/
ks_test.rs

/*!
Two-sample Kolmogorov–Smirnov test.

This module provides functionality to perform a two-sample KS test,
adapted from the [`kolmogorov_smirnov`](https://crates.io/crates/kolmogorov_smirnov)
crate under the Apache 2.0 License. The KS test compares two samples to determine
whether they come from the same distribution.

# Overview

The main entry points are:

- The [`TotalF64`] struct, a wrapper around `f64` that provides a total order (even when NaN values occur).
- The [`two_sample_ks_test`] function, which returns a [`TestResult`] containing the test statistic,
  p-value, and a boolean flag indicating whether the null hypothesis is rejected at a given significance level.

The lower-level building blocks `compute_ks_statistic`, `ks_p_value`, `pks`, and `qks` are also
public; they perform the necessary computations based on algorithms found in
*Numerical Recipes (Third Edition)*.
*/
20use std::cmp::Ordering;
21
22use rand_distr::num_traits::ToPrimitive;
23// use std::cmp::Ordering;
24
/**
A wrapper around `f64` that implements a total ordering.

Ordering follows IEEE 754 `totalOrder` (via [`f64::total_cmp`]):
`-NaN < -inf < ... < -0.0 < +0.0 < ... < +inf < +NaN`. Positive NaNs therefore sort
after every finite number and after `+inf`, and negative NaNs sort before `-inf`.

Equality is defined by the same total order, which compares bit patterns:
`TotalF64(f64::NAN) == TotalF64(f64::NAN)` holds, while `TotalF64(0.0) != TotalF64(-0.0)`.
This keeps `PartialEq`, `Eq`, `Ord` and `Hash` mutually consistent, as required for
correct behavior in sorting and in hash-based collections.

# Examples

```rust
use mini_mcmc::ks_test::TotalF64;

let mut values = [TotalF64(3.0), TotalF64(f64::NAN), TotalF64(1.0)];
values.sort();
assert_eq!(values[0].0, 1.0);
assert_eq!(values[1].0, 3.0);
assert!(values[2].0.is_nan());
```
*/
#[derive(Debug, Copy, Clone)]
pub struct TotalF64(pub f64);

impl PartialEq for TotalF64 {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`. A derived impl would use IEEE `==`, for which
        // NaN != NaN (breaking `Eq` reflexivity) and 0.0 == -0.0 (contradicting `cmp`).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for TotalF64 {}

impl std::hash::Hash for TotalF64 {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // `total_cmp` reports `Equal` exactly when the bit patterns match, so hashing
        // the raw bits is consistent with `Eq`.
        state.write_u64(self.0.to_bits());
    }
}

impl PartialOrd for TotalF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TotalF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}
59
60/**
61Performs a two-sample Kolmogorov–Smirnov test on two samples at the given significance level.
62
63The function computes the maximum difference between the empirical distribution functions
64of the two samples and then estimates a p-value. If the p-value is less than the provided
65significance level, the null hypothesis (that the two samples are drawn from the same distribution)
66is rejected.
67
68# Type Parameters
69
70* `T`: A type that implements `Ord`, `Clone`, and `Copy`. In practice, you may wrap your
71  floating-point numbers with [`TotalF64`] to ensure a total ordering.
72
73# Arguments
74
75* `sample_1` - A slice containing the first sample.
76* `sample_2` - A slice containing the second sample.
77* `level` - The significance level at which to test the null hypothesis (e.g. 0.05).
78
79# Returns
80
81Returns a [`TestResult`] with the KS test statistic, p-value, and a flag indicating if the null
82hypothesis is rejected.
83
84# Errors
85
86Returns an error `String` if either sample is empty or if the p-value cannot be computed.
87
88# Examples
89
90```rust
91use mini_mcmc::ks_test::{two_sample_ks_test, TotalF64};
92
93// Two identical samples should yield a KS statistic of 0 and a p-value of 1.
94let sample_1: Vec<TotalF64> = (0..10).map(|x| TotalF64(x as f64)).collect();
95let result = two_sample_ks_test(&sample_1, &sample_1, 0.05).unwrap();
96assert_eq!(result.statistic, 0.0);
97assert!((result.p_value - 1.0).abs() < 1e-10);
98
99// For different samples, the statistic will be > 0 and the p-value will be less than 1.
100let sample_2: Vec<TotalF64> = sample_1.iter().map(|x| TotalF64(x.0 * x.0)).collect();
101let result_diff = two_sample_ks_test(&sample_1, &sample_2, 0.05).unwrap();
102assert!(result_diff.statistic > 0.0);
103assert!(result_diff.p_value < 1.0);
104```
105*/
106pub fn two_sample_ks_test<T: Ord + Clone + Copy>(
107    sample_1: &[T],
108    sample_2: &[T],
109    level: f64,
110) -> Result<TestResult, String> {
111    let statistic = compute_ks_statistic(sample_1, sample_2)?;
112    let p_value = ks_p_value(statistic, sample_1.len(), sample_2.len())?;
113    Ok(TestResult {
114        is_rejected: p_value < level,
115        statistic,
116        p_value,
117        level,
118    })
119}
120
/**
The result of a two-sample Kolmogorov–Smirnov test, as produced by [`two_sample_ks_test`].

Contains the test statistic, the computed p-value, the significance level used for testing,
and a boolean flag `is_rejected` indicating whether the null hypothesis (that the two samples
come from the same distribution) is rejected.
*/
#[derive(Debug, Clone, PartialEq)]
pub struct TestResult {
    /// Whether the null hypothesis is rejected; [`two_sample_ks_test`] sets this to
    /// `p_value < level`.
    pub is_rejected: bool,
    /// The KS statistic: the maximum absolute difference between the two empirical CDFs.
    pub statistic: f64,
    /// The (approximate) p-value associated with `statistic`.
    pub p_value: f64,
    /// The significance level the test was performed at.
    pub level: f64,
}
135
136/**
137Computes the Kolmogorov–Smirnov p-value for the two-sample case.
138
139This function uses an approximation based on the effective sample size and
140the KS test statistic. It asserts that both samples have sizes greater than 7 for accuracy.
141
142# Arguments
143
144* `statistic` - The KS test statistic.
145* `n1` - The size of the first sample.
146* `n2` - The size of the second sample.
147
148# Returns
149
150Returns the p-value as an `f64` if successful.
151
152*/
153pub fn ks_p_value(statistic: f64, n1: usize, n2: usize) -> Result<f64, String> {
154    if n1 <= 7 || n2 <= 7 {
155        return Err(("Requires sample sizes > 7 for accuracy.").to_string());
156    }
157
158    let factor = ((n1 as f64 * n2 as f64) / (n1 as f64 + n2 as f64)).sqrt();
159    let term = factor * statistic;
160
161    // We call `qks` to get the complementary CDF of the KS distribution.
162    let p_value = qks(term)?;
163    assert!((0.0..=1.0).contains(&p_value));
164
165    Ok(p_value)
166}
167
/**
Computes the two-sample KS statistic as the maximum absolute difference between the
empirical distribution functions (ECDFs) of the two samples.

Sorted copies of the inputs are made (the caller's slices are left untouched), and both
are then swept in a single merge-like pass. Before the ECDFs are compared, every value
equal to the current threshold is consumed in both samples. This handles ties (within a
sample or across samples) and matches `sup_x |F1(x) - F2(x)|`.

# Arguments

* `sample_1` - The first sample.
* `sample_2` - The second sample.

# Returns

Returns the KS statistic, a value in `[0, 1]`, if both samples are non-empty.

# Errors

Returns an error if either sample is empty.
*/
pub fn compute_ks_statistic<T: Ord + Clone + Copy>(
    sample_1: &[T],
    sample_2: &[T],
) -> Result<f64, String> {
    if sample_1.is_empty() {
        return Err("Expected sample_1 to be non-empty.".into());
    }
    if sample_2.is_empty() {
        return Err("Expected sample_2 to be non-empty.".into());
    }

    let mut sorted_1 = sample_1.to_vec();
    let mut sorted_2 = sample_2.to_vec();
    sorted_1.sort_unstable();
    sorted_2.sort_unstable();

    let (n, m) = (sorted_1.len(), sorted_2.len());
    let (n_f64, m_f64) = (n as f64, m as f64);

    // `i` and `j` count the elements of each sample that are <= the current threshold,
    // so the ECDF values are `i / n` and `j / m`. Unsigned counts avoid the narrowing
    // casts a signed "last index" cursor would need on very large samples.
    let (mut i, mut j) = (0_usize, 0_usize);
    let mut max_diff: f64 = 0.0;

    // Once either sample is exhausted its ECDF is 1 and the other ECDF can only grow
    // towards 1, so the gap cannot increase any further and the sweep can stop.
    while i < n && j < m {
        let cur_x = sorted_1[i].min(sorted_2[j]);
        i = advance(i, &sorted_1, &cur_x);
        j = advance(j, &sorted_2, &cur_x);

        let fi = i as f64 / n_f64;
        let fj = j as f64 / m_f64;
        max_diff = max_diff.max((fj - fi).abs());
    }
    Ok(max_diff)
}

/**
Returns the index of the first element at or after `start` in the sorted slice `sample`
that is strictly greater than `cur_x`, or `sample.len()` if there is none.

[`compute_ks_statistic`] uses this to move a sample's cursor past every value less than
or equal to the current threshold. Requires `start <= sample.len()`.
*/
fn advance<T: Ord>(start: usize, sample: &[T], cur_x: &T) -> usize {
    start + sample[start..].iter().take_while(|&v| v <= cur_x).count()
}
259
/**
Computes the cumulative distribution function (CDF) of the Kolmogorov distribution,
`P_KS(z) = 1 - 2 * sum_{j>=1} (-1)^(j-1) * exp(-2 j^2 z^2)`.

The algorithm is adapted from *Numerical Recipes (Third Edition)*:
- for `z < 1.18` it uses the equivalent series
  `sqrt(2 pi) / z * sum_{j>=1} exp(-(2j-1)^2 pi^2 / (8 z^2))`, truncated after four terms;
- otherwise it uses the alternating series above, truncated after three terms.

For `z < 0.042` the CDF is below `1e-300` and is returned as exactly `0`.

# Arguments

* `z` - The argument of the CDF (must be non-negative and not NaN).

# Returns

Returns the CDF value for the Kolmogorov distribution.

# Errors

Returns an error if `z` is negative or NaN.

# Examples

```rust
// For z = 0, the CDF should be 0.
let cdf = mini_mcmc::ks_test::pks(0.0).unwrap();
assert_eq!(cdf, 0.0);
```
*/
pub fn pks(z: f64) -> Result<f64, String> {
    // NaN fails every comparison, so it must be rejected explicitly; otherwise it would
    // fall through to the large-z branch and return NaN.
    if z.is_nan() || z < 0. {
        return Err("Bad z for KS distribution function.".into());
    }
    // Numerical Recipes' cutoff. It is also needed for correctness: for z below about
    // 0.041, `exp(-1.2337 / z^2)` underflows to 0, `-ln(0)` is +inf, and the product
    // below would evaluate to `inf * 0 = NaN`.
    if z < 0.042 {
        return Ok(0.);
    }
    if z < 1.18 {
        let y = (-1.233_700_550_136_169_7 / z.powi(2)).exp();
        return Ok(2.256_758_334_191_025
            * (-y.ln()).sqrt()
            * (y + y.powf(9.) + y.powf(25.) + y.powf(49.)));
    }
    let x = (-2. * z.powi(2)).exp();
    Ok(1. - 2. * (x - x.powf(4.) + x.powf(9.)))
}
301
302/**
303Computes the complementary CDF (Q-function) of the KS distribution.
304
305This function is also adapted from *Numerical Recipes (Third Edition)*.
306
307# Arguments
308
309* `z` - The argument of the Q-function (must be non-negative).
310
311# Returns
312
313Returns the complementary probability for the KS distribution.
314
315# Errors
316
317Returns an error if `z` is negative.
318
319# Examples
320
321```rust
322// For z = 0, the Q-function should return 1.
323let q = mini_mcmc::ks_test::qks(0.0).unwrap();
324assert_eq!(q, 1.0);
325```
326*/
327pub fn qks(z: f64) -> Result<f64, String> {
328    if z < 0. {
329        return Err("Bad z for KS distribution function.".into());
330    }
331    if z == 0. {
332        return Ok(1.);
333    }
334    if z < 1.18 {
335        return Ok(1. - pks(z)?);
336    }
337    let x = (-2. * z.powi(2)).exp();
338    Ok(2. * (x - x.powf(4.) + x.powf(9.)))
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{rngs::SmallRng, Rng, SeedableRng};

    #[test]
    fn test_ks_p_value_too_few() {
        let res = ks_p_value(1., 1, 1);
        assert!(res.is_err(), "Expected to get an Err object");
    }

    #[test]
    fn test_ks_p_value_ok() {
        let res = ks_p_value(1., 8, 8);
        assert!(res.is_ok(), "Expected to get an Ok object");
    }

    #[test]
    fn test_ks_simple_case() {
        // Three-element samples with partial overlap; we expect D ~ 1/3.
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [2.0, 3.0, 4.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 1.0 / 3.0).abs() < 1e-9, "Expected D ~ 1/3, got {}", d);
    }

    #[test]
    fn test_ks_identical_samples() {
        // Identical => D=0.
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [1.0, 2.0, 3.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 0.0, "KS should be 0 for identical samples.");
    }

    #[test]
    fn test_ks_non_overlapping() {
        // Disjoint => D=1.
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [10.0, 11.0, 12.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 1.0, "Non-overlapping samples => D=1.");
    }

    #[test]
    fn test_ks_single_element() {
        // s1=[2], s2=[5] => D=1.
        let s1 = [TotalF64(2.0)];
        let s2 = [TotalF64(5.0)];
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 1.0);
    }

    #[test]
    fn test_ks_repeated_values() {
        // Tie-handling with repeated values; expect around 0.2 from R-like logic.
        let s1 = [1.0, 1.0, 1.0, 2.0, 2.0].map(TotalF64);
        let s2 = [1.0, 1.0, 2.0, 2.0, 2.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 0.2).abs() < 1e-6, "Expected ~0.2, got {}", d);
    }

    #[test]
    fn test_ks_partial_overlap() {
        // Overlapping but not identical => D=0.25.
        let s1 = [0.0, 1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 0.25).abs() < 1e-9, "Expected 0.25, got {}", d);
    }

    #[test]
    fn test_ks_rep_similar() {
        // Repeated pattern, slight difference => check statistic & p-value.
        let s1: Vec<TotalF64> = [0.12, 0.25, 0.25, 0.78, 0.99, 0.33, 0.15, 0.5]
            .iter()
            .cycle()
            .take(8 * 20)
            .copied()
            .map(TotalF64)
            .collect();
        let s2: Vec<TotalF64> = [0.12, 0.25, 0.25, 0.78, 0.99, 0.33, 0.15, 0.51]
            .iter()
            .cycle()
            .take(8 * 20)
            .copied()
            .map(TotalF64)
            .collect();

        let result = two_sample_ks_test(&s1, &s2, 0.05).unwrap();
        assert!((result.statistic - 0.125).abs() < 1e-9, "D mismatch");
        assert!((result.p_value - 0.1641).abs() < 1e-4, "p-value mismatch");
    }

    #[test]
    fn test_ks_empty_1() {
        let s1 = [];
        let s2 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let res = compute_ks_statistic(&s1, &s2);
        assert!(
            res.is_err(),
            "Expected compute_ks_statistic(...) to return an error since the first list is empty, got {:?}.",
            res
        );
    }

    #[test]
    fn test_ks_empty_2() {
        let s1 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let s2 = [];
        let res = compute_ks_statistic(&s1, &s2);
        assert!(
            res.is_err(),
            "Expected compute_ks_statistic(...) to return an error since the second list is empty, got {:?}.",
            res
        );
    }

    #[test]
    fn test_bad_z_for_pks() {
        let res = pks(-1.0);
        assert!(
            res.is_err(),
            "Expected pks(-1.0) to return an error, got {:?}.",
            res
        );
    }

    #[test]
    fn test_pks_zero() {
        match pks(0.0) {
            Err(msg) => panic!("Expected pks(0.0) == 0, got error message {:?}.", msg),
            Ok(val) => assert!(val == 0.0, "Expected pks(0.0) == 0, got {:?}.", val),
        }
    }

    #[test]
    fn test_pks_large_1() {
        match pks(1.23) {
            Err(msg) => panic!(
                "Expected pks(1.23) to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.9029731024047791).abs() < 1e-8,
                "Expected pks(1.23) ~= 0.9029731024047791, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_pks_large_2() {
        match pks(2.34) {
            Err(msg) => panic!(
                "Expected pks(2.34) to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.9999649260833611).abs() < 1e-8,
                "Expected pks(2.34) ~= 0.9999649260833611, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_pks_large_3() {
        match pks(3.45) {
            Err(msg) => panic!(
                "Expected pks(3.45) to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 1.0).abs() < 1e-8,
                "Expected pks(3.45) ~= 1.0, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_qks_zero() {
        match qks(0.0) {
            Err(msg) => panic!(
                "Expected qks(0.0) to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(val == 1.0, "Expected qks(0.0) = 1.0, got {:?}.", val),
        }
    }

    #[test]
    fn test_qks_large() {
        match qks(1.2) {
            Err(msg) => panic!(
                "Expected qks(1.2) to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.11224966667072497).abs() < 1e-8,
                "Expected qks(1.2) ~= 0.11224966667072497, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_bad_z_for_qks() {
        let res = qks(-1.0);
        assert!(
            res.is_err(),
            "Expected qks(-1.0) to return an error, got {:?}.",
            res
        );
    }

    #[test]
    fn test_cmp_f64_middle_nan() {
        let mut s = [1.0, f64::NAN, 3.0].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 1.0 && s[1].0 == 3.0 && s[2].0.is_nan(),
            "Expected sorting [1.0, NAN, 3.0] to give [1.0, 3.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_beginning_nan() {
        let mut s = [f64::NAN, 2.0, 3.0].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 2.0 && s[1].0 == 3.0 && s[2].0.is_nan(),
            "Expected sorting [NAN, 2.0, 3.0] to give [2.0, 3.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_end_nan() {
        let mut s = [1.0, 2.0, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 1.0 && s[1].0 == 2.0 && s[2].0.is_nan(),
            "Expected sorting [1.0, 2.0, NAN] to give [1.0, 2.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_double_nan() {
        let mut s = [f64::NAN, 2.0, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 2.0 && s[1].0.is_nan() && s[2].0.is_nan(),
            "Expected sorting [NAN, 2.0, NAN] to give [2.0, NAN, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_all_nan() {
        let mut s = [f64::NAN, f64::NAN, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0.is_nan() && s[1].0.is_nan() && s[2].0.is_nan(),
            "Expected sorting [NAN, NAN, NAN] to give [NAN, NAN, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_same_as_external() {
        let mut rng = SmallRng::seed_from_u64(42);

        let s1: Vec<TotalF64> = (0..100000).map(|_| rng.gen()).map(TotalF64).collect();
        let s2: Vec<TotalF64> = (0..100000).map(|_| rng.gen()).map(TotalF64).collect();
        let res_external = kolmogorov_smirnov::test(&s1, &s2, 0.95);
        let res_internal =
            two_sample_ks_test(&s1, &s2, 0.05).expect("Expected KS test to succeed");

        // Both implementations compute the exact ECDF distance, so the statistics must
        // agree. The p-values are not compared here, since the external crate's p-value
        // approximation may differ from ours.
        assert!(
            (res_external.statistic - res_internal.statistic).abs() < 1e-12,
            "Statistic mismatch: external={:?}, internal={:?}",
            res_external.statistic,
            res_internal.statistic
        );
    }
}