probability/distribution/
gaussian.rs

1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A Gaussian (normal) distribution.
#[derive(Debug, Clone, Copy)]
pub struct Gaussian {
    /// Mean.
    mu: f64,
    /// Standard deviation.
    sigma: f64,
    /// Normalizing constant of the density, `sqrt(2π) * sigma`, cached at
    /// construction time.
    norm: f64,
}
15
16impl Gaussian {
17    /// Create a Gaussian distribution with mean `mu` and standard deviation
18    /// `sigma`.
19    ///
20    /// It should hold that `sigma > 0`.
21    #[inline]
22    pub fn new(mu: f64, sigma: f64) -> Self {
23        use core::f64::consts::PI;
24        should!(sigma > 0.0);
25        Gaussian {
26            mu,
27            sigma,
28            norm: (2.0 * PI).sqrt() * sigma,
29        }
30    }
31
32    /// Return the mean.
33    #[inline(always)]
34    pub fn mu(&self) -> f64 {
35        self.mu
36    }
37
38    /// Return the standard deviation.
39    #[inline(always)]
40    pub fn sigma(&self) -> f64 {
41        self.sigma
42    }
43}
44
45impl Default for Gaussian {
46    #[inline]
47    fn default() -> Self {
48        Gaussian::new(0.0, 1.0)
49    }
50}
51
52impl distribution::Continuous for Gaussian {
53    fn density(&self, x: f64) -> f64 {
54        (-(x - self.mu).powi(2) / (2.0 * self.sigma * self.sigma)).exp() / self.norm
55    }
56}
57
58impl distribution::Distribution for Gaussian {
59    type Value = f64;
60
61    fn distribution(&self, x: f64) -> f64 {
62        use core::f64::consts::SQRT_2;
63        use special::Error;
64        (1.0 + ((x - self.mu) / (self.sigma * SQRT_2)).error()) / 2.0
65    }
66}
67
68impl distribution::Entropy for Gaussian {
69    #[inline]
70    fn entropy(&self) -> f64 {
71        use core::f64::consts::{E, PI};
72        0.5 * (2.0 * PI * E * self.sigma * self.sigma).ln()
73    }
74}
75
76impl distribution::Inverse for Gaussian {
77    /// Compute the inverse of the cumulative distribution function.
78    ///
79    /// ## References
80    ///
81    /// 1. M. J. Wichura, “Algorithm as 241: The percentage points of the normal
82    ///    distribution,” Journal of the Royal Statistical Society. Series C
83    ///    (Applied Statistics), vol. 37, no. 3, pp. pp. 477–484, 1988.
84    ///
85    /// 2. <http://people.sc.fsu.edu/~jburkardt/c_src/asa241/asa241.html>
86    #[inline(always)]
87    fn inverse(&self, p: f64) -> f64 {
88        self.mu + self.sigma * inverse(p)
89    }
90}
91
92impl distribution::Kurtosis for Gaussian {
93    #[inline]
94    fn kurtosis(&self) -> f64 {
95        0.0
96    }
97}
98
99impl distribution::Mean for Gaussian {
100    #[inline]
101    fn mean(&self) -> f64 {
102        self.mu
103    }
104}
105
106impl distribution::Median for Gaussian {
107    #[inline]
108    fn median(&self) -> f64 {
109        self.mu
110    }
111}
112
113impl distribution::Modes for Gaussian {
114    #[inline]
115    fn modes(&self) -> Vec<f64> {
116        vec![self.mu]
117    }
118}
119
120impl distribution::Sample for Gaussian {
121    /// Draw a sample.
122    ///
123    /// ## References
124    ///
125    /// 1. G. Marsaglia and W. W. Tsang, “The ziggurat method for generating
126    ///    random variables,” Journal of Statistical Software, vol. 5, no. 8,
127    ///    pp. 1–7, 10 2000.
128    ///
129    /// 2. D. Eddelbuettel, “Ziggurat Revisited,” 2014.
130    #[inline]
131    fn sample<S>(&self, source: &mut S) -> f64
132    where
133        S: Source,
134    {
135        self.sigma * sample(source) + self.mu
136    }
137}
138
139impl distribution::Skewness for Gaussian {
140    #[inline]
141    fn skewness(&self) -> f64 {
142        0.0
143    }
144}
145
146impl distribution::Variance for Gaussian {
147    #[inline]
148    fn variance(&self) -> f64 {
149        self.sigma * self.sigma
150    }
151
152    #[inline]
153    fn deviation(&self) -> f64 {
154        self.sigma
155    }
156}
157
158impl core::iter::FromIterator<f64> for Gaussian {
159    /// Infer the distribution from an iterator.
160    fn from_iter<T: IntoIterator<Item = f64>>(iterator: T) -> Self {
161        let samples: Vec<f64> = iterator.into_iter().collect();
162        let mu = samples.iter().fold(0.0, |a, b| a + b) / samples.len() as f64;
163        let sigma = f64::sqrt(
164            samples
165                .iter()
166                .fold(0.0, |a, b| a + f64::powf(b - mu as f64, 2.0))
167                / (samples.len() - 1) as f64,
168        );
169        Gaussian::new(mu, sigma)
170    }
171}
172
173/// Compute the inverse cumulative distribution function of the standard
174/// Gaussian distribution.
175#[allow(clippy::excessive_precision)]
176pub fn inverse(p: f64) -> f64 {
177    use core::f64::{INFINITY, NEG_INFINITY};
178
179    should!((0.0..=1.0).contains(&p));
180
181    const CONST1: f64 = 0.180625;
182    const CONST2: f64 = 1.6;
183    const SPLIT1: f64 = 0.425;
184    const SPLIT2: f64 = 5.0;
185    const A: [f64; 8] = [
186        3.3871328727963666080e+00,
187        1.3314166789178437745e+02,
188        1.9715909503065514427e+03,
189        1.3731693765509461125e+04,
190        4.5921953931549871457e+04,
191        6.7265770927008700853e+04,
192        3.3430575583588128105e+04,
193        2.5090809287301226727e+03,
194    ];
195    const B: [f64; 8] = [
196        1.0000000000000000000e+00,
197        4.2313330701600911252e+01,
198        6.8718700749205790830e+02,
199        5.3941960214247511077e+03,
200        2.1213794301586595867e+04,
201        3.9307895800092710610e+04,
202        2.8729085735721942674e+04,
203        5.2264952788528545610e+03,
204    ];
205    const C: [f64; 8] = [
206        1.42343711074968357734e+00,
207        4.63033784615654529590e+00,
208        5.76949722146069140550e+00,
209        3.64784832476320460504e+00,
210        1.27045825245236838258e+00,
211        2.41780725177450611770e-01,
212        2.27238449892691845833e-02,
213        7.74545014278341407640e-04,
214    ];
215    const D: [f64; 8] = [
216        1.00000000000000000000e+00,
217        2.05319162663775882187e+00,
218        1.67638483018380384940e+00,
219        6.89767334985100004550e-01,
220        1.48103976427480074590e-01,
221        1.51986665636164571966e-02,
222        5.47593808499534494600e-04,
223        1.05075007164441684324e-09,
224    ];
225    const E: [f64; 8] = [
226        6.65790464350110377720e+00,
227        5.46378491116411436990e+00,
228        1.78482653991729133580e+00,
229        2.96560571828504891230e-01,
230        2.65321895265761230930e-02,
231        1.24266094738807843860e-03,
232        2.71155556874348757815e-05,
233        2.01033439929228813265e-07,
234    ];
235    const F: [f64; 8] = [
236        1.00000000000000000000e+00,
237        5.99832206555887937690e-01,
238        1.36929880922735805310e-01,
239        1.48753612908506148525e-02,
240        7.86869131145613259100e-04,
241        1.84631831751005468180e-05,
242        1.42151175831644588870e-07,
243        2.04426310338993978564e-15,
244    ];
245
246    #[inline(always)]
247    #[rustfmt::skip]
248    fn poly(c: &[f64], x: f64) -> f64 {
249        c[0] + x * (c[1] + x * (c[2] + x * (c[3] + x * (
250        c[4] + x * (c[5] + x * (c[6] + x * (c[7])))))))
251    }
252
253    if p <= 0.0 {
254        return NEG_INFINITY;
255    }
256    if 1.0 <= p {
257        return INFINITY;
258    }
259
260    let q = p - 0.5;
261
262    if (if q < 0.0 { -q } else { q }) <= SPLIT1 {
263        let x = CONST1 - q * q;
264        return q * poly(&A, x) / poly(&B, x);
265    }
266
267    let mut x = if q < 0.0 { p } else { 1.0 - p };
268
269    x = (-x.ln()).sqrt();
270
271    if x <= SPLIT2 {
272        x -= CONST2;
273        x = poly(&C, x) / poly(&D, x);
274    } else {
275        x -= SPLIT2;
276        x = poly(&E, x) / poly(&F, x);
277    }
278
279    if q < 0.0 {
280        -x
281    } else {
282        x
283    }
284}
285
286/// Draw a sample from the standard Gaussian distribution.
287pub fn sample<S: Source>(source: &mut S) -> f64 {
288    loop {
289        let u = source.read::<u64>();
290
291        let i = (u & 0x7F) as usize;
292        let j = ((u >> 8) & 0xFFFFFF) as u32;
293        let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };
294
295        if j < K[i] {
296            let x = j as f64 * W[i];
297            return s * x;
298        }
299
300        let (x, y) = if i < 127 {
301            let x = j as f64 * W[i];
302            let y = Y[i + 1] + (Y[i] - Y[i + 1]) * source.read::<f64>();
303            (x, y)
304        } else {
305            let x = R - (-source.read::<f64>()).ln_1p() / R;
306            let y = (-R * (x - 0.5 * R)).exp() * source.read::<f64>();
307            (x, y)
308        };
309
310        if y < (-0.5 * x * x).exp() {
311            return s * x;
312        }
313    }
314}
315
/// Start of the tail region of the ziggurat. When `sample` falls through in
/// the base layer (index 127), it draws `x = R - ln(1 - U) / R`. That value is
/// at least `R` provided `U` lies in `[0, 1)` (assumed of `Source::read::<f64>`
/// — confirm).
const R: f64 = 3.44428647676;
317
/// Per-layer acceptance thresholds for the ziggurat. In `sample`, a 24-bit
/// position `j` in layer `i` with `j < K[i]` is returned immediately as
/// `j * W[i]`, without evaluating the density. `K[0]` is zero, so layer 0
/// never takes this fast path.
#[rustfmt::skip]
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939,
    15384584, 15635009, 15807561, 15933577,
    16029594, 16105155, 16166147, 16216399,
    16258508, 16294295, 16325078, 16351831,
    16375291, 16396026, 16414479, 16431002,
    16445880, 16459343, 16471578, 16482744,
    16492970, 16502368, 16511031, 16519039,
    16526459, 16533352, 16539769, 16545755,
    16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977,
    16585398, 16588629, 16591685, 16594575,
    16597311, 16599901, 16602354, 16604679,
    16606881, 16608968, 16610945, 16612818,
    16614592, 16616272, 16617861, 16619363,
    16620782, 16622121, 16623383, 16624570,
    16625685, 16626730, 16627708, 16628619,
    16629465, 16630248, 16630969, 16631628,
    16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774,
    16634903, 16634972, 16634980, 16634926,
    16634810, 16634628, 16634381, 16634066,
    16633680, 16633222, 16632688, 16632075,
    16631380, 16630598, 16629726, 16628757,
    16627686, 16626507, 16625212, 16623794,
    16622243, 16620548, 16618698, 16616679,
    16614476, 16612071, 16609444, 16606571,
    16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120,
    16562917, 16554758, 16545450, 16534739,
    16522287, 16507638, 16490152, 16468907,
    16442518, 16408804, 16364095, 16301683,
    16207738, 16047994, 15704248, 15472926
];
353
/// Heights of the ziggurat layer boundaries on the unnormalized density
/// `exp(-x² / 2)`, decreasing from the peak value `1.0`. For an inner layer
/// `i < 127`, `sample` draws the candidate height uniformly between `Y[i + 1]`
/// and `Y[i]`.
#[rustfmt::skip]
const Y: [f64; 128] = [
    1.0000000000000, 0.96359862301100, 0.93628081335300, 0.91304110425300,
    0.8922785066960, 0.87323935691900, 0.85549640763400, 0.83877892834900,
    0.8229020836990, 0.80773273823400, 0.79317104551900, 0.77913972650500,
    0.7655774360820, 0.75243445624800, 0.73966978767700, 0.72724912028500,
    0.7151433774130, 0.70332764645500, 0.69178037703500, 0.68048276891000,
    0.6694182972330, 0.65857233912000, 0.64793187618900, 0.63748525489600,
    0.6272219914500, 0.61713261153200, 0.60720851746700, 0.59744187729600,
    0.5878255314650, 0.57835291380300, 0.56901798419800, 0.55981517091100,
    0.5507393208770, 0.54178565668200, 0.53294973914500, 0.52422743462800,
    0.5156148863730, 0.50710848925300, 0.49870486747800, 0.49040085481200,
    0.4821934769860, 0.47407993601000, 0.46605759612500, 0.45812397121400,
    0.4502767134670, 0.44251360317100, 0.43483253947300, 0.42723153202200,
    0.4197086933790, 0.41226223212000, 0.40489044654800, 0.39759171895500,
    0.3903645103820, 0.38320735581600, 0.37611885978800, 0.36909769233400,
    0.3621425852820, 0.35525232883400, 0.34842576841500, 0.34166180177600,
    0.3349593763110, 0.32831748658800, 0.32173517206300, 0.31521151497000,
    0.3087456383670, 0.30233670433800, 0.29598391232000, 0.28968649757100,
    0.2834437297390, 0.27725491156000, 0.27111937764900, 0.26503649338700,
    0.2590056539120, 0.25302628318300, 0.24709783313900, 0.24121978293200,
    0.2353916382390, 0.22961293064900, 0.22388321712200, 0.21820207951800,
    0.2125691242010, 0.20698398170900, 0.20144630649600, 0.19595577674500,
    0.1905120942560, 0.18511498440600, 0.17976419618500, 0.17445950232400,
    0.1692006994920, 0.16398760860000, 0.15882007519500, 0.15369796996400,
    0.1486211893480, 0.14358965629500, 0.13860332114300, 0.13366216266900,
    0.1287661893090, 0.12391544058200, 0.11910998874500, 0.11434994070300,
    0.1096354402300, 0.10496667053300, 0.10034385723200, 0.09576727182660,
    0.0912372357329, 0.08675412501270, 0.08231837593200, 0.07793049152950,
    0.0735910494266, 0.06930071117420, 0.06506023352900, 0.06087048217450,
    0.0567324485840, 0.05264727098000, 0.04861626071630, 0.04464093597690,
    0.0407230655415, 0.03686472673860, 0.03306838393780, 0.02933699774110,
    0.0256741818288, 0.02208443726340, 0.01857352005770, 0.01514905528540,
    0.0118216532614, 0.00860719483079, 0.00553245272614, 0.00265435214565,
];
389
/// Per-layer scale factors for the ziggurat. In `sample`, a 24-bit position
/// `j` in layer `i` is mapped to the abscissa `x = j * W[i]`. Each entry is
/// the layer's rectangle width divided by `2^24`.
#[rustfmt::skip]
const W: [f64; 128] = [
    1.62318314817e-08, 2.16291505214e-08, 2.54246305087e-08, 2.84579525938e-08,
    3.10340022482e-08, 3.33011726243e-08, 3.53439060345e-08, 3.72152672658e-08,
    3.89509895720e-08, 4.05763964764e-08, 4.21101548915e-08, 4.35664624904e-08,
    4.49563968336e-08, 4.62887864029e-08, 4.75707945735e-08, 4.88083237257e-08,
    5.00063025384e-08, 5.11688950428e-08, 5.22996558616e-08, 5.34016475624e-08,
    5.44775307871e-08, 5.55296344581e-08, 5.65600111659e-08, 5.75704813695e-08,
    5.85626690412e-08, 5.95380306862e-08, 6.04978791776e-08, 6.14434034901e-08,
    6.23756851626e-08, 6.32957121259e-08, 6.42043903937e-08, 6.51025540077e-08,
    6.59909735447e-08, 6.68703634341e-08, 6.77413882848e-08, 6.86046683810e-08,
    6.94607844804e-08, 7.03102820203e-08, 7.11536748229e-08, 7.19914483720e-08,
    7.28240627230e-08, 7.36519550992e-08, 7.44755422158e-08, 7.52952223703e-08,
    7.61113773308e-08, 7.69243740467e-08, 7.77345662086e-08, 7.85422956743e-08,
    7.93478937793e-08, 8.01516825471e-08, 8.09539758128e-08, 8.17550802699e-08,
    8.25552964535e-08, 8.33549196661e-08, 8.41542408569e-08, 8.49535474601e-08,
    8.57531242006e-08, 8.65532538723e-08, 8.73542180955e-08, 8.81562980590e-08,
    8.89597752521e-08, 8.97649321908e-08, 9.05720531451e-08, 9.13814248700e-08,
    9.21933373471e-08, 9.30080845407e-08, 9.38259651738e-08, 9.46472835298e-08,
    9.54723502847e-08, 9.63014833769e-08, 9.71350089201e-08, 9.79732621669e-08,
    9.88165885297e-08, 9.96653446693e-08, 1.00519899658e-07, 1.01380636230e-07,
    1.02247952126e-07, 1.03122261554e-07, 1.04003996769e-07, 1.04893609795e-07,
    1.05791574313e-07, 1.06698387725e-07, 1.07614573423e-07, 1.08540683296e-07,
    1.09477300508e-07, 1.10425042570e-07, 1.11384564771e-07, 1.12356564007e-07,
    1.13341783071e-07, 1.14341015475e-07, 1.15355110887e-07, 1.16384981291e-07,
    1.17431607977e-07, 1.18496049514e-07, 1.19579450872e-07, 1.20683053909e-07,
    1.21808209468e-07, 1.22956391410e-07, 1.24129212952e-07, 1.25328445797e-07,
    1.26556042658e-07, 1.27814163916e-07, 1.29105209375e-07, 1.30431856341e-07,
    1.31797105598e-07, 1.33204337360e-07, 1.34657379914e-07, 1.36160594606e-07,
    1.37718982103e-07, 1.39338316679e-07, 1.41025317971e-07, 1.42787873535e-07,
    1.44635331499e-07, 1.46578891730e-07, 1.48632138436e-07, 1.50811780719e-07,
    1.53138707402e-07, 1.55639532047e-07, 1.58348931426e-07, 1.61313325908e-07,
    1.64596952856e-07, 1.68292495203e-07, 1.72541128694e-07, 1.77574279496e-07,
    1.83813550477e-07, 1.92166040885e-07, 2.05295471952e-07, 2.22600839893e-07,
];
425
// Unit tests for `Gaussian` and the standard-normal helpers in this file.
#[cfg(test)]
mod tests {
    use core::iter::FromIterator;

    use alloc::{vec, vec::Vec};
    use assert;
    use prelude::*;

    // Shorthand for `Gaussian::new($mu, $sigma)`.
    macro_rules! new(
        ($mu:expr, $sigma:expr) => (Gaussian::new($mu, $sigma));
    );

    // Density of the distribution with mu = 1 and sigma = 2, on a grid from
    // -4 to 4 in steps of 0.5, against reference values to within 1e-14.
    #[test]
    fn density() {
        let d = new!(1.0, 2.0);
        let x = vec![
            -4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5,
            4.0,
        ];
        let p = vec![
            8.764150246784270e-03,
            1.586982591783371e-02,
            2.699548325659403e-02,
            4.313865941325577e-02,
            6.475879783294587e-02,
            9.132454269451096e-02,
            1.209853622595717e-01,
            1.505687160774022e-01,
            1.760326633821498e-01,
            1.933340584014246e-01,
            1.994711402007164e-01,
            1.933340584014246e-01,
            1.760326633821498e-01,
            1.505687160774022e-01,
            1.209853622595717e-01,
            9.132454269451096e-02,
            6.475879783294587e-02,
        ];

        assert::close(
            &x.iter().map(|&x| d.density(x)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    // Cumulative distribution function of the same distribution on the same
    // grid, against reference values to within 1e-14.
    #[test]
    fn distribution() {
        let d = new!(1.0, 2.0);
        let x = vec![
            -4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5,
            4.0,
        ];
        let p = vec![
            6.209665325776139e-03,
            1.222447265504470e-02,
            2.275013194817922e-02,
            4.005915686381709e-02,
            6.680720126885809e-02,
            1.056497736668553e-01,
            1.586552539314571e-01,
            2.266273523768682e-01,
            3.085375387259869e-01,
            4.012936743170763e-01,
            5.000000000000000e-01,
            5.987063256829237e-01,
            6.914624612740131e-01,
            7.733726476231317e-01,
            8.413447460685429e-01,
            8.943502263331446e-01,
            9.331927987311419e-01,
        ];

        assert::close(
            &x.iter().map(|&x| d.distribution(x)).collect::<Vec<_>>(),
            &p,
            1e-14,
        );
    }

    // For sigma = 1 the entropy is (ln(2π) + 1) / 2; compared for exact equality.
    #[test]
    fn entropy() {
        use core::f64::consts::PI;
        assert_eq!(new!(0.0, 1.0).entropy(), ((2.0 * PI).ln() + 1.0) / 2.0);
    }

    // Quantiles of the distribution with mu = -1 and sigma = 0.25 for
    // p = 0, 0.05, ..., 1. The endpoints must map to -∞ and +∞; the rest are
    // checked to within 1e-14.
    #[test]
    fn inverse() {
        use core::f64::{INFINITY, NEG_INFINITY};

        let d = new!(-1.0, 0.25);
        let p = vec![
            0.00, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65,
            0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.00,
        ];
        let x = vec![
            NEG_INFINITY,
            -1.411213406737868e+00,
            -1.320387891386150e+00,
            -1.259108347373447e+00,
            -1.210405308393228e+00,
            -1.168622437549020e+00,
            -1.131100128177010e+00,
            -1.096330116601892e+00,
            -1.063336775783950e+00,
            -1.031415336713768e+00,
            -1.000000000000000e+00,
            -9.685846632862315e-01,
            -9.366632242160501e-01,
            -9.036698833981082e-01,
            -8.688998718229899e-01,
            -8.313775624509796e-01,
            -7.895946916067714e-01,
            -7.408916526265525e-01,
            -6.796121086138498e-01,
            -5.887865932621319e-01,
            INFINITY,
        ];

        assert::close(
            &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
            &x,
            1e-14,
        );
    }

    // Moment and location summaries, which have closed forms.
    #[test]
    fn kurtosis() {
        assert_eq!(new!(0.0, 2.0).kurtosis(), 0.0);
    }

    #[test]
    fn mean() {
        assert_eq!(new!(0.0, 1.0).mean(), 0.0);
    }

    #[test]
    fn median() {
        assert_eq!(new!(0.0, 2.0).median(), 0.0);
    }

    #[test]
    fn modes() {
        assert_eq!(new!(2.0, 5.0).modes(), vec![2.0]);
    }

    #[test]
    fn skewness() {
        assert_eq!(new!(0.0, 2.0).skewness(), 0.0);
    }

    #[test]
    fn variance() {
        assert_eq!(new!(0.0, 2.0).variance(), 4.0);
    }

    #[test]
    fn deviation() {
        assert_eq!(new!(0.0, 2.0).deviation(), 2.0);
    }

    // Fit a Gaussian to 10,000 seeded draws from the distribution with mu = 1
    // and sigma = 2. Each fitted field, including the cached normalizing
    // constant, must lie within 0.1 of the original.
    #[test]
    fn from_iter() {
        let mut source = source::default(42);
        let distribution = new!(1.0, 2.0);
        let sampler = Independent(&distribution, &mut source);
        let samples = sampler.take(10000).collect::<Vec<_>>();
        let derived_distribution = Gaussian::from_iter(samples);

        assert::close(derived_distribution.mu, distribution.mu, 0.1);
        assert::close(derived_distribution.sigma, distribution.sigma, 0.1);
        assert::close(derived_distribution.norm, distribution.norm, 0.1);
    }
}