1#[cfg(feature = "serde1")]
5use serde::{Deserialize, Serialize};
6
7use crate::consts::TWO_PI;
8use crate::impl_display;
9use crate::traits::{
10 Cdf, ContinuousDistr, HasDensity, InverseCdf, Parameterized, Sampleable,
11 Support,
12};
13use rand::Rng;
14use std::f64::consts::{PI, SQRT_2};
15
/// Whether `x` agrees with `y` to within `atol + rtol * |y|`.
///
/// Any NaN among the inputs makes the comparison fail.
#[inline]
fn within_tol(x: f64, y: f64, atol: f64, rtol: f64) -> bool {
    let allowed = rtol.mul_add(y.abs(), atol);
    let gap = (x - y).abs();
    gap <= allowed
}
21
/// Asymptotic distribution of the two-sided Kolmogorov–Smirnov statistic
/// (the Kolmogorov distribution).
///
/// For `x > 0` its CDF is
/// `K(x) = (√(2π)/x) · Σ_{k≥1} exp(-(2k-1)²π²/(8x²))`.
///
/// The distribution has no parameters, so every instance is identical.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct KsTwoAsymptotic {}
46
47impl Parameterized for KsTwoAsymptotic {
48 type Parameters = ();
49
50 fn emit_params(&self) -> Self::Parameters {}
51
52 fn from_params(_params: Self::Parameters) -> Self {
53 Self {}
54 }
55}
56
/// Cumulative distribution and probability density evaluated at a single
/// point; produced by `KsTwoAsymptotic::compute`.
struct CdfPdf {
    /// Cumulative distribution function value.
    cdf: f64,
    /// Probability density function value.
    pdf: f64,
}
61
/// `f64::exp` of any argument below roughly this value underflows to `0.0`.
const MIN_EXP: f64 = -746.0;
/// `KsTwoAsymptotic::compute` returns a CDF and PDF of zero for every
/// `x <= MIN_THRESHOLD` (this evaluates to π/(8·746) ≈ 5.3e-4).
// NOTE(review): the small-x series term u = exp(-π²/(8x²)) already
// underflows for x below about π/sqrt(8·746) ≈ 0.041, far above this
// value; confirm whether π/sqrt(8·|MIN_EXP|) was the intended cutoff.
const MIN_THRESHOLD: f64 = PI / (8.0 * -MIN_EXP);
/// `compute` uses the small-x series for `x <= KOLMOGO_CUTOVER` and the
/// large-x alternating series above it.
const KOLMOGO_CUTOVER: f64 = 0.82;
/// Maximum number of root-finding iterations in `KsTwoAsymptotic::inverse`.
const MAX_ITERS: usize = 2000;
66
67impl KsTwoAsymptotic {
68 #[inline]
70 #[must_use]
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 #[allow(clippy::many_single_char_names)]
76 fn compute(x: f64) -> CdfPdf {
77 if x <= MIN_THRESHOLD {
78 CdfPdf { cdf: 0.0, pdf: 0.0 }
79 } else if x <= KOLMOGO_CUTOVER {
80 let mut p: f64 = 1.0;
86 let mut d: f64 = 0.0;
87
88 let w = TWO_PI.sqrt() / x;
89 let logu8 = -PI * PI / (x * x);
90 let u = (logu8 / 8.0).exp();
91
92 if u == 0.0 {
93 let log_p = logu8 / 8.0 + w.ln();
94 p = log_p.exp();
95 } else {
96 let u_8 = logu8.exp();
97 let u_8cub = u_8.powi(3);
98
99 p = u_8cub.mul_add(p, 1.0);
100 d = u_8cub.mul_add(d, 25.0);
101
102 p = u_8cub.mul_add(p, 1.0);
103 d = u_8cub.mul_add(d, 9.0);
104
105 p = u_8cub.mul_add(p, 1.0);
106 d = u_8cub.mul_add(d, 1.0);
107
108 d = (PI * PI / (4.0 * x * x)).mul_add(d, -p);
109 d *= w * u / x;
110 p *= w * u;
111 }
112
113 CdfPdf {
114 cdf: p.clamp(0.0, 1.0),
115 pdf: d.max(0.0),
116 }
117 } else {
118 let mut p: f64 = 1.0;
119 let mut d: f64 = 0.0;
120 let logv = -2.0 * x * x;
126 let v = logv.exp();
127 let vsq = v * v;
132 let v3 = v.powi(3);
133 let mut vpwr;
134
135 vpwr = v3 * v3 * v;
136 p = vpwr.mul_add(-p, 1.0);
137 d = 3.0_f64.mul_add(3.0, -vpwr * d);
138
139 vpwr = v3 * vsq;
140 p = vpwr.mul_add(-p, 1.0);
141 d = 2.0_f64.mul_add(2.0, -vpwr * d);
142
143 vpwr = v3;
144 p = vpwr.mul_add(-p, 1.0);
145 d = 1.0_f64.mul_add(1.0, -vpwr * d);
146
147 p *= 2.0 * v;
148 d *= 8.0 * v * x;
149 p = p.max(0.0);
150 let cdf = (1.0 - p).clamp(0.0, 1.0);
151 let pdf = d.max(0.0);
152 CdfPdf { cdf, pdf }
153 }
154 }
155
156 #[allow(clippy::many_single_char_names)]
160 fn inverse(sf: f64, cdf: f64) -> f64 {
161 if !(sf >= 0.0 && cdf >= 0.0 && sf <= 1.0 && cdf <= 1.0)
162 || (1.0 - cdf - sf).abs() > 4.0 * f64::EPSILON
163 {
164 f64::NAN
165 } else if cdf == 0.0 {
166 0.0
167 } else if sf == 0.0 {
168 f64::INFINITY
169 } else {
170 let mut x: f64;
171 let mut a: f64;
172 let mut b: f64;
173
174 if cdf <= 0.5 {
175 let logcdf = cdf.ln();
176 let log_sqrt_2pi: f64 = (2.0 * PI).sqrt().ln();
177
178 a = PI
179 / (2.0
180 * SQRT_2
181 * (-(logcdf + logcdf / 2.0 - log_sqrt_2pi)).sqrt());
182 b = PI
183 / (2.0 * SQRT_2 * (-(logcdf + 0.0 - log_sqrt_2pi)).sqrt());
184 a = PI
185 / (2.0
186 * SQRT_2
187 * (-(logcdf + a.ln() - log_sqrt_2pi)).sqrt());
188 b = PI
189 / (2.0
190 * SQRT_2
191 * (-(logcdf + b.ln() - log_sqrt_2pi)).sqrt());
192 x = (a + b) / 2.0;
193 } else {
194 const JITTERB: f64 = f64::EPSILON * 256.0;
195 let pba = sf / (2.0 * (1.0 - (-4.0_f64).exp()));
196 let pbb = sf * (1.0 - JITTERB) / 2.0;
197
198 a = (-0.5 * pba.ln()).sqrt();
199 b = (-0.5 * pbb.ln()).sqrt();
200
201 let q = sf / 2.0;
202 let q2 = q * q;
203 let q3 = q2 * q;
204
205 let q0 = q3.mul_add(
206 q3.mul_add(
207 q2.mul_add(
208 q.mul_add(
209 q2.mul_add(140.0_f64.mul_add(q, -13.0), 22.0),
210 -1.0,
211 ),
212 4.0,
213 ),
214 1.0,
215 ),
216 1.0,
217 );
218 let q0 = q0 * q;
219
220 x = (-(q0).ln() / 2.0).sqrt();
221 if x < a || x > b {
222 x = (a + b) / 2.0;
223 }
224 }
225 assert!(a <= b, "{a} > {b}");
226
227 for _ in 0..MAX_ITERS {
228 let x0 = x;
229 let c = Self::compute(x0);
230 let df = if cdf < 0.5 {
231 cdf - c.cdf
232 } else {
233 (1.0 - c.cdf) - sf
234 };
235
236 if df == 0.0 {
237 break;
238 }
239
240 if df > 0.0 && x > a {
241 a = x;
242 } else if df < 0.0 && x < b {
243 b = x;
244 }
245
246 let dfdx = -c.pdf;
247 if dfdx.abs() <= f64::EPSILON {
248 x = (a + b) / 2.0;
249 } else {
250 let t = df / dfdx;
251 x = x0 - t;
252 }
253
254 if x >= a && x <= b {
255 if within_tol(x, x0, f64::EPSILON, f64::EPSILON * 2.0) {
256 break;
257 } else if (x - a).abs() < f64::EPSILON
258 || (x - b).abs() < f64::EPSILON
259 {
260 x = (a + b) / 2.0;
261 if (x - a).abs() > f64::EPSILON
262 || (x - b).abs() < f64::EPSILON
263 {
264 break;
265 }
266 }
267 } else {
268 x = (a + b) / 2.0;
269 if within_tol(x, x0, f64::EPSILON, f64::EPSILON * 2.0) {
270 break;
271 }
272 }
273 }
274
275 x
276 }
277 }
278}
279
280impl From<&KsTwoAsymptotic> for String {
281 fn from(_kstwobign: &KsTwoAsymptotic) -> String {
282 "KsTwoAsymptotic()".to_string()
283 }
284}
285
286impl_display!(KsTwoAsymptotic);
287
288macro_rules! impl_traits {
289 ($kind:ty) => {
290 impl HasDensity<$kind> for KsTwoAsymptotic {
291 fn ln_f(&self, x: &$kind) -> f64 {
292 Self::compute(*x as f64).pdf.ln()
293 }
294 }
295
296 impl Sampleable<$kind> for KsTwoAsymptotic {
297 fn draw<R: Rng>(&self, rng: &mut R) -> $kind {
298 let p: f64 = rng.random();
299 self.invcdf(p)
300 }
301 }
302
303 impl Support<$kind> for KsTwoAsymptotic {
304 fn supports(&self, x: &$kind) -> bool {
305 *x >= 0.0 && *x <= 1.0
306 }
307 }
308
309 impl ContinuousDistr<$kind> for KsTwoAsymptotic {}
310
311 impl Cdf<$kind> for KsTwoAsymptotic {
312 fn cdf(&self, x: &$kind) -> f64 {
313 Self::compute(*x as f64).cdf
314 }
315 }
316
317 impl InverseCdf<$kind> for KsTwoAsymptotic {
318 fn invcdf(&self, p: f64) -> $kind {
319 Self::inverse(1.0 - p, p) as $kind
320 }
321 }
322 };
323}
324
325impl_traits!(f32);
326impl_traits!(f64);
327
#[cfg(test)]
mod test {
    use super::*;
    use crate::misc::ks_test;
    use rand::SeedableRng;
    use rand_xoshiro::Xoshiro256Plus;
    const TOL: f64 = 1E-5;

    /// Log-density at x = 0.1, 0.2, ..., 1.0 against reference values.
    #[test]
    fn ln_f() {
        let dist = KsTwoAsymptotic::new();
        let points: [f64; 10] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        let expected: [f64; 10] = [
            -112.341_671_780_246_46,
            -22.599_002_391_501_188,
            -7.106_946_223_524_299,
            -2.290_405_151_293_747_6,
            -0.446_939_110_544_928_63,
            0.280_750_537_963_607_1,
            0.509_665_510_735_263_6,
            0.486_752_791_638_526_86,
            0.322_610_211_790_590_3,
            0.069_478_074_891_398,
        ];

        for (x, &want) in points.iter().zip(expected.iter()) {
            let got: f64 = dist.ln_f(x);
            assert::close(got, want, TOL);
        }
    }

    /// CDF on an even grid over [0.1, 2.0] against reference values.
    #[test]
    fn cdf() {
        let dist = KsTwoAsymptotic::new();
        let points: [f64; 10] = [
            0.1,
            0.311_111_111_111_111_1,
            0.522_222_222_222_222_3,
            0.733_333_333_333_333_3,
            0.944_444_444_444_444_4,
            1.155_555_555_555_555_7,
            1.366_666_666_666_666_7,
            1.577_777_777_777_778,
            1.788_888_888_888_889,
            2.0,
        ];
        let expected: [f64; 10] = [
            6.609_305_242_245_699e-53,
            2.347_446_802_363_517e-5,
            0.052_070_628_335_016_79,
            0.344_735_508_258_350_1,
            0.665_645_486_961_299_3,
            0.861_626_906_810_242,
            0.952_280_824_435_727_8,
            0.986_234_895_897_317_9,
            0.996_677_705_889_700_3,
            0.999_329_074_744_220_3,
        ];

        for (x, &want) in points.iter().zip(expected.iter()) {
            let got: f64 = dist.cdf(x);
            assert::close(got, want, TOL);
        }
    }

    /// Quantiles at p = 0.0, 0.1, ..., 0.9, checked from the top down.
    #[test]
    fn invcdf() {
        let dist = KsTwoAsymptotic::new();
        let probs: [f64; 10] = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let expected: [f64; 10] = [
            0.0,
            0.571_173_265_106_340_1,
            0.644_812_606_166_356_7,
            0.706_732_652_306_898_1,
            0.766_185_555_561_768_2,
            0.827_573_555_189_905_9,
            0.894_764_454_985_119_6,
            0.973_063_375_332_372_6,
            1.072_749_174_939_648,
            1.223_847_870_217_082_5,
        ];

        for (&p, &want) in probs.iter().zip(expected.iter()).rev() {
            let got: f64 = dist.invcdf(p);
            assert::close(got, want, TOL);
        }
    }

    /// Samples should be consistent with the distribution's own CDF.
    #[test]
    fn draw() {
        let dist = KsTwoAsymptotic::new();
        let mut rng = Xoshiro256Plus::seed_from_u64(0x1234);
        let sample: Vec<f64> = dist.sample(1000, &mut rng);
        let (_, alpha) = ks_test(&sample, |x| dist.cdf(&x));
        assert!(alpha >= 0.05);
    }

    /// Round-tripping through the (empty) parameter set is lossless.
    #[test]
    fn emit_and_from_params_are_identity() {
        let original = KsTwoAsymptotic::new();
        let rebuilt = KsTwoAsymptotic::from_params(original.emit_params());
        assert_eq!(original, rebuilt);
    }
}