1use crate::error::{StatsError, StatsResult};
31use num_traits::ToPrimitive;
32use serde::{Deserialize, Serialize};
33
34#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
47pub struct ExponentialConfig<T>
48where
49 T: ToPrimitive,
50{
51 pub lambda: T,
53}
54
55impl<T> ExponentialConfig<T>
56where
57 T: ToPrimitive,
58{
59 pub fn new(lambda: T) -> StatsResult<Self> {
67 let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
68 message: "ExponentialConfig::new: Failed to convert lambda to f64".to_string(),
69 })?;
70
71 if lambda_64 > 0.0 {
72 Ok(Self { lambda })
73 } else {
74 Err(StatsError::InvalidInput {
75 message: "ExponentialConfig::new: lambda must be positive".to_string(),
76 })
77 }
78 }
79}
80
81#[inline]
108pub fn exponential_pdf<T>(x: T, lambda: T) -> StatsResult<f64>
109where
110 T: ToPrimitive,
111{
112 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
113 message: "exponential_pdf: Failed to convert x to f64".to_string(),
114 })?;
115
116 if x_64 < 0.0 {
117 return Err(StatsError::InvalidInput {
118 message: "exponential_pdf: x must be non-negative".to_string(),
119 });
120 }
121
122 let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
123 message: "exponential_pdf: Failed to convert lambda to f64".to_string(),
124 })?;
125
126 if lambda_64 <= 0.0 {
127 return Err(StatsError::InvalidInput {
128 message: "exponential_pdf: lambda must be positive".to_string(),
129 });
130 }
131
132 Ok(if x_64 == 0.0 {
133 lambda_64
134 } else {
135 lambda_64 * (-lambda_64 * x_64).exp()
136 })
137}
138
139#[inline]
166pub fn exponential_cdf<T>(x: T, lambda: T) -> StatsResult<f64>
167where
168 T: ToPrimitive,
169{
170 let x_64 = x.to_f64().ok_or_else(|| StatsError::ConversionError {
171 message: "exponential_cdf: Failed to convert x to f64".to_string(),
172 })?;
173
174 if x_64 < 0.0 {
175 return Err(StatsError::InvalidInput {
176 message: "exponential_cdf: x must be non-negative".to_string(),
177 });
178 }
179
180 let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
181 message: "exponential_cdf: Failed to convert lambda to f64".to_string(),
182 })?;
183
184 if lambda_64 <= 0.0 {
185 return Err(StatsError::InvalidInput {
186 message: "exponential_cdf: lambda must be positive".to_string(),
187 });
188 }
189 Ok(1.0 - (-lambda_64 * x_64).exp())
190}
191
192#[inline]
221pub fn exponential_inverse_cdf<T>(p: T, lambda: T) -> StatsResult<f64>
222where
223 T: ToPrimitive,
224{
225 let p_64 = p.to_f64().ok_or_else(|| StatsError::ConversionError {
226 message: "exponential_inverse_cdf: Failed to convert p to f64".to_string(),
227 })?;
228
229 if !(0.0..=1.0).contains(&p_64) {
230 return Err(StatsError::InvalidInput {
231 message: "exponential_inverse_cdf: p must be between 0 and 1".to_string(),
232 });
233 }
234
235 let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
236 message: "exponential_inverse_cdf: Failed to convert lambda to f64".to_string(),
237 })?;
238
239 if lambda_64 <= 0.0 {
240 return Err(StatsError::InvalidInput {
241 message: "exponential_inverse_cdf: lambda must be positive".to_string(),
242 });
243 }
244 Ok(-((1.0 - p_64).ln()) / lambda_64)
245}
246
247#[inline]
271pub fn exponential_mean<T>(lambda: T) -> StatsResult<f64>
272where
273 T: ToPrimitive,
274{
275 let lambda_64 = lambda.to_f64().ok_or_else(|| StatsError::ConversionError {
276 message: "exponential_mean: Failed to convert lambda to f64".to_string(),
277 })?;
278 if lambda_64 <= 0.0 {
279 return Err(StatsError::InvalidInput {
280 message: "exponential_mean: lambda must be positive".to_string(),
281 });
282 }
283
284 Ok(1.0 / lambda_64)
285}
286
287#[inline]
311pub fn exponential_variance(lambda: f64) -> StatsResult<f64> {
312 if lambda <= 0.0 {
313 return Err(StatsError::InvalidInput {
314 message: "exponential_variance: lambda must be positive".to_string(),
315 });
316 }
317
318 Ok(1.0 / (lambda * lambda))
319}
320
/// Exponential distribution parameterized by its rate `lambda` (λ).
///
/// Prefer [`Exponential::new`] or [`Exponential::fit`], which validate the
/// rate. The field is public, so a value built directly is not checked.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Exponential {
    /// Rate parameter (λ); the distribution's mean is `1 / lambda`.
    pub lambda: f64,
}
338
339impl Exponential {
340 pub fn new(lambda: f64) -> StatsResult<Self> {
342 if lambda <= 0.0 {
343 return Err(StatsError::InvalidInput {
344 message: "Exponential::new: lambda must be positive".to_string(),
345 });
346 }
347 Ok(Self { lambda })
348 }
349
350 pub fn fit(data: &[f64]) -> StatsResult<Self> {
352 if data.is_empty() {
353 return Err(StatsError::InvalidInput {
354 message: "Exponential::fit: data must not be empty".to_string(),
355 });
356 }
357 if data.iter().any(|&x| x < 0.0) {
358 return Err(StatsError::InvalidInput {
359 message: "Exponential::fit: all data values must be non-negative".to_string(),
360 });
361 }
362 let mean = data.iter().sum::<f64>() / data.len() as f64;
363 Self::new(1.0 / mean)
364 }
365}
366
367impl crate::distributions::traits::Distribution for Exponential {
368 fn name(&self) -> &str {
369 "Exponential"
370 }
371 fn num_params(&self) -> usize {
372 1
373 }
374 fn pdf(&self, x: f64) -> StatsResult<f64> {
375 exponential_pdf(x, self.lambda)
376 }
377 fn logpdf(&self, x: f64) -> StatsResult<f64> {
378 if x < 0.0 {
379 return Err(StatsError::InvalidInput {
380 message: "Exponential::logpdf: x must be non-negative".to_string(),
381 });
382 }
383 Ok(self.lambda.ln() - self.lambda * x)
384 }
385 fn cdf(&self, x: f64) -> StatsResult<f64> {
386 exponential_cdf(x, self.lambda)
387 }
388 fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
389 exponential_inverse_cdf(p, self.lambda)
390 }
391 fn mean(&self) -> f64 {
392 1.0 / self.lambda
393 }
394 fn variance(&self) -> f64 {
395 1.0 / (self.lambda * self.lambda)
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributions::traits::Distribution;

    const EPSILON: f64 = 1e-10;

    #[test]
    fn test_exponential_pdf() {
        let lambda = 2.0;

        // The density at the origin is exactly lambda.
        let result = exponential_pdf(0.0, lambda).unwrap();
        assert_eq!(result, lambda);

        let result = exponential_pdf(1.0, lambda).unwrap();
        let expected = lambda * (-lambda).exp();
        assert!((result - expected).abs() < EPSILON);

        let result = exponential_pdf(0.5, lambda).unwrap();
        let expected = lambda * (-lambda * 0.5).exp();
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_exponential_cdf() {
        let lambda = 2.0_f64;

        let result = exponential_cdf(0.0, lambda).unwrap();
        assert!(result.abs() < EPSILON);

        let result = exponential_cdf(1.0, lambda).unwrap();
        let expected = 1.0 - (-lambda).exp();
        assert!((result - expected).abs() < EPSILON);

        let result = exponential_cdf(0.5, lambda).unwrap();
        let expected = 1.0 - (-lambda * 0.5).exp();
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_exponential_inverse_cdf() {
        let lambda = 2.0_f64;

        // Round trip: cdf(inverse_cdf(p)) must recover p.
        let test_cases = [0.1, 0.25, 0.5, 0.75, 0.9];

        for &p in test_cases.iter() {
            let x = exponential_inverse_cdf(p, lambda).unwrap();
            let cdf = exponential_cdf(x, lambda).unwrap();
            assert!(
                (cdf - p).abs() < EPSILON,
                "Inverse CDF failed for p = {}: got {}, expected {}",
                p,
                cdf,
                p
            );
        }
    }

    #[test]
    fn test_exponential_mean() {
        let lambda = 2.0;
        let result = exponential_mean(lambda).unwrap();
        let expected = 1.0 / lambda;
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_exponential_variance() {
        let lambda = 2.0;
        let result = exponential_variance(lambda).unwrap();
        let expected = 1.0 / (lambda * lambda);
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_exponential_pdf_invalid_lambda() {
        let result = exponential_pdf(1.0, -2.0);
        assert!(result.is_err());
        match result {
            Err(StatsError::InvalidInput { message }) => {
                assert!(message.contains("lambda must be positive"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_exponential_pdf_invalid_x() {
        let result = exponential_pdf(-1.0, 2.0);
        assert!(result.is_err());
        match result {
            Err(StatsError::InvalidInput { message }) => {
                assert!(message.contains("x must be non-negative"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }

    #[test]
    fn test_exponential_config() {
        let config = ExponentialConfig::new(2.0);
        assert!(config.is_ok());

        let config = ExponentialConfig::new(0.0);
        assert!(config.is_err());

        let config = ExponentialConfig::new(-1.0);
        assert!(config.is_err());
    }

    #[test]
    fn test_exponential_inverse_cdf_p_negative() {
        let result = exponential_inverse_cdf(-0.1, 2.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_inverse_cdf_p_greater_than_one() {
        let result = exponential_inverse_cdf(1.5, 2.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_cdf_invalid_lambda() {
        let result = exponential_cdf(1.0, -2.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_cdf_invalid_x() {
        let result = exponential_cdf(-1.0, 2.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_mean_invalid_lambda() {
        let result = exponential_mean(0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_variance_invalid_lambda() {
        let result = exponential_variance(0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_pdf_x_positive() {
        let result = exponential_pdf(0.5, 2.0).unwrap();
        let lambda: f64 = 2.0;
        let x: f64 = 0.5;
        let expected = lambda * (-lambda * x).exp();
        assert!((result - expected).abs() < EPSILON);
    }

    #[test]
    fn test_exponential_inverse_cdf_p_zero() {
        let result = exponential_inverse_cdf(0.0, 2.0).unwrap();
        assert_eq!(result, 0.0);
    }

    #[test]
    fn test_exponential_inverse_cdf_p_one() {
        // The quantile at p = 1 diverges.
        let result = exponential_inverse_cdf(1.0, 2.0).unwrap();
        assert!(result.is_infinite() || result > 1e10);
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_zero() {
        let result = exponential_inverse_cdf(0.5, 0.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_inverse_cdf_lambda_negative() {
        let result = exponential_inverse_cdf(0.5, -1.0);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_exponential_new_validates_lambda() {
        assert!(Exponential::new(2.0).is_ok());
        assert!(Exponential::new(0.0).is_err());
        assert!(Exponential::new(-1.0).is_err());
    }

    #[test]
    fn test_exponential_fit() {
        // The MLE of the rate is the reciprocal of the sample mean (here 2.0).
        let dist = Exponential::fit(&[1.0, 2.0, 3.0]).unwrap();
        assert!((dist.lambda - 0.5).abs() < EPSILON);

        assert!(Exponential::fit(&[]).is_err());
        assert!(Exponential::fit(&[1.0, -2.0]).is_err());
    }

    #[test]
    fn test_exponential_distribution_trait() {
        let dist = Exponential::new(2.0).unwrap();
        assert_eq!(dist.name(), "Exponential");
        assert_eq!(dist.num_params(), 1);
        assert!((dist.mean() - 0.5).abs() < EPSILON);
        assert!((dist.variance() - 0.25).abs() < EPSILON);

        // logpdf must agree with ln(pdf) wherever pdf does not underflow.
        for &x in [0.0, 0.5, 1.0, 3.0].iter() {
            let log_density = dist.logpdf(x).unwrap();
            let density = dist.pdf(x).unwrap();
            assert!((log_density - density.ln()).abs() < EPSILON);
        }
        assert!(dist.logpdf(-1.0).is_err());
    }
}