1use crate::dist::KsTwoAsymptotic;
2use crate::traits::Cdf;
3use num::Integer;
4use num::integer::binomial;
5
6pub fn ks_test<X, F>(xs: &[X], cdf: F) -> (f64, f64)
38where
39 X: Copy + PartialOrd,
40 F: Fn(X) -> f64,
41{
42 let mut xs_r: Vec<X> = xs.to_vec();
43 xs_r.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
44
45 let n: f64 = xs_r.len() as f64;
46 let d = xs_r.iter().enumerate().fold(0.0, |acc, (i, &x)| {
47 let diff = ((i as f64) / n - cdf(x)).abs();
48 if diff > acc { diff } else { acc }
49 });
50
51 let p = 1.0 - ks_cdf(xs.len(), d);
52 (d, p)
53}
54
/// Sample-size threshold used by `KsMode::Auto` in `ks_two_sample`: when the
/// larger of the two samples has at most this many elements the exact
/// p-value is computed, otherwise the asymptotic approximation is used.
const KS_AUTO_CUTOVER: usize = 10_000;
56
/// How `ks_two_sample` computes its p-value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum KsMode {
    /// Exact computation by counting lattice paths.
    Exact,
    /// Large-sample (asymptotic) approximation.
    Asymptotic,
    /// `Exact` when the larger sample has at most `KS_AUTO_CUTOVER`
    /// elements, `Asymptotic` otherwise.
    #[default]
    Auto,
}
68
/// Alternative hypothesis for `ks_two_sample`.
///
/// With `F_x` and `F_y` the empirical CDFs of the first (`xs`) and second
/// (`ys`) samples, each variant selects the statistic that is reported.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum KsAlternative {
    /// Statistic `max |F_x - F_y|`.
    #[default]
    TwoSided,
    /// Statistic `max (F_y - F_x)`.
    Less,
    /// Statistic `max (F_x - F_y)`.
    Greater,
}
82
/// Errors returned by `ks_two_sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KsError {
    /// At least one of the two input samples was empty.
    EmptySlice,
    /// `KsMode::Exact` was requested for samples considered too large for
    /// the exact computation.
    TooLongForExact,
}

impl std::fmt::Display for KsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptySlice => f.write_str("KS test requires non-empty samples"),
            Self::TooLongForExact => {
                f.write_str("samples are too large for the exact KS computation")
            }
        }
    }
}

impl std::error::Error for KsError {}
91
92#[allow(clippy::many_single_char_names)]
119pub fn ks_two_sample<X>(
120 xs: &[X],
121 ys: &[X],
122 mode: KsMode,
123 alternative: KsAlternative,
124) -> Result<(f64, f64), KsError>
125where
126 X: Copy + PartialOrd,
127{
128 if xs.is_empty() || ys.is_empty() {
129 return Err(KsError::EmptySlice);
130 }
131
132 let n_x = xs.len();
133 let n_y = ys.len();
134 let n_x_f = xs.len() as f64;
135 let n_y_f = ys.len() as f64;
136
137 let mut xs = xs.to_vec();
138 xs.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
139
140 let mut ys = ys.to_vec();
141 ys.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
142
143 let mut cdf_x = Vec::new();
144 let mut cdf_y = Vec::new();
145
146 for x in [&xs[..], &ys[..]].concat() {
147 match xs.binary_search_by(|b| b.partial_cmp(&x).unwrap()) {
148 Ok(z) => cdf_x.push((z as f64) / n_x_f),
149 Err(z) => cdf_x.push((z as f64) / n_x_f),
150 }
151 match ys.binary_search_by(|b| b.partial_cmp(&x).unwrap()) {
152 Ok(z) => cdf_y.push((z as f64) / n_y_f),
153 Err(z) => cdf_y.push((z as f64) / n_y_f),
154 }
155 }
156
157 let (min_s, max_s) = cdf_x
158 .iter()
159 .zip(cdf_y.iter())
160 .map(|(cx, cy)| cx - cy)
161 .fold((f64::MAX, f64::MIN), |(min, max), z| {
162 let new_min = min.min(z);
163 let new_max = max.max(z);
164 (new_min, new_max)
165 });
166
167 let min_s = -min_s;
168
169 let stat = match alternative {
170 KsAlternative::Less => min_s,
171 KsAlternative::Greater => max_s,
172 KsAlternative::TwoSided => max_s.max(min_s),
173 };
174
175 let g = n_x.gcd(&n_y);
176 let g_f = g as f64;
177 let n_x_g = n_x_f / (g as f64);
178 let n_y_g = n_y_f / (g as f64);
179
180 let use_method = match mode {
181 KsMode::Asymptotic => KsMode::Asymptotic,
182 KsMode::Auto => {
183 if n_x.max(n_y) <= KS_AUTO_CUTOVER {
184 KsMode::Exact
185 } else {
186 KsMode::Asymptotic
187 }
188 }
189 KsMode::Exact => {
190 if n_x_g > f64::MAX / n_y_g {
191 return Err(KsError::TooLongForExact);
192 }
193 KsMode::Exact
194 }
195 };
196
197 match use_method {
198 KsMode::Exact => {
199 let lcm = (n_x_f / g_f) * n_y_f;
200 let h = (stat * lcm).round();
201 let stat = h / lcm;
202 if h == 0.0 {
203 Ok((stat, 1.0))
204 } else {
205 match alternative {
206 KsAlternative::TwoSided => {
207 if n_x == n_y {
208 Ok((stat, paths_outside_proportion(n_x, h)))
209 } else {
210 Ok((
211 stat,
212 1.0 - paths_inside_proportion(n_x, n_y, g, h),
213 ))
214 }
215 }
216 _ => {
217 if n_x == n_y {
218 let p = (0..(h as usize)).fold(1.0, |p, j| {
219 ((n_x - j) as f64) * p
220 / (n_x_f + (j as f64) + 1.0)
221 });
222 Ok((stat, p))
223 } else {
224 let paths = paths_outside(n_x, n_y, g, h);
225 let bin = binomial(n_x + n_y, n_x);
226 Ok((stat, (paths as f64) / (bin as f64)))
227 }
228 }
229 }
230 }
231 }
232 KsMode::Asymptotic => {
233 let ks_dist = KsTwoAsymptotic::new();
234
235 if let KsAlternative::TwoSided = alternative {
236 let en = (n_x_f * n_y_f / (n_y_f + n_x_f)).sqrt();
237 Ok((stat, 1.0 - ks_dist.cdf(&(en * stat))))
238 } else {
239 let m = n_x.max(n_y) as f64;
240 let n = n_x.min(n_y) as f64;
241
242 let z = (m * n / (m + n)).sqrt() * stat;
243 let expt = (-2.0 * z).mul_add(
244 z,
245 -2.0 * z * 2.0_f64.mul_add(n, m)
246 / (m * n * (m + n)).sqrt()
247 / 3.0,
248 );
249 let p = expt.exp();
250 Ok((stat, p))
251 }
252 }
253 KsMode::Auto => unreachable!(),
254 }
255}
256
257#[allow(clippy::many_single_char_names)]
258fn paths_outside(m: usize, n: usize, g: usize, h: f64) -> usize {
259 let (m, n) = (m.max(n), m.min(n));
260 let mg = m / g;
261 let ng = n / g;
262 let ng_f = ng as f64;
263 let mg_f = mg as f64;
264
265 let xj: Vec<usize> = (0..=n)
266 .map(|j| (mg_f.mul_add(j as f64, h) / ng_f).ceil() as usize)
267 .filter(|&x| x <= m)
268 .collect();
269
270 let lxj = xj.len();
271
272 if lxj == 0 {
273 binomial(m + n, n)
274 } else {
275 let mut b: Vec<usize> = (0..lxj).map(|_| 0).collect();
276 b[0] = 1;
277 for j in 1..lxj {
278 let mut bj = binomial(xj[j] + j, j);
279 for i in 0..j {
280 let bin = binomial(xj[j] - xj[i] + j - i, j - i);
281 let dec = bin * b[i];
282 bj -= dec;
283 }
284 b[j] = bj;
285 }
286 let mut num_paths = 0;
287 for j in 0..lxj {
288 let bin = binomial((m - xj[j]) + (n - j), n - j);
289 let term = b[j] * bin;
290 num_paths += term;
291 }
292 num_paths
293 }
294}
295
/// Proportion of lattice paths for two samples of equal size `n` whose
/// scaled ECDF difference reaches `h` in absolute value.
///
/// This is the exact two-sided p-value for equal sample sizes. It is computed
/// as twice an alternating sum of binomial-coefficient ratios. The sum is
/// evaluated Horner-style from the last term down, so the large binomials are
/// never formed and there is no subtractive cancellation.
fn paths_outside_proportion(n: usize, h: f64) -> f64 {
    let n_f = n as f64;
    let factors = h as usize;
    let last_k = (n_f / h) as usize;

    let nested = (0..=last_k).rev().fold(0.0, |inner, k| {
        let k_f = k as f64;
        // Ratio of consecutive binomial terms, built one factor at a time.
        let ratio = (0..factors).fold(1.0, |r, j| {
            let j_f = j as f64;
            (k_f.mul_add(-h, n_f) - j_f) * r / (k_f.mul_add(h, n_f) + j_f + 1.0)
        });
        ratio * (1.0 - inner)
    });

    2.0 * nested
}
314
/// Proportion of monotone lattice paths from `(0, 0)` to `(M, N)` that stay
/// strictly inside the band `|(N/g) * i - (M/g) * j| < h`.
///
/// Here `M = max(m, n)` and `N = min(m, n)`; `i` counts steps along the larger
/// sample and `j` along the smaller one. The exact two-sided p-value for
/// unequal sample sizes is `1 -` this value.
///
/// Only a sliding window `[min_j, max_j)` of the current column of path
/// counts is kept, stored in `a[0..max_j - min_j]`. Each new column is the
/// running (prefix) sum of the admissible part of the previous one. The
/// normalisation by `C(M + N, N)` is folded in one factor `i / (N + i)` per
/// column so the stored values stay bounded.
///
/// Returns `0.0` if the band closes before `(M, N)` is reached.
#[allow(clippy::many_single_char_names)]
fn paths_inside_proportion(m: usize, n: usize, g: usize, h: f64) -> f64 {
    // The proportion does not depend on the order of the sizes; use m >= n.
    let (m, n) = (m.max(n), m.min(n));
    let n_f = n as f64;
    let mg_f = (m / g) as f64;
    let ng_f = (n / g) as f64;

    // Column i = 0: admissible j satisfy mg * j < h.
    let mut min_j = 0;
    let mut max_j = (n + 1).min((h / mg_f).ceil() as usize);
    let mut cur_len = max_j - min_j;

    // Window buffer: 2 * ceil(h / mg) + 2 entries, capped at n + 1.
    let len_a = (n + 1).min(2 * max_j + 2);
    let mut a = vec![0.0; len_a];
    a[min_j..max_j].fill(1.0);

    for i in 1..=m {
        let i_f = i as f64;
        let last_min_j = min_j;
        let last_len = cur_len;

        // Admissible j satisfy ng*i - h < mg*j < ng*i + h (strict on both
        // sides), clamped to 0..=n.
        min_j = (((ng_f.mul_add(i_f, -h) / mg_f).floor() + 1.0).max(0.0) as usize).min(n);
        max_j = (n + 1).min((ng_f.mul_add(i_f, h) / mg_f).ceil() as usize);
        if max_j <= min_j {
            return 0.0;
        }

        // New column = prefix sum of the previous column, shifted by how far
        // the window moved. min_j never decreases, so each read index
        // `offset + j` is >= the write index `j` and the update can be done
        // in place.
        let offset = min_j - last_min_j;
        cur_len = max_j - min_j;
        let mut running = 0.0;
        for j in 0..cur_len {
            running += a[offset + j];
            a[j] = running;
        }

        // Clear entries carried over from a wider previous window.
        if last_len > cur_len {
            a.iter_mut()
                .skip(cur_len)
                .take(last_len - cur_len)
                .for_each(|v| *v = 0.0);
        }

        let scaling_factor = i_f / (n_f + i_f);
        a.iter_mut().for_each(|v| *v *= scaling_factor);
    }
    a[max_j - min_j - 1]
}
365
/// Multiplies two square matrices stored as rows, returning `xs * ys`.
///
/// The dimension is taken from `xs.len()`. Each entry is accumulated with
/// fused multiply-adds in increasing inner-index order.
fn mmul(xs: &[Vec<f64>], ys: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let dim = xs.len();
    xs.iter()
        .map(|row| {
            (0..dim)
                .map(|col| {
                    (0..dim).fold(0.0, |acc, k| row[k].mul_add(ys[k][col], acc))
                })
                .collect()
        })
        .collect()
}
377
378fn mpow(xs: &[Vec<f64>], ea: i32, n: usize) -> (Vec<Vec<f64>>, i32) {
379 let m = xs.len();
380 if n == 1 {
381 (xs.to_owned(), ea)
382 } else {
383 let (mut zs, mut ev) = mpow(xs, ea, n / 2);
384 let ys = mmul(&zs, &zs);
385 let eb = 2 * ev;
386 if n % 2 == 0 {
387 zs = ys;
388 ev = eb;
389 } else {
390 zs = mmul(xs, &ys);
391 ev = ea + eb;
392 }
393 if zs[m / 2][m / 2] > 1E140 {
394 for zs_i in &mut zs {
395 zs_i.iter_mut().for_each(|z| (*z) *= 1E-140);
396 }
397 ev += 140;
398 }
399 (zs, ev)
400 }
401}
402
/// Cumulative distribution function of the one-sample Kolmogorov–Smirnov
/// statistic: returns `P(D_n <= d)` for sample size `n`.
///
/// This appears to be a port of the C routine from Marsaglia, Tsang & Wang
/// (2003), "Evaluating Kolmogorov's Distribution".
///
/// * When `d*d*n > 7.24`, or `d*d*n > 3.76` with `n > 99`, the closed-form
///   right-tail approximation `1 - 2 exp(-(2.000071 + 0.331/sqrt(n) +
///   1.409/n) * d*d*n)` is returned.
/// * Otherwise an `m x m` matrix `H` is built, with `k = floor(n*d) + 1` and
///   `m = 2k - 1`. `H` is raised to the `n`-th power with `mpow`, and its
///   `(k-1, k-1)` entry is multiplied by `n! / n^n`.
///
/// Intermediate magnitudes are controlled by a separate power-of-ten exponent
/// (`eq`), which is applied once at the end.
#[allow(clippy::needless_range_loop)]
#[allow(clippy::many_single_char_names)]
fn ks_cdf(n: usize, d: f64) -> f64 {
    let nf = n as f64;
    let s: f64 = d * d * nf;
    if s > 7.24 || (s > 3.76 && n > 99) {
        // 1 - 2 * exp(-(2.000071 + 0.331 / sqrt(n) + 1.409 / n) * s)
        2.0_f64.mul_add(
            -(-(2.000_071 + 0.331 / nf.sqrt() + 1.409 / nf) * s).exp(),
            1.0,
        )
    } else {
        let k: usize = ((nf * d) as usize) + 1;
        let m: usize = 2 * k - 1;
        // h = k - n*d
        let h: f64 = nf.mul_add(-d, k as f64);

        // Ones on and below the first superdiagonal (j <= i + 1), zeros above.
        let mut hs = vec![vec![0.0; m]; m];
        for i in 0..m {
            for j in 0..m {
                if ((i as i32) - (j as i32) + 1) >= 0 {
                    hs[i][j] = 1.0;
                }
            }
        }

        // Corrections along the first column and the last row.
        for i in 0..m {
            hs[i][0] -= h.powi((i as i32) + 1);
            hs[m - 1][i] -= h.powi((m as i32) - (i as i32));
        }

        // Bottom-left corner: add (2h - 1)^m when 2h - 1 is positive.
        hs[m - 1][0] += if 2.0_f64.mul_add(h, -1.0) > 0.0 {
            2.0_f64.mul_add(h, -1.0).powi(m as i32)
        } else {
            0.0
        };

        // Divide each entry with i - j + 1 > 0 by (i - j + 1)!.
        for i in 0..m {
            for j in 0..m {
                if (i as i32) - (j as i32) + 1 > 0 {
                    for g in 1..=i - j + 1 {
                        hs[i][j] /= g as f64;
                    }
                }
            }
        }

        // H^n, returned as (matrix, decimal exponent).
        let (qs, mut eq) = mpow(&hs, 0, n);
        let mut s = qs[k - 1][k - 1];
        // Multiply by (n-1)! / n^(n-1) == n! / n^n one factor at a time,
        // rescaling by 1e140 whenever s gets tiny to avoid underflow.
        for i in 1..n {
            s *= (i as f64) / nf;
            if s < 1e-140 {
                s *= 1e140;
                eq -= 140;
            }
        }
        s * 10.0_f64.powi(eq)
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dist::Gaussian;

    const TOL: f64 = 1E-12;

    /// First sample shared by the two-sample tests.
    const XS: [f64; 10] = [
        0.956_920_26,
        1.134_881_2,
        -0.765_792_39,
        -0.580_656_53,
        -0.051_223_93,
        0.715_987_54,
        1.398_735_28,
        0.427_905_27,
        1.845_807_64,
        0.642_285_21,
    ];

    /// Second sample shared by the two-sample tests.
    const YS: [f64; 10] = [
        0.694_867_8,
        -0.374_182_5,
        0.366_572_79,
        1.158_341_74,
        -0.324_217_06,
        -0.384_992_95,
        1.449_769_91,
        0.250_460_8,
        -0.536_947_74,
        1.422_219_93,
    ];

    #[test]
    fn ks_cdf_normal() {
        assert::close(ks_cdf(10, 0.274), 0.628_479_615_456_504_3, TOL);
    }

    #[test]
    fn ks_cdf_large_n() {
        assert::close(ks_cdf(1000, 0.074), 0.999_967_173_529_903_7, TOL);
    }

    #[test]
    fn ks_test_pval() {
        let xs: Vec<f64> =
            vec![0.42, 0.24, 0.86, 0.85, 0.82, 0.82, 0.25, 0.78, 0.13, 0.27];

        let g = Gaussian::standard();
        let cdf = |x: f64| g.cdf(&x);
        let (ks, p) = ks_test(&xs, cdf);

        assert::close(ks, 0.551_716_786_654_561_1, TOL);
        assert::close(p, 0.002_180_450_252_694_976_5, TOL);
    }

    #[test]
    fn ks_two_sample_exact() {
        let (stat, alpha) =
            ks_two_sample(&XS, &YS, KsMode::Exact, KsAlternative::TwoSided)
                .unwrap();
        assert::close(stat, 0.3, 1E-8);
        assert::close(alpha, 0.786_929_788_477_776_1, 1E-8);

        let (stat, alpha) =
            ks_two_sample(&XS, &YS, KsMode::Exact, KsAlternative::Less).unwrap();
        assert::close(stat, 0.3, 1E-8);
        assert::close(alpha, 0.419_580_419_580_419_53, 1E-8);

        let (stat, alpha) =
            ks_two_sample(&XS, &YS, KsMode::Exact, KsAlternative::Greater)
                .unwrap();
        assert::close(stat, 0.2, 1E-8);
        assert::close(alpha, 0.681_818_181_818_181_8, 1E-8);
    }

    #[test]
    fn ks_two_sample_asymp() {
        let (stat, alpha) = ks_two_sample(
            &XS,
            &YS,
            KsMode::Asymptotic,
            KsAlternative::TwoSided,
        )
        .unwrap();
        assert::close(stat, 0.3, 1E-8);
        assert::close(alpha, 0.759_097_838_420_394_8, 1E-8);

        let (stat, alpha) =
            ks_two_sample(&XS, &YS, KsMode::Asymptotic, KsAlternative::Less)
                .unwrap();
        assert::close(stat, 0.3, 1E-8);
        assert::close(alpha, 0.301_194_211_912_202_14, 1E-8);

        let (stat, alpha) =
            ks_two_sample(&XS, &YS, KsMode::Asymptotic, KsAlternative::Greater)
                .unwrap();
        assert::close(stat, 0.2, 1E-8);
        assert::close(alpha, 0.548_811_636_094_026_4, 1E-8);
    }

    #[test]
    fn ks_two_sample_exact_one_sided_unequal_sizes() {
        let xs = [1.0, 2.0, 3.0];
        let ys = [4.0, 5.0];

        // Every x precedes every y, so F_x - F_y reaches 1. Only 1 of the
        // C(5, 2) = 10 interleavings does so.
        let (stat, alpha) =
            ks_two_sample(&xs, &ys, KsMode::Exact, KsAlternative::Greater)
                .unwrap();
        assert::close(stat, 1.0, TOL);
        assert::close(alpha, 0.1, TOL);

        // Auto selects the exact method for samples this small.
        let (_, alpha) =
            ks_two_sample(&xs, &ys, KsMode::Auto, KsAlternative::Greater)
                .unwrap();
        assert::close(alpha, 0.1, TOL);

        // F_y never exceeds F_x here.
        let (stat, alpha) =
            ks_two_sample(&xs, &ys, KsMode::Exact, KsAlternative::Less).unwrap();
        assert::close(stat, 0.0, TOL);
        assert::close(alpha, 1.0, TOL);
    }

    #[test]
    fn ks_two_sample_empty_sample_is_error() {
        let empty: [f64; 0] = [];
        let ys = [1.0, 2.0];
        assert!(matches!(
            ks_two_sample(&empty, &ys, KsMode::Auto, KsAlternative::TwoSided),
            Err(KsError::EmptySlice)
        ));
        assert!(matches!(
            ks_two_sample(&ys, &empty, KsMode::Auto, KsAlternative::TwoSided),
            Err(KsError::EmptySlice)
        ));
    }

    #[test]
    fn paths_outside_proportion_small_n() {
        // n = 1: the two single observations are always fully separated.
        assert::close(paths_outside_proportion(1, 1.0), 1.0, TOL);
        // n = 2, h = 2: 2 of the 6 interleavings reach |F_x - F_y| = 1.
        assert::close(paths_outside_proportion(2, 2.0), 1.0 / 3.0, TOL);
    }
}