1use num_traits::ToPrimitive;
2
3use crate::error::{StatsError, StatsResult};
4use crate::prob::erf;
5use crate::utils::constants::{INV_SQRT_2PI, SQRT_2};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
22pub struct NormalConfig<T>
23where
24 T: ToPrimitive,
25{
26 pub mean: T,
28 pub std_dev: T,
30}
31
32impl<T> NormalConfig<T>
33where
34 T: ToPrimitive,
35{
36 pub fn new(mean: T, std_dev: T) -> StatsResult<Self> {
56 let std_dev_64 = std_dev
57 .to_f64()
58 .ok_or_else(|| StatsError::ConversionError {
59 message: "NormalConfig::new: Failed to convert std_dev to f64".to_string(),
60 })?;
61 let mean_64 = mean.to_f64().ok_or_else(|| StatsError::ConversionError {
62 message: "NormalConfig::new: Failed to convert mean to f64".to_string(),
63 })?;
64
65 if std_dev_64 > 0.0 && !mean_64.is_nan() && !std_dev_64.is_nan() {
66 Ok(Self { mean, std_dev })
67 } else {
68 Err(StatsError::InvalidInput {
69 message: "NormalConfig::new: std_dev must be positive".to_string(),
70 })
71 }
72 }
73}
74
75#[inline]
103pub fn normal_pdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
104where
105 T: ToPrimitive,
106{
107 if std_dev <= 0.0 {
108 return Err(StatsError::InvalidInput {
109 message: "normal_pdf: Standard deviation must be positive".to_string(),
110 });
111 }
112
113 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
114 message: "normal_pdf: Failed to convert x to f64".to_string(),
115 })?;
116
117 let z = (x_64 - mean) / std_dev;
119 let exponent = -0.5 * z * z;
120 Ok(exponent.exp() * INV_SQRT_2PI / std_dev)
122}
123
124#[inline]
152pub fn normal_cdf<T>(x: T, mean: f64, std_dev: f64) -> StatsResult<f64>
153where
154 T: ToPrimitive,
155{
156 if std_dev <= 0.0 {
157 return Err(StatsError::InvalidInput {
158 message: "normal_cdf: Standard deviation must be positive".to_string(),
159 });
160 }
161
162 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
163 message: "normal_cdf: Failed to convert x to f64".to_string(),
164 })?;
165
166 if x_64 == mean {
168 return Ok(0.5);
169 }
170
171 let z = (x_64 - mean) / (std_dev * SQRT_2);
175 Ok(0.5 * (1.0 + erf(z)?))
176}
177
178#[inline]
199pub fn normal_inverse_cdf<T>(p: T, mean: f64, std_dev: f64) -> StatsResult<f64>
200where
201 T: ToPrimitive,
202{
203 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
204 message: "normal_inverse_cdf: Failed to convert p to f64".to_string(),
205 })?;
206
207 if !(0.0..=1.0).contains(&p_64) {
208 return Err(StatsError::InvalidInput {
209 message: "normal_inverse_cdf: Probability must be between 0 and 1".to_string(),
210 });
211 }
212
213 if p_64 == 0.0 {
215 return Ok(f64::NEG_INFINITY);
216 }
217 if p_64 == 1.0 {
218 return Ok(f64::INFINITY);
219 }
220
221 let q = if p_64 <= 0.5 { p_64 } else { 1.0 - p_64 };
226
227 if q <= 0.0 {
229 return if p_64 <= 0.5 {
230 Ok(f64::NEG_INFINITY)
231 } else {
232 Ok(f64::INFINITY)
233 };
234 }
235
236 let a = [
238 -3.969_683_028_665_376e1,
239 2.209_460_984_245_205e2,
240 -2.759_285_104_469_687e2,
241 1.383_577_518_672_69e2,
242 -3.066_479_806_614_716e1,
243 2.506_628_277_459_239,
244 ];
245
246 let b = [
247 -5.447_609_879_822_406e1,
248 1.615_858_368_580_409e2,
249 -1.556_989_798_598_866e2,
250 6.680_131_188_771_972e1,
251 -1.328_068_155_288_572e1,
252 1.0,
253 ];
254
255 let r = q - 0.5;
257
258 let z = if q > 0.02425 && q < 0.97575 {
259 let r2 = r * r;
261 let num = ((((a[0] * r2 + a[1]) * r2 + a[2]) * r2 + a[3]) * r2 + a[4]) * r2 + a[5];
262 let den = ((((b[0] * r2 + b[1]) * r2 + b[2]) * r2 + b[3]) * r2 + b[4]) * r2 + b[5];
263 r * num / den
264 } else {
265 let s = if r < 0.0 { q } else { 1.0 - q };
267 let t = (-2.0 * s.ln()).sqrt();
268
269 let c = [
271 -7.784_894_002_430_293e-3,
272 -3.223_964_580_411_365e-1,
273 -2.400_758_277_161_838,
274 -2.549_732_539_343_734,
275 4.374_664_141_464_968,
276 2.938_163_982_698_783,
277 ];
278
279 let d = [
280 7.784_695_709_041_462e-3,
281 3.224_671_290_700_398e-1,
282 2.445_134_137_142_996,
283 3.754_408_661_907_416,
284 1.0,
285 ];
286
287 let num = ((((c[0] * t + c[1]) * t + c[2]) * t + c[3]) * t + c[4]) * t + c[5];
288 let den = (((d[0] * t + d[1]) * t + d[2]) * t + d[3]) * t + d[4];
289 if r < 0.0 {
290 -t - num / den
291 } else {
292 t - num / den
293 }
294 };
295
296 let final_z = if p_64 > 0.5 { -z } else { z };
298
299 let result = mean + std_dev * final_z;
300 Ok(result)
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Looser tolerance used for CDF / inverse-CDF comparisons.
    const EPSILON: f64 = 1e-7;
    /// Tight tolerance for closed-form values.
    const TIGHT: f64 = 1e-10;

    /// Passes when `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Rounds `value` to the nearest multiple of `1 / scale` (e.g. `1e7` keeps 7 decimals).
    fn rounded(value: f64, scale: f64) -> f64 {
        (value * scale).round() / scale
    }

    /// Passes when `result` is an `Err(StatsError::InvalidInput { .. })`.
    fn assert_invalid_input<R: std::fmt::Debug>(result: StatsResult<R>) {
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_normal_pdf_standard() {
        let (mu, sd) = (0.0, 1.0);
        assert_close(normal_pdf(mu, mu, sd).unwrap(), 0.3989422804014327, TIGHT);
        assert_close(normal_pdf(mu + sd, mu, sd).unwrap(), 0.24197072451914337, TIGHT);
    }

    #[test]
    fn test_normal_pdf_non_standard() {
        let (mu, sd) = (5.0, 2.0);
        assert_close(normal_pdf(mu, mu, sd).unwrap(), 0.19947114020071635, TIGHT);
        assert_close(normal_pdf(mu + sd, mu, sd).unwrap(), 0.12098536225957168, TIGHT);
    }

    #[test]
    fn test_normal_pdf_symmetry() {
        let (mu, sd, offset) = (0.0, 1.0, 1.5);
        let right = normal_pdf(mu + offset, mu, sd).unwrap();
        let left = normal_pdf(mu - offset, mu, sd).unwrap();
        assert_close(right, left, TIGHT);
    }

    #[test]
    fn test_normal_cdf_standard() {
        let (mu, sd) = (0.0, 1.0);
        assert_close(normal_cdf(mu, mu, sd).unwrap(), 0.5, TIGHT);
        assert_close(normal_cdf(mu + sd, mu, sd).unwrap(), 0.8413447460685429, EPSILON);
        assert_close(normal_cdf(mu - sd, mu, sd).unwrap(), 0.15865525393145707, EPSILON);
    }

    #[test]
    fn test_normal_cdf_non_standard() {
        let (mu, sd) = (100.0, 15.0);
        assert_close(normal_cdf(mu, mu, sd).unwrap(), 0.5, TIGHT);
        assert_close(normal_cdf(mu + sd, mu, sd).unwrap(), 0.8413447460685429, EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf() {
        let (mu, sd) = (0.0, 1.0);
        assert_close(normal_inverse_cdf(0.5, mu, sd).unwrap(), mu, EPSILON);
        assert_close(normal_inverse_cdf(0.8413447460685429, mu, sd).unwrap(), sd, EPSILON);
        assert_close(normal_inverse_cdf(0.15865525393145707, mu, sd).unwrap(), -sd, EPSILON);
    }

    #[test]
    fn test_normal_inverse_cdf_non_standard() {
        let (mu, sd) = (50.0, 5.0);
        assert_close(normal_inverse_cdf(0.5, mu, sd).unwrap(), mu, EPSILON);
        assert_close(
            normal_inverse_cdf(0.8413447460685429, mu, sd).unwrap(),
            mu + sd,
            EPSILON,
        );
    }

    #[test]
    fn test_normal_pdf_standard_normal() {
        assert_close(rounded(normal_pdf(0.0, 0.0, 1.0).unwrap(), 1e7), 0.3989423, EPSILON);

        let at_plus_one = normal_pdf(1.0, 0.0, 1.0).unwrap();
        let at_minus_one = normal_pdf(-1.0, 0.0, 1.0).unwrap();
        assert_close(at_plus_one, at_minus_one, EPSILON);

        assert_close(at_plus_one, 0.2419707, EPSILON);
        assert_close(normal_pdf(2.0, 0.0, 1.0).unwrap(), 0.0539909, EPSILON);
    }

    #[test]
    fn test_normal_pdf_invalid_sigma() {
        let outcome = normal_pdf(0.0, 0.0, -1.0);
        assert!(
            outcome.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(outcome);
    }

    #[test]
    fn test_normal_cdf_standard_normal() {
        // (x, rounding scale, expected rounded CDF value)
        let cases = [
            (0.0, 1e1, 0.5),
            (1.0, 1e7, 0.8413447),
            (-1.0, 1e7, 0.1586553),
            (2.0, 1e7, 0.9772499),
        ];
        for &(x, scale, expected) in cases.iter() {
            assert_close(rounded(normal_cdf(x, 0.0, 1.0).unwrap(), scale), expected, EPSILON);
        }
    }

    #[test]
    fn test_normal_cdf_invalid_sigma() {
        let outcome = normal_cdf(0.0, 0.0, -1.0);
        assert!(
            outcome.is_err(),
            "Should return error for negative standard deviation"
        );
        assert_invalid_input(outcome);
    }

    #[test]
    fn test_normal_inverse_cdf_standard_normal() {
        let median = rounded(normal_inverse_cdf(0.5, 0.0, 1.0).unwrap(), 1e7);
        assert!(median.abs() < EPSILON);

        assert_close(normal_inverse_cdf(0.8413447, 0.0, 1.0).unwrap(), 1.0, 0.01);
        assert_close(normal_inverse_cdf(0.1586553, 0.0, 1.0).unwrap(), -1.0, 0.01);
    }

    #[test]
    fn test_normal_config_new_nan_mean() {
        assert_invalid_input(NormalConfig::new(f64::NAN, 1.0));
    }

    #[test]
    fn test_normal_config_new_nan_std_dev() {
        assert_invalid_input(NormalConfig::new(0.0, f64::NAN));
    }

    #[test]
    fn test_normal_config_new_std_dev_zero() {
        assert_invalid_input(NormalConfig::new(0.0, 0.0));
    }

    #[test]
    fn test_normal_config_new_std_dev_negative() {
        assert_invalid_input(NormalConfig::new(0.0, -1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_negative() {
        assert_invalid_input(normal_inverse_cdf(-0.1, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(normal_inverse_cdf(1.5, 0.0, 1.0));
    }

    #[test]
    fn test_normal_inverse_cdf_p_zero() {
        assert_eq!(normal_inverse_cdf(0.0, 0.0, 1.0).unwrap(), f64::NEG_INFINITY);
    }

    #[test]
    fn test_normal_inverse_cdf_p_one() {
        assert_eq!(normal_inverse_cdf(1.0, 0.0, 1.0).unwrap(), f64::INFINITY);
    }

    #[test]
    fn test_normal_pdf_std_dev_zero() {
        assert_invalid_input(normal_pdf(0.0, 0.0, 0.0));
    }

    #[test]
    fn test_normal_cdf_std_dev_zero() {
        assert_invalid_input(normal_cdf(0.0, 0.0, 0.0));
    }

    // normal_inverse_cdf does not validate std_dev; p = 0.5 lands exactly on the mean.
    #[test]
    fn test_normal_inverse_cdf_std_dev_zero() {
        assert_eq!(normal_inverse_cdf(0.5, 5.0, 0.0).unwrap(), 5.0);
    }

    #[test]
    fn test_normal_inverse_cdf_std_dev_negative() {
        assert_eq!(normal_inverse_cdf(0.5, 0.0, -1.0).unwrap(), 0.0);
    }

    #[test]
    fn test_normal_config_new_valid() {
        let built = NormalConfig::new(0.0, 1.0);
        assert!(built.is_ok());
        assert_eq!(built.unwrap().mean, 0.0);
    }
}