compute 0.2.3

A crate for statistical computing.
Documentation
#![allow(clippy::many_single_char_names)]

use crate::{distributions::*, prelude::erf};
use std::f64::consts::PI;

/// The [Normal](https://en.wikipedia.org/wiki/Normal_distribution) (Gaussian) distribution,
/// parameterised by its mean `mu` and its standard deviation `sigma`.
#[derive(Clone, Copy, Debug)]
pub struct Normal {
    /// Location parameter: the mean of the distribution.
    mu: f64,
    /// Scale parameter: the standard deviation (not the variance).
    sigma: f64,
}

impl Normal {
    /// Creates a new Normal distribution with mean `mu` and standard deviation `sigma`.
    ///
    /// # Panics
    /// Panics if `sigma` is negative or NaN.
    pub fn new(mu: f64, sigma: f64) -> Self {
        Self::check_sigma(sigma);
        Normal { mu, sigma }
    }

    /// Sets the mean (location) parameter, returning `self` so calls can be chained.
    pub fn set_mu(&mut self, mu: f64) -> &mut Self {
        self.mu = mu;
        self
    }

    /// Sets the standard deviation (scale) parameter, returning `self` so calls can be chained.
    ///
    /// # Panics
    /// Panics if `sigma` is negative or NaN; the distribution is left unchanged in that case.
    pub fn set_sigma(&mut self, sigma: f64) -> &mut Self {
        Self::check_sigma(sigma);
        self.sigma = sigma;
        self
    }

    /// Evaluates the cumulative distribution function at `x`,
    /// `0.5 * (1 + erf((x - mu) / (sigma * sqrt(2))))`.
    ///
    /// TODO: make `cdf` a method of the `Continuous` trait.
    pub fn cdf(&self, x: f64) -> f64 {
        0.5 * (1. + erf((x - self.mu) / (self.sigma * std::f64::consts::SQRT_2)))
    }

    /// Validation shared by `new` and `set_sigma`.
    ///
    /// Written as `sigma >= 0.` because every comparison involving NaN is false: this rejects
    /// NaN as well as negative values, whereas a `sigma < 0.` test would let NaN through.
    fn check_sigma(sigma: f64) {
        assert!(sigma >= 0., "Sigma must be non-negative.");
    }
}

impl Default for Normal {
    fn default() -> Self {
        Self::new(0., 1.)
    }
}

impl Distribution for Normal {
    type Output = f64;
    /// Sample from the given Normal distribution.
    ///
    /// Uses a 128-layer ziggurat rejection sampler driven by the tables `K`, `Y`, `W` and the
    /// tail start `R`. A standard normal magnitude `x` is drawn, then returned as
    /// `s * x * sigma + mu`, where `s` is a random sign.
    ///
    /// NOTE(review): the constants appear to match GSL's `gsl_ran_gaussian_ziggurat`. The
    /// order of RNG draws below determines the output stream, so preserve it when editing.
    fn sample(&self) -> f64 {
        loop {
            // One 64-bit draw supplies the layer, the sign and the 24-bit position `j`.
            // presumably `alea::u64()` is uniform over all u64 values — confirm.
            let u = alea::u64();

            // Bits 0-6: layer index in 0..128.
            let i = (u & 0x7F) as usize;
            // Bits 8-31: 24-bit position within the layer (bits 32-63 are unused).
            let j = ((u >> 8) & 0xFFFFFF) as u32;
            // Bit 7: sign of the result.
            let s = if u & 0x80 != 0 { 1.0 } else { -1.0 };

            // Fast path: below the threshold `K[i]`, `j * W[i]` is accepted outright, with no
            // evaluation of the density.
            if j < K[i] {
                let x = j as f64 * W[i];
                return s * x * self.sigma + self.mu;
            }

            let (x, y) = if i < 127 {
                // Wedge: keep the same `x` and draw a height `y` uniformly between this layer's
                // lower (`Y[i + 1]`) and upper (`Y[i]`) boundary heights.
                // presumably `alea::f64()` is uniform on [0, 1) — confirm.
                let x = j as f64 * W[i];
                let y = Y[i + 1] + (Y[i] - Y[i + 1]) * alea::f64();
                (x, y)
            } else {
                // Base layer past its fast-path threshold: sample the tail x > R instead.
                // `ln_1p(-u) = ln(1 - u)`, so `x = R + E / R` with `E = -ln(1 - u)`, i.e. an
                // exponential proposal starting at `R`. `y` is uniform under the envelope
                // `exp(-R * (x - R / 2))`, which is >= exp(-x^2 / 2) for every x (the gap is
                // exp(-(x - R)^2 / 2) <= 1) and touches it at x = R.
                let x = R - (-alea::f64()).ln_1p() / R;
                let y = (-R * (x - 0.5 * R)).exp() * alea::f64();
                (x, y)
            };

            // Accept if the point lies under the unnormalised density exp(-x^2 / 2);
            // otherwise start over with a fresh draw.
            if y < (-0.5 * x * x).exp() {
                return s * x * self.sigma + self.mu;
            }
        }
    }
}

impl Distribution1D for Normal {
    /// Replaces the parameters from a slice laid out as `[mu, sigma]`.
    ///
    /// Any elements after the first two are ignored. Panics if `params` has fewer than two
    /// elements (after `mu` has already been written when exactly one is present), or if
    /// `set_sigma` rejects the new `sigma`.
    fn update(&mut self, params: &[f64]) {
        self.set_mu(params[0]);
        self.set_sigma(params[1]);
    }
}

impl Continuous for Normal {
    type PDFType = f64;
    /// Evaluates the probability density function of the given Normal distribution at `x`:
    /// `exp(-z^2 / 2) / (sigma * sqrt(2 * pi))` with `z = (x - mu) / sigma`.
    fn pdf(&self, x: f64) -> f64 {
        let scale = 1. / (self.sigma * (2. * PI).sqrt());
        let z = (x - self.mu) / self.sigma;
        scale * (-0.5 * z.powi(2)).exp()
    }

    /// Evaluates the natural log of the density at `x`.
    ///
    /// Computed directly in log space rather than as `pdf(x).ln()`, so it does not collapse
    /// to `-inf` far in the tails where `pdf` underflows to zero.
    fn ln_pdf(&self, x: Self::PDFType) -> f64 {
        let z = (x - self.mu) / self.sigma;
        let log_norm = (self.sigma * (2. * PI).sqrt()).ln();
        -0.5 * z.powi(2) - log_norm
    }
}

impl Mean for Normal {
    type MeanType = f64;
    /// Returns the mean of the given Normal distribution.
    fn mean(&self) -> f64 {
        self.mu
    }
}

impl Variance for Normal {
    type VarianceType = f64;
    /// Returns the variance `sigma^2` of the given Normal distribution.
    ///
    /// `sigma` stores the standard deviation, so it must be squared; returning it unchanged
    /// would report the standard deviation rather than the variance.
    fn var(&self) -> f64 {
        self.sigma * self.sigma
    }
}

#[test]
fn maxprob() {
    // The density of a normal distribution peaks at its mean, here 5.
    let dist = Normal::new(5., 4.);
    let peak = dist.pdf(5.);
    for x in 0_i32..20 {
        assert!(peak >= dist.pdf(f64::from(x)));
    }
    assert!(peak > dist.pdf(2.));
    assert!(peak > dist.pdf(6.));
}

/// Start of the ziggurat's tail: the tail branch of `Normal::sample` draws `x = R + E / R`
/// with `E >= 0`, so every tail sample satisfies `x >= R`.
///
/// NOTE(review): the value appears to match `PARAM_R` in GSL's Gaussian ziggurat; confirm.
const R: f64 = 3.444_286_476_76;

/// Fast-acceptance thresholds for the 24-bit position `j` in `Normal::sample`.
///
/// When `j < K[i]`, the sample `x = j * W[i]` is accepted without evaluating the density.
/// `K[0] == 0`, so the top layer never takes the fast path. Judging by the table values,
/// `K[i]` is roughly `2^24 * x_i / x_{i+1}`, where `x_i` is the layer edge with
/// `Y[i] = exp(-x_i^2 / 2)`. For example, `K[1] / 2^24 ≈ 0.7505 ≈ x_1 / x_2`. Confirm before
/// regenerating the table.
const K: [u32; 128] = [
    00000000, 12590644, 14272653, 14988939, 15384584, 15635009, 15807561, 15933577, 16029594,
    16105155, 16166147, 16216399, 16258508, 16294295, 16325078, 16351831, 16375291, 16396026,
    16414479, 16431002, 16445880, 16459343, 16471578, 16482744, 16492970, 16502368, 16511031,
    16519039, 16526459, 16533352, 16539769, 16545755, 16551348, 16556584, 16561493, 16566101,
    16570433, 16574511, 16578353, 16581977, 16585398, 16588629, 16591685, 16594575, 16597311,
    16599901, 16602354, 16604679, 16606881, 16608968, 16610945, 16612818, 16614592, 16616272,
    16617861, 16619363, 16620782, 16622121, 16623383, 16624570, 16625685, 16626730, 16627708,
    16628619, 16629465, 16630248, 16630969, 16631628, 16632228, 16632768, 16633248, 16633671,
    16634034, 16634340, 16634586, 16634774, 16634903, 16634972, 16634980, 16634926, 16634810,
    16634628, 16634381, 16634066, 16633680, 16633222, 16632688, 16632075, 16631380, 16630598,
    16629726, 16628757, 16627686, 16626507, 16625212, 16623794, 16622243, 16620548, 16618698,
    16616679, 16614476, 16612071, 16609444, 16606571, 16603425, 16599973, 16596178, 16591995,
    16587369, 16582237, 16576520, 16570120, 16562917, 16554758, 16545450, 16534739, 16522287,
    16507638, 16490152, 16468907, 16442518, 16408804, 16364095, 16301683, 16207738, 16047994,
    15704248, 15472926,
];

/// Heights of the ziggurat layer boundaries on the unnormalised density `exp(-x^2 / 2)`.
///
/// The values decrease from `Y[0] = 1` (the peak, at `x = 0`). In the wedge branch of
/// `Normal::sample`, layer `i` draws its height uniformly between `Y[i + 1]` and `Y[i]`. That
/// height is then compared against `exp(-x^2 / 2)`, which is why the heights are not
/// normalised.
const Y: [f64; 128] = [
    1.0000000000000,
    0.96359862301100,
    0.93628081335300,
    0.91304110425300,
    0.8922785066960,
    0.87323935691900,
    0.85549640763400,
    0.83877892834900,
    0.8229020836990,
    0.80773273823400,
    0.79317104551900,
    0.77913972650500,
    0.7655774360820,
    0.75243445624800,
    0.73966978767700,
    0.72724912028500,
    0.7151433774130,
    0.70332764645500,
    0.69178037703500,
    0.68048276891000,
    0.6694182972330,
    0.65857233912000,
    0.64793187618900,
    0.63748525489600,
    0.6272219914500,
    0.61713261153200,
    0.60720851746700,
    0.59744187729600,
    0.5878255314650,
    0.57835291380300,
    0.56901798419800,
    0.55981517091100,
    0.5507393208770,
    0.54178565668200,
    0.53294973914500,
    0.52422743462800,
    0.5156148863730,
    0.50710848925300,
    0.49870486747800,
    0.49040085481200,
    0.4821934769860,
    0.47407993601000,
    0.46605759612500,
    0.45812397121400,
    0.4502767134670,
    0.44251360317100,
    0.43483253947300,
    0.42723153202200,
    0.4197086933790,
    0.41226223212000,
    0.40489044654800,
    0.39759171895500,
    0.3903645103820,
    0.38320735581600,
    0.37611885978800,
    0.36909769233400,
    0.3621425852820,
    0.35525232883400,
    0.34842576841500,
    0.34166180177600,
    0.3349593763110,
    0.32831748658800,
    0.32173517206300,
    0.31521151497000,
    0.3087456383670,
    0.30233670433800,
    0.29598391232000,
    0.28968649757100,
    0.2834437297390,
    0.27725491156000,
    0.27111937764900,
    0.26503649338700,
    0.2590056539120,
    0.25302628318300,
    0.24709783313900,
    0.24121978293200,
    0.2353916382390,
    0.22961293064900,
    0.22388321712200,
    0.21820207951800,
    0.2125691242010,
    0.20698398170900,
    0.20144630649600,
    0.19595577674500,
    0.1905120942560,
    0.18511498440600,
    0.17976419618500,
    0.17445950232400,
    0.1692006994920,
    0.16398760860000,
    0.15882007519500,
    0.15369796996400,
    0.1486211893480,
    0.14358965629500,
    0.13860332114300,
    0.13366216266900,
    0.1287661893090,
    0.12391544058200,
    0.11910998874500,
    0.11434994070300,
    0.1096354402300,
    0.10496667053300,
    0.10034385723200,
    0.09576727182660,
    0.0912372357329,
    0.08675412501270,
    0.08231837593200,
    0.07793049152950,
    0.0735910494266,
    0.06930071117420,
    0.06506023352900,
    0.06087048217450,
    0.0567324485840,
    0.05264727098000,
    0.04861626071630,
    0.04464093597690,
    0.0407230655415,
    0.03686472673860,
    0.03306838393780,
    0.02933699774110,
    0.0256741818288,
    0.02208443726340,
    0.01857352005770,
    0.01514905528540,
    0.0118216532614,
    0.00860719483079,
    0.00553245272614,
    0.00265435214565,
];

/// Per-layer scale factors that map the 24-bit position `j` to a coordinate, `x = j * W[i]`,
/// in `Normal::sample`.
///
/// Judging by the table values, for `i < 127` the entry is roughly `x_{i+1} / 2^24`, the
/// right edge of layer `i` divided by 2^24. For example, `W[126] * 2^24 ≈ 3.444 ≈ R`.
/// `W[127]` is wider (`W[127] * 2^24 ≈ 3.73`): it covers the base strip, whose part beyond
/// `R` is resampled by the tail branch. Confirm before regenerating the table.
const W: [f64; 128] = [
    1.62318314817e-08,
    2.16291505214e-08,
    2.54246305087e-08,
    2.84579525938e-08,
    3.10340022482e-08,
    3.33011726243e-08,
    3.53439060345e-08,
    3.72152672658e-08,
    3.89509895720e-08,
    4.05763964764e-08,
    4.21101548915e-08,
    4.35664624904e-08,
    4.49563968336e-08,
    4.62887864029e-08,
    4.75707945735e-08,
    4.88083237257e-08,
    5.00063025384e-08,
    5.11688950428e-08,
    5.22996558616e-08,
    5.34016475624e-08,
    5.44775307871e-08,
    5.55296344581e-08,
    5.65600111659e-08,
    5.75704813695e-08,
    5.85626690412e-08,
    5.95380306862e-08,
    6.04978791776e-08,
    6.14434034901e-08,
    6.23756851626e-08,
    6.32957121259e-08,
    6.42043903937e-08,
    6.51025540077e-08,
    6.59909735447e-08,
    6.68703634341e-08,
    6.77413882848e-08,
    6.86046683810e-08,
    6.94607844804e-08,
    7.03102820203e-08,
    7.11536748229e-08,
    7.19914483720e-08,
    7.28240627230e-08,
    7.36519550992e-08,
    7.44755422158e-08,
    7.52952223703e-08,
    7.61113773308e-08,
    7.69243740467e-08,
    7.77345662086e-08,
    7.85422956743e-08,
    7.93478937793e-08,
    8.01516825471e-08,
    8.09539758128e-08,
    8.17550802699e-08,
    8.25552964535e-08,
    8.33549196661e-08,
    8.41542408569e-08,
    8.49535474601e-08,
    8.57531242006e-08,
    8.65532538723e-08,
    8.73542180955e-08,
    8.81562980590e-08,
    8.89597752521e-08,
    8.97649321908e-08,
    9.05720531451e-08,
    9.13814248700e-08,
    9.21933373471e-08,
    9.30080845407e-08,
    9.38259651738e-08,
    9.46472835298e-08,
    9.54723502847e-08,
    9.63014833769e-08,
    9.71350089201e-08,
    9.79732621669e-08,
    9.88165885297e-08,
    9.96653446693e-08,
    1.00519899658e-07,
    1.01380636230e-07,
    1.02247952126e-07,
    1.03122261554e-07,
    1.04003996769e-07,
    1.04893609795e-07,
    1.05791574313e-07,
    1.06698387725e-07,
    1.07614573423e-07,
    1.08540683296e-07,
    1.09477300508e-07,
    1.10425042570e-07,
    1.11384564771e-07,
    1.12356564007e-07,
    1.13341783071e-07,
    1.14341015475e-07,
    1.15355110887e-07,
    1.16384981291e-07,
    1.17431607977e-07,
    1.18496049514e-07,
    1.19579450872e-07,
    1.20683053909e-07,
    1.21808209468e-07,
    1.22956391410e-07,
    1.24129212952e-07,
    1.25328445797e-07,
    1.26556042658e-07,
    1.27814163916e-07,
    1.29105209375e-07,
    1.30431856341e-07,
    1.31797105598e-07,
    1.33204337360e-07,
    1.34657379914e-07,
    1.36160594606e-07,
    1.37718982103e-07,
    1.39338316679e-07,
    1.41025317971e-07,
    1.42787873535e-07,
    1.44635331499e-07,
    1.46578891730e-07,
    1.48632138436e-07,
    1.50811780719e-07,
    1.53138707402e-07,
    1.55639532047e-07,
    1.58348931426e-07,
    1.61313325908e-07,
    1.64596952856e-07,
    1.68292495203e-07,
    1.72541128694e-07,
    1.77574279496e-07,
    1.83813550477e-07,
    1.92166040885e-07,
    2.05295471952e-07,
    2.22600839893e-07,
];

#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::{mean, std};
    use approx_eq::assert_approx_eq;

    /// Sample moments of a million draws should match the requested mean and standard
    /// deviation within the `1e-2` tolerance of `assert_approx_eq!`.
    #[test]
    fn test_moments() {
        const DRAWS: usize = 1_000_000;

        let standard = Normal::new(0., 1.).sample_n(DRAWS);
        assert_approx_eq!(0., mean(&standard), 1e-2);
        assert_approx_eq!(1., std(&standard), 1e-2);

        let shifted = Normal::new(10., 20.).sample_n(DRAWS);
        assert_approx_eq!(10., mean(&shifted), 1e-2);
        assert_approx_eq!(20., std(&shifted), 1e-2);
    }

    /// The standard normal `cdf` should agree with tabulated values within the `1e-3`
    /// tolerance of `assert_approx_eq!`.
    #[test]
    fn test_cdf() {
        let cases = [
            (-4., 3.167124183311986e-05),
            (-3.9, 4.8096344017602614e-05),
            (-2.81, 0.002477074998785861),
            (-2.67, 0.0037925623476854887),
            (-2.01, 0.022215594429431475),
            (0.01, 0.5039893563146316),
            (0.75, 0.7733726476231317),
            (1.5, 0.9331927987311419),
            (1.79, 0.9632730443012737),
        ];

        let sn = Normal::new(0., 1.);

        for &(x, expected) in cases.iter() {
            assert_approx_eq!(sn.cdf(x), expected, 1e-3);
        }
    }
}