1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A Gaussian (normal) distribution with mean `mu` and standard deviation `sigma`.
#[derive(Clone, Copy, Debug)]
pub struct Gaussian {
    /// Mean; also reported as the median and the single mode.
    mu: f64,
    /// Standard deviation; `new` checks `sigma > 0.0` with `should!`.
    sigma: f64,
    /// Density normalizer `sqrt(2π) · sigma`, precomputed in `new` and divided out in
    /// `density`.
    norm: f64,
}
15
16impl Gaussian {
17 #[inline]
22 pub fn new(mu: f64, sigma: f64) -> Self {
23 use core::f64::consts::PI;
24 should!(sigma > 0.0);
25 Gaussian {
26 mu,
27 sigma,
28 norm: (2.0 * PI).sqrt() * sigma,
29 }
30 }
31
32 #[inline(always)]
34 pub fn mu(&self) -> f64 {
35 self.mu
36 }
37
38 #[inline(always)]
40 pub fn sigma(&self) -> f64 {
41 self.sigma
42 }
43}
44
45impl Default for Gaussian {
46 #[inline]
47 fn default() -> Self {
48 Gaussian::new(0.0, 1.0)
49 }
50}
51
52impl distribution::Continuous for Gaussian {
53 fn density(&self, x: f64) -> f64 {
54 (-(x - self.mu).powi(2) / (2.0 * self.sigma * self.sigma)).exp() / self.norm
55 }
56}
57
58impl distribution::Distribution for Gaussian {
59 type Value = f64;
60
61 fn distribution(&self, x: f64) -> f64 {
62 use core::f64::consts::SQRT_2;
63 use special::Error;
64 (1.0 + ((x - self.mu) / (self.sigma * SQRT_2)).error()) / 2.0
65 }
66}
67
68impl distribution::Entropy for Gaussian {
69 #[inline]
70 fn entropy(&self) -> f64 {
71 use core::f64::consts::{E, PI};
72 0.5 * (2.0 * PI * E * self.sigma * self.sigma).ln()
73 }
74}
75
76impl distribution::Inverse for Gaussian {
77 #[inline(always)]
87 fn inverse(&self, p: f64) -> f64 {
88 self.mu + self.sigma * inverse(p)
89 }
90}
91
92impl distribution::Kurtosis for Gaussian {
93 #[inline]
94 fn kurtosis(&self) -> f64 {
95 0.0
96 }
97}
98
99impl distribution::Mean for Gaussian {
100 #[inline]
101 fn mean(&self) -> f64 {
102 self.mu
103 }
104}
105
106impl distribution::Median for Gaussian {
107 #[inline]
108 fn median(&self) -> f64 {
109 self.mu
110 }
111}
112
113impl distribution::Modes for Gaussian {
114 #[inline]
115 fn modes(&self) -> Vec<f64> {
116 vec![self.mu]
117 }
118}
119
120impl distribution::Sample for Gaussian {
121 #[inline]
131 fn sample<S>(&self, source: &mut S) -> f64
132 where
133 S: Source,
134 {
135 self.sigma * sample(source) + self.mu
136 }
137}
138
139impl distribution::Skewness for Gaussian {
140 #[inline]
141 fn skewness(&self) -> f64 {
142 0.0
143 }
144}
145
146impl distribution::Variance for Gaussian {
147 #[inline]
148 fn variance(&self) -> f64 {
149 self.sigma * self.sigma
150 }
151
152 #[inline]
153 fn deviation(&self) -> f64 {
154 self.sigma
155 }
156}
157
158impl core::iter::FromIterator<f64> for Gaussian {
159 fn from_iter<T: IntoIterator<Item = f64>>(iterator: T) -> Self {
161 let samples: Vec<f64> = iterator.into_iter().collect();
162 let mu = samples.iter().fold(0.0, |a, b| a + b) / samples.len() as f64;
163 let sigma = f64::sqrt(
164 samples
165 .iter()
166 .fold(0.0, |a, b| a + f64::powf(b - mu as f64, 2.0))
167 / (samples.len() - 1) as f64,
168 );
169 Gaussian::new(mu, sigma)
170 }
171}
172
/// Compute the inverse of the standard normal cumulative distribution function, that is,
/// the `p`-quantile of `N(0, 1)`.
///
/// `p` should lie in `[0, 1]`; this is checked with the crate's `should!` macro. The
/// endpoints map to infinities: `p <= 0` returns negative infinity and `p >= 1` returns
/// positive infinity.
///
/// The function uses three rational approximations, each a ratio of two degree-7
/// polynomials:
/// - central region `|p - 0.5| <= 0.425`: evaluated in `0.425^2 - (p - 0.5)^2`;
/// - tails: evaluated in `r = sqrt(-ln(min(p, 1 - p)))`, with one approximation for
///   `r <= 5` and another beyond.
///
/// NOTE(review): the coefficients appear to be those of Wichura's algorithm AS 241
/// (`PPND16`); confirm provenance.
// The coefficients are kept at their published precision even where it exceeds what
// `f64` can represent, hence the lint allowance.
#[allow(clippy::excessive_precision)]
pub fn inverse(p: f64) -> f64 {
    use core::f64::{INFINITY, NEG_INFINITY};

    should!((0.0..=1.0).contains(&p));

    // CONST1 = SPLIT1^2 = 0.425^2; CONST2 shifts `r` in the near tail.
    const CONST1: f64 = 0.180625;
    const CONST2: f64 = 1.6;
    // |p - 0.5| threshold for the central region, and `r` threshold for the far tail.
    const SPLIT1: f64 = 0.425;
    const SPLIT2: f64 = 5.0;
    // Central region: numerator (A) and denominator (B).
    const A: [f64; 8] = [
        3.3871328727963666080e+00,
        1.3314166789178437745e+02,
        1.9715909503065514427e+03,
        1.3731693765509461125e+04,
        4.5921953931549871457e+04,
        6.7265770927008700853e+04,
        3.3430575583588128105e+04,
        2.5090809287301226727e+03,
    ];
    const B: [f64; 8] = [
        1.0000000000000000000e+00,
        4.2313330701600911252e+01,
        6.8718700749205790830e+02,
        5.3941960214247511077e+03,
        2.1213794301586595867e+04,
        3.9307895800092710610e+04,
        2.8729085735721942674e+04,
        5.2264952788528545610e+03,
    ];
    // Near tail (`r <= SPLIT2`): numerator (C) and denominator (D) in `r - CONST2`.
    const C: [f64; 8] = [
        1.42343711074968357734e+00,
        4.63033784615654529590e+00,
        5.76949722146069140550e+00,
        3.64784832476320460504e+00,
        1.27045825245236838258e+00,
        2.41780725177450611770e-01,
        2.27238449892691845833e-02,
        7.74545014278341407640e-04,
    ];
    const D: [f64; 8] = [
        1.00000000000000000000e+00,
        2.05319162663775882187e+00,
        1.67638483018380384940e+00,
        6.89767334985100004550e-01,
        1.48103976427480074590e-01,
        1.51986665636164571966e-02,
        5.47593808499534494600e-04,
        1.05075007164441684324e-09,
    ];
    // Far tail (`r > SPLIT2`): numerator (E) and denominator (F) in `r - SPLIT2`.
    const E: [f64; 8] = [
        6.65790464350110377720e+00,
        5.46378491116411436990e+00,
        1.78482653991729133580e+00,
        2.96560571828504891230e-01,
        2.65321895265761230930e-02,
        1.24266094738807843860e-03,
        2.71155556874348757815e-05,
        2.01033439929228813265e-07,
    ];
    const F: [f64; 8] = [
        1.00000000000000000000e+00,
        5.99832206555887937690e-01,
        1.36929880922735805310e-01,
        1.48753612908506148525e-02,
        7.86869131145613259100e-04,
        1.84631831751005468180e-05,
        1.42151175831644588870e-07,
        2.04426310338993978564e-15,
    ];

    // Evaluate the degree-7 polynomial `c[0] + c[1] x + ... + c[7] x^7` with Horner's
    // scheme.
    #[inline(always)]
    #[rustfmt::skip]
    fn poly(c: &[f64], x: f64) -> f64 {
        c[0] + x * (c[1] + x * (c[2] + x * (c[3] + x * (
        c[4] + x * (c[5] + x * (c[6] + x * (c[7])))))))
    }

    if p <= 0.0 {
        return NEG_INFINITY;
    }
    if 1.0 <= p {
        return INFINITY;
    }

    let q = p - 0.5;

    // Central region. The absolute value is written out by hand, presumably because
    // `f64::abs` is not available in `core` on the targeted toolchain
    // (NOTE(review): confirm).
    if (if q < 0.0 { -q } else { q }) <= SPLIT1 {
        let x = CONST1 - q * q;
        return q * poly(&A, x) / poly(&B, x);
    }

    // Tails: work with the smaller of `p` and `1 - p`, then restore the sign from `q`.
    let mut x = if q < 0.0 { p } else { 1.0 - p };

    x = (-x.ln()).sqrt();

    if x <= SPLIT2 {
        x -= CONST2;
        x = poly(&C, x) / poly(&D, x);
    } else {
        x -= SPLIT2;
        x = poly(&E, x) / poly(&F, x);
    }

    if q < 0.0 {
        -x
    } else {
        x
    }
}
285
/// Draw a sample from the standard normal distribution.
///
/// Uses a 128-layer ziggurat rejection scheme driven by the tables `K`, `W` and `Y` and
/// the tail start `R`. Each attempt consumes one `u64` from `source`:
/// - bits 0–6 select the layer `i`;
/// - bit 7 selects the sign;
/// - bits 8–31 form a 24-bit position `j` inside the layer;
/// - the upper 32 bits are not used.
///
/// Attempts that miss the fast path consume extra `f64` draws, and rejected attempts
/// retry.
///
/// NOTE(review): the structure and constants appear to follow GSL's
/// `gsl_ran_gaussian_ziggurat`; confirm provenance. This also assumes
/// `source.read::<f64>()` is uniform on `[0, 1)`.
pub fn sample<S: Source>(source: &mut S) -> f64 {
    loop {
        let u = source.read::<u64>();

        // Layer index, 24-bit position within the layer, and sign.
        let i = (u & 0x7F) as usize;
        let j = ((u >> 8) & 0xFFFFFF) as u32;
        let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };

        // Fast path: the point lies in the part of layer `i` that is entirely under the
        // curve, so it is accepted without evaluating the density (never taken for
        // `i == 0`, where `K[0] == 0`).
        if j < K[i] {
            let x = j as f64 * W[i];
            return s * x;
        }

        let (x, y) = if i < 127 {
            // Wedge of an inner layer: draw `y` uniformly between the layer's bottom
            // `Y[i + 1]` and top `Y[i]`.
            let x = j as f64 * W[i];
            let y = Y[i + 1] + (Y[i] - Y[i + 1]) * source.read::<f64>();
            (x, y)
        } else {
            // Base layer beyond its rectangle: propose `x = R - ln(1 - U) / R`, i.e. `R`
            // plus an exponential variate of rate `R`. Draw `y` under the envelope
            // `exp(-R * (x - R / 2))`, which bounds `exp(-x^2 / 2)` from above because
            // their log-ratio is `(x - R)^2 / 2 >= 0`.
            // The two `read::<f64>()` calls must stay in this order to keep the stream.
            let x = R - (-source.read::<f64>()).ln_1p() / R;
            let y = (-R * (x - 0.5 * R)).exp() * source.read::<f64>();
            (x, y)
        };

        // Accept when the point lies under the unnormalized density `exp(-x^2 / 2)`.
        if y < (-0.5 * x * x).exp() {
            return s * x;
        }
    }
}
315
/// Start of the tail used by `sample` for the base layer: tail proposals are `x > R`.
/// It equals `W[126] * 2^24`, the right edge of the outermost inner layer.
const R: f64 = 3.44428647676;

/// Fast-path thresholds for `sample`: a 24-bit position `j` in layer `i` is accepted
/// without evaluating the density when `j < K[i]`. `K[0] == 0`, so the top layer always
/// goes through the wedge test.
#[rustfmt::skip]
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939,
    15384584, 15635009, 15807561, 15933577,
    16029594, 16105155, 16166147, 16216399,
    16258508, 16294295, 16325078, 16351831,
    16375291, 16396026, 16414479, 16431002,
    16445880, 16459343, 16471578, 16482744,
    16492970, 16502368, 16511031, 16519039,
    16526459, 16533352, 16539769, 16545755,
    16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977,
    16585398, 16588629, 16591685, 16594575,
    16597311, 16599901, 16602354, 16604679,
    16606881, 16608968, 16610945, 16612818,
    16614592, 16616272, 16617861, 16619363,
    16620782, 16622121, 16623383, 16624570,
    16625685, 16626730, 16627708, 16628619,
    16629465, 16630248, 16630969, 16631628,
    16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774,
    16634903, 16634972, 16634980, 16634926,
    16634810, 16634628, 16634381, 16634066,
    16633680, 16633222, 16632688, 16632075,
    16631380, 16630598, 16629726, 16628757,
    16627686, 16626507, 16625212, 16623794,
    16622243, 16620548, 16618698, 16616679,
    16614476, 16612071, 16609444, 16606571,
    16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120,
    16562917, 16554758, 16545450, 16534739,
    16522287, 16507638, 16490152, 16468907,
    16442518, 16408804, 16364095, 16301683,
    16207738, 16047994, 15704248, 15472926
];

/// Layer heights on the scale of the unnormalized density `exp(-x^2 / 2)`
/// (`Y[0] == 1.0`). Layer `i` spans vertically from `Y[i + 1]` (bottom) to `Y[i]` (top)
/// in the wedge test of `sample`.
#[rustfmt::skip]
const Y: [f64; 128] = [
    1.0000000000000, 0.96359862301100, 0.93628081335300, 0.91304110425300,
    0.8922785066960, 0.87323935691900, 0.85549640763400, 0.83877892834900,
    0.8229020836990, 0.80773273823400, 0.79317104551900, 0.77913972650500,
    0.7655774360820, 0.75243445624800, 0.73966978767700, 0.72724912028500,
    0.7151433774130, 0.70332764645500, 0.69178037703500, 0.68048276891000,
    0.6694182972330, 0.65857233912000, 0.64793187618900, 0.63748525489600,
    0.6272219914500, 0.61713261153200, 0.60720851746700, 0.59744187729600,
    0.5878255314650, 0.57835291380300, 0.56901798419800, 0.55981517091100,
    0.5507393208770, 0.54178565668200, 0.53294973914500, 0.52422743462800,
    0.5156148863730, 0.50710848925300, 0.49870486747800, 0.49040085481200,
    0.4821934769860, 0.47407993601000, 0.46605759612500, 0.45812397121400,
    0.4502767134670, 0.44251360317100, 0.43483253947300, 0.42723153202200,
    0.4197086933790, 0.41226223212000, 0.40489044654800, 0.39759171895500,
    0.3903645103820, 0.38320735581600, 0.37611885978800, 0.36909769233400,
    0.3621425852820, 0.35525232883400, 0.34842576841500, 0.34166180177600,
    0.3349593763110, 0.32831748658800, 0.32173517206300, 0.31521151497000,
    0.3087456383670, 0.30233670433800, 0.29598391232000, 0.28968649757100,
    0.2834437297390, 0.27725491156000, 0.27111937764900, 0.26503649338700,
    0.2590056539120, 0.25302628318300, 0.24709783313900, 0.24121978293200,
    0.2353916382390, 0.22961293064900, 0.22388321712200, 0.21820207951800,
    0.2125691242010, 0.20698398170900, 0.20144630649600, 0.19595577674500,
    0.1905120942560, 0.18511498440600, 0.17976419618500, 0.17445950232400,
    0.1692006994920, 0.16398760860000, 0.15882007519500, 0.15369796996400,
    0.1486211893480, 0.14358965629500, 0.13860332114300, 0.13366216266900,
    0.1287661893090, 0.12391544058200, 0.11910998874500, 0.11434994070300,
    0.1096354402300, 0.10496667053300, 0.10034385723200, 0.09576727182660,
    0.0912372357329, 0.08675412501270, 0.08231837593200, 0.07793049152950,
    0.0735910494266, 0.06930071117420, 0.06506023352900, 0.06087048217450,
    0.0567324485840, 0.05264727098000, 0.04861626071630, 0.04464093597690,
    0.0407230655415, 0.03686472673860, 0.03306838393780, 0.02933699774110,
    0.0256741818288, 0.02208443726340, 0.01857352005770, 0.01514905528540,
    0.0118216532614, 0.00860719483079, 0.00553245272614, 0.00265435214565,
];

/// Per-layer scale factors: `sample` maps the 24-bit position `j` in layer `i` to the
/// abscissa `x = j * W[i]`. Thus `W[i] * 2^24` is the width of layer `i`.
#[rustfmt::skip]
const W: [f64; 128] = [
    1.62318314817e-08, 2.16291505214e-08, 2.54246305087e-08, 2.84579525938e-08,
    3.10340022482e-08, 3.33011726243e-08, 3.53439060345e-08, 3.72152672658e-08,
    3.89509895720e-08, 4.05763964764e-08, 4.21101548915e-08, 4.35664624904e-08,
    4.49563968336e-08, 4.62887864029e-08, 4.75707945735e-08, 4.88083237257e-08,
    5.00063025384e-08, 5.11688950428e-08, 5.22996558616e-08, 5.34016475624e-08,
    5.44775307871e-08, 5.55296344581e-08, 5.65600111659e-08, 5.75704813695e-08,
    5.85626690412e-08, 5.95380306862e-08, 6.04978791776e-08, 6.14434034901e-08,
    6.23756851626e-08, 6.32957121259e-08, 6.42043903937e-08, 6.51025540077e-08,
    6.59909735447e-08, 6.68703634341e-08, 6.77413882848e-08, 6.86046683810e-08,
    6.94607844804e-08, 7.03102820203e-08, 7.11536748229e-08, 7.19914483720e-08,
    7.28240627230e-08, 7.36519550992e-08, 7.44755422158e-08, 7.52952223703e-08,
    7.61113773308e-08, 7.69243740467e-08, 7.77345662086e-08, 7.85422956743e-08,
    7.93478937793e-08, 8.01516825471e-08, 8.09539758128e-08, 8.17550802699e-08,
    8.25552964535e-08, 8.33549196661e-08, 8.41542408569e-08, 8.49535474601e-08,
    8.57531242006e-08, 8.65532538723e-08, 8.73542180955e-08, 8.81562980590e-08,
    8.89597752521e-08, 8.97649321908e-08, 9.05720531451e-08, 9.13814248700e-08,
    9.21933373471e-08, 9.30080845407e-08, 9.38259651738e-08, 9.46472835298e-08,
    9.54723502847e-08, 9.63014833769e-08, 9.71350089201e-08, 9.79732621669e-08,
    9.88165885297e-08, 9.96653446693e-08, 1.00519899658e-07, 1.01380636230e-07,
    1.02247952126e-07, 1.03122261554e-07, 1.04003996769e-07, 1.04893609795e-07,
    1.05791574313e-07, 1.06698387725e-07, 1.07614573423e-07, 1.08540683296e-07,
    1.09477300508e-07, 1.10425042570e-07, 1.11384564771e-07, 1.12356564007e-07,
    1.13341783071e-07, 1.14341015475e-07, 1.15355110887e-07, 1.16384981291e-07,
    1.17431607977e-07, 1.18496049514e-07, 1.19579450872e-07, 1.20683053909e-07,
    1.21808209468e-07, 1.22956391410e-07, 1.24129212952e-07, 1.25328445797e-07,
    1.26556042658e-07, 1.27814163916e-07, 1.29105209375e-07, 1.30431856341e-07,
    1.31797105598e-07, 1.33204337360e-07, 1.34657379914e-07, 1.36160594606e-07,
    1.37718982103e-07, 1.39338316679e-07, 1.41025317971e-07, 1.42787873535e-07,
    1.44635331499e-07, 1.46578891730e-07, 1.48632138436e-07, 1.50811780719e-07,
    1.53138707402e-07, 1.55639532047e-07, 1.58348931426e-07, 1.61313325908e-07,
    1.64596952856e-07, 1.68292495203e-07, 1.72541128694e-07, 1.77574279496e-07,
    1.83813550477e-07, 1.92166040885e-07, 2.05295471952e-07, 2.22600839893e-07,
];
425
#[cfg(test)]
mod tests {
    use core::iter::FromIterator;

    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // Shorthand constructor used throughout the tests.
    macro_rules! new(
        ($mu:expr, $sigma:expr) => (Gaussian::new($mu, $sigma));
    );

    // Reference density values of N(1, 2^2) on the grid -4.0, -3.5, ..., 4.0.
    #[test]
    fn density() {
        let d = new!(1.0, 2.0);
        let x = vec![
            -4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5,
            4.0,
        ];
        let p = vec![
            8.764150246784270e-03,
            1.586982591783371e-02,
            2.699548325659403e-02,
            4.313865941325577e-02,
            6.475879783294587e-02,
            9.132454269451096e-02,
            1.209853622595717e-01,
            1.505687160774022e-01,
            1.760326633821498e-01,
            1.933340584014246e-01,
            1.994711402007164e-01,
            1.933340584014246e-01,
            1.760326633821498e-01,
            1.505687160774022e-01,
            1.209853622595717e-01,
            9.132454269451096e-02,
            6.475879783294587e-02,
        ];

        assert::close(
            &x.iter().map(|&x| d.density(x)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    // Reference CDF values of N(1, 2^2) on the grid -4.0, -3.5, ..., 4.0.
    #[test]
    fn distribution() {
        let d = new!(1.0, 2.0);
        let x = vec![
            -4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5,
            4.0,
        ];
        let p = vec![
            6.209665325776139e-03,
            1.222447265504470e-02,
            2.275013194817922e-02,
            4.005915686381709e-02,
            6.680720126885809e-02,
            1.056497736668553e-01,
            1.586552539314571e-01,
            2.266273523768682e-01,
            3.085375387259869e-01,
            4.012936743170763e-01,
            5.000000000000000e-01,
            5.987063256829237e-01,
            6.914624612740131e-01,
            7.733726476231317e-01,
            8.413447460685429e-01,
            8.943502263331446e-01,
            9.331927987311419e-01,
        ];

        assert::close(
            &x.iter().map(|&x| d.distribution(x)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    // Compared with exact floating-point equality, so `entropy` must round identically
    // to `((2π).ln() + 1) / 2` for N(0, 1).
    #[test]
    fn entropy() {
        use core::f64::consts::PI;
        assert_eq!(new!(0.0, 1.0).entropy(), ((2.0 * PI).ln() + 1.0) / 2.0);
    }

    // Reference quantiles of N(-1, 0.25^2) at p = 0.00, 0.05, ..., 1.00; the endpoints
    // map to negative and positive infinity.
    #[test]
    fn inverse() {
        use core::f64::{INFINITY, NEG_INFINITY};

        let d = new!(-1.0, 0.25);
        let p = vec![
            0.00, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65,
            0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.00,
        ];
        let x = vec![
            NEG_INFINITY,
            -1.411213406737868e+00,
            -1.320387891386150e+00,
            -1.259108347373447e+00,
            -1.210405308393228e+00,
            -1.168622437549020e+00,
            -1.131100128177010e+00,
            -1.096330116601892e+00,
            -1.063336775783950e+00,
            -1.031415336713768e+00,
            -1.000000000000000e+00,
            -9.685846632862315e-01,
            -9.366632242160501e-01,
            -9.036698833981082e-01,
            -8.688998718229899e-01,
            -8.313775624509796e-01,
            -7.895946916067714e-01,
            -7.408916526265525e-01,
            -6.796121086138498e-01,
            -5.887865932621319e-01,
            INFINITY,
        ];

        assert::close(
            &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
            &x,
            1e-14,
        );
    }

    #[test]
    fn kurtosis() {
        assert_eq!(new!(0.0, 2.0).kurtosis(), 0.0);
    }

    #[test]
    fn mean() {
        assert_eq!(new!(0.0, 1.0).mean(), 0.0);
    }

    #[test]
    fn median() {
        assert_eq!(new!(0.0, 2.0).median(), 0.0);
    }

    #[test]
    fn modes() {
        assert_eq!(new!(2.0, 5.0).modes(), vec![2.0]);
    }

    #[test]
    fn skewness() {
        assert_eq!(new!(0.0, 2.0).skewness(), 0.0);
    }

    #[test]
    fn variance() {
        assert_eq!(new!(0.0, 2.0).variance(), 4.0);
    }

    #[test]
    fn deviation() {
        assert_eq!(new!(0.0, 2.0).deviation(), 2.0);
    }

    // Fit a distribution to 10 000 draws from N(1, 2^2), using a fixed seed, and check
    // the recovered fields within 0.1.
    #[test]
    fn from_iter() {
        let mut source = source::default(42);
        let distribution = new!(1.0, 2.0);
        let sampler = Independent(&distribution, &mut source);
        let samples = sampler.take(10000).collect::<Vec<_>>();
        let derived_distribution = Gaussian::from_iter(samples);

        assert::close(derived_distribution.mu, distribution.mu, 0.1);
        assert::close(derived_distribution.sigma, distribution.sigma, 0.1);
        assert::close(derived_distribution.norm, distribution.norm, 0.1);
    }
}