#![allow(clippy::many_single_char_names)]
use crate::{distributions::*, prelude::erf};
use std::f64::consts::PI;
/// Normal (Gaussian) distribution parameterised by its mean `mu` and its
/// standard deviation `sigma`.
///
/// `PartialEq` is derived so two distributions can be compared directly
/// (e.g. in tests or when checking whether parameters changed).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Normal {
    /// Mean (location) of the distribution.
    mu: f64,
    /// Standard deviation (scale); the public constructors reject negative values.
    sigma: f64,
}
impl Normal {
    /// Create a normal distribution with mean `mu` and standard deviation `sigma`.
    ///
    /// # Panics
    ///
    /// Panics with "Sigma must be non-negative." if `sigma` is negative or NaN.
    pub fn new(mu: f64, sigma: f64) -> Self {
        Self::check_sigma(sigma);
        Normal { mu, sigma }
    }

    /// Set the mean and return `self` so setters can be chained.
    pub fn set_mu(&mut self, mu: f64) -> &mut Self {
        self.mu = mu;
        self
    }

    /// Set the standard deviation and return `self` so setters can be chained.
    ///
    /// # Panics
    ///
    /// Panics with "Sigma must be non-negative." if `sigma` is negative or NaN.
    /// The check runs before the field is written, so `self` is unchanged on panic.
    pub fn set_sigma(&mut self, sigma: f64) -> &mut Self {
        Self::check_sigma(sigma);
        self.sigma = sigma;
        self
    }

    /// Cumulative distribution function `P(X <= x) = ½·(1 + erf((x − μ) / (σ·√2)))`.
    pub fn cdf(&self, x: f64) -> f64 {
        // `SQRT_2` is the correctly rounded √2, bit-identical to `2_f64.sqrt()`.
        0.5 * (1. + erf((x - self.mu) / (self.sigma * std::f64::consts::SQRT_2)))
    }

    /// Shared validation of the standard deviation for `new` and `set_sigma`.
    fn check_sigma(sigma: f64) {
        // `sigma < 0.` is false for NaN, which would let NaN through; asserting
        // `sigma >= 0.` rejects both negative values and NaN.
        assert!(sigma >= 0., "Sigma must be non-negative.");
    }
}
impl Default for Normal {
fn default() -> Self {
Self::new(0., 1.)
}
}
impl Distribution for Normal {
    type Output = f64;
    /// Draw one sample from N(mu, sigma²) by rejection sampling over 128
    /// precomputed strips (the `K`, `W`, `Y` tables plus the tail edge `R`).
    ///
    /// NOTE(review): the structure appears to follow GSL's
    /// `gsl_ran_gaussian_ziggurat`; the exact order of the RNG draws below
    /// determines the output sequence, so do not reorder them casually.
    fn sample(&self) -> f64 {
        loop {
            // One 64-bit draw supplies the strip index, the sign and a 24-bit
            // integer; only the low 32 bits are used.
            let u = alea::u64();
            // Bits 0..=6: strip index in 0..128 (one entry per table slot).
            let i = (u & 0x7F) as usize;
            // Bits 8..=31: 24-bit integer that is scaled by `W[i]` into a candidate x.
            let j = ((u >> 8) & 0xFFFFFF) as u32;
            // Bit 7: sign of the result.
            let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };
            // Fast path: below the threshold `K[i]`, the candidate is accepted
            // without evaluating `exp` (never taken for i == 0 since K[0] == 0).
            if j < K[i] {
                let x = j as f64 * W[i];
                return s * x * self.sigma + self.mu;
            }
            let (x, y) = if i < 127 {
                // Non-tail strip: keep the same x, and pick a height y between
                // the strip boundaries Y[i + 1] and Y[i] with a uniform draw.
                let x = j as f64 * W[i];
                let y = Y[i + 1] + (Y[i] - Y[i + 1]) * alea::f64();
                (x, y)
            } else {
                // Tail strip: x = R − ln(1 − u)/R, so x >= R. `ln_1p(-u)` computes
                // ln(1 − u); assuming `alea::f64()` lies in [0, 1), the log stays finite.
                let x = R - (-alea::f64()).ln_1p() / R;
                // Height under the exponential envelope exp(−R·(x − R/2)).
                let y = (-R * (x - 0.5 * R)).exp() * alea::f64();
                (x, y)
            };
            // Accept if the point lies under the unnormalised density exp(−x²/2);
            // otherwise start over with a fresh draw.
            if y < (-0.5 * x * x).exp() {
                return s * x * self.sigma + self.mu;
            }
        }
    }
}
impl Distribution1D for Normal {
    /// Replace both parameters from a slice laid out as `[mu, sigma, ..]`.
    ///
    /// Elements after the second are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `params` has fewer than two elements, or if `set_sigma`
    /// rejects `params[1]` (e.g. a negative sigma). Both failure cases are
    /// detected before `mu` is written, so a panicking call leaves `self`
    /// unchanged.
    fn update(&mut self, params: &[f64]) {
        assert!(
            params.len() >= 2,
            "Normal::update expects [mu, sigma], got {} parameter(s)",
            params.len()
        );
        // Set sigma first: it is the only setter that can panic, so doing it
        // before `set_mu` avoids leaving a half-applied update behind.
        self.set_sigma(params[1]).set_mu(params[0]);
    }
}
impl Continuous for Normal {
    type PDFType = f64;
    /// Probability density `exp(−z²/2) / (σ·√(2π))`, where `z = (x − μ)/σ`.
    fn pdf(&self, x: f64) -> f64 {
        let z = (x - self.mu) / self.sigma;
        let norm = self.sigma * (2. * PI).sqrt();
        1. / norm * (-0.5 * z.powi(2)).exp()
    }
    /// Natural log of the density, `−z²/2 − ln(σ·√(2π))`, computed directly
    /// rather than as `pdf(x).ln()`.
    fn ln_pdf(&self, x: Self::PDFType) -> f64 {
        let z = (x - self.mu) / self.sigma;
        let norm = self.sigma * (2. * PI).sqrt();
        -0.5 * z.powi(2) - norm.ln()
    }
}
impl Mean for Normal {
type MeanType = f64;
fn mean(&self) -> f64 {
self.mu
}
}
impl Variance for Normal {
    type VarianceType = f64;
    /// Variance of the distribution, `σ²`.
    ///
    /// The struct stores the standard deviation `sigma` (it scales samples and
    /// appears unsquared in the density's normaliser), so it must be squared
    /// here. Returning `sigma` directly would report the standard deviation.
    fn var(&self) -> f64 {
        self.sigma * self.sigma
    }
}
/// The density of `Normal::new(5., 4.)` should peak at its mean, so `pdf(5.)`
/// must be at least as large as the density anywhere else.
#[test]
fn maxprob() {
    let n = Normal::new(5., 4.);
    let peak = n.pdf(5.);
    for x in 0..20_i32 {
        assert!(peak >= n.pdf(f64::from(x)));
    }
    assert!(peak > n.pdf(2.));
    assert!(peak > n.pdf(6.));
}
/// Abscissa where the tail begins: the tail branch of `Normal::sample` draws
/// `x = R − ln(1 − u)/R`, so every tail sample satisfies `x >= R`.
const R: f64 = 3.444_286_476_76;
/// Fast-acceptance thresholds for the 24-bit integer `j` drawn in
/// `Normal::sample`: when `j < K[i]`, the candidate `j * W[i]` is accepted
/// immediately without evaluating `exp`. `K[0]` is 0, so strip 0 never takes
/// the fast path.
///
/// NOTE(review): the values appear to match `ktab` in GSL's
/// `randist/gausszig.c` (described there as 2^24 · x[i]/x[i+1]). Confirm
/// against that source before regenerating or editing.
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939, 15384584, 15635009, 15807561, 15933577, 16029594,
    16105155, 16166147, 16216399, 16258508, 16294295, 16325078, 16351831, 16375291, 16396026,
    16414479, 16431002, 16445880, 16459343, 16471578, 16482744, 16492970, 16502368, 16511031,
    16519039, 16526459, 16533352, 16539769, 16545755, 16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977, 16585398, 16588629, 16591685, 16594575, 16597311,
    16599901, 16602354, 16604679, 16606881, 16608968, 16610945, 16612818, 16614592, 16616272,
    16617861, 16619363, 16620782, 16622121, 16623383, 16624570, 16625685, 16626730, 16627708,
    16628619, 16629465, 16630248, 16630969, 16631628, 16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774, 16634903, 16634972, 16634980, 16634926, 16634810,
    16634628, 16634381, 16634066, 16633680, 16633222, 16632688, 16632075, 16631380, 16630598,
    16629726, 16628757, 16627686, 16626507, 16625212, 16623794, 16622243, 16620548, 16618698,
    16616679, 16614476, 16612071, 16609444, 16606571, 16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120, 16562917, 16554758, 16545450, 16534739, 16522287,
    16507638, 16490152, 16468907, 16442518, 16408804, 16364095, 16301683, 16207738, 16047994,
    15704248, 15472926,
];
/// Heights of the strip boundaries, starting at 1.0 and decreasing.
/// For strip `i < 127`, `Normal::sample` picks a height between `Y[i + 1]` and
/// `Y[i]` with a uniform draw and compares it with the unnormalised density
/// `exp(−x²/2)`.
///
/// NOTE(review): the values appear to match `ytab` in GSL's
/// `randist/gausszig.c`; confirm against that source before editing.
const Y: [f64; 128] = [
    1.0000000000000,
    0.96359862301100,
    0.93628081335300,
    0.91304110425300,
    0.8922785066960,
    0.87323935691900,
    0.85549640763400,
    0.83877892834900,
    0.8229020836990,
    0.80773273823400,
    0.79317104551900,
    0.77913972650500,
    0.7655774360820,
    0.75243445624800,
    0.73966978767700,
    0.72724912028500,
    0.7151433774130,
    0.70332764645500,
    0.69178037703500,
    0.68048276891000,
    0.6694182972330,
    0.65857233912000,
    0.64793187618900,
    0.63748525489600,
    0.6272219914500,
    0.61713261153200,
    0.60720851746700,
    0.59744187729600,
    0.5878255314650,
    0.57835291380300,
    0.56901798419800,
    0.55981517091100,
    0.5507393208770,
    0.54178565668200,
    0.53294973914500,
    0.52422743462800,
    0.5156148863730,
    0.50710848925300,
    0.49870486747800,
    0.49040085481200,
    0.4821934769860,
    0.47407993601000,
    0.46605759612500,
    0.45812397121400,
    0.4502767134670,
    0.44251360317100,
    0.43483253947300,
    0.42723153202200,
    0.4197086933790,
    0.41226223212000,
    0.40489044654800,
    0.39759171895500,
    0.3903645103820,
    0.38320735581600,
    0.37611885978800,
    0.36909769233400,
    0.3621425852820,
    0.35525232883400,
    0.34842576841500,
    0.34166180177600,
    0.3349593763110,
    0.32831748658800,
    0.32173517206300,
    0.31521151497000,
    0.3087456383670,
    0.30233670433800,
    0.29598391232000,
    0.28968649757100,
    0.2834437297390,
    0.27725491156000,
    0.27111937764900,
    0.26503649338700,
    0.2590056539120,
    0.25302628318300,
    0.24709783313900,
    0.24121978293200,
    0.2353916382390,
    0.22961293064900,
    0.22388321712200,
    0.21820207951800,
    0.2125691242010,
    0.20698398170900,
    0.20144630649600,
    0.19595577674500,
    0.1905120942560,
    0.18511498440600,
    0.17976419618500,
    0.17445950232400,
    0.1692006994920,
    0.16398760860000,
    0.15882007519500,
    0.15369796996400,
    0.1486211893480,
    0.14358965629500,
    0.13860332114300,
    0.13366216266900,
    0.1287661893090,
    0.12391544058200,
    0.11910998874500,
    0.11434994070300,
    0.1096354402300,
    0.10496667053300,
    0.10034385723200,
    0.09576727182660,
    0.0912372357329,
    0.08675412501270,
    0.08231837593200,
    0.07793049152950,
    0.0735910494266,
    0.06930071117420,
    0.06506023352900,
    0.06087048217450,
    0.0567324485840,
    0.05264727098000,
    0.04861626071630,
    0.04464093597690,
    0.0407230655415,
    0.03686472673860,
    0.03306838393780,
    0.02933699774110,
    0.0256741818288,
    0.02208443726340,
    0.01857352005770,
    0.01514905528540,
    0.0118216532614,
    0.00860719483079,
    0.00553245272614,
    0.00265435214565,
];
/// Per-strip scale factors that turn the 24-bit integer `j` drawn in
/// `Normal::sample` into a candidate abscissa: `x = j as f64 * W[i]`.
///
/// NOTE(review): the values appear to match `wtab` in GSL's
/// `randist/gausszig.c`; confirm against that source before editing.
const W: [f64; 128] = [
    1.62318314817e-08,
    2.16291505214e-08,
    2.54246305087e-08,
    2.84579525938e-08,
    3.10340022482e-08,
    3.33011726243e-08,
    3.53439060345e-08,
    3.72152672658e-08,
    3.89509895720e-08,
    4.05763964764e-08,
    4.21101548915e-08,
    4.35664624904e-08,
    4.49563968336e-08,
    4.62887864029e-08,
    4.75707945735e-08,
    4.88083237257e-08,
    5.00063025384e-08,
    5.11688950428e-08,
    5.22996558616e-08,
    5.34016475624e-08,
    5.44775307871e-08,
    5.55296344581e-08,
    5.65600111659e-08,
    5.75704813695e-08,
    5.85626690412e-08,
    5.95380306862e-08,
    6.04978791776e-08,
    6.14434034901e-08,
    6.23756851626e-08,
    6.32957121259e-08,
    6.42043903937e-08,
    6.51025540077e-08,
    6.59909735447e-08,
    6.68703634341e-08,
    6.77413882848e-08,
    6.86046683810e-08,
    6.94607844804e-08,
    7.03102820203e-08,
    7.11536748229e-08,
    7.19914483720e-08,
    7.28240627230e-08,
    7.36519550992e-08,
    7.44755422158e-08,
    7.52952223703e-08,
    7.61113773308e-08,
    7.69243740467e-08,
    7.77345662086e-08,
    7.85422956743e-08,
    7.93478937793e-08,
    8.01516825471e-08,
    8.09539758128e-08,
    8.17550802699e-08,
    8.25552964535e-08,
    8.33549196661e-08,
    8.41542408569e-08,
    8.49535474601e-08,
    8.57531242006e-08,
    8.65532538723e-08,
    8.73542180955e-08,
    8.81562980590e-08,
    8.89597752521e-08,
    8.97649321908e-08,
    9.05720531451e-08,
    9.13814248700e-08,
    9.21933373471e-08,
    9.30080845407e-08,
    9.38259651738e-08,
    9.46472835298e-08,
    9.54723502847e-08,
    9.63014833769e-08,
    9.71350089201e-08,
    9.79732621669e-08,
    9.88165885297e-08,
    9.96653446693e-08,
    1.00519899658e-07,
    1.01380636230e-07,
    1.02247952126e-07,
    1.03122261554e-07,
    1.04003996769e-07,
    1.04893609795e-07,
    1.05791574313e-07,
    1.06698387725e-07,
    1.07614573423e-07,
    1.08540683296e-07,
    1.09477300508e-07,
    1.10425042570e-07,
    1.11384564771e-07,
    1.12356564007e-07,
    1.13341783071e-07,
    1.14341015475e-07,
    1.15355110887e-07,
    1.16384981291e-07,
    1.17431607977e-07,
    1.18496049514e-07,
    1.19579450872e-07,
    1.20683053909e-07,
    1.21808209468e-07,
    1.22956391410e-07,
    1.24129212952e-07,
    1.25328445797e-07,
    1.26556042658e-07,
    1.27814163916e-07,
    1.29105209375e-07,
    1.30431856341e-07,
    1.31797105598e-07,
    1.33204337360e-07,
    1.34657379914e-07,
    1.36160594606e-07,
    1.37718982103e-07,
    1.39338316679e-07,
    1.41025317971e-07,
    1.42787873535e-07,
    1.44635331499e-07,
    1.46578891730e-07,
    1.48632138436e-07,
    1.50811780719e-07,
    1.53138707402e-07,
    1.55639532047e-07,
    1.58348931426e-07,
    1.61313325908e-07,
    1.64596952856e-07,
    1.68292495203e-07,
    1.72541128694e-07,
    1.77574279496e-07,
    1.83813550477e-07,
    1.92166040885e-07,
    2.05295471952e-07,
    2.22600839893e-07,
];
#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, std};
    use approx_eq::assert_approx_eq;

    /// Draw a large sample for each parameter set and check that the empirical
    /// mean and standard deviation agree with `mu` and `sigma`.
    #[test]
    fn test_moments() {
        const N: usize = 1_000_000;
        for &(mu, sigma) in &[(0., 1.), (10., 20.)] {
            let data = Normal::new(mu, sigma).sample_n(N);
            assert_approx_eq!(mu, mean(&data), 1e-2);
            assert_approx_eq!(sigma, std(&data), 1e-2);
        }
    }

    /// Compare the standard-normal CDF against precomputed reference values.
    #[test]
    fn test_cdf() {
        let cases = [
            (-4., 3.167124183311986e-05),
            (-3.9, 4.8096344017602614e-05),
            (-2.81, 0.002477074998785861),
            (-2.67, 0.0037925623476854887),
            (-2.01, 0.022215594429431475),
            (0.01, 0.5039893563146316),
            (0.75, 0.7733726476231317),
            (1.5, 0.9331927987311419),
            (1.79, 0.9632730443012737),
        ];
        let sn = Normal::new(0., 1.);
        for &(x, expected) in &cases {
            assert_approx_eq!(sn.cdf(x), expected, 1e-3);
        }
    }
}