1use num_traits::ToPrimitive;
48
49use crate::distributions::traits::Distribution;
50use crate::error::{StatsError, StatsResult};
51use crate::prob::erf;
52use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
53use serde::{Deserialize, Serialize};
54
55#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
69pub struct NormalConfig<T>
70where
71 T: ToPrimitive,
72{
73 pub mean: T,
75 pub std_dev: T,
77}
78
79impl<T> NormalConfig<T>
80where
81 T: ToPrimitive,
82{
83 pub fn new(mean: T, std_dev: T) -> StatsResult<Self> {
103 let std_dev_64 = std_dev
104 .to_f64()
105 .ok_or_else(|| StatsError::ConversionError {
106 message: "NormalConfig::new: Failed to convert std_dev to f64".to_string(),
107 })?;
108 let mean_64 = mean.to_f64().ok_or_else(|| StatsError::ConversionError {
109 message: "NormalConfig::new: Failed to convert mean to f64".to_string(),
110 })?;
111
112 if std_dev_64 > 0.0 && !mean_64.is_nan() && !std_dev_64.is_nan() {
113 Ok(Self { mean, std_dev })
114 } else {
115 Err(StatsError::InvalidInput {
116 message: "NormalConfig::new: std_dev must be positive".to_string(),
117 })
118 }
119 }
120}
121
122#[inline]
150pub fn normal_pdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
151where
152 T: ToPrimitive,
153{
154 if std_dev <= 0.0 {
155 return Err(StatsError::InvalidInput {
156 message: "normal_pdf: Standard deviation must be positive".to_string(),
157 });
158 }
159
160 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
161 message: "normal_pdf: Failed to convert x to f64".to_string(),
162 })?;
163
164 let z = (x_64 - mean) / std_dev;
166 let exponent = -0.5 * z * z;
167 Ok(exponent.exp() * INV_SQRT_2PI / std_dev)
169}
170
171#[inline]
199pub fn normal_cdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
200where
201 T: ToPrimitive,
202{
203 if std_dev <= 0.0 {
204 return Err(StatsError::InvalidInput {
205 message: "normal_cdf: Standard deviation must be positive".to_string(),
206 });
207 }
208
209 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
210 message: "normal_cdf: Failed to convert x to f64".to_string(),
211 })?;
212
213 if x_64 == mean {
215 return Ok(0.5);
216 }
217
218 let z = (x_64 - mean) / (std_dev * SQRT_2);
222 Ok(0.5 * (1.0 + erf(z)?))
223}
224
225#[inline]
246pub fn normal_inverse_cdf<T>(p: T, mean: f64, std_dev: f64) -> StatsResult<f64>
247where
248 T: ToPrimitive,
249{
250 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
251 message: "normal_inverse_cdf: Failed to convert p to f64".to_string(),
252 })?;
253
254 if !(0.0..=1.0).contains(&p_64) {
255 return Err(StatsError::InvalidInput {
256 message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
257 });
258 }
259
260 if p_64 == 0.0 {
262 return Ok(f64::NEG_INFINITY);
263 }
264 if p_64 == 1.0 {
265 return Ok(f64::INFINITY);
266 }
267
268 let q = if p_64 <= 0.5 { p_64 } else { 1.0 - p_64 };
273
274 if q <= 0.0 {
276 return if p_64 <= 0.5 {
277 Ok(f64::NEG_INFINITY)
278 } else {
279 Ok(f64::INFINITY)
280 };
281 }
282
283 let a = [
285 -3.969_683_028_665_376e1,
286 2.209_460_984_245_205e2,
287 -2.759_285_104_469_687e2,
288 1.383_577_518_672_69e2,
289 -3.066_479_806_614_716e1,
290 2.506_628_277_459_239,
291 ];
292
293 let b = [
294 -5.447_609_879_822_406e1,
295 1.615_858_368_580_409e2,
296 -1.556_989_798_598_866e2,
297 6.680_131_188_771_972e1,
298 -1.328_068_155_288_572e1,
299 1.0,
300 ];
301
302 let r = q - 0.5;
304
305 let z = if q > 0.02425 && q < 0.97575 {
306 let r2 = r * r;
308 let num = ((((a[0] * r2 + a[1]) * r2 + a[2]) * r2 + a[3]) * r2 + a[4]) * r2 + a[5];
309 let den = ((((b[0] * r2 + b[1]) * r2 + b[2]) * r2 + b[3]) * r2 + b[4]) * r2 + b[5];
310 r * num / den
311 } else {
312 let s = if r < 0.0 { q } else { 1.0 - q };
314 let t = (-2.0 * s.ln()).sqrt();
315
316 let c = [
318 -7.784_894_002_430_293e-3,
319 -3.223_964_580_411_365e-1,
320 -2.400_758_277_161_838,
321 -2.549_732_539_343_734,
322 4.374_664_141_464_968,
323 2.938_163_982_698_783,
324 ];
325
326 let d = [
327 7.784_695_709_041_462e-3,
328 3.224_671_290_700_398e-1,
329 2.445_134_137_142_996,
330 3.754_408_661_907_416,
331 1.0,
332 ];
333
334 let num = ((((c[0] * t + c[1]) * t + c[2]) * t + c[3]) * t + c[4]) * t + c[5];
335 let den = (((d[0] * t + d[1]) * t + d[2]) * t + d[3]) * t + d[4];
336 if r < 0.0 {
337 -t - num / den
338 } else {
339 t - num / den
340 }
341 };
342
343 let final_z = if p_64 > 0.5 { -z } else { z };
345
346 let result = mean + std_dev * final_z;
347 Ok(result)
349}
350
351#[derive(Debug, Clone, Copy)]
367pub struct Normal {
368 pub mean: f64,
370 pub std_dev: f64,
372}
373
374impl Normal {
375 pub fn new(mean: f64, std_dev: f64) -> StatsResult<Self> {
377 if std_dev <= 0.0 || std_dev.is_nan() || mean.is_nan() {
378 return Err(StatsError::InvalidInput {
379 message: "Normal::new: std_dev must be positive and parameters must be finite"
380 .to_string(),
381 });
382 }
383 Ok(Self { mean, std_dev })
384 }
385
386 pub fn fit(data: &[f64]) -> StatsResult<Self> {
390 if data.is_empty() {
391 return Err(StatsError::InvalidInput {
392 message: "Normal::fit: data must not be empty".to_string(),
393 });
394 }
395 let n = data.len() as f64;
396 let mean = data.iter().sum::<f64>() / n;
397 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
398 Self::new(mean, variance.sqrt())
399 }
400}
401
402impl Distribution for Normal {
403 fn name(&self) -> &str {
404 "Normal"
405 }
406 fn num_params(&self) -> usize {
407 2
408 }
409 fn pdf(&self, x: f64) -> StatsResult<f64> {
410 normal_pdf(x, self.mean, self.std_dev)
411 }
412 fn logpdf(&self, x: f64) -> StatsResult<f64> {
413 let z = (x - self.mean) / self.std_dev;
414 Ok(-0.5 * z * z - self.std_dev.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln())
415 }
416 fn cdf(&self, x: f64) -> StatsResult<f64> {
417 normal_cdf(x, self.mean, self.std_dev)
418 }
419 fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
420 normal_inverse_cdf(p, self.mean, self.std_dev)
421 }
422 fn mean(&self) -> f64 {
423 self.mean
424 }
425 fn variance(&self) -> f64 {
426 self.std_dev * self.std_dev
427 }
428}
429
#[cfg(test)]
mod tests {
    // Unit tests for the free normal_* functions, `NormalConfig`, and `Normal`.
    use super::*;

    const EPSILON: f64 = 1e-7;

    #[test]
    fn test_normal_pdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.3989422804014327).abs() < 1e-10);

        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.24197072451914337).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let mean = 5.0;
        let sigma = 2.0;

        let result = normal_pdf(mean, mean, sigma).unwrap();
        assert!((result - 0.19947114020071635).abs() < 1e-10);

        let result = normal_pdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.12098536225957168).abs() < 1e-10);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let mean = 0.0;
        let sigma = 1.0;
        let x = 1.5;

        let pdf_plus = normal_pdf(mean + x, mean, sigma).unwrap();
        let pdf_minus = normal_pdf(mean - x, mean, sigma).unwrap();

        assert!((pdf_plus - pdf_minus).abs() < 1e-10);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);

        let result = normal_cdf(mean - sigma, mean, sigma).unwrap();
        assert!((result - 0.15865525393145707).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let mean = 100.0;
        let sigma = 15.0;

        let result = normal_cdf(mean, mean, sigma).unwrap();
        assert!((result - 0.5).abs() < 1e-10);

        let result = normal_cdf(mean + sigma, mean, sigma).unwrap();
        assert!((result - 0.8413447460685429).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let mean = 0.0;
        let sigma = 1.0;

        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - sigma).abs() < EPSILON);

        let result = normal_inverse_cdf(0.15865525393145707, mean, sigma).unwrap();
        assert!((result - (-sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let mean = 50.0;
        let sigma = 5.0;

        let result = normal_inverse_cdf(0.5, mean, sigma).unwrap();
        assert!((result - mean).abs() < EPSILON);

        let result = normal_inverse_cdf(0.8413447460685429, mean, sigma).unwrap();
        assert!((result - (mean + sigma)).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        let pdf = (normal_pdf(0.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((pdf - 0.3989423).abs() < EPSILON);

        let pdf_plus1 = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let pdf_minus1 = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert!((pdf_plus1 - pdf_minus1).abs() < EPSILON);

        assert!((normal_pdf(1.0, 0.0, 1.0).unwrap() - 0.2419707).abs() < EPSILON);
        assert!((normal_pdf(2.0, 0.0, 1.0).unwrap() - 0.0539909).abs() < EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let result = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        let cdf = (normal_cdf(0.0, 0.0, 1.0).unwrap() * 1e1).round() / 1e1;
        assert!((cdf - 0.5).abs() < EPSILON);

        let cdf = (normal_cdf(1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.8413447).abs() < EPSILON);

        let cdf = (normal_cdf(-1.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.1586553).abs() < EPSILON);

        let cdf = (normal_cdf(2.0, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!((cdf - 0.9772499).abs() < EPSILON);
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let result = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            result.is_err(),
            "Should return error for negative standard deviation"
        );
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        let x = (normal_inverse_cdf(0.5, 0.0, 1.0).unwrap() * 1e7).round() / 1e7;
        assert!(x.abs() < EPSILON);

        assert!((normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap() - 1.0).abs() < 0.01);
        assert!((normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap() + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_normal_config_new_nan_mean() {
        let result = NormalConfig::new(f64::NAN, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_config_new_nan_std_dev() {
        let result = NormalConfig::new(0.0, f64::NAN);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_config_new_std_dev_zero() {
        let result = NormalConfig::new(0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_config_new_std_dev_negative() {
        let result = NormalConfig::new(0.0, -1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        let result = normal_inverse_cdf(-0.1, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        let result = normal_inverse_cdf(1.5, 0.0, 1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        let result = normal_inverse_cdf(0.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        let result = normal_inverse_cdf(1.0, 0.0, 1.0).unwrap();
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        let result = normal_pdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        let result = normal_cdf(0.0, 0.0, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        // normal_inverse_cdf does not validate std_dev; a zero scale collapses to the mean.
        let result = normal_inverse_cdf(0.5, 5.0, 0.0).unwrap();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        let result = normal_inverse_cdf(0.5, 0.0, -1.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_normal_config_new_valid() {
        let config = NormalConfig::new(0.0, 1.0);
        assert!(config.is_ok());
        let config = config.unwrap();
        assert_eq!(config.mean, 0.0);
    }

    #[test]
    fn test_normal_new_validation() {
        assert!(Normal::new(0.0, 1.0).is_ok());
        assert!(matches!(
            Normal::new(0.0, 0.0),
            Err(StatsError::InvalidInput { .. })
        ));
        assert!(matches!(
            Normal::new(0.0, -2.0),
            Err(StatsError::InvalidInput { .. })
        ));
        assert!(matches!(
            Normal::new(f64::NAN, 1.0),
            Err(StatsError::InvalidInput { .. })
        ));
        assert!(matches!(
            Normal::new(0.0, f64::NAN),
            Err(StatsError::InvalidInput { .. })
        ));
    }

    #[test]
    fn test_normal_fit() {
        let dist = Normal::fit(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        assert!((dist.mean - 3.0).abs() < 1e-12);
        // Population (MLE) variance: sum of squared deviations (10) divided by n (5).
        assert!((dist.std_dev - 2.0_f64.sqrt()).abs() < 1e-12);

        assert!(matches!(
            Normal::fit(&[]),
            Err(StatsError::InvalidInput { .. })
        ));
        // A single sample has zero spread, which Normal::new rejects.
        assert!(matches!(
            Normal::fit(&[4.2]),
            Err(StatsError::InvalidInput { .. })
        ));
    }

    #[test]
    fn test_normal_distribution_trait() {
        let dist = Normal::new(2.0, 3.0).unwrap();
        assert_eq!(dist.name(), "Normal");
        assert_eq!(dist.num_params(), 2);
        assert_eq!(dist.mean(), 2.0);
        assert_eq!(dist.variance(), 9.0);
        assert_eq!(dist.cdf(2.0).unwrap(), 0.5);
        assert_eq!(dist.inverse_cdf(0.5).unwrap(), 2.0);

        let pdf = dist.pdf(3.5).unwrap();
        let logpdf = dist.logpdf(3.5).unwrap();
        assert!((logpdf - pdf.ln()).abs() < 1e-9);
    }
}