compute/distributions/
normal.rs

1#![allow(clippy::many_single_char_names)]
2
3use crate::{distributions::*, prelude::erf};
4use std::f64::consts::PI;
5
/// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) distribution.
///
/// Both parameters are private; values are built with `Normal::new` and changed through
/// `set_mu` / `set_sigma`, which reject a negative `sigma`.
#[derive(Debug, Clone, Copy)]
pub struct Normal {
    /// Mean (or location) parameter.
    mu: f64,
    /// Standard deviation (or scale) parameter. Zero is permitted by the constructor.
    sigma: f64,
}
14
15impl Normal {
16    /// Create a new Normal distribution with mean `mu` and standard deviation `sigma`.
17    ///
18    /// # Errors
19    /// Panics if `sigma < 0`.
20    pub fn new(mu: f64, sigma: f64) -> Self {
21        if sigma < 0. {
22            panic!("Sigma must be non-negative.")
23        }
24        Normal { mu, sigma }
25    }
26    pub fn set_mu(&mut self, mu: f64) -> &mut Self {
27        self.mu = mu;
28        self
29    }
30    pub fn set_sigma(&mut self, sigma: f64) -> &mut Self {
31        if sigma < 0. {
32            panic!("Sigma must be non-negative.")
33        }
34        self.sigma = sigma;
35        self
36    }
37    /// TODO: make `cdf` a method of the `Continuous` trait.
38    pub fn cdf(&self, x: f64) -> f64 {
39        0.5 * (1. + erf((x - self.mu) / (self.sigma * 2_f64.sqrt())))
40    }
41}
42
43impl Default for Normal {
44    fn default() -> Self {
45        Self::new(0., 1.)
46    }
47}
48
49impl Distribution for Normal {
50    type Output = f64;
51    /// Sample from the given Normal distribution.
52    fn sample(&self) -> f64 {
53        loop {
54            let u = alea::u64();
55
56            let i = (u & 0x7F) as usize;
57            let j = ((u >> 8) & 0xFFFFFF) as u32;
58            let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };
59
60            if j < K[i] {
61                let x = j as f64 * W[i];
62                return s * x * self.sigma + self.mu;
63            }
64
65            let (x, y) = if i < 127 {
66                let x = j as f64 * W[i];
67                let y = Y[i + 1] + (Y[i] - Y[i + 1]) * alea::f64();
68                (x, y)
69            } else {
70                let x = R - (-alea::f64()).ln_1p() / R;
71                let y = (-R * (x - 0.5 * R)).exp() * alea::f64();
72                (x, y)
73            };
74
75            if y < (-0.5 * x * x).exp() {
76                return s * x * self.sigma + self.mu;
77            }
78        }
79    }
80}
81
82impl Distribution1D for Normal {
83    fn update(&mut self, params: &[f64]) {
84        self.set_mu(params[0]).set_sigma(params[1]);
85    }
86}
87
88impl Continuous for Normal {
89    type PDFType = f64;
90    /// Calculates the probability density function of the given Normal distribution at `x`.
91    fn pdf(&self, x: f64) -> f64 {
92        1. / (self.sigma * (2. * PI).sqrt()) * (-0.5 * ((x - self.mu) / self.sigma).powi(2)).exp()
93    }
94
95    fn ln_pdf(&self, x: Self::PDFType) -> f64 {
96        -0.5 * ((x - self.mu) / self.sigma).powi(2) - (self.sigma * (2. * PI).sqrt()).ln()
97    }
98}
99
100impl Mean for Normal {
101    type MeanType = f64;
102    /// Returns the mean of the given Normal distribution.
103    fn mean(&self) -> f64 {
104        self.mu
105    }
106}
107
108impl Variance for Normal {
109    type VarianceType = f64;
110    /// Returns the variance of the given Normal distribution.
111    fn var(&self) -> f64 {
112        self.sigma
113    }
114}
115
/// The density of `Normal::new(5., 4.)` must peak at its mean: no integer point in `0..20`
/// may exceed the value at `5`, and the neighbours `2` and `6` must be strictly lower.
#[test]
fn maxprob() {
    let n = self::Normal::new(5., 4.);
    for x in 0..20 {
        assert!(n.pdf(5.) >= n.pdf(x as f64));
    }
    assert!(n.pdf(5.) > n.pdf(2.));
    assert!(n.pdf(5.) > n.pdf(6.));
}
125
/// Start of the tail region used by `Normal::sample`: when the last layer (index 127) is
/// selected and the fast acceptance test fails, candidates are generated as
/// `R - ln(1 - u) / R` with `u` taken from `alea::f64()`.
const R: f64 = 3.44428647676;
127
/// Per-layer fast-acceptance thresholds read by `Normal::sample`.
///
/// For the selected layer `i`, the 24-bit random integer `j` is accepted immediately when
/// `j < K[i]`. `K[0]` is zero, so layer 0 never takes the fast path.
///
/// NOTE(review): the values appear to match the `ktab` table of GSL's
/// `gsl_ran_gaussian_ziggurat` — confirm provenance before editing.
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939, 15384584, 15635009, 15807561, 15933577, 16029594,
    16105155, 16166147, 16216399, 16258508, 16294295, 16325078, 16351831, 16375291, 16396026,
    16414479, 16431002, 16445880, 16459343, 16471578, 16482744, 16492970, 16502368, 16511031,
    16519039, 16526459, 16533352, 16539769, 16545755, 16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977, 16585398, 16588629, 16591685, 16594575, 16597311,
    16599901, 16602354, 16604679, 16606881, 16608968, 16610945, 16612818, 16614592, 16616272,
    16617861, 16619363, 16620782, 16622121, 16623383, 16624570, 16625685, 16626730, 16627708,
    16628619, 16629465, 16630248, 16630969, 16631628, 16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774, 16634903, 16634972, 16634980, 16634926, 16634810,
    16634628, 16634381, 16634066, 16633680, 16633222, 16632688, 16632075, 16631380, 16630598,
    16629726, 16628757, 16627686, 16626507, 16625212, 16623794, 16622243, 16620548, 16618698,
    16616679, 16614476, 16612071, 16609444, 16606571, 16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120, 16562917, 16554758, 16545450, 16534739, 16522287,
    16507638, 16490152, 16468907, 16442518, 16408804, 16364095, 16301683, 16207738, 16047994,
    15704248, 15472926,
];
145
/// Per-layer ordinates read by `Normal::sample` on the slow path.
///
/// For a layer `i < 127` the candidate height is interpolated between `Y[i + 1]` and `Y[i]`
/// by a factor from `alea::f64()` and then compared against `exp(-x^2 / 2)`. The table
/// starts at `1.0` and decreases with the index.
///
/// NOTE(review): the values appear to match the `ytab` table of GSL's
/// `gsl_ran_gaussian_ziggurat` — confirm provenance before editing.
const Y: [f64; 128] = [
    1.0000000000000,
    0.96359862301100,
    0.93628081335300,
    0.91304110425300,
    0.8922785066960,
    0.87323935691900,
    0.85549640763400,
    0.83877892834900,
    0.8229020836990,
    0.80773273823400,
    0.79317104551900,
    0.77913972650500,
    0.7655774360820,
    0.75243445624800,
    0.73966978767700,
    0.72724912028500,
    0.7151433774130,
    0.70332764645500,
    0.69178037703500,
    0.68048276891000,
    0.6694182972330,
    0.65857233912000,
    0.64793187618900,
    0.63748525489600,
    0.6272219914500,
    0.61713261153200,
    0.60720851746700,
    0.59744187729600,
    0.5878255314650,
    0.57835291380300,
    0.56901798419800,
    0.55981517091100,
    0.5507393208770,
    0.54178565668200,
    0.53294973914500,
    0.52422743462800,
    0.5156148863730,
    0.50710848925300,
    0.49870486747800,
    0.49040085481200,
    0.4821934769860,
    0.47407993601000,
    0.46605759612500,
    0.45812397121400,
    0.4502767134670,
    0.44251360317100,
    0.43483253947300,
    0.42723153202200,
    0.4197086933790,
    0.41226223212000,
    0.40489044654800,
    0.39759171895500,
    0.3903645103820,
    0.38320735581600,
    0.37611885978800,
    0.36909769233400,
    0.3621425852820,
    0.35525232883400,
    0.34842576841500,
    0.34166180177600,
    0.3349593763110,
    0.32831748658800,
    0.32173517206300,
    0.31521151497000,
    0.3087456383670,
    0.30233670433800,
    0.29598391232000,
    0.28968649757100,
    0.2834437297390,
    0.27725491156000,
    0.27111937764900,
    0.26503649338700,
    0.2590056539120,
    0.25302628318300,
    0.24709783313900,
    0.24121978293200,
    0.2353916382390,
    0.22961293064900,
    0.22388321712200,
    0.21820207951800,
    0.2125691242010,
    0.20698398170900,
    0.20144630649600,
    0.19595577674500,
    0.1905120942560,
    0.18511498440600,
    0.17976419618500,
    0.17445950232400,
    0.1692006994920,
    0.16398760860000,
    0.15882007519500,
    0.15369796996400,
    0.1486211893480,
    0.14358965629500,
    0.13860332114300,
    0.13366216266900,
    0.1287661893090,
    0.12391544058200,
    0.11910998874500,
    0.11434994070300,
    0.1096354402300,
    0.10496667053300,
    0.10034385723200,
    0.09576727182660,
    0.0912372357329,
    0.08675412501270,
    0.08231837593200,
    0.07793049152950,
    0.0735910494266,
    0.06930071117420,
    0.06506023352900,
    0.06087048217450,
    0.0567324485840,
    0.05264727098000,
    0.04861626071630,
    0.04464093597690,
    0.0407230655415,
    0.03686472673860,
    0.03306838393780,
    0.02933699774110,
    0.0256741818288,
    0.02208443726340,
    0.01857352005770,
    0.01514905528540,
    0.0118216532614,
    0.00860719483079,
    0.00553245272614,
    0.00265435214565,
];
276
/// Per-layer scale factors read by `Normal::sample`.
///
/// For the selected layer `i`, the 24-bit random integer `j` is turned into the candidate
/// abscissa as `j as f64 * W[i]`.
///
/// NOTE(review): the values appear to match the `wtab` table of GSL's
/// `gsl_ran_gaussian_ziggurat` — confirm provenance before editing.
const W: [f64; 128] = [
    1.62318314817e-08,
    2.16291505214e-08,
    2.54246305087e-08,
    2.84579525938e-08,
    3.10340022482e-08,
    3.33011726243e-08,
    3.53439060345e-08,
    3.72152672658e-08,
    3.89509895720e-08,
    4.05763964764e-08,
    4.21101548915e-08,
    4.35664624904e-08,
    4.49563968336e-08,
    4.62887864029e-08,
    4.75707945735e-08,
    4.88083237257e-08,
    5.00063025384e-08,
    5.11688950428e-08,
    5.22996558616e-08,
    5.34016475624e-08,
    5.44775307871e-08,
    5.55296344581e-08,
    5.65600111659e-08,
    5.75704813695e-08,
    5.85626690412e-08,
    5.95380306862e-08,
    6.04978791776e-08,
    6.14434034901e-08,
    6.23756851626e-08,
    6.32957121259e-08,
    6.42043903937e-08,
    6.51025540077e-08,
    6.59909735447e-08,
    6.68703634341e-08,
    6.77413882848e-08,
    6.86046683810e-08,
    6.94607844804e-08,
    7.03102820203e-08,
    7.11536748229e-08,
    7.19914483720e-08,
    7.28240627230e-08,
    7.36519550992e-08,
    7.44755422158e-08,
    7.52952223703e-08,
    7.61113773308e-08,
    7.69243740467e-08,
    7.77345662086e-08,
    7.85422956743e-08,
    7.93478937793e-08,
    8.01516825471e-08,
    8.09539758128e-08,
    8.17550802699e-08,
    8.25552964535e-08,
    8.33549196661e-08,
    8.41542408569e-08,
    8.49535474601e-08,
    8.57531242006e-08,
    8.65532538723e-08,
    8.73542180955e-08,
    8.81562980590e-08,
    8.89597752521e-08,
    8.97649321908e-08,
    9.05720531451e-08,
    9.13814248700e-08,
    9.21933373471e-08,
    9.30080845407e-08,
    9.38259651738e-08,
    9.46472835298e-08,
    9.54723502847e-08,
    9.63014833769e-08,
    9.71350089201e-08,
    9.79732621669e-08,
    9.88165885297e-08,
    9.96653446693e-08,
    1.00519899658e-07,
    1.01380636230e-07,
    1.02247952126e-07,
    1.03122261554e-07,
    1.04003996769e-07,
    1.04893609795e-07,
    1.05791574313e-07,
    1.06698387725e-07,
    1.07614573423e-07,
    1.08540683296e-07,
    1.09477300508e-07,
    1.10425042570e-07,
    1.11384564771e-07,
    1.12356564007e-07,
    1.13341783071e-07,
    1.14341015475e-07,
    1.15355110887e-07,
    1.16384981291e-07,
    1.17431607977e-07,
    1.18496049514e-07,
    1.19579450872e-07,
    1.20683053909e-07,
    1.21808209468e-07,
    1.22956391410e-07,
    1.24129212952e-07,
    1.25328445797e-07,
    1.26556042658e-07,
    1.27814163916e-07,
    1.29105209375e-07,
    1.30431856341e-07,
    1.31797105598e-07,
    1.33204337360e-07,
    1.34657379914e-07,
    1.36160594606e-07,
    1.37718982103e-07,
    1.39338316679e-07,
    1.41025317971e-07,
    1.42787873535e-07,
    1.44635331499e-07,
    1.46578891730e-07,
    1.48632138436e-07,
    1.50811780719e-07,
    1.53138707402e-07,
    1.55639532047e-07,
    1.58348931426e-07,
    1.61313325908e-07,
    1.64596952856e-07,
    1.68292495203e-07,
    1.72541128694e-07,
    1.77574279496e-07,
    1.83813550477e-07,
    1.92166040885e-07,
    2.05295471952e-07,
    2.22600839893e-07,
];
407
#[cfg(test)]
mod tests {

    use super::*;
    use crate::statistics::{mean, std};
    use approx_eq::assert_approx_eq;

    /// Draws one million samples per parameter set and checks that the sample mean and
    /// sample standard deviation land near the requested `mu` and `sigma`.
    #[test]
    fn test_moments() {
        for &(mu, sigma) in &[(0., 1.), (10., 20.)] {
            let data = Normal::new(mu, sigma).sample_n(1e6 as usize);
            assert_approx_eq!(mu, mean(&data), 1e-2);
            assert_approx_eq!(sigma, std(&data), 1e-2);
        }
    }

    /// Compares `cdf` of the standard Normal distribution against reference values.
    #[test]
    fn test_cdf() {
        let x: [f64; 9] = [-4., -3.9, -2.81, -2.67, -2.01, 0.01, 0.75, 1.5, 1.79];
        let y: [f64; 9] = [
            3.167124183311986e-05,
            4.8096344017602614e-05,
            0.002477074998785861,
            0.0037925623476854887,
            0.022215594429431475,
            0.5039893563146316,
            0.7733726476231317,
            0.9331927987311419,
            0.9632730443012737,
        ];
        assert_eq!(x.len(), y.len());

        let sn = Normal::new(0., 1.);

        for (&point, &expected) in x.iter().zip(y.iter()) {
            assert_approx_eq!(sn.cdf(point), expected, 1e-3);
        }
    }
}
448}