1use std::cmp::Ordering;
21
22use rand_distr::num_traits::ToPrimitive;
/// Newtype around `f64` that is totally ordered via [`f64::total_cmp`].
///
/// Plain `f64` is only `PartialOrd`, so it cannot be passed to routines that
/// require `T: Ord`; wrapping values in `TotalF64` makes them usable there.
///
/// Equality is defined through the same total order as [`Ord::cmp`], so the
/// `Eq`/`Ord` contracts hold: `a == b` exactly when `a.cmp(&b)` is
/// `Ordering::Equal`. Consequently a NaN compares equal to a NaN with the same
/// bit pattern, and `-0.0` is *not* equal to `+0.0` (it orders below it).
#[derive(Debug, Copy, Clone)]
pub struct TotalF64(pub f64);

impl PartialEq for TotalF64 {
    /// Two values are equal iff the total order considers them equal.
    ///
    /// This is deliberately not the IEEE `==` of the inner floats: that
    /// relation is not reflexive for NaN and would contradict `cmp`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for TotalF64 {}

impl PartialOrd for TotalF64 {
    /// Always `Some`, because the underlying order is total.
    fn partial_cmp(&self, other: &TotalF64) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for TotalF64 {
    /// Delegates to [`f64::total_cmp`] on the wrapped values.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}
59
60pub fn two_sample_ks_test<T: Ord + Clone + Copy>(
107 sample_1: &[T],
108 sample_2: &[T],
109 level: f64,
110) -> Result<TestResult, String> {
111 let statistic = compute_ks_statistic(sample_1, sample_2)?;
112 let p_value = ks_p_value(statistic, sample_1.len(), sample_2.len())?;
113 Ok(TestResult {
114 is_rejected: p_value < level,
115 statistic,
116 p_value,
117 level,
118 })
119}
120
/// Outcome of a two-sample Kolmogorov–Smirnov test.
///
/// All fields are plain values, so the type is cheap to copy and can be
/// compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TestResult {
    /// `true` when `p_value` is strictly below `level`.
    pub is_rejected: bool,
    /// The KS statistic computed from the two samples.
    pub statistic: f64,
    /// The p-value associated with `statistic`.
    pub p_value: f64,
    /// The threshold the p-value was compared against.
    pub level: f64,
}
135
136pub fn ks_p_value(statistic: f64, n1: usize, n2: usize) -> Result<f64, String> {
154 if n1 <= 7 || n2 <= 7 {
155 return Err(("Requires sample sizes > 7 for accuracy.").to_string());
156 }
157
158 let factor = ((n1 as f64 * n2 as f64) / (n1 as f64 + n2 as f64)).sqrt();
159 let term = factor * statistic;
160
161 let p_value = qks(term)?;
163 assert!((0.0..=1.0).contains(&p_value));
164
165 Ok(p_value)
166}
167
/// Computes the two-sample Kolmogorov–Smirnov statistic: the largest absolute
/// difference between the empirical distribution functions of `sample_1` and
/// `sample_2`.
///
/// The inputs do not need to be sorted; sorted copies are made internally and
/// the originals are left untouched.
///
/// Positions are tracked as `usize` counts, so sample lengths are never
/// narrowed to a smaller integer type.
///
/// # Errors
///
/// Returns an error if either sample is empty.
pub fn compute_ks_statistic<T: Ord + Clone + Copy>(
    sample_1: &[T],
    sample_2: &[T],
) -> Result<f64, String> {
    if sample_1.is_empty() {
        return Err("Expected sample_1 to be non-empty.".into());
    }
    if sample_2.is_empty() {
        return Err("Expected sample_2 to be non-empty.".into());
    }

    let mut sorted_1 = sample_1.to_vec();
    let mut sorted_2 = sample_2.to_vec();
    sorted_1.sort_unstable();
    sorted_2.sort_unstable();

    let (n, m) = (sorted_1.len(), sorted_2.len());
    let (n_f64, m_f64) = (n as f64, m as f64);

    // `seen_k` is the number of elements of sample k that are <= the current
    // evaluation point, i.e. the numerator of that sample's empirical CDF.
    let (mut seen_1, mut seen_2) = (0_usize, 0_usize);
    let mut max_diff: f64 = 0.0;

    // Once one sample is exhausted its empirical CDF is already 1 while the
    // other can only grow towards 1, so the gap cannot widen any further.
    while seen_1 < n && seen_2 < m {
        // Next evaluation point: the smallest value not yet consumed.
        let cur_x = sorted_1[seen_1].min(sorted_2[seen_2]);

        advance(&mut seen_1, &sorted_1, &cur_x);
        advance(&mut seen_2, &sorted_2, &cur_x);

        let cdf_1 = seen_1 as f64 / n_f64;
        let cdf_2 = seen_2 as f64 / m_f64;
        max_diff = max_diff.max((cdf_2 - cdf_1).abs());
    }

    Ok(max_diff)
}

/// Moves `count` forward past every element of the sorted `sample` that is
/// less than or equal to `cur_x`, starting at index `*count`.
fn advance<T: Ord>(count: &mut usize, sample: &[T], cur_x: &T) {
    while *count < sample.len() && sample[*count] <= *cur_x {
        *count += 1;
    }
}
259
/// Evaluates the Kolmogorov–Smirnov cumulative distribution function at `z`.
///
/// Two series are used: one for `z < 1.18` and one for larger `z`.
/// `pks(0)` is exactly `0`.
///
/// For very small positive `z` the leading term `exp(-c / z^2)` underflows to
/// zero; in that case the function returns `0` instead of evaluating
/// `sqrt(-ln(0)) * 0`, which would be `inf * 0 = NaN`.
///
/// # Errors
///
/// Returns an error when `z` is negative or NaN.
pub fn pks(z: f64) -> Result<f64, String> {
    if z.is_nan() || z < 0. {
        return Err("Bad z for KS distribution function.".into());
    }
    if z == 0. {
        return Ok(0.);
    }
    if z < 1.18 {
        // The constant is numerically pi^2 / 8.
        let exponent = 1.233_700_550_136_169_7 / z.powi(2);
        let y = (-exponent).exp();
        if y == 0. {
            // Underflow: every term of the series is zero.
            return Ok(0.);
        }
        // `-ln(y)` equals `exponent` by construction, so take its square root
        // directly rather than round-tripping through `ln`.
        return Ok(2.256_758_334_191_025
            * exponent.sqrt()
            * (y + y.powf(9.) + y.powf(25.) + y.powf(49.)));
    }
    let x = (-2. * z.powi(2)).exp();
    Ok(1. - 2. * (x - x.powf(4.) + x.powf(9.)))
}
301
302pub fn qks(z: f64) -> Result<f64, String> {
328 if z < 0. {
329 return Err("Bad z for KS distribution function.".into());
330 }
331 if z == 0. {
332 return Ok(1.);
333 }
334 if z < 1.18 {
335 return Ok(1. - pks(z)?);
336 }
337 let x = (-2. * z.powi(2)).exp();
338 Ok(2. * (x - x.powf(4.) + x.powf(9.)))
339}
340
/// Unit tests for the KS statistic, the `pks`/`qks` distribution functions and
/// the ordering of `TotalF64`.
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{rngs::SmallRng, Rng, SeedableRng};

    #[test]
    fn test_ks_p_value_too_few() {
        let res = ks_p_value(1., 1, 1);
        assert!(res.is_err(), "Expected to get an Err object");
    }

    #[test]
    fn test_ks_p_value_ok() {
        let res = ks_p_value(1., 8, 8);
        assert!(res.is_ok(), "Expected to get a Ok object");
    }

    #[test]
    fn test_ks_simple_case() {
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [2.0, 3.0, 4.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 1.0 / 3.0).abs() < 1e-9, "Expected D ~ 1/3, got {}", d);
    }

    #[test]
    fn test_ks_identical_samples() {
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [1.0, 2.0, 3.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 0.0, "KS should be 0 for identical samples.");
    }

    #[test]
    fn test_ks_non_overlapping() {
        let s1 = [1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [10.0, 11.0, 12.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 1.0, "Non-overlapping samples => D=1.");
    }

    #[test]
    fn test_ks_single_element() {
        let s1 = [TotalF64(2.0)];
        let s2 = [TotalF64(5.0)];
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert_eq!(d, 1.0);
    }

    #[test]
    fn test_ks_repeated_values() {
        let s1 = [1.0, 1.0, 1.0, 2.0, 2.0].map(TotalF64);
        let s2 = [1.0, 1.0, 2.0, 2.0, 2.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 0.2).abs() < 1e-6, "Expected ~0.2, got {}", d);
    }

    #[test]
    fn test_ks_partial_overlap() {
        let s1 = [0.0, 1.0, 2.0, 3.0].map(TotalF64);
        let s2 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let d = compute_ks_statistic(&s1, &s2).unwrap();
        assert!((d - 0.25).abs() < 1e-9, "Expected 0.25, got {}", d);
    }

    #[test]
    fn test_ks_rep_similar() {
        // Two samples of 160 values each that differ only in one of the eight
        // repeated values (0.5 vs 0.51).
        let s1: Vec<TotalF64> = [0.12, 0.25, 0.25, 0.78, 0.99, 0.33, 0.15, 0.5]
            .iter()
            .cycle()
            .take(8 * 20)
            .copied()
            .map(TotalF64)
            .collect();
        let s2: Vec<TotalF64> = [0.12, 0.25, 0.25, 0.78, 0.99, 0.33, 0.15, 0.51]
            .iter()
            .cycle()
            .take(8 * 20)
            .copied()
            .map(TotalF64)
            .collect();

        let result = two_sample_ks_test(&s1, &s2, 0.05).unwrap();
        assert!((result.statistic - 0.125).abs() < 1e-9, "D mismatch");
        assert!((result.p_value - 0.1641).abs() < 1e-4, "p-value mismatch");
    }

    #[test]
    fn test_ks_empty_1() {
        let s1 = [];
        let s2 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let res = compute_ks_statistic(&s1, &s2);
        assert!(
            res.is_err(),
            "Expected compute_ks_statistic(...) to return an error since the first list is empty, got {:?}.",
            res
        );
    }

    #[test]
    fn test_ks_empty_2() {
        let s1 = [1.0, 2.0, 3.0, 4.0].map(TotalF64);
        let s2 = [];
        let res = compute_ks_statistic(&s1, &s2);
        assert!(
            res.is_err(),
            "Expected compute_ks_statistic(...) to return an error since the second list is empty, got {:?}.",
            res
        );
    }

    #[test]
    fn test_bad_z_for_pks() {
        let res = pks(-1.0);
        assert!(
            res.is_err(),
            "Expected pks(-1.0) to return an error, got {:?}.",
            res
        );
    }

    #[test]
    fn test_pks_zero() {
        match pks(0.0) {
            Err(msg) => panic!("Expected pks(0.0) == 0, got error message {:?}.", msg),
            Ok(val) => assert!(val == 0.0, "Expected pks(0.0) == 0, got {:?}.", val),
        }
    }

    #[test]
    fn test_pks_large_1() {
        match pks(1.23) {
            Err(msg) => panic!(
                "Expected pks(1.23), to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.9029731024047791).abs() < 1e-8,
                "Expected pks(1.23) ~= 0.9029731024047791, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_pks_large_2() {
        match pks(2.34) {
            Err(msg) => panic!(
                "Expected pks(2.34), to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.9999649260833611).abs() < 1e-8,
                "Expected pks(2.34) ~= 0.9999649260833611, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_pks_large_3() {
        match pks(3.45) {
            Err(msg) => panic!(
                "Expected pks(3.45), to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 1.0).abs() < 1e-8,
                "Expected pks(3.45) ~= 1.0, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_qks_zero() {
        match qks(0.0) {
            Err(msg) => panic!(
                "Expected qks(0.0), to not error out, got error message {:?}.",
                msg
            ),
            // The message previously claimed the expected value was 0.0,
            // contradicting the assertion.
            Ok(val) => assert!(val == 1.0, "Expected qks(0.0) == 1.0, got {:?}.", val),
        }
    }

    #[test]
    fn test_qks_large() {
        match qks(1.2) {
            Err(msg) => panic!(
                "Expected qks(1.2), to not error out, got error message {:?}.",
                msg
            ),
            Ok(val) => assert!(
                (val - 0.11224966667072497).abs() < 1e-8,
                "Expected qks(1.2) ~= 0.11224966667072497, got {:?}.",
                val
            ),
        }
    }

    #[test]
    fn test_bad_z_for_qks() {
        let res = qks(-1.0);
        assert!(
            res.is_err(),
            "Expected qks(-1.0) to return an error, got {:?}.",
            res
        );
    }

    #[test]
    fn test_cmp_f64_middle_nan() {
        // Exercises `f64::total_cmp` directly rather than through `TotalF64`.
        let mut s = [1.0, f64::NAN, 3.0];
        s.sort_by(|a, b| a.total_cmp(b));
        assert!(
            s[0] == 1.0 && s[1] == 3.0 && s[2].is_nan(),
            "Expected sorting [1.0, NAN, 3.0] to give [1.0, 3.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_beginning_nan() {
        let mut s = [f64::NAN, 2.0, 3.0].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 2.0 && s[1].0 == 3.0 && s[2].0.is_nan(),
            "Expected sorting [NAN, 2.0, 3.0] to give [2.0, 3.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_end_nan() {
        let mut s = [1.0, 2.0, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 1.0 && s[1].0 == 2.0 && s[2].0.is_nan(),
            "Expected sorting [1.0, 2.0, NAN] to give [1.0, 2.0, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_double_nana() {
        let mut s = [f64::NAN, 2.0, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0 == 2.0 && s[1].0.is_nan() && s[2].0.is_nan(),
            "Expected sorting [NAN, 2.0, NAN] to give [2.0, NAN, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_cmp_f64_all_nana() {
        let mut s = [f64::NAN, f64::NAN, f64::NAN].map(TotalF64);
        s.sort();
        assert!(
            s[0].0.is_nan() && s[1].0.is_nan() && s[2].0.is_nan(),
            "Expected sorting [NAN, NAN, NAN] to give [NAN, NAN, NAN], got {s:?}."
        );
    }

    #[test]
    fn test_same_as_external() {
        // NOTE(review): this test only prints both results for manual
        // comparison; it asserts nothing beyond the internal test succeeding.
        let mut rng = SmallRng::seed_from_u64(42);

        let s1: Vec<TotalF64> = (0..100000).map(|_| rng.gen()).map(TotalF64).collect();
        let s2: Vec<TotalF64> = (0..100000).map(|_| rng.gen()).map(TotalF64).collect();
        let res_external = kolmogorov_smirnov::test(&s1, &s2, 0.95);
        let res_internal = two_sample_ks_test(&s1, &s2, 0.05).expect("Expected KS test to succeed");
        println!(
            "EXTERNAL:\n statistic={:?}\n is_rejected={:?}\n reject_probability={:?}",
            res_external.statistic, res_external.is_rejected, res_external.reject_probability
        );
        println!(
            "INTERNAL:\n statistic={:?}\n is_rejected={:?}\n reject_probability={:?}",
            res_internal.statistic,
            res_internal.is_rejected,
            1.0 - res_internal.p_value
        );
        println!("{res_internal:?}");
    }
}