rs_stats/distributions/
normal_distribution.rs

1use crate::prob::erf::erf;
2use rand::Rng;
3use rand_distr::{Distribution, Normal as RandNormal};
4use serde::{Deserialize, Serialize};
5use std::f64::consts::PI;
6
/// Configuration for the Normal distribution.
///
/// A plain, copyable pair of parameters. Both fields are public, so a value
/// written as a struct literal (as in the example below) is not checked in any
/// way; use [`NormalConfig::new`] when the parameters should be validated.
/// The derived `Deserialize` implementation likewise fills the fields without
/// running that validation.
///
/// # Fields
/// * `mean` - The mean (location parameter)
/// * `std_dev` - The standard deviation (scale parameter, must be positive)
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::NormalConfig;
///
/// let config = NormalConfig { mean: 0.0, std_dev: 1.0 };
/// assert!(config.std_dev > 0.0);
/// ```
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NormalConfig {
    /// The mean (μ) of the distribution.
    pub mean: f64,
    /// The standard deviation (σ) of the distribution.
    pub std_dev: f64,
}
27
28impl NormalConfig {
29    /// Creates a new NormalConfig with validation
30    ///
31    /// # Arguments
32    /// * `mean` - The mean of the distribution
33    /// * `std_dev` - The standard deviation of the distribution
34    ///
35    /// # Returns
36    /// `Some(NormalConfig)` if parameters are valid, `None` otherwise
37    ///
38    /// # Examples
39    /// ```
40    /// use rs_stats::distributions::normal_distribution::NormalConfig;
41    ///
42    /// let standard_normal = NormalConfig::new(0.0, 1.0);
43    /// assert!(standard_normal.is_some());
44    ///
45    /// let invalid_config = NormalConfig::new(0.0, -1.0);
46    /// assert!(invalid_config.is_none());
47    /// ```
48    pub fn new(mean: f64, std_dev: f64) -> Option<Self> {
49        if std_dev > 0.0 && !mean.is_nan() && !std_dev.is_nan() {
50            Some(Self { mean, std_dev })
51        } else {
52            None
53        }
54    }
55}
56
/// Evaluates the normal (Gaussian) probability density at `x`.
///
/// The density is `exp(-z² / 2) / (σ · sqrt(2π))` with `z = (x - μ) / σ`.
///
/// # Arguments
/// * `x` - The value at which to evaluate the PDF
/// * `mean` - The mean (μ) of the distribution
/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
///
/// # Returns
/// The probability density at point x
///
/// # Panics
/// Panics if std_dev is not positive.
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::normal_pdf;
///
/// // Standard normal distribution at x = 0
/// let pdf = normal_pdf(0.0, 0.0, 1.0);
/// assert!((pdf - 0.3989422804014327).abs() < 1e-10);
///
/// // Normal distribution with mean = 5, std_dev = 2 at x = 5
/// let pdf = normal_pdf(5.0, 5.0, 2.0);
/// assert!((pdf - 0.19947114020071635).abs() < 1e-10);
/// ```
pub fn normal_pdf(x: f64, mean: f64, std_dev: f64) -> f64 {
    assert!(std_dev > 0.0, "Standard deviation must be positive");

    // Distance from the mean measured in standard deviations.
    let standardized = (x - mean) / std_dev;
    // Unnormalised bell curve value.
    let kernel = (-0.5 * standardized.powi(2)).exp();
    // Scaling that makes the density integrate to one.
    let normalizer = 1.0 / (std_dev * (2.0 * PI).sqrt());

    normalizer * kernel
}
88
89/// Calculates the cumulative distribution function (CDF) for the normal distribution.
90///
91/// # Arguments
92/// * `x` - The value at which to evaluate the CDF
93/// * `mean` - The mean (μ) of the distribution
94/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
95///
96/// # Returns
97/// The probability that a random variable is less than or equal to x
98///
99/// # Panics
100/// Panics if std_dev is not positive.
101///
102/// # Examples
103/// ```
104/// use rs_stats::distributions::normal_distribution::normal_cdf;
105///
106/// // Standard normal distribution at x = 0
107/// let cdf = normal_cdf(0.0, 0.0, 1.0);
108/// assert!((cdf - 0.5).abs() < 1e-7);
109///
110/// // Normal distribution with mean = 5, std_dev = 2 at x = 7
111/// let cdf = normal_cdf(7.0, 5.0, 2.0);
112/// assert!((cdf - 0.8413447460685429).abs() < 1e-7);
113/// ```
114pub fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> f64 {
115    assert!(std_dev > 0.0, "Standard deviation must be positive");
116
117    // Special case to handle exact value at the mean
118    if x == mean {
119        return 0.5;
120    }
121
122    // Calculate the standardized value z
123    let z = (x - mean) / std_dev;
124
125    // Use a more numerically stable form of the calculation
126    // The sqrt(2) factor is included in the argument to erf
127    0.5 * (1.0 + erf(z / std::f64::consts::SQRT_2))
128}
129
/// Calculates the inverse cumulative distribution function (Quantile function) for the normal distribution.
///
/// The standard-normal quantile is obtained from Peter J. Acklam's rational
/// approximation and then mapped to the requested distribution as
/// `mean + sigma * z`. The computation is always carried out on the lower
/// half (`q = min(p, 1 - p)`) and mirrored for `p > 0.5`.
///
/// # Arguments
/// * `p` - Probability value between 0 and 1
/// * `mean` - The mean (μ) of the distribution
/// * `sigma` - The standard deviation (σ) of the distribution (not validated here)
///
/// # Returns
/// The value x such that P(X ≤ x) = p. Returns negative infinity for `p == 0`
/// and positive infinity for `p == 1`.
///
/// # Panics
/// Panics if `p` is outside `[0, 1]` (this includes NaN).
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::{normal_cdf, normal_inverse_cdf};
///
/// // Check that inverse_cdf is the inverse of cdf
/// let x = 0.5;
/// let p = normal_cdf(x, 0.0, 1.0);
/// let x_back = normal_inverse_cdf(p, 0.0, 1.0);
/// assert!((x - x_back).abs() < 1e-8);
/// ```
pub fn normal_inverse_cdf(p: f64, mean: f64, sigma: f64) -> f64 {
    // Break-point between the tail and central approximations.
    const P_LOW: f64 = 0.02425;

    // Central-region coefficients (numerator A, denominator B; the
    // denominator's constant term is 1).
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_69e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239,
    ];
    const B: [f64; 5] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
    ];

    // Tail-region coefficients (numerator C, denominator D; the
    // denominator's constant term is 1).
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838,
        -2.549_732_539_343_734,
        4.374_664_141_464_968,
        2.938_163_982_698_783,
    ];
    const D: [f64; 4] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996,
        3.754_408_661_907_416,
    ];

    assert!(
        (0.0..=1.0).contains(&p),
        "Probability must be between 0 and 1"
    );

    // The quantile is unbounded at the ends of the probability range.
    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }

    // Work on the lower half only and mirror the result afterwards. For any
    // p strictly inside (0, 1), q is strictly positive and at most 0.5.
    let upper_half = p > 0.5;
    let q = if upper_half { 1.0 - p } else { p };

    // Standard-normal quantile for probability q (always <= 0).
    let z_lower = if q > P_LOW {
        // Central region: odd rational function of (q - 0.5).
        let r = q - 0.5;
        let r2 = r * r;
        let num = ((((A[0] * r2 + A[1]) * r2 + A[2]) * r2 + A[3]) * r2 + A[4]) * r2 + A[5];
        let den = ((((B[0] * r2 + B[1]) * r2 + B[2]) * r2 + B[3]) * r2 + B[4]) * r2 + 1.0;
        r * num / den
    } else {
        // Lower tail: the rational function of t = sqrt(-2 ln q) *is* the
        // quantile. It must not be combined with t again; doing so yields
        // values far too close to zero (about -0.71 instead of -2.33 at q = 0.01).
        let t = (-2.0 * q.ln()).sqrt();
        let num = ((((C[0] * t + C[1]) * t + C[2]) * t + C[3]) * t + C[4]) * t + C[5];
        let den = (((D[0] * t + D[1]) * t + D[2]) * t + D[3]) * t + 1.0;
        num / den
    };

    // Mirror for probabilities above one half.
    let z = if upper_half { -z_lower } else { z_lower };

    // Convert from standard normal to the specified distribution
    mean + sigma * z
}
248
249/// Generates a random sample from the normal distribution.
250///
251/// # Arguments
252/// * `mean` - The mean (μ) of the distribution
253/// * `sigma` - The standard deviation (σ) of the distribution
254/// * `rng` - A random number generator
255///
256/// # Returns
257/// A random value from the normal distribution
258///
259/// # Examples
260/// ```
261/// use rs_stats::distributions::normal_distribution::normal_sample;
262/// use rand::thread_rng;
263///
264/// let mut rng = thread_rng();
265/// let sample = normal_sample(10.0, 2.0, &mut rng);
266/// // sample is a random value from Normal(10, 2)
267/// ```
268pub fn normal_sample<R: Rng>(mean: f64, sigma: f64, rng: &mut R) -> f64 {
269    let normal = RandNormal::new(mean, sigma).unwrap();
270    normal.sample(rng)
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    // Tolerance for comparisons against values that pass through `erf` or the
    // rational quantile approximation.
    const EPSILON: f64 = 1e-7;

    #[test]
    fn test_normal_config_new_validation() {
        // Valid parameters are stored unchanged.
        let config = NormalConfig::new(2.0, 3.0).expect("valid parameters");
        assert_eq!(config.mean, 2.0);
        assert_eq!(config.std_dev, 3.0);

        // Non-positive or NaN parameters are rejected.
        assert!(NormalConfig::new(0.0, 0.0).is_none());
        assert!(NormalConfig::new(0.0, -1.0).is_none());
        assert!(NormalConfig::new(f64::NAN, 1.0).is_none());
        assert!(NormalConfig::new(0.0, f64::NAN).is_none());
    }

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean (peak of the density)
        let result = normal_pdf(mean, mean, sigma);
        assert!((result - 0.3989422804014327).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma);
        assert!((result - 0.24197072451914337).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;

        // Test at mean
        let result = normal_pdf(mean, mean, sigma);
        assert!((result - 0.19947114020071635).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma);
        assert!((result - 0.12098536225957168).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;

        let pdf_plus = normal_pdf(mean + x, mean, sigma);
        let pdf_minus = normal_pdf(mean - x, mean, sigma);

        assert!((pdf_plus - pdf_minus).abs() < 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma);
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma);
        assert!((result - 0.8413447460685429).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_cdf(mean - sigma, mean, sigma);
        assert!((result - 0.15865525393145707).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma);
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma);
        assert!((result - 0.8413447460685429).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma);
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma);
        assert!((result - sigma).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_inverse_cdf(0.15865525393145707, mean, sigma);
        assert!((result - (-sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma);
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma);
        assert!((result - (mean + sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_endpoints() {
        // The quantile function is unbounded at p = 0 and p = 1.
        assert_eq!(normal_inverse_cdf(0.0, 0.0, 1.0), f64::NEG_INFINITY);
        assert_eq!(normal_inverse_cdf(1.0, 0.0, 1.0), f64::INFINITY);
    }

    #[test]
    #[should_panic(expected = "Probability must be between 0 and 1")]
    fn test_normal_inverse_cdf_invalid_probability() {
        normal_inverse_cdf(1.5, 0.0, 1.0);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        // PDF for standard normal at mean should be maximum (approx 0.3989)
        let pdf = (normal_pdf(0.0, 0.0, 1.0) * 1e7).round() / 1e7;
        assert!((pdf - 0.3989423).abs() < EPSILON);

        // Test symmetry around mean
        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0);
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0);
        assert!((pdf_plus1 - pdf_minus1).abs() < EPSILON);

        // Test at specific points
        assert!((normal_pdf(1.0, 0.0, 1.0) - 0.2419707).abs() < EPSILON);
        assert!((normal_pdf(2.0, 0.0, 1.0) - 0.0539909).abs() < EPSILON);
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_pdf_invalid_sigma() {
        normal_pdf(0.0, 0.0, -1.0);
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // CDF at mean should be 0.5. Compare the raw value: rounding it to a
        // single decimal place first would let anything near 0.5 pass.
        let cdf = normal_cdf(0.0, 0.0, 1.0);
        assert!((cdf - 0.5).abs() < EPSILON);

        // Test at specific points (reference values are given to 7 decimals,
        // so the result is rounded to the same precision before comparing)
        let cdf = (normal_cdf(1.0, 0.0, 1.0) * 1e7).round() / 1e7;
        assert!((cdf - 0.8413447).abs() < EPSILON);

        let cdf = (normal_cdf(-1.0, 0.0, 1.0) * 1e7).round() / 1e7;
        assert!((cdf - 0.1586553).abs() < EPSILON);

        let cdf = (normal_cdf(2.0, 0.0, 1.0) * 1e7).round() / 1e7;
        assert!((cdf - 0.9772499).abs() < EPSILON);
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_cdf_invalid_sigma() {
        normal_cdf(0.0, 0.0, -1.0);
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        // Inverse CDF of 0.5 should be the mean (0)
        let x = (normal_inverse_cdf(0.5, 0.0, 1.0) * 1e7).round() / 1e7;
        assert!(x.abs() < EPSILON);

        // Test at specific probabilities
        assert!((normal_inverse_cdf(0.8413447, 0.0, 1.0) - 1.0).abs() < 0.01);
        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0) + 1.0).abs() < 0.01);
    }
}
437        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0) + 1.0).abs() < 0.01);
438    }
439}