rs_stats/distributions/normal_distribution.rs
1use crate::prob::erf::erf;
2use rand::Rng;
3use rand_distr::{Distribution, Normal as RandNormal};
4use serde::{Deserialize, Serialize};
5use std::f64::consts::PI;
6
7/// Configuration for the Normal distribution.
8///
9/// # Fields
10/// * `mean` - The mean (location parameter)
11/// * `std_dev` - The standard deviation (scale parameter, must be positive)
12///
13/// # Examples
14/// ```
15/// use rs_stats::distributions::normal_distribution::NormalConfig;
16///
17/// let config = NormalConfig { mean: 0.0, std_dev: 1.0 };
18/// assert!(config.std_dev > 0.0);
19/// ```
20#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
21pub struct NormalConfig {
22 /// The mean (μ) of the distribution.
23 pub mean: f64,
24 /// The standard deviation (σ) of the distribution.
25 pub std_dev: f64,
26}
27
28impl NormalConfig {
29 /// Creates a new NormalConfig with validation
30 ///
31 /// # Arguments
32 /// * `mean` - The mean of the distribution
33 /// * `std_dev` - The standard deviation of the distribution
34 ///
35 /// # Returns
36 /// `Some(NormalConfig)` if parameters are valid, `None` otherwise
37 ///
38 /// # Examples
39 /// ```
40 /// use rs_stats::distributions::normal_distribution::NormalConfig;
41 ///
42 /// let standard_normal = NormalConfig::new(0.0, 1.0);
43 /// assert!(standard_normal.is_some());
44 ///
45 /// let invalid_config = NormalConfig::new(0.0, -1.0);
46 /// assert!(invalid_config.is_none());
47 /// ```
48 pub fn new(mean: f64, std_dev: f64) -> Option<Self> {
49 if std_dev > 0.0 && !mean.is_nan() && !std_dev.is_nan() {
50 Some(Self { mean, std_dev })
51 } else {
52 None
53 }
54 }
55}
56
/// Evaluates the probability density function (PDF) of a normal distribution.
///
/// Computes `exp(-z² / 2) / (σ · √(2π))` where `z = (x - μ) / σ`.
///
/// # Arguments
/// * `x` - Point at which the density is evaluated
/// * `mean` - Mean (μ) of the distribution
/// * `std_dev` - Standard deviation (σ) of the distribution; must be strictly positive
///
/// # Returns
/// The density of `N(mean, std_dev²)` at `x`.
///
/// # Panics
/// Panics if `std_dev` is not strictly positive (NaN included).
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::normal_pdf;
///
/// // Standard normal distribution at x = 0
/// let pdf = normal_pdf(0.0, 0.0, 1.0);
/// assert!((pdf - 0.3989422804014327).abs() < 1e-10);
///
/// // Normal distribution with mean = 5, std_dev = 2 at x = 5
/// let pdf = normal_pdf(5.0, 5.0, 2.0);
/// assert!((pdf - 0.19947114020071635).abs() < 1e-10);
/// ```
pub fn normal_pdf(x: f64, mean: f64, std_dev: f64) -> f64 {
    assert!(std_dev > 0.0, "Standard deviation must be positive");

    // Distance from the mean, measured in standard deviations.
    let z = (x - mean) / std_dev;
    // Normalising constant 1 / (σ √(2π)) so the density integrates to one.
    let normalizer = 1.0 / (std_dev * (2.0 * PI).sqrt());
    normalizer * (-0.5 * z.powi(2)).exp()
}
88
89/// Calculates the cumulative distribution function (CDF) for the normal distribution.
90///
91/// # Arguments
92/// * `x` - The value at which to evaluate the CDF
93/// * `mean` - The mean (μ) of the distribution
94/// * `std_dev` - The standard deviation (σ) of the distribution (must be positive)
95///
96/// # Returns
97/// The probability that a random variable is less than or equal to x
98///
99/// # Panics
100/// Panics if std_dev is not positive.
101///
102/// # Examples
103/// ```
104/// use rs_stats::distributions::normal_distribution::normal_cdf;
105///
106/// // Standard normal distribution at x = 0
107/// let cdf = normal_cdf(0.0, 0.0, 1.0);
108/// assert!((cdf - 0.5).abs() < 1e-7);
109///
110/// // Normal distribution with mean = 5, std_dev = 2 at x = 7
111/// let cdf = normal_cdf(7.0, 5.0, 2.0);
112/// assert!((cdf - 0.8413447460685429).abs() < 1e-7);
113/// ```
114pub fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> f64 {
115 assert!(std_dev > 0.0, "Standard deviation must be positive");
116
117 // Special case to handle exact value at the mean
118 if x == mean {
119 return 0.5;
120 }
121
122 // Calculate the standardized value z
123 let z = (x - mean) / std_dev;
124
125 // Use a more numerically stable form of the calculation
126 // The sqrt(2) factor is included in the argument to erf
127 0.5 * (1.0 + erf(z / std::f64::consts::SQRT_2))
128}
129
/// Calculates the inverse cumulative distribution function (quantile function) for the normal
/// distribution.
///
/// The standard-normal quantile is computed with Peter J. Acklam's rational approximation
/// (one formula for the central region, another for the tails) and then rescaled to
/// `N(mean, sigma²)`. No refinement step is applied.
///
/// # Arguments
/// * `p` - Probability value between 0 and 1 (inclusive)
/// * `mean` - The mean (μ) of the distribution
/// * `sigma` - The standard deviation (σ) of the distribution; expected to be positive
///   (not validated here)
///
/// # Returns
/// The value x such that P(X ≤ x) = p. Returns `f64::NEG_INFINITY` for `p == 0` and
/// `f64::INFINITY` for `p == 1`.
///
/// # Panics
/// Panics if `p` lies outside `[0, 1]` (NaN included).
///
/// # Examples
/// ```
/// use rs_stats::distributions::normal_distribution::{normal_cdf, normal_inverse_cdf};
///
/// // Check that inverse_cdf is the inverse of cdf
/// let x = 0.5;
/// let p = normal_cdf(x, 0.0, 1.0);
/// let x_back = normal_inverse_cdf(p, 0.0, 1.0);
/// assert!((x - x_back).abs() < 1e-8);
///
/// // Tail probabilities are handled too: the 1% quantile of N(0, 1) is about -2.3263.
/// assert!((normal_inverse_cdf(0.01, 0.0, 1.0) + 2.326_347_874).abs() < 1e-6);
/// ```
pub fn normal_inverse_cdf(p: f64, mean: f64, sigma: f64) -> f64 {
    assert!(
        (0.0..=1.0).contains(&p),
        "Probability must be between 0 and 1"
    );

    // The quantile diverges at the ends of the interval.
    if p == 0.0 {
        return f64::NEG_INFINITY;
    }
    if p == 1.0 {
        return f64::INFINITY;
    }

    // Boundary between Acklam's central and tail approximations.
    const P_LOW: f64 = 0.02425;

    // Central region: numerator (A) and denominator (B) coefficients; the denominator's
    // constant term is 1.
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_69e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239,
    ];
    const B: [f64; 5] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
    ];

    // Tail region: numerator (C) and denominator (D) coefficients; the denominator's
    // constant term is 1.
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838,
        -2.549_732_539_343_734,
        4.374_664_141_464_968,
        2.938_163_982_698_783,
    ];
    const D: [f64; 4] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996,
        3.754_408_661_907_416,
    ];

    // By symmetry Φ⁻¹(p) = -Φ⁻¹(1 - p), so only the lower half q = min(p, 1 - p) ∈ (0, 0.5]
    // is evaluated and the sign is flipped afterwards for p > 0.5.
    let upper_half = p > 0.5;
    let q = if upper_half { 1.0 - p } else { p };

    let z = if q >= P_LOW {
        // Central region: rational function in r² scaled by r = q - 0.5 (r ≤ 0 here).
        let r = q - 0.5;
        let r2 = r * r;
        let num = ((((A[0] * r2 + A[1]) * r2 + A[2]) * r2 + A[3]) * r2 + A[4]) * r2 + A[5];
        let den = ((((B[0] * r2 + B[1]) * r2 + B[2]) * r2 + B[3]) * r2 + B[4]) * r2 + 1.0;
        r * num / den
    } else {
        // Lower tail: the rational function in t = √(-2 ln q) IS the (negative) quantile.
        // It must not be offset by ±t.
        let t = (-2.0 * q.ln()).sqrt();
        let num = ((((C[0] * t + C[1]) * t + C[2]) * t + C[3]) * t + C[4]) * t + C[5];
        let den = (((D[0] * t + D[1]) * t + D[2]) * t + D[3]) * t + 1.0;
        num / den
    };

    let z = if upper_half { -z } else { z };

    // Convert from the standard normal to the requested distribution.
    mean + sigma * z
}
248
249/// Generates a random sample from the normal distribution.
250///
251/// # Arguments
252/// * `mean` - The mean (μ) of the distribution
253/// * `sigma` - The standard deviation (σ) of the distribution
254/// * `rng` - A random number generator
255///
256/// # Returns
257/// A random value from the normal distribution
258///
259/// # Examples
260/// ```
261/// use rs_stats::distributions::normal_distribution::normal_sample;
262/// use rand::thread_rng;
263///
264/// let mut rng = thread_rng();
265/// let sample = normal_sample(10.0, 2.0, &mut rng);
266/// // sample is a random value from Normal(10, 2)
267/// ```
268pub fn normal_sample<R: Rng>(mean: f64, sigma: f64, rng: &mut R) -> f64 {
269 let normal = RandNormal::new(mean, sigma).unwrap();
270 normal.sample(rng)
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    // Small epsilon for floating-point comparisons
    const EPSILON: f64 = 1e-7;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {}, got {}",
            expected,
            tol,
            actual
        );
    }

    /// Rounds `value` onto the grid defined by `scale` (e.g. `1e7` keeps seven decimals).
    fn round_to(value: f64, scale: f64) -> f64 {
        (value * scale).round() / scale
    }

    #[test]
    fn test_normal_pdf_standard() {
        let (mu, sd) = (0.0, 1.0);

        // The density peaks at the mean.
        assert_close(normal_pdf(mu, mu, sd), 0.3989422804014327, 1e-10);
        // One standard deviation away from the mean.
        assert_close(normal_pdf(mu + sd, mu, sd), 0.24197072451914337, 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let (mu, sd) = (5.0, 2.0);

        assert_close(normal_pdf(mu, mu, sd), 0.19947114020071635, 1e-10);
        assert_close(normal_pdf(mu + sd, mu, sd), 0.12098536225957168, 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let (mu, sd, offset) = (0.0, 1.0, 1.5);

        let right = normal_pdf(mu + offset, mu, sd);
        let left = normal_pdf(mu - offset, mu, sd);
        assert_close(right, left, 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let (mu, sd) = (0.0, 1.0);

        assert_close(normal_cdf(mu, mu, sd), 0.5, 1e-10);
        assert_close(normal_cdf(mu + sd, mu, sd), 0.8413447460685429, EPSILON);
        assert_close(normal_cdf(mu - sd, mu, sd), 0.15865525393145707, EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let (mu, sd) = (100.0, 15.0);

        assert_close(normal_cdf(mu, mu, sd), 0.5, 1e-10);
        assert_close(normal_cdf(mu + sd, mu, sd), 0.8413447460685429, EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let (mu, sd) = (0.0, 1.0);

        assert_close(normal_inverse_cdf(0.5, mu, sd), mu, EPSILON);
        assert_close(normal_inverse_cdf(0.8413447460685429, mu, sd), sd, EPSILON);
        assert_close(normal_inverse_cdf(0.15865525393145707, mu, sd), -sd, EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let (mu, sd) = (50.0, 5.0);

        assert_close(normal_inverse_cdf(0.5, mu, sd), mu, EPSILON);
        assert_close(normal_inverse_cdf(0.8413447460685429, mu, sd), mu + sd, EPSILON);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        // The density peaks at the mean (≈ 0.3989).
        assert_close(round_to(normal_pdf(0.0, 0.0, 1.0), 1e7), 0.3989423, EPSILON);

        // Symmetric around the mean.
        assert_close(normal_pdf(1.0, 0.0, 1.0), normal_pdf(-1.0, 0.0, 1.0), EPSILON);

        // Reference values at specific points.
        assert_close(normal_pdf(1.0, 0.0, 1.0), 0.2419707, EPSILON);
        assert_close(normal_pdf(2.0, 0.0, 1.0), 0.0539909, EPSILON);
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_pdf_invalid_sigma() {
        normal_pdf(0.0, 0.0, -1.0);
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // The CDF at the mean is one half.
        assert_close(round_to(normal_cdf(0.0, 0.0, 1.0), 1e1), 0.5, EPSILON);

        // Reference values at specific points.
        assert_close(round_to(normal_cdf(1.0, 0.0, 1.0), 1e7), 0.8413447, EPSILON);
        assert_close(round_to(normal_cdf(-1.0, 0.0, 1.0), 1e7), 0.1586553, EPSILON);
        assert_close(round_to(normal_cdf(2.0, 0.0, 1.0), 1e7), 0.9772499, EPSILON);
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_cdf_invalid_sigma() {
        normal_cdf(0.0, 0.0, -1.0);
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        // The median of the standard normal is zero.
        let median = round_to(normal_inverse_cdf(0.5, 0.0, 1.0), 1e7);
        assert_close(median, 0.0, EPSILON);

        // Coarse checks at ±1 standard deviation.
        assert_close(normal_inverse_cdf(0.8413447, 0.0, 1.0), 1.0, 0.01);
        assert_close(normal_inverse_cdf(0.1586553, 0.0, 1.0), -1.0, 0.01);
    }
}