1use crate::distributions::traits::Distribution;
48use crate::error::{StatsError, StatsResult};
49use crate::prob::erf;
50use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
51
52#[inline]
71fn normal_pdf(x: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
72 if std_dev <= 0.0 {
73 return Err(StatsError::InvalidInput {
74 message: "normal_pdf: standard deviation must be positive".to_string(),
75 });
76 }
77 let z = (x - mean) / std_dev;
78 Ok((-0.5 * z * z).exp() * INV_SQRT_2PI / std_dev)
79}
80
81#[inline]
97pub(crate) fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
98 if std_dev <= 0.0 {
99 return Err(StatsError::InvalidInput {
100 message: "normal_cdf: standard deviation must be positive".to_string(),
101 });
102 }
103 if x == mean {
104 return Ok(0.5);
105 }
106 let z = (x - mean) / (std_dev * SQRT_2);
107 Ok(0.5 * (1.0 + erf(z)?))
108}
109
110#[inline]
121pub(crate) fn normal_inverse_cdf(p: f64, mean: f64, std_dev: f64) -> StatsResult<f64> {
122 let p_64 = p;
123
124 if !(0.0..=1.0).contains(&p_64) {
125 return Err(StatsError::InvalidInput {
126 message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
127 });
128 }
129
130 if p_64 == 0.0 {
132 return Ok(f64::NEG_INFINITY);
133 }
134 if p_64 == 1.0 {
135 return Ok(f64::INFINITY);
136 }
137
138 let a = [
144 -3.969_683_028_665_376e1,
145 2.209_460_984_245_205e2,
146 -2.759_285_104_469_687e2,
147 1.383_577_518_672_69e2,
148 -3.066_479_806_614_716e1,
149 2.506_628_277_459_239,
150 ];
151 let b = [
152 -5.447_609_879_822_406e1,
153 1.615_858_368_580_409e2,
154 -1.556_989_798_598_866e2,
155 6.680_131_188_771_972e1,
156 -1.328_068_155_288_572e1,
157 1.0,
158 ];
159 let c = [
161 -7.784_894_002_430_293e-3,
162 -3.223_964_580_411_365e-1,
163 -2.400_758_277_161_838,
164 -2.549_732_539_343_734,
165 4.374_664_141_464_968,
166 2.938_163_982_698_783,
167 ];
168 let d = [
169 7.784_695_709_041_462e-3,
170 3.224_671_290_700_398e-1,
171 2.445_134_137_142_996,
172 3.754_408_661_907_416,
173 ];
174
175 const P_LOW: f64 = 0.02425;
176 const P_HIGH: f64 = 1.0 - P_LOW;
177
178 let z = if p_64 < P_LOW {
179 let q = (-2.0 * p_64.ln()).sqrt();
181 let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
182 let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
183 num / den
184 } else if p_64 > P_HIGH {
185 let q = (-2.0 * (1.0 - p_64).ln()).sqrt();
187 let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
188 let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
189 -num / den
190 } else {
191 let q = p_64 - 0.5;
193 let r = q * q;
194 let num = ((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5];
195 let den = ((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + b[5];
196 q * num / den
197 };
198
199 Ok(mean + std_dev * z)
200}
201
/// A normal (Gaussian) distribution described by its mean and standard deviation.
///
/// Both fields are public, so a value built with a struct literal bypasses the validation
/// that `Normal::new` performs.
#[derive(Debug, Clone, Copy)]
pub struct Normal {
    /// Location parameter: the mean of the distribution.
    pub mean: f64,
    /// Scale parameter: the standard deviation (`Normal::new` requires it to be positive).
    pub std_dev: f64,
}
224
225impl Normal {
226 pub fn new(mean: f64, std_dev: f64) -> StatsResult<Self> {
228 if std_dev <= 0.0 || std_dev.is_nan() || mean.is_nan() {
229 return Err(StatsError::InvalidInput {
230 message: "Normal::new: std_dev must be positive and parameters must be finite"
231 .to_string(),
232 });
233 }
234 Ok(Self { mean, std_dev })
235 }
236
237 pub fn fit(data: &[f64]) -> StatsResult<Self> {
242 if data.is_empty() {
243 return Err(StatsError::InvalidInput {
244 message: "Normal::fit: data must not be empty".to_string(),
245 });
246 }
247 let mut count = 0.0_f64;
248 let mut mean = 0.0_f64;
249 let mut m2 = 0.0_f64;
250 for &x in data {
251 count += 1.0;
252 let delta = x - mean;
253 mean += delta / count;
254 m2 += delta * (x - mean);
255 }
256 let variance = m2 / count; Self::new(mean, variance.sqrt())
258 }
259}
260
261impl Distribution for Normal {
262 type X = f64;
263 fn name(&self) -> &str {
264 "Normal"
265 }
266 fn num_params(&self) -> usize {
267 2
268 }
269 fn pdf(&self, x: f64) -> StatsResult<f64> {
270 normal_pdf(x, self.mean, self.std_dev)
271 }
272 fn logpdf(&self, x: f64) -> StatsResult<f64> {
273 let z = (x - self.mean) / self.std_dev;
274 Ok(-0.5 * z * z - self.std_dev.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln())
275 }
276 fn log_likelihood_fast(&self, data: &[f64]) -> f64 {
281 let inv_sigma = 1.0 / self.std_dev;
282 let mut sum_sq = 0.0_f64;
283 for &x in data {
284 let z = (x - self.mean) * inv_sigma;
285 sum_sq += z * z;
286 }
287 let n = data.len() as f64;
288 -0.5 * sum_sq - n * (self.std_dev.ln() + 0.5 * (2.0 * std::f64::consts::PI).ln())
289 }
290 fn cdf(&self, x: f64) -> StatsResult<f64> {
291 normal_cdf(x, self.mean, self.std_dev)
292 }
293 fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
294 normal_inverse_cdf(p, self.mean, self.std_dev)
295 }
296 fn mean(&self) -> f64 {
297 self.mean
298 }
299 fn variance(&self) -> f64 {
300 self.std_dev * self.std_dev
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for comparisons against reference values.
    const EPSILON: f64 = 1e-7;

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.3989422804014327).abs() < 1e-10);

        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.24197072451914337).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;

        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.19947114020071635).abs() < 1e-10);

        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.12098536225957168).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;

        let pdf_plus = normal_pdf(mean + x, mean, sigma).unwrap();
        let pdf_minus = normal_pdf(mean - x, mean, sigma).unwrap();

        assert!((pdf_plus - pdf_minus).abs() < 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);

        let result = normal_cdf(mean - sigma, mean, sigma).unwrap();
        assert!((result - 0.15865525393145707).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;

        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - sigma).abs() < EPSILON);

        let result = normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap();
        assert!((result - (-sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;

        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - (mean + sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        let pdf = (normal_pdf(0.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((pdf - 0.3989423).abs() < EPSILON);

        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert!((pdf_plus1 - pdf_minus1).abs() < EPSILON);

        assert!((normal_pdf(1.0, 0.0, 1.0).unwrap() - 0.2419707).abs() < EPSILON);
        assert!((normal_pdf(2.0, 0.0, 1.0).unwrap() - 0.0539909).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let result = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        let cdf = (normal_cdf(0.0, 0.0, 1.0).unwrap() * 1e1).round() / 1e1;
        assert!((cdf - 0.5).abs() < EPSILON);

        let cdf = (normal_cdf(1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.8413447).abs() < EPSILON);

        let cdf = (normal_cdf(-1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.1586553).abs() < EPSILON);

        let cdf = (normal_cdf(2.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.9772499).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let result = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        let x = (normal_inverse_cdf(0.5, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!(x.abs() < EPSILON);

        assert!((normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap() - 1.0).abs() < 0.01);
        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap() + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        let result = normal_inverse_cdf(-0.1, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        let result = normal_inverse_cdf(1.5, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        let result = normal_inverse_cdf(0.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        let result = normal_inverse_cdf(1.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        let result = normal_pdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        let result = normal_cdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // `normal_inverse_cdf` deliberately does not validate `std_dev`.
        let result = normal_inverse_cdf(0.5, 5.0, 0.0).unwrap();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        let result = normal_inverse_cdf(0.5, 0.0, -1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_normal_new_valid() {
        let dist = Normal::new(0.0, 1.0).unwrap();
        assert_eq!(dist.mean, 0.0);
        assert_eq!(dist.std_dev, 1.0);
    }

    // ---- Coverage for branches and methods not exercised above ----

    #[test]
    fn test_normal_inverse_cdf_tails() {
        // p = 0.01 and p = 0.001 hit the lower-tail branch (p < 0.02425);
        // p = 0.99 hits the upper-tail branch.
        let lower = normal_inverse_cdf(0.01, 0.0, 1.0).unwrap();
        assert!((lower - (-2.3263478740408408)).abs() < EPSILON);

        let far_lower = normal_inverse_cdf(0.001, 0.0, 1.0).unwrap();
        assert!((far_lower - (-3.090232306167813)).abs() < EPSILON);

        let upper = normal_inverse_cdf(0.99, 0.0, 1.0).unwrap();
        assert!((upper - 2.3263478740408408).abs() < EPSILON);
    }

    #[test]
    fn test_normal_new_rejects_invalid_parameters() {
        assert!(Normal::new(0.0, 0.0).is_err());
        assert!(Normal::new(0.0, -1.0).is_err());
        assert!(Normal::new(0.0, f64::NAN).is_err());
        assert!(Normal::new(f64::NAN, 1.0).is_err());
    }

    #[test]
    fn test_normal_fit_population_std_dev() {
        // Mean 2.5; population variance = 5 / 4 = 1.25 (divides by n, not n - 1).
        let dist = Normal::fit(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!((dist.mean - 2.5).abs() < 1e-12);
        assert!((dist.std_dev - 1.25_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_normal_fit_rejects_degenerate_data() {
        assert!(matches!(
            Normal::fit(&[]),
            Err(StatsError::InvalidInput { .. })
        ));
        // A single observation or constant data has zero variance.
        assert!(matches!(
            Normal::fit(&[3.0]),
            Err(StatsError::InvalidInput { .. })
        ));
        assert!(Normal::fit(&[2.0, 2.0, 2.0]).is_err());
    }

    #[test]
    fn test_normal_distribution_methods() {
        let dist = Normal::new(1.0, 2.0).unwrap();
        assert_eq!(dist.name(), "Normal");
        assert_eq!(dist.num_params(), 2);
        assert_eq!(dist.mean(), 1.0);
        assert_eq!(dist.variance(), 4.0);
        assert!((dist.inverse_cdf(0.5).unwrap() - 1.0).abs() < EPSILON);
    }

    #[test]
    fn test_normal_logpdf_matches_ln_pdf() {
        let dist = Normal::new(1.0, 2.0).unwrap();
        for &x in &[-3.0, 0.0, 1.0, 2.5] {
            let expected = dist.pdf(x).unwrap().ln();
            assert!((dist.logpdf(x).unwrap() - expected).abs() < 1e-9);
        }
    }

    #[test]
    fn test_normal_log_likelihood_fast_matches_logpdf_sum() {
        let dist = Normal::new(0.5, 1.5).unwrap();
        let data = [0.0, -1.0, 2.0, 0.5];
        let summed: f64 = data.iter().map(|&x| dist.logpdf(x).unwrap()).sum();
        assert!((dist.log_likelihood_fast(&data) - summed).abs() < 1e-12);
        assert_eq!(dist.log_likelihood_fast(&[]), 0.0);
    }
}