Skip to main content

rs_stats/distributions/
normal_distribution.rs

1//! # Normal Distribution
2//!
3//! The Normal (Gaussian) distribution N(μ, σ) is the most widely used continuous
4//! distribution, arising naturally as the limiting distribution of sums and averages
5//! of independent random variables (Central Limit Theorem).
6//!
7//! **PDF**: f(x) = 1/(σ√(2π)) · exp(−(x−μ)²/(2σ²))
8//!
9//! **CDF**: F(x) = Φ((x−μ)/σ), where Φ is the standard normal CDF
10//!
11//! ## Medical applications
12//!
13//! | Measurement | Typical parameters |
14//! |-------------|-------------------|
15//! | **Systolic blood pressure** (healthy adults) | N(120, 10) mmHg |
16//! | **Diastolic blood pressure** (healthy adults) | N(80, 8) mmHg |
17//! | **Adult height** (men, Western population) | N(175, 7) cm |
18//! | **Haemoglobin** (adult men) | N(14.5, 1.0) g/dL |
19//! | **Body temperature** | N(37.0, 0.4) °C |
20//! | **IQ scores** (by design) | N(100, 15) |
21//! | **Lab measurement error** | N(0, σ_instrument) |
22//!
23//! ## Example — blood pressure reference intervals
24//!
25//! ```rust
26//! use rs_stats::distributions::normal_distribution::Normal;
27//! use rs_stats::distributions::traits::Distribution;
28//!
29//! // Diastolic BP in a healthy cohort: N(80, 8) mmHg
30//! let bp = Normal::new(80.0, 8.0).unwrap();
31//!
32//! // P(DBP > 90 mmHg) — stage 1 hypertension threshold
33//! let p_high = 1.0 - bp.cdf(90.0).unwrap();
34//! println!("P(DBP > 90 mmHg) = {:.1}%", p_high * 100.0);  // ≈ 10.6%
35//!
36//! // 95% reference interval (2.5th – 97.5th percentile)
37//! let lower = bp.inverse_cdf(0.025).unwrap();
38//! let upper = bp.inverse_cdf(0.975).unwrap();
39//! println!("Reference interval: [{:.1}, {:.1}] mmHg", lower, upper);
40//!
41//! // Fit to patient data (MLE: μ̂ = mean, σ̂ = pop std-dev)
42//! let readings = vec![78.0, 82.0, 79.0, 85.0, 81.0, 77.0, 83.0, 80.0];
43//! let fitted = Normal::fit(&readings).unwrap();
44//! println!("Fitted μ = {:.2}, σ = {:.2}", fitted.mean(), fitted.std_dev());
45//! ```
46
47use crate::distributions::traits::Distribution;
48use crate::error::{StatsError, StatsResult};
49use crate::prob::erf;
50use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
51
// Crate-internal math helpers (`normal_pdf` is module-private; the CDF and
// inverse-CDF helpers are `pub(crate)`). The public API is the [`Normal`]
// struct's [`Distribution`] impl below.
54
55/// Calculates the probability density function (PDF) for the normal distribution.
56///
57/// # Arguments
58/// * `x` - The value at which to evaluate the PDF
59/// * `mean` - The mean (μ) of the distribution
60/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
61///
62/// # Returns
63/// The probability density at point x
64///
65/// # Errors
66/// Returns an error if:
67/// - std_dev is not positive
68/// - Type conversion to f64 fails
69///
70#[inline]
71fn normal_pdf(x: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
72    if std_dev <= 0.0 {
73        return Err(StatsError::InvalidInput {
74            message: "normal_pdf: standard deviation must be positive".to_string(),
75        });
76    }
77    let z = (x - mean) / std_dev;
78    Ok((-0.5 * z * z).exp() * INV_SQRT_2PI / std_dev)
79}
80
81/// Calculates the cumulative distribution function (CDF) for the normal distribution.
82///
83/// # Arguments
84/// * `x` - The value at which to evaluate the CDF
85/// * `mean` - The mean (μ) of the distribution
86/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
87///
88/// # Returns
89/// The probability that a random variable is less than or equal to x
90///
91/// # Errors
92/// Returns an error if:
93/// - std_dev is not positive
94/// - Type conversion to f64 fails
95///
96#[inline]
97pub(crate) fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
98    if std_dev <= 0.0 {
99        return Err(StatsError::InvalidInput {
100            message: "normal_cdf: standard deviation must be positive".to_string(),
101        });
102    }
103    if x == mean {
104        return Ok(0.5);
105    }
106    let z = (x - mean) / (std_dev * SQRT_2);
107    Ok(0.5 * (1.0 + erf(z)?))
108}
109
110/// Calculates the inverse cumulative distribution function (Quantile function) for the normal distribution.
111///
112/// # Arguments
113/// * `p` - Probability value between 0 and 1
114/// * `mean` - The mean (μ) of the distribution
115/// * `std_dev` - The standard deviation (σ) of the distribution
116///
117/// # Returns
118/// The value x such that P(X ≤ x) = p
119///
120#[inline]
121pub(crate) fn normal_inverse_cdf(p: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
122    let p_64 = p;
123
124    if !(0.0..=1.0).contains(&p_64) {
125        return Err(StatsError::InvalidInput {
126            message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
127        });
128    }
129
130    // Handle edge cases
131    if p_64 == 0.0 {
132        return Ok(f64::NEG_INFINITY);
133    }
134    if p_64 == 1.0 {
135        return Ok(f64::INFINITY);
136    }
137
138    // Acklam's rational approximation for the inverse standard normal CDF
139    // (https://web.archive.org/web/20151030215612/http://home.online.no/~pjacklam/notes/invnorm/),
140    // accurate to ~1.15 × 10⁻⁹ over the entire support.
141
142    // Coefficients — central region (|p − 0.5| ≤ 0.47575)
143    let a = [
144        -3.969_683_028_665_376e1,
145        2.209_460_984_245_205e2,
146        -2.759_285_104_469_687e2,
147        1.383_577_518_672_69e2,
148        -3.066_479_806_614_716e1,
149        2.506_628_277_459_239,
150    ];
151    let b = [
152        -5.447_609_879_822_406e1,
153        1.615_858_368_580_409e2,
154        -1.556_989_798_598_866e2,
155        6.680_131_188_771_972e1,
156        -1.328_068_155_288_572e1,
157        1.0,
158    ];
159    // Coefficients — tail region
160    let c = [
161        -7.784_894_002_430_293e-3,
162        -3.223_964_580_411_365e-1,
163        -2.400_758_277_161_838,
164        -2.549_732_539_343_734,
165        4.374_664_141_464_968,
166        2.938_163_982_698_783,
167    ];
168    let d = [
169        7.784_695_709_041_462e-3,
170        3.224_671_290_700_398e-1,
171        2.445_134_137_142_996,
172        3.754_408_661_907_416,
173    ];
174
175    const P_LOW: f64 = 0.02425;
176    const P_HIGH: f64 = 1.0 - P_LOW;
177
178    let z = if p_64 < P_LOW {
179        // Lower tail
180        let q = (-2.0 * p_64.ln()).sqrt();
181        let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
182        let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
183        num / den
184    } else if p_64 > P_HIGH {
185        // Upper tail
186        let q = (-2.0 * (1.0 - p_64).ln()).sqrt();
187        let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
188        let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
189        -num / den
190    } else {
191        // Central region
192        let q = p_64 - 0.5;
193        let r = q * q;
194        let num = ((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5];
195        let den = ((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + b[5];
196        q * num / den
197    };
198
199    Ok(mean + std_dev * z)
200}
201
202// ── Typed struct + Distribution impl ──────────────────────────────────────────
203
/// Normal (Gaussian) distribution as a typed struct, parameterised by its
/// mean μ and standard deviation σ (i.e. variance σ²).
///
/// Implements [`Distribution`] for use with `fit_all` / `fit_best`.
///
/// Prefer [`Normal::new`] or [`Normal::fit`] for construction: both validate
/// the parameters. The fields are public, so values written directly into the
/// struct bypass that validation.
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::Normal;
/// use rs_stats::distributions::traits::Distribution;
///
/// let n = Normal::new(0.0, 1.0).unwrap();
/// assert!((n.mean() - 0.0).abs() < 1e-10);
/// assert!((n.pdf(0.0).unwrap() - 0.398_942_280_401_4).abs() < 1e-10);
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Normal {
    /// Mean μ (location parameter)
    pub mean: f64,
    /// Standard deviation σ (scale parameter, must be > 0)
    pub std_dev: f64,
}
224
225impl Normal {
226    /// Creates a `Normal` distribution with validation.
227    pub fn new(mean: f64, std_dev: f64) -> StatsResult<Self> {
228        if std_dev <= 0.0 || std_dev.is_nan() || mean.is_nan() {
229            return Err(StatsError::InvalidInput {
230                message: "Normal::new: std_dev must be positive and parameters must be finite"
231                    .to_string(),
232            });
233        }
234        Ok(Self { mean, std_dev })
235    }
236
237    /// Maximum-likelihood estimate from data.
238    ///
239    /// MLE: μ = mean(data), σ = population std-dev. Single-pass online
240    /// (Welford) — never walks `data` twice and never allocates.
241    pub fn fit(data: &[f64]) -> StatsResult<Self> {
242        if data.is_empty() {
243            return Err(StatsError::InvalidInput {
244                message: "Normal::fit: data must not be empty".to_string(),
245            });
246        }
247        let mut count = 0.0_f64;
248        let mut mean = 0.0_f64;
249        let mut m2 = 0.0_f64;
250        for &x in data {
251            count += 1.0;
252            let delta = x - mean;
253            mean += delta / count;
254            m2 += delta * (x - mean);
255        }
256        let variance = m2 / count; // population (MLE)
257        Self::new(mean, variance.sqrt())
258    }
259}
260
261impl Distribution for Normal {
262    type X = f64;
263    fn name(&self) -> &str {
264        "Normal"
265    }
266    fn num_params(&self) -> usize {
267        2
268    }
269    fn pdf(&self, x: f64) -> StatsResult<f64> {
270        normal_pdf(x, self.mean, self.std_dev)
271    }
272    fn logpdf(&self, x: f64) -> StatsResult<f64> {
273        let z = (x - self.mean) / self.std_dev;
274        Ok(-0.5 * z * z - self.std_dev.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln())
275    }
276    /// Closed-form bulk log-likelihood. Lets LLVM autovectorise the
277    /// `Σ z_i²` reduction (no per-point Result-returning closure).
278    ///
279    /// `Σ ln f(xᵢ) = −½ · Σ ((xᵢ−μ)/σ)² − n·(ln σ + ½·ln 2π)`
280    fn log_likelihood_fast(&self, data: &[f64]) -> f64 {
281        let inv_sigma = 1.0 / self.std_dev;
282        let mut sum_sq = 0.0_f64;
283        for &x in data {
284            let z = (x - self.mean) * inv_sigma;
285            sum_sq += z * z;
286        }
287        let n = data.len() as f64;
288        -0.5 * sum_sq - n * (self.std_dev.ln() + 0.5 * (2.0 * std::f64::consts::PI).ln())
289    }
290    fn cdf(&self, x: f64) -> StatsResult<f64> {
291        normal_cdf(x, self.mean, self.std_dev)
292    }
293    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
294        normal_inverse_cdf(p, self.mean, self.std_dev)
295    }
296    fn mean(&self) -> f64 {
297        self.mean
298    }
299    fn variance(&self) -> f64 {
300        self.std_dev * self.std_dev
301    }
302}
303
// Unit tests for the crate-internal helpers (`normal_pdf`, `normal_cdf`,
// `normal_inverse_cdf`) and for `Normal::new`. Reference values are the
// standard normal density / CDF at 0, ±1σ and 2σ.
#[cfg(test)]
mod tests {
    use super::*;

    // Small epsilon for floating-point comparisons
    const EPSILON: f64 = 1e-7;

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean (peak of the density)
        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.3989422804014327).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.24197072451914337).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;

        // Test at mean (standard-normal peak divided by sigma)
        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.19947114020071635).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.12098536225957168).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;

        // The density must be symmetric about the mean.
        let pdf_plus = normal_pdf(mean + x, mean, sigma).unwrap();
        let pdf_minus = normal_pdf(mean - x, mean, sigma).unwrap();

        assert!((pdf_plus - pdf_minus).abs() < 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_cdf(mean - sigma, mean, sigma).unwrap();
        assert!((result - 0.15865525393145707).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - sigma).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap();
        assert!((result - (-sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - (mean + sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        // PDF for standard normal at mean should be maximum (approx 0.3989)
        // `(v * 1e7).round() / 1e7` rounds to 7 decimal places before comparing.
        let pdf = (normal_pdf(0.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((pdf - 0.3989423).abs() < EPSILON);

        // Test symmetry around mean
        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert!((pdf_plus1 - pdf_minus1).abs() < EPSILON);

        // Test at specific points
        assert!((normal_pdf(1.0, 0.0, 1.0).unwrap() - 0.2419707).abs() < EPSILON);
        assert!((normal_pdf(2.0, 0.0, 1.0).unwrap() - 0.0539909).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let result = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // CDF at mean should be 0.5
        // (rounded to 1 decimal place here, 7 decimal places below)
        let cdf = (normal_cdf(0.0, 0.0, 1.0).unwrap() * 1e1).round() / 1e1;
        assert!((cdf - 0.5).abs() < EPSILON);

        // Test at specific points
        let cdf = (normal_cdf(1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.8413447).abs() < EPSILON);

        let cdf = (normal_cdf(-1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.1586553).abs() < EPSILON);

        let cdf = (normal_cdf(2.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.9772499).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let result = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        // Inverse CDF of 0.5 should be the mean (0)
        let x = (normal_inverse_cdf(0.5, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!(x.abs() < EPSILON);

        // Test at specific probabilities
        // (inputs are 7-digit roundings of the exact CDF values, hence the loose 0.01 tolerance)
        assert!((normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap() - 1.0).abs() < 0.01);
        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap() + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        // Probabilities below 0 are rejected.
        let result = normal_inverse_cdf(-0.1, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        // Probabilities above 1 are rejected.
        let result = normal_inverse_cdf(1.5, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        // p = 0 maps to the lower end of the support.
        let result = normal_inverse_cdf(0.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        // p = 1 maps to the upper end of the support.
        let result = normal_inverse_cdf(1.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        let result = normal_pdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        let result = normal_cdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    // The next two tests pin the fact that `normal_inverse_cdf`, unlike
    // `normal_pdf` / `normal_cdf`, does not validate `std_dev`.

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // std_dev = 0 should still work (just returns mean)
        let result = normal_inverse_cdf(0.5, 5.0, 0.0).unwrap();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        // std_dev < 0 should still work (just scales the result)
        let result = normal_inverse_cdf(0.5, 0.0, -1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_normal_new_valid() {
        // Valid parameters are stored unchanged.
        let dist = Normal::new(0.0, 1.0).unwrap();
        assert_eq!(dist.mean, 0.0);
        assert_eq!(dist.std_dev, 1.0);
    }
}