1#![allow(clippy::many_single_char_names)]
2
3use crate::{distributions::*, prelude::erf};
4use std::f64::consts::PI;
5
/// A normal (Gaussian) distribution described by its mean and standard deviation.
#[derive(Clone, Copy, Debug)]
pub struct Normal {
    /// Mean (location) of the distribution.
    mu: f64,
    /// Standard deviation (scale) of the distribution; the constructor and
    /// `set_sigma` reject negative values.
    sigma: f64,
}
14
15impl Normal {
16 pub fn new(mu: f64, sigma: f64) -> Self {
21 if sigma < 0. {
22 panic!("Sigma must be non-negative.")
23 }
24 Normal { mu, sigma }
25 }
26 pub fn set_mu(&mut self, mu: f64) -> &mut Self {
27 self.mu = mu;
28 self
29 }
30 pub fn set_sigma(&mut self, sigma: f64) -> &mut Self {
31 if sigma < 0. {
32 panic!("Sigma must be non-negative.")
33 }
34 self.sigma = sigma;
35 self
36 }
37 pub fn cdf(&self, x: f64) -> f64 {
39 0.5 * (1. + erf((x - self.mu) / (self.sigma * 2_f64.sqrt())))
40 }
41}
42
43impl Default for Normal {
44 fn default() -> Self {
45 Self::new(0., 1.)
46 }
47}
48
impl Distribution for Normal {
    type Output = f64;

    /// Draws one sample from this normal distribution.
    ///
    /// Implements a 128-layer ziggurat rejection sampler driven by the `K`, `W`,
    /// `Y` tables and the tail start `R` defined in this file (the tables and
    /// control flow appear to mirror GSL's `gsl_ran_gaussian_ziggurat` — confirm).
    /// A standard-normal deviate `x` with sign `s` is produced and returned as
    /// `s * x * sigma + mu`.
    ///
    /// The order of RNG calls matters for reproducibility under a fixed seed;
    /// do not reorder the `alea::*` calls below.
    fn sample(&self) -> f64 {
        loop {
            // A single 64-bit draw supplies layer index, sign and in-layer
            // position for the fast path. Bits 32..64 are unused.
            let u = alea::u64();

            // Low 7 bits: layer index in 0..=127.
            let i = (u & 0x7F) as usize;
            // Bits 8..32: 24-bit integer position within the layer.
            let j = ((u >> 8) & 0xFFFFFF) as u32;
            // Bit 7: sign of the result.
            let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };

            // Fast path: below the per-layer threshold K[i] the point is accepted
            // without evaluating exp(). K[0] == 0, so layer 0 never takes this path.
            if j < K[i] {
                let x = j as f64 * W[i];
                return s * x * self.sigma + self.mu;
            }

            let (x, y) = if i < 127 {
                // Wedge region: choose a height uniformly between this layer's
                // bounds Y[i + 1] and Y[i] (assuming alea::f64() is uniform on
                // [0, 1)) and test it against the density below.
                let x = j as f64 * W[i];
                let y = Y[i + 1] + (Y[i] - Y[i + 1]) * alea::f64();
                (x, y)
            } else {
                // Base layer: draw x from the tail beyond R.
                // ln_1p(-u) == ln(1 - u) <= 0, hence x >= R.
                let x = R - (-alea::f64()).ln_1p() / R;
                let y = (-R * (x - 0.5 * R)).exp() * alea::f64();
                (x, y)
            };

            // Accept when (x, y) lies under the unnormalised density exp(-x²/2);
            // otherwise start over with a fresh draw.
            if y < (-0.5 * x * x).exp() {
                return s * x * self.sigma + self.mu;
            }
        }
    }
}
81
82impl Distribution1D for Normal {
83 fn update(&mut self, params: &[f64]) {
84 self.set_mu(params[0]).set_sigma(params[1]);
85 }
86}
87
88impl Continuous for Normal {
89 type PDFType = f64;
90 fn pdf(&self, x: f64) -> f64 {
92 1. / (self.sigma * (2. * PI).sqrt()) * (-0.5 * ((x - self.mu) / self.sigma).powi(2)).exp()
93 }
94
95 fn ln_pdf(&self, x: Self::PDFType) -> f64 {
96 -0.5 * ((x - self.mu) / self.sigma).powi(2) - (self.sigma * (2. * PI).sqrt()).ln()
97 }
98}
99
100impl Mean for Normal {
101 type MeanType = f64;
102 fn mean(&self) -> f64 {
104 self.mu
105 }
106}
107
108impl Variance for Normal {
109 type VarianceType = f64;
110 fn var(&self) -> f64 {
112 self.sigma
113 }
114}
115
/// The density of `N(5, 4²)` must peak at its mean.
#[test]
fn maxprob() {
    let dist = Normal::new(5., 4.);
    let peak = dist.pdf(5.);
    for x in 0..20 {
        assert!(peak >= dist.pdf(f64::from(x)));
    }
    assert!(peak > dist.pdf(2.));
    assert!(peak > dist.pdf(6.));
}
125
/// Start of the tail region used by `Normal::sample`'s base layer (layer 127):
/// tail samples are drawn as `x >= R`.
const R: f64 = 3.444_286_476_76;
127
/// Fast-acceptance thresholds for `Normal::sample`, one per ziggurat layer.
///
/// A draw in layer `i` with 24-bit position `j` is accepted immediately, without
/// evaluating `exp`, when `j < K[i]`. All entries are below `2^24`, matching the
/// 24-bit mask applied to `j`. `K[0]` is `0`, so layer 0 never takes the fast path.
/// Values appear to come from GSL's ziggurat tables — confirm before editing.
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939, 15384584, 15635009, 15807561, 15933577, 16029594,
    16105155, 16166147, 16216399, 16258508, 16294295, 16325078, 16351831, 16375291, 16396026,
    16414479, 16431002, 16445880, 16459343, 16471578, 16482744, 16492970, 16502368, 16511031,
    16519039, 16526459, 16533352, 16539769, 16545755, 16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977, 16585398, 16588629, 16591685, 16594575, 16597311,
    16599901, 16602354, 16604679, 16606881, 16608968, 16610945, 16612818, 16614592, 16616272,
    16617861, 16619363, 16620782, 16622121, 16623383, 16624570, 16625685, 16626730, 16627708,
    16628619, 16629465, 16630248, 16630969, 16631628, 16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774, 16634903, 16634972, 16634980, 16634926, 16634810,
    16634628, 16634381, 16634066, 16633680, 16633222, 16632688, 16632075, 16631380, 16630598,
    16629726, 16628757, 16627686, 16626507, 16625212, 16623794, 16622243, 16620548, 16618698,
    16616679, 16614476, 16612071, 16609444, 16606571, 16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120, 16562917, 16554758, 16545450, 16534739, 16522287,
    16507638, 16490152, 16468907, 16442518, 16408804, 16364095, 16301683, 16207738, 16047994,
    15704248, 15472926,
];
145
/// Layer heights for `Normal::sample`, decreasing from `1.0`, the peak of `exp(-x²/2)`.
///
/// In the wedge test for layer `i < 127`, the candidate height is drawn between
/// `Y[i + 1]` and `Y[i]` and compared against `exp(-x²/2)`.
/// Values appear to come from GSL's ziggurat tables — confirm before editing.
const Y: [f64; 128] = [
    1.0000000000000,
    0.96359862301100,
    0.93628081335300,
    0.91304110425300,
    0.8922785066960,
    0.87323935691900,
    0.85549640763400,
    0.83877892834900,
    0.8229020836990,
    0.80773273823400,
    0.79317104551900,
    0.77913972650500,
    0.7655774360820,
    0.75243445624800,
    0.73966978767700,
    0.72724912028500,
    0.7151433774130,
    0.70332764645500,
    0.69178037703500,
    0.68048276891000,
    0.6694182972330,
    0.65857233912000,
    0.64793187618900,
    0.63748525489600,
    0.6272219914500,
    0.61713261153200,
    0.60720851746700,
    0.59744187729600,
    0.5878255314650,
    0.57835291380300,
    0.56901798419800,
    0.55981517091100,
    0.5507393208770,
    0.54178565668200,
    0.53294973914500,
    0.52422743462800,
    0.5156148863730,
    0.50710848925300,
    0.49870486747800,
    0.49040085481200,
    0.4821934769860,
    0.47407993601000,
    0.46605759612500,
    0.45812397121400,
    0.4502767134670,
    0.44251360317100,
    0.43483253947300,
    0.42723153202200,
    0.4197086933790,
    0.41226223212000,
    0.40489044654800,
    0.39759171895500,
    0.3903645103820,
    0.38320735581600,
    0.37611885978800,
    0.36909769233400,
    0.3621425852820,
    0.35525232883400,
    0.34842576841500,
    0.34166180177600,
    0.3349593763110,
    0.32831748658800,
    0.32173517206300,
    0.31521151497000,
    0.3087456383670,
    0.30233670433800,
    0.29598391232000,
    0.28968649757100,
    0.2834437297390,
    0.27725491156000,
    0.27111937764900,
    0.26503649338700,
    0.2590056539120,
    0.25302628318300,
    0.24709783313900,
    0.24121978293200,
    0.2353916382390,
    0.22961293064900,
    0.22388321712200,
    0.21820207951800,
    0.2125691242010,
    0.20698398170900,
    0.20144630649600,
    0.19595577674500,
    0.1905120942560,
    0.18511498440600,
    0.17976419618500,
    0.17445950232400,
    0.1692006994920,
    0.16398760860000,
    0.15882007519500,
    0.15369796996400,
    0.1486211893480,
    0.14358965629500,
    0.13860332114300,
    0.13366216266900,
    0.1287661893090,
    0.12391544058200,
    0.11910998874500,
    0.11434994070300,
    0.1096354402300,
    0.10496667053300,
    0.10034385723200,
    0.09576727182660,
    0.0912372357329,
    0.08675412501270,
    0.08231837593200,
    0.07793049152950,
    0.0735910494266,
    0.06930071117420,
    0.06506023352900,
    0.06087048217450,
    0.0567324485840,
    0.05264727098000,
    0.04861626071630,
    0.04464093597690,
    0.0407230655415,
    0.03686472673860,
    0.03306838393780,
    0.02933699774110,
    0.0256741818288,
    0.02208443726340,
    0.01857352005770,
    0.01514905528540,
    0.0118216532614,
    0.00860719483079,
    0.00553245272614,
    0.00265435214565,
];
276
/// Per-layer scale factors for `Normal::sample`.
///
/// They convert the 24-bit integer position `j` into an abscissa: `x = j as f64 * W[i]`.
/// Values appear to come from GSL's ziggurat tables — confirm before editing.
const W: [f64; 128] = [
    1.62318314817e-08,
    2.16291505214e-08,
    2.54246305087e-08,
    2.84579525938e-08,
    3.10340022482e-08,
    3.33011726243e-08,
    3.53439060345e-08,
    3.72152672658e-08,
    3.89509895720e-08,
    4.05763964764e-08,
    4.21101548915e-08,
    4.35664624904e-08,
    4.49563968336e-08,
    4.62887864029e-08,
    4.75707945735e-08,
    4.88083237257e-08,
    5.00063025384e-08,
    5.11688950428e-08,
    5.22996558616e-08,
    5.34016475624e-08,
    5.44775307871e-08,
    5.55296344581e-08,
    5.65600111659e-08,
    5.75704813695e-08,
    5.85626690412e-08,
    5.95380306862e-08,
    6.04978791776e-08,
    6.14434034901e-08,
    6.23756851626e-08,
    6.32957121259e-08,
    6.42043903937e-08,
    6.51025540077e-08,
    6.59909735447e-08,
    6.68703634341e-08,
    6.77413882848e-08,
    6.86046683810e-08,
    6.94607844804e-08,
    7.03102820203e-08,
    7.11536748229e-08,
    7.19914483720e-08,
    7.28240627230e-08,
    7.36519550992e-08,
    7.44755422158e-08,
    7.52952223703e-08,
    7.61113773308e-08,
    7.69243740467e-08,
    7.77345662086e-08,
    7.85422956743e-08,
    7.93478937793e-08,
    8.01516825471e-08,
    8.09539758128e-08,
    8.17550802699e-08,
    8.25552964535e-08,
    8.33549196661e-08,
    8.41542408569e-08,
    8.49535474601e-08,
    8.57531242006e-08,
    8.65532538723e-08,
    8.73542180955e-08,
    8.81562980590e-08,
    8.89597752521e-08,
    8.97649321908e-08,
    9.05720531451e-08,
    9.13814248700e-08,
    9.21933373471e-08,
    9.30080845407e-08,
    9.38259651738e-08,
    9.46472835298e-08,
    9.54723502847e-08,
    9.63014833769e-08,
    9.71350089201e-08,
    9.79732621669e-08,
    9.88165885297e-08,
    9.96653446693e-08,
    1.00519899658e-07,
    1.01380636230e-07,
    1.02247952126e-07,
    1.03122261554e-07,
    1.04003996769e-07,
    1.04893609795e-07,
    1.05791574313e-07,
    1.06698387725e-07,
    1.07614573423e-07,
    1.08540683296e-07,
    1.09477300508e-07,
    1.10425042570e-07,
    1.11384564771e-07,
    1.12356564007e-07,
    1.13341783071e-07,
    1.14341015475e-07,
    1.15355110887e-07,
    1.16384981291e-07,
    1.17431607977e-07,
    1.18496049514e-07,
    1.19579450872e-07,
    1.20683053909e-07,
    1.21808209468e-07,
    1.22956391410e-07,
    1.24129212952e-07,
    1.25328445797e-07,
    1.26556042658e-07,
    1.27814163916e-07,
    1.29105209375e-07,
    1.30431856341e-07,
    1.31797105598e-07,
    1.33204337360e-07,
    1.34657379914e-07,
    1.36160594606e-07,
    1.37718982103e-07,
    1.39338316679e-07,
    1.41025317971e-07,
    1.42787873535e-07,
    1.44635331499e-07,
    1.46578891730e-07,
    1.48632138436e-07,
    1.50811780719e-07,
    1.53138707402e-07,
    1.55639532047e-07,
    1.58348931426e-07,
    1.61313325908e-07,
    1.64596952856e-07,
    1.68292495203e-07,
    1.72541128694e-07,
    1.77574279496e-07,
    1.83813550477e-07,
    1.92166040885e-07,
    2.05295471952e-07,
    2.22600839893e-07,
];
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, std};
    use approx_eq::assert_approx_eq;

    /// Number of draws used by the moment checks.
    const SAMPLES: usize = 1_000_000;

    /// The empirical mean and standard deviation of a large sample should match
    /// `mu` and `sigma` within the macro's tolerance.
    #[test]
    fn test_moments() {
        for &(mu, sigma) in &[(0., 1.), (10., 20.)] {
            let data = Normal::new(mu, sigma).sample_n(SAMPLES);
            assert_approx_eq!(mu, mean(&data), 1e-2);
            assert_approx_eq!(sigma, std(&data), 1e-2);
        }
    }

    /// Standard-normal CDF at assorted points, checked against reference values.
    #[test]
    fn test_cdf() {
        let cases = [
            (-4., 3.167124183311986e-05),
            (-3.9, 4.8096344017602614e-05),
            (-2.81, 0.002477074998785861),
            (-2.67, 0.0037925623476854887),
            (-2.01, 0.022215594429431475),
            (0.01, 0.5039893563146316),
            (0.75, 0.7733726476231317),
            (1.5, 0.9331927987311419),
            (1.79, 0.9632730443012737),
        ];

        let standard = Normal::new(0., 1.);

        for &(x, expected) in cases.iter() {
            assert_approx_eq!(standard.cdf(x), expected, 1e-3);
        }
    }
}