1use num_traits::ToPrimitive;
48
49use crate::distributions::traits::Distribution;
50use crate::error::{StatsError, StatsResult};
51use crate::prob::erf;
52use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
53use serde::{Deserialize, Serialize};
54
55#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
69pub struct NormalConfig<T>
70where
71 T: ToPrimitive,
72{
73 pub mean: T,
75 pub std_dev: T,
77}
78
79impl<T> NormalConfig<T>
80where
81 T: ToPrimitive,
82{
83 pub fn new(mean: T, std_dev: T) -> StatsResult<Self> {
103 let std_dev_64 = std_dev
104 .to_f64()
105 .ok_or_else(|| StatsError::ConversionError {
106 message: "NormalConfig::new: Failed to convert std_dev to f64".to_string(),
107 })?;
108 let mean_64 = mean.to_f64().ok_or_else(|| StatsError::ConversionError {
109 message: "NormalConfig::new: Failed to convert mean to f64".to_string(),
110 })?;
111
112 if std_dev_64 > 0.0 && !mean_64.is_nan() && !std_dev_64.is_nan() {
113 Ok(Self { mean, std_dev })
114 } else {
115 Err(StatsError::InvalidInput {
116 message: "NormalConfig::new: std_dev must be positive".to_string(),
117 })
118 }
119 }
120}
121
122#[inline]
150pub fn normal_pdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
151where
152 T: ToPrimitive,
153{
154 if std_dev <= 0.0 {
155 return Err(StatsError::InvalidInput {
156 message: "normal_pdf: Standard deviation must be positive".to_string(),
157 });
158 }
159
160 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
161 message: "normal_pdf: Failed to convert x to f64".to_string(),
162 })?;
163
164 let z = (x_64 - mean) / std_dev;
166 let exponent = -0.5 * z * z;
167 Ok(exponent.exp() * INV_SQRT_2PI / std_dev)
169}
170
171#[inline]
199pub fn normal_cdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
200where
201 T: ToPrimitive,
202{
203 if std_dev <= 0.0 {
204 return Err(StatsError::InvalidInput {
205 message: "normal_cdf: Standard deviation must be positive".to_string(),
206 });
207 }
208
209 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
210 message: "normal_cdf: Failed to convert x to f64".to_string(),
211 })?;
212
213 if x_64 == mean {
215 return Ok(0.5);
216 }
217
218 let z = (x_64 - mean) / (std_dev * SQRT_2);
222 Ok(0.5 * (1.0 + erf(z)?))
223}
224
225#[inline]
246pub fn normal_inverse_cdf<T>(p: T, mean: f64, std_dev: f64) -> StatsResult<f64>
247where
248 T: ToPrimitive,
249{
250 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
251 message: "normal_inverse_cdf: Failed to convert p to f64".to_string(),
252 })?;
253
254 if !(0.0..=1.0).contains(&p_64) {
255 return Err(StatsError::InvalidInput {
256 message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
257 });
258 }
259
260 if p_64 == 0.0 {
262 return Ok(f64::NEG_INFINITY);
263 }
264 if p_64 == 1.0 {
265 return Ok(f64::INFINITY);
266 }
267
268 let a = [
274 -3.969_683_028_665_376e1,
275 2.209_460_984_245_205e2,
276 -2.759_285_104_469_687e2,
277 1.383_577_518_672_69e2,
278 -3.066_479_806_614_716e1,
279 2.506_628_277_459_239,
280 ];
281 let b = [
282 -5.447_609_879_822_406e1,
283 1.615_858_368_580_409e2,
284 -1.556_989_798_598_866e2,
285 6.680_131_188_771_972e1,
286 -1.328_068_155_288_572e1,
287 1.0,
288 ];
289 let c = [
291 -7.784_894_002_430_293e-3,
292 -3.223_964_580_411_365e-1,
293 -2.400_758_277_161_838,
294 -2.549_732_539_343_734,
295 4.374_664_141_464_968,
296 2.938_163_982_698_783,
297 ];
298 let d = [
299 7.784_695_709_041_462e-3,
300 3.224_671_290_700_398e-1,
301 2.445_134_137_142_996,
302 3.754_408_661_907_416,
303 ];
304
305 const P_LOW: f64 = 0.02425;
306 const P_HIGH: f64 = 1.0 - P_LOW;
307
308 let z = if p_64 < P_LOW {
309 let q = (-2.0 * p_64.ln()).sqrt();
311 let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
312 let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
313 num / den
314 } else if p_64 > P_HIGH {
315 let q = (-2.0 * (1.0 - p_64).ln()).sqrt();
317 let num = ((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5];
318 let den = (((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0;
319 -num / den
320 } else {
321 let q = p_64 - 0.5;
323 let r = q * q;
324 let num = ((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5];
325 let den = ((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + b[5];
326 q * num / den
327 };
328
329 Ok(mean + std_dev * z)
330}
331
/// A normal (Gaussian) distribution, described by its mean and standard
/// deviation.
///
/// Prefer [`Normal::new`] or [`Normal::fit`], which validate the parameters.
/// The fields are public, so building the struct with a literal skips those
/// checks.
#[derive(Debug, Clone, Copy)]
pub struct Normal {
    /// Location parameter μ.
    pub mean: f64,
    /// Scale parameter σ.
    pub std_dev: f64,
}
354
355impl Normal {
356 pub fn new(mean: f64, std_dev: f64) -> StatsResult<Self> {
358 if std_dev <= 0.0 || std_dev.is_nan() || mean.is_nan() {
359 return Err(StatsError::InvalidInput {
360 message: "Normal::new: std_dev must be positive and parameters must be finite"
361 .to_string(),
362 });
363 }
364 Ok(Self { mean, std_dev })
365 }
366
367 pub fn fit(data: &[f64]) -> StatsResult<Self> {
372 if data.is_empty() {
373 return Err(StatsError::InvalidInput {
374 message: "Normal::fit: data must not be empty".to_string(),
375 });
376 }
377 let mut count = 0.0_f64;
378 let mut mean = 0.0_f64;
379 let mut m2 = 0.0_f64;
380 for &x in data {
381 count += 1.0;
382 let delta = x - mean;
383 mean += delta / count;
384 m2 += delta * (x - mean);
385 }
386 let variance = m2 / count; Self::new(mean, variance.sqrt())
388 }
389}
390
391impl Distribution for Normal {
392 fn name(&self) -> &str {
393 "Normal"
394 }
395 fn num_params(&self) -> usize {
396 2
397 }
398 fn pdf(&self, x: f64) -> StatsResult<f64> {
399 normal_pdf(x, self.mean, self.std_dev)
400 }
401 fn logpdf(&self, x: f64) -> StatsResult<f64> {
402 let z = (x - self.mean) / self.std_dev;
403 Ok(-0.5 * z * z - self.std_dev.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln())
404 }
405 fn cdf(&self, x: f64) -> StatsResult<f64> {
406 normal_cdf(x, self.mean, self.std_dev)
407 }
408 fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
409 normal_inverse_cdf(p, self.mean, self.std_dev)
410 }
411 fn mean(&self) -> f64 {
412 self.mean
413 }
414 fn variance(&self) -> f64 {
415 self.std_dev * self.std_dev
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Default tolerance for comparisons against tabulated reference values.
    const EPSILON: f64 = 1e-7;

    /// Fails the test unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {}, got {}",
            expected,
            tol,
            actual
        );
    }

    /// Rounds `value` onto the grid `1 / scale` (e.g. `scale = 1e7` keeps seven decimals).
    fn round_to(value: f64, scale: f64) -> f64 {
        (value * scale).round() / scale
    }

    /// Fails the test unless `result` is an `InvalidInput` error.
    fn assert_invalid_input<T>(result: StatsResult<T>) {
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_normal_pdf_standard() {
        let (mean, sigma) = (0.0, 1.0);
        assert_close(normal_pdf(mean, mean, sigma).unwrap(), 0.3989422804014327, 1e-10);
        assert_close(
            normal_pdf(mean + sigma, mean, sigma).unwrap(),
            0.24197072451914337,
            1e-10,
        );
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let (mean, sigma) = (5.0, 2.0);
        assert_close(normal_pdf(mean, mean, sigma).unwrap(), 0.19947114020071635, 1e-10);
        assert_close(
            normal_pdf(mean + sigma, mean, sigma).unwrap(),
            0.12098536225957168,
            1e-10,
        );
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let (mean, sigma, offset) = (0.0, 1.0, 1.5);
        let right = normal_pdf(mean + offset, mean, sigma).unwrap();
        let left = normal_pdf(mean - offset, mean, sigma).unwrap();
        assert_close(right, left, 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let (mean, sigma) = (0.0, 1.0);
        assert_close(normal_cdf(mean, mean, sigma).unwrap(), 0.5, 1e-10);
        assert_close(
            normal_cdf(mean + sigma, mean, sigma).unwrap(),
            0.8413447460685429,
            EPSILON,
        );
        assert_close(
            normal_cdf(mean - sigma, mean, sigma).unwrap(),
            0.15865525393145707,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let (mean, sigma) = (100.0, 15.0);
        assert_close(normal_cdf(mean, mean, sigma).unwrap(), 0.5, 1e-10);
        assert_close(
            normal_cdf(mean + sigma, mean, sigma).unwrap(),
            0.8413447460685429,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let (mean, sigma) = (0.0, 1.0);
        assert_close(normal_inverse_cdf(0.5, mean, sigma).unwrap(), mean, EPSILON);
        assert_close(
            normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap(),
            sigma,
            EPSILON,
        );
        assert_close(
            normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap(),
            -sigma,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let (mean, sigma) = (50.0, 5.0);
        assert_close(normal_inverse_cdf(0.5, mean, sigma).unwrap(), mean, EPSILON);
        assert_close(
            normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap(),
            mean + sigma,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        assert_close(round_to(normal_pdf(0.0, 0.0, 1.0).unwrap(), 1e7), 0.3989423, EPSILON);

        let right = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let left = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert_close(right, left, EPSILON);

        assert_close(normal_pdf(1.0, 0.0, 1.0).unwrap(), 0.2419707, EPSILON);
        assert_close(normal_pdf(2.0, 0.0, 1.0).unwrap(), 0.0539909, EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let result = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(result);
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        assert_close(round_to(normal_cdf(0.0, 0.0, 1.0).unwrap(), 1e1), 0.5, EPSILON);
        assert_close(round_to(normal_cdf(1.0, 0.0, 1.0).unwrap(), 1e7), 0.8413447, EPSILON);
        assert_close(round_to(normal_cdf(-1.0, 0.0, 1.0).unwrap(), 1e7), 0.1586553, EPSILON);
        assert_close(round_to(normal_cdf(2.0, 0.0, 1.0).unwrap(), 1e7), 0.9772499, EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let result = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(result);
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        let median = round_to(normal_inverse_cdf(0.5, 0.0, 1.0).unwrap(), 1e7);
        assert!(median.abs() < EPSILON);

        assert_close(normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap(), 1.0, 0.01);
        assert_close(normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap(), -1.0, 0.01);
    }

    #[test]
    fn test_normal_config_new_nan_mean() {
        assert_invalid_input(NormalConfig::new(f64::NAN, 1.0));
    }

    #[test]
    fn test_normal_config_new_nan_std_dev() {
        assert_invalid_input(NormalConfig::new(0.0, f64::NAN));
    }

    #[test]
    fn test_normal_config_new_std_dev_zero() {
        assert_invalid_input(NormalConfig::new(0.0, 0.0));
    }

    #[test]
    fn test_normal_config_new_std_dev_negative() {
        assert_invalid_input(NormalConfig::new(0.0, -1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        assert_invalid_input(normal_inverse_cdf(-0.1, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(normal_inverse_cdf(1.5, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        assert_eq!(normal_inverse_cdf(0.0, 0.0, 1.0).unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        assert_eq!(normal_inverse_cdf(1.0, 0.0, 1.0).unwrap(), f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        assert_invalid_input(normal_pdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        assert_invalid_input(normal_cdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // normal_inverse_cdf does not validate std_dev; a zero scale collapses onto the mean.
        assert_eq!(normal_inverse_cdf(0.5, 5.0, 0.0).unwrap(), 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        // A negative scale is applied as-is; at the median, z is 0 so the result is the mean.
        assert_eq!(normal_inverse_cdf(0.5, 0.0, -1.0).unwrap(), 0.0);
    }

    #[test]
    fn test_normal_config_new_valid() {
        let config = NormalConfig::new(0.0, 1.0);
        assert!(config.is_ok());
        assert_eq!(config.unwrap().mean, 0.0);
    }
}