rs_stats/distributions/
normal_distribution.rs

1use num_traits::ToPrimitive;
2
3use crate::error::{StatsError, StatsResult};
4use crate::prob::erf;
5use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
6use serde::{Deserialize, Serialize};
7
8/// Configuration for the Normal distribution.
9///
10/// # Fields
11/// * `mean` - The mean (location parameter)
12/// * `std_dev` - The standard deviation (scale parameter, must be positive)
13///
14/// # Examples
15/// ```
16/// use rs_stats::distributions::normal_distribution::NormalConfig;
17///
18/// let config = NormalConfig { mean: 0.0, std_dev: 1.0 };
19/// assert!(config.std_dev > 0.0);
20/// ```
21#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
22pub struct NormalConfig<T>
23where
24    T: ToPrimitive,
25{
26    /// The mean (μ) of the distribution.
27    pub mean: T,
28    /// The standard deviation (σ) of the distribution.
29    pub std_dev: T,
30}
31
32impl<T> NormalConfig<T>
33where
34    T: ToPrimitive,
35{
36    /// Creates a new NormalConfig with validation
37    ///
38    /// # Arguments
39    /// * `mean` - The mean of the distribution
40    /// * `std_dev` - The standard deviation of the distribution
41    ///
42    /// # Returns
43    /// `Some(NormalConfig)` if parameters are valid, `None` otherwise
44    ///
45    /// # Examples
46    /// ```
47    /// use rs_stats::distributions::normal_distribution::NormalConfig;
48    ///
49    /// let standard_normal = NormalConfig::new(0.0, 1.0);
50    /// assert!(standard_normal.is_ok());
51    ///
52    /// let invalid_config = NormalConfig::new(0.0, -1.0);
53    /// assert!(invalid_config.is_err());
54    /// ```
55    pub fn new(mean: T, std_dev: T) -> StatsResult<Self> {
56        let std_dev_64 = std_dev
57            .to_f64()
58            .ok_or_else(|| StatsError::ConversionError {
59                message: "NormalConfig::new: Failed to convert std_dev to f64".to_string(),
60            })?;
61        let mean_64 = mean.to_f64().ok_or_else(|| StatsError::ConversionError {
62            message: "NormalConfig::new: Failed to convert mean to f64".to_string(),
63        })?;
64
65        if std_dev_64 > 0.0 && !mean_64.is_nan() && !std_dev_64.is_nan() {
66            Ok(Self { mean, std_dev })
67        } else {
68            Err(StatsError::InvalidInput {
69                message: "NormalConfig::new: std_dev must be positive".to_string(),
70            })
71        }
72    }
73}
74
75/// Calculates the probability density function (PDF) for the normal distribution.
76///
77/// # Arguments
78/// * `x` - The value at which to evaluate the PDF
79/// * `mean` - The mean (μ) of the distribution
80/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
81///
82/// # Returns
83/// The probability density at point x
84///
85/// # Errors
86/// Returns an error if:
87/// - std_dev is not positive
88/// - Type conversion to f64 fails
89///
90/// # Examples
91/// ```
92/// use rs_stats::distributions::normal_distribution::normal_pdf;
93///
94/// // Standard normal distribution at x = 0
95/// let pdf = normal_pdf(0.0, 0.0, 1.0).unwrap();
96/// assert!((pdf - 0.3989422804014327).abs() < 1e-10);
97///
98/// // Normal distribution with mean = 5, std_dev = 2 at x = 5
99/// let pdf = normal_pdf(5.0, 5.0, 2.0).unwrap();
100/// assert!((pdf - 0.19947114020071635).abs() < 1e-10);
101/// ```
102#[inline]
103pub fn normal_pdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
104where
105    T: ToPrimitive,
106{
107    if std_dev <= 0.0 {
108        return Err(StatsError::InvalidInput {
109            message: "normal_pdf: Standard deviation must be positive".to_string(),
110        });
111    }
112
113    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
114        message: "normal_pdf: Failed to convert x to f64".to_string(),
115    })?;
116
117    // Use multiplication instead of powi(2) for better performance
118    let z = (x_64 - mean) / std_dev;
119    let exponent = -0.5 * z * z;
120    // Use precomputed constant instead of computing sqrt(2π) every call
121    Ok(exponent.exp() * INV_SQRT_2PI / std_dev)
122}
123
124/// Calculates the cumulative distribution function (CDF) for the normal distribution.
125///
126/// # Arguments
127/// * `x` - The value at which to evaluate the CDF
128/// * `mean` - The mean (μ) of the distribution
129/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
130///
131/// # Returns
132/// The probability that a random variable is less than or equal to x
133///
134/// # Errors
135/// Returns an error if:
136/// - std_dev is not positive
137/// - Type conversion to f64 fails
138///
139/// # Examples
140/// ```
141/// use rs_stats::distributions::normal_distribution::normal_cdf;
142///
143/// // Standard normal distribution at x = 0
144/// let cdf = normal_cdf(0.0, 0.0, 1.0).unwrap();
145/// assert!((cdf - 0.5).abs() < 1e-7);
146///
147/// // Normal distribution with mean = 5, std_dev = 2 at x = 7
148/// let cdf = normal_cdf(7.0, 5.0, 2.0).unwrap();
149/// assert!((cdf - 0.8413447460685429).abs() < 1e-7);
150/// ```
151#[inline]
152pub fn normal_cdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
153where
154    T: ToPrimitive,
155{
156    if std_dev <= 0.0 {
157        return Err(StatsError::InvalidInput {
158            message: "normal_cdf: Standard deviation must be positive".to_string(),
159        });
160    }
161
162    let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
163        message: "normal_cdf: Failed to convert x to f64".to_string(),
164    })?;
165
166    // Special case to handle exact value at the mean
167    if x_64 == mean {
168        return Ok(0.5);
169    }
170
171    // Inline z-score calculation and combine with SQRT_2 division
172    // z_score = (x - mean) / std_dev, then divide by SQRT_2
173    // Combined: (x - mean) / (std_dev * SQRT_2)
174    let z = (x_64 - mean) / (std_dev * SQRT_2);
175    Ok(0.5 * (1.0 + erf(z)?))
176}
177
178/// Calculates the inverse cumulative distribution function (Quantile function) for the normal distribution.
179///
180/// # Arguments
181/// * `p` - Probability value between 0 and 1
182/// * `mean` - The mean (μ) of the distribution
183/// * `std_dev` - The standard deviation (σ) of the distribution
184///
185/// # Returns
186/// The value x such that P(X ≤ x) = p
187///
188/// # Examples
189/// ```
190/// use rs_stats::distributions::normal_distribution::{normal_cdf, normal_inverse_cdf};
191///
192/// // Check that inverse_cdf is the inverse of cdf
193/// let x = 0.5;
194/// let p = normal_cdf(x, 0.0, 1.0).unwrap();
195/// let x_back = normal_inverse_cdf(p, 0.0, 1.0).unwrap();
196/// assert!((x - x_back).abs() < 1e-8);
197/// ```
198#[inline]
199pub fn normal_inverse_cdf<T>(p: T, mean: f64, std_dev: f64) -> StatsResult<f64>
200where
201    T: ToPrimitive,
202{
203    let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
204        message: "normal_inverse_cdf: Failed to convert p to f64".to_string(),
205    })?;
206
207    if !(0.0..=1.0).contains(&p_64) {
208        return Err(StatsError::InvalidInput {
209            message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
210        });
211    }
212
213    // Handle edge cases
214    if p_64 == 0.0 {
215        return Ok(f64::NEG_INFINITY);
216    }
217    if p_64 == 1.0 {
218        return Ok(f64::INFINITY);
219    }
220
221    // Use a simple and reliable implementation based on the Rational Approximation
222    // by Peter J. Acklam
223
224    // Convert to standard normal calculation
225    let q = if p_64 <= 0.5 { p_64 } else { 1.0 - p_64 };
226
227    // Avoid numerical issues at boundaries
228    if q <= 0.0 {
229        return if p_64 <= 0.5 {
230            Ok(f64::NEG_INFINITY)
231        } else {
232            Ok(f64::INFINITY)
233        };
234    }
235
236    // Coefficients for central region (small |z|)
237    let a = [
238        -3.969_683_028_665_376e1,
239        2.209_460_984_245_205e2,
240        -2.759_285_104_469_687e2,
241        1.383_577_518_672_69e2,
242        -3.066_479_806_614_716e1,
243        2.506_628_277_459_239,
244    ];
245
246    let b = [
247        -5.447_609_879_822_406e1,
248        1.615_858_368_580_409e2,
249        -1.556_989_798_598_866e2,
250        6.680_131_188_771_972e1,
251        -1.328_068_155_288_572e1,
252        1.0,
253    ];
254
255    // Compute rational approximation
256    let r = q - 0.5;
257
258    let z = if q > 0.02425 && q < 0.97575 {
259        // Central region
260        let r2 = r * r;
261        let num = ((((a[0] * r2 + a[1]) * r2 + a[2]) * r2 + a[3]) * r2 + a[4]) * r2 + a[5];
262        let den = ((((b[0] * r2 + b[1]) * r2 + b[2]) * r2 + b[3]) * r2 + b[4]) * r2 + b[5];
263        r * num / den
264    } else {
265        // Tail region
266        let s = if r < 0.0 { q } else { 1.0 - q };
267        let t = (-2.0 * s.ln()).sqrt();
268
269        // Rational approximation for tail
270        let c = [
271            -7.784_894_002_430_293e-3,
272            -3.223_964_580_411_365e-1,
273            -2.400_758_277_161_838,
274            -2.549_732_539_343_734,
275            4.374_664_141_464_968,
276            2.938_163_982_698_783,
277        ];
278
279        let d = [
280            7.784_695_709_041_462e-3,
281            3.224_671_290_700_398e-1,
282            2.445_134_137_142_996,
283            3.754_408_661_907_416,
284            1.0,
285        ];
286
287        let num = ((((c[0] * t + c[1]) * t + c[2]) * t + c[3]) * t + c[4]) * t + c[5];
288        let den = (((d[0] * t + d[1]) * t + d[2]) * t + d[3]) * t + d[4];
289        if r < 0.0 {
290            -t - num / den
291        } else {
292            t - num / den
293        }
294    };
295
296    // If p > 0.5, we need to flip the sign of z
297    let final_z = if p_64 > 0.5 { -z } else { z };
298
299    let result = mean + std_dev * final_z;
300    // Convert from standard normal to the specified distribution
301    Ok(result)
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // Small epsilon for floating-point comparisons
    const EPSILON: f64 = 1e-7;

    /// Asserts that `result` is an error of the `StatsError::InvalidInput` variant.
    fn assert_invalid_input<T>(result: StatsResult<T>) {
        assert!(
            matches!(result, Err(StatsError::InvalidInput { .. })),
            "expected Err(StatsError::InvalidInput)"
        );
    }

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean (peak of the density)
        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.3989422804014327).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.24197072451914337).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;

        // Test at mean
        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.19947114020071635).abs() < 1e-10);

        // Test at one standard deviation away
        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.12098536225957168).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;

        let pdf_plus = normal_pdf(mean + x, mean, sigma).unwrap();
        let pdf_minus = normal_pdf(mean - x, mean, sigma).unwrap();

        assert!((pdf_plus - pdf_minus).abs() < 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_cdf(mean - sigma, mean, sigma).unwrap();
        assert!((result - 0.15865525393145707).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;

        // Test at mean
        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        // Test at one standard deviation above mean
        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - sigma).abs() < EPSILON);

        // Test at one standard deviation below mean
        let result = normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap();
        assert!((result - (-sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;

        // Test at median
        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        // Test at one standard deviation above mean
        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - (mean + sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        // PDF for standard normal at mean should be maximum (approx 0.3989)
        let pdf = (normal_pdf(0.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((pdf - 0.3989423).abs() < EPSILON);

        // Test symmetry around mean
        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert!((pdf_plus1 - pdf_minus1).abs() < EPSILON);

        // Test at specific points
        assert!((normal_pdf(1.0, 0.0, 1.0).unwrap() - 0.2419707).abs() < EPSILON);
        assert!((normal_pdf(2.0, 0.0, 1.0).unwrap() - 0.0539909).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let result = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(result);
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // CDF at mean should be 0.5. The raw value is compared so that the
        // tolerance really is EPSILON.
        let cdf = normal_cdf(0.0, 0.0, 1.0).unwrap();
        assert!((cdf - 0.5).abs() < EPSILON);

        // Test at specific points
        let cdf = (normal_cdf(1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.8413447).abs() < EPSILON);

        let cdf = (normal_cdf(-1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.1586553).abs() < EPSILON);

        let cdf = (normal_cdf(2.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.9772499).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let result = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(result);
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        // Inverse CDF of 0.5 should be the mean (0)
        let x = (normal_inverse_cdf(0.5, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!(x.abs() < EPSILON);

        // Test at specific probabilities
        assert!((normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap() - 1.0).abs() < 0.01);
        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap() + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_normal_config_new_nan_mean() {
        assert_invalid_input(NormalConfig::new(f64::NAN, 1.0));
    }

    #[test]
    fn test_normal_config_new_nan_std_dev() {
        assert_invalid_input(NormalConfig::new(0.0, f64::NAN));
    }

    #[test]
    fn test_normal_config_new_std_dev_zero() {
        assert_invalid_input(NormalConfig::new(0.0, 0.0));
    }

    #[test]
    fn test_normal_config_new_std_dev_negative() {
        assert_invalid_input(NormalConfig::new(0.0, -1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        assert_invalid_input(normal_inverse_cdf(-0.1, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(normal_inverse_cdf(1.5, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        let result = normal_inverse_cdf(0.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        let result = normal_inverse_cdf(1.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        assert_invalid_input(normal_pdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        assert_invalid_input(normal_cdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // std_dev = 0 should still work (just returns mean)
        let result = normal_inverse_cdf(0.5, 5.0, 0.0).unwrap();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        // std_dev < 0 should still work (just scales the result)
        let result = normal_inverse_cdf(0.5, 0.0, -1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_normal_config_new_valid() {
        let config = NormalConfig::new(0.0, 1.0);
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.mean, 0.0);
    }
}
596        assert_eq!(config.mean, 0.0);
597    }
598}