Skip to main content

so_stats/
distributions.rs

//! Probability distributions for statistical computing.
//!
//! Provides common probability distributions with PDF, CDF, quantile functions,
//! and random number generation.

6#![allow(non_snake_case)] // Allow mathematical notation (N, K, etc.)
7
8use rand::Rng;
9use rand::seq::SliceRandom;
10use statrs::distribution::{
11    Bernoulli, Binomial, Geometric, Hypergeometric, NegativeBinomial, Poisson,
12};
13use statrs::distribution::{
14    Beta, Cauchy, ChiSquared, Continuous, ContinuousCDF, Discrete, DiscreteCDF, FisherSnedecor,
15    Gamma, LogNormal, Normal, StudentsT, Weibull,
16};
17use statrs::function::gamma::gamma;
18use thiserror::Error;
19
20/// Errors for distribution operations
21#[derive(Error, Debug)]
22pub enum DistributionError {
23    #[error("Invalid parameter: {0}")]
24    InvalidParameter(String),
25
26    #[error("Numerical error: {0}")]
27    NumericalError(String),
28
29    #[error("Distribution not supported: {0}")]
30    NotSupported(String),
31}
32
33/// Result type for distribution operations
34pub type Result<T> = std::result::Result<T, DistributionError>;
35
/// A continuous probability distribution together with its parameters.
///
/// Prefer the validating constructors (`ContinuousDistribution::normal`, ...) over
/// building variants directly, since variant fields are not checked on construction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ContinuousDistribution {
    /// Normal distribution with the given mean and standard deviation.
    Normal { mean: f64, std_dev: f64 },
    /// Standard Student's t-distribution (location 0, scale 1) with `df` degrees of freedom.
    StudentsT { df: f64 },
    /// Chi-squared distribution with `df` degrees of freedom.
    ChiSquared { df: f64 },
    /// F-distribution with numerator `d1` and denominator `d2` degrees of freedom.
    FisherSnedecor { d1: f64, d2: f64 },
    /// Exponential distribution with density `rate * exp(-rate * x)` for `x >= 0`.
    Exponential { rate: f64 },
    /// Gamma distribution in the shape/rate parameterisation (mean `shape / rate`).
    Gamma { shape: f64, rate: f64 },
    /// Beta distribution on `[0, 1]`.
    Beta { alpha: f64, beta: f64 },
    /// Log-normal distribution; `mu` and `sigma` describe the underlying normal.
    LogNormal { mu: f64, sigma: f64 },
    /// Cauchy distribution with the given location and scale.
    Cauchy { location: f64, scale: f64 },
    /// Weibull distribution with the given shape and scale.
    Weibull { shape: f64, scale: f64 },
    /// Continuous uniform distribution on `[lower, upper]`.
    Uniform { lower: f64, upper: f64 },
}
51
52impl ContinuousDistribution {
53    /// Create a standard normal distribution
54    pub fn standard_normal() -> Self {
55        Self::Normal {
56            mean: 0.0,
57            std_dev: 1.0,
58        }
59    }
60
61    /// Create a normal distribution with given parameters
62    pub fn normal(mean: f64, std_dev: f64) -> Result<Self> {
63        if std_dev <= 0.0 {
64            return Err(DistributionError::InvalidParameter(
65                "Standard deviation must be positive".to_string(),
66            ));
67        }
68        Ok(Self::Normal { mean, std_dev })
69    }
70
71    /// Create a t-distribution
72    pub fn students_t(df: f64) -> Result<Self> {
73        if df <= 0.0 {
74            return Err(DistributionError::InvalidParameter(
75                "Degrees of freedom must be positive".to_string(),
76            ));
77        }
78        Ok(Self::StudentsT { df })
79    }
80
81    /// Create a chi-squared distribution
82    pub fn chi_squared(df: f64) -> Result<Self> {
83        if df <= 0.0 {
84            return Err(DistributionError::InvalidParameter(
85                "Degrees of freedom must be positive".to_string(),
86            ));
87        }
88        Ok(Self::ChiSquared { df })
89    }
90
91    /// Create an F-distribution
92    pub fn fisher_snedecor(d1: f64, d2: f64) -> Result<Self> {
93        if d1 <= 0.0 || d2 <= 0.0 {
94            return Err(DistributionError::InvalidParameter(
95                "Degrees of freedom must be positive".to_string(),
96            ));
97        }
98        Ok(Self::FisherSnedecor { d1, d2 })
99    }
100
101    /// Probability density function
102    pub fn pdf(&self, x: f64) -> f64 {
103        match self {
104            Self::Normal { mean, std_dev } => {
105                let dist = Normal::new(*mean, *std_dev).unwrap();
106                dist.pdf(x)
107            }
108            Self::StudentsT { df } => {
109                let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
110                dist.pdf(x)
111            }
112            Self::ChiSquared { df } => {
113                let dist = ChiSquared::new(*df).unwrap();
114                dist.pdf(x)
115            }
116            Self::FisherSnedecor { d1, d2 } => {
117                let dist = FisherSnedecor::new(*d1, *d2).unwrap();
118                dist.pdf(x)
119            }
120            Self::Exponential { rate } => {
121                if x < 0.0 {
122                    0.0
123                } else {
124                    rate * (-rate * x).exp()
125                }
126            }
127            Self::Gamma { shape, rate } => {
128                let dist = Gamma::new(*shape, *rate).unwrap();
129                dist.pdf(x)
130            }
131            Self::Beta { alpha, beta } => {
132                let dist = Beta::new(*alpha, *beta).unwrap();
133                dist.pdf(x)
134            }
135            Self::LogNormal { mu, sigma } => {
136                let dist = LogNormal::new(*mu, *sigma).unwrap();
137                dist.pdf(x)
138            }
139            Self::Cauchy { location, scale } => {
140                let dist = Cauchy::new(*location, *scale).unwrap();
141                dist.pdf(x)
142            }
143            Self::Weibull { shape, scale } => {
144                let dist = Weibull::new(*shape, *scale).unwrap();
145                dist.pdf(x)
146            }
147            Self::Uniform { lower, upper } => {
148                if x < *lower || x > *upper {
149                    0.0
150                } else {
151                    1.0 / (upper - lower)
152                }
153            }
154        }
155    }
156
157    /// Cumulative distribution function
158    pub fn cdf(&self, x: f64) -> f64 {
159        match self {
160            Self::Normal { mean, std_dev } => {
161                let dist = Normal::new(*mean, *std_dev).unwrap();
162                dist.cdf(x)
163            }
164            Self::StudentsT { df } => {
165                let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
166                dist.cdf(x)
167            }
168            Self::ChiSquared { df } => {
169                let dist = ChiSquared::new(*df).unwrap();
170                dist.cdf(x)
171            }
172            Self::FisherSnedecor { d1, d2 } => {
173                let dist = FisherSnedecor::new(*d1, *d2).unwrap();
174                dist.cdf(x)
175            }
176            Self::Exponential { rate } => {
177                if x < 0.0 {
178                    0.0
179                } else {
180                    1.0 - (-rate * x).exp()
181                }
182            }
183            Self::Gamma { shape, rate } => {
184                let dist = Gamma::new(*shape, *rate).unwrap();
185                dist.cdf(x)
186            }
187            Self::Beta { alpha, beta } => {
188                let dist = Beta::new(*alpha, *beta).unwrap();
189                dist.cdf(x)
190            }
191            Self::LogNormal { mu, sigma } => {
192                let dist = LogNormal::new(*mu, *sigma).unwrap();
193                dist.cdf(x)
194            }
195            Self::Cauchy { location, scale } => {
196                let dist = Cauchy::new(*location, *scale).unwrap();
197                dist.cdf(x)
198            }
199            Self::Weibull { shape, scale } => {
200                let dist = Weibull::new(*shape, *scale).unwrap();
201                dist.cdf(x)
202            }
203            Self::Uniform { lower, upper } => {
204                if x < *lower {
205                    0.0
206                } else if x > *upper {
207                    1.0
208                } else {
209                    (x - lower) / (upper - lower)
210                }
211            }
212        }
213    }
214
215    /// Quantile function (inverse CDF)
216    pub fn quantile(&self, p: f64) -> Option<f64> {
217        if !(0.0..=1.0).contains(&p) {
218            return None;
219        }
220
221        match self {
222            Self::Normal { mean, std_dev } => {
223                let dist = Normal::new(*mean, *std_dev).unwrap();
224                Some(dist.inverse_cdf(p))
225            }
226            Self::StudentsT { df } => {
227                let dist = StudentsT::new(0.0, 1.0, *df).unwrap();
228                Some(dist.inverse_cdf(p))
229            }
230            Self::ChiSquared { df } => {
231                let dist = ChiSquared::new(*df).unwrap();
232                Some(dist.inverse_cdf(p))
233            }
234            Self::FisherSnedecor { d1, d2 } => {
235                let dist = FisherSnedecor::new(*d1, *d2).unwrap();
236                Some(dist.inverse_cdf(p))
237            }
238            Self::Exponential { rate } => {
239                if p <= 0.0 {
240                    Some(0.0)
241                } else if p >= 1.0 {
242                    Some(f64::INFINITY)
243                } else {
244                    Some(-(1.0 - p).ln() / rate)
245                }
246            }
247            Self::Gamma { shape, rate } => {
248                let dist = Gamma::new(*shape, *rate).unwrap();
249                Some(dist.inverse_cdf(p))
250            }
251            Self::Beta { alpha, beta } => {
252                let dist = Beta::new(*alpha, *beta).unwrap();
253                Some(dist.inverse_cdf(p))
254            }
255            Self::LogNormal { mu, sigma } => {
256                let dist = LogNormal::new(*mu, *sigma).unwrap();
257                Some(dist.inverse_cdf(p))
258            }
259            Self::Cauchy { location, scale } => {
260                let dist = Cauchy::new(*location, *scale).unwrap();
261                Some(dist.inverse_cdf(p))
262            }
263            Self::Weibull { shape, scale } => {
264                let dist = Weibull::new(*shape, *scale).unwrap();
265                Some(dist.inverse_cdf(p))
266            }
267            Self::Uniform { lower, upper } => Some(lower + p * (upper - lower)),
268        }
269    }
270
271    /// Generate random sample from distribution
272    pub fn sample<R: Rng>(&self, _rng: &mut R) -> f64 {
273        match self {
274            Self::Normal { mean, .. } => *mean,
275            Self::StudentsT { df } => {
276                if *df > 1.0 {
277                    0.0
278                } else {
279                    f64::NAN
280                }
281            }
282            Self::ChiSquared { df } => *df,
283            Self::FisherSnedecor { d1: _, d2 } => {
284                if *d2 > 2.0 {
285                    *d2 / (*d2 - 2.0)
286                } else {
287                    f64::NAN
288                }
289            }
290            Self::Exponential { rate } => 1.0 / rate,
291            Self::Gamma { shape, rate } => shape / rate,
292            Self::Beta { alpha, beta } => alpha / (alpha + beta),
293            Self::LogNormal { mu, sigma } => (mu + sigma.powi(2) / 2.0).exp(),
294            Self::Cauchy { location, .. } => *location,
295            Self::Weibull { shape, scale } => scale * gamma(1.0 + 1.0 / shape),
296            Self::Uniform { lower, upper } => (lower + upper) / 2.0,
297        }
298    }
299}
300
/// A discrete probability distribution together with its parameters.
///
/// Prefer the validating constructors (`DiscreteDistribution::binomial`, ...) over
/// building variants directly, since variant fields are not checked on construction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DiscreteDistribution {
    /// Single trial that yields 1 with probability `p`, otherwise 0.
    Bernoulli { p: f64 },
    /// Number of successes in `n` independent trials with success probability `p`.
    Binomial { n: u64, p: f64 },
    /// Poisson distribution with mean `lambda`.
    Poisson { lambda: f64 },
    /// Number of trials up to and including the first success; support `{1, 2, ...}`.
    Geometric { p: f64 },
    /// Number of failures before the `r`-th success (the statrs convention).
    NegativeBinomial { r: f64, p: f64 },
    /// Successes in `n` draws without replacement from a population of `N` items,
    /// `K` of which are successes.
    Hypergeometric { N: u64, K: u64, n: u64 },
}
311
312impl DiscreteDistribution {
313    /// Create a Bernoulli distribution
314    pub fn bernoulli(p: f64) -> Result<Self> {
315        if !(0.0..=1.0).contains(&p) {
316            return Err(DistributionError::InvalidParameter(
317                "Probability must be between 0 and 1".to_string(),
318            ));
319        }
320        Ok(Self::Bernoulli { p })
321    }
322
323    /// Create a binomial distribution
324    pub fn binomial(n: u64, p: f64) -> Result<Self> {
325        if !(0.0..=1.0).contains(&p) {
326            return Err(DistributionError::InvalidParameter(
327                "Probability must be between 0 and 1".to_string(),
328            ));
329        }
330        Ok(Self::Binomial { n, p })
331    }
332
333    /// Create a Poisson distribution
334    pub fn poisson(lambda: f64) -> Result<Self> {
335        if lambda <= 0.0 {
336            return Err(DistributionError::InvalidParameter(
337                "Lambda must be positive".to_string(),
338            ));
339        }
340        Ok(Self::Poisson { lambda })
341    }
342
343    /// Create a geometric distribution
344    pub fn geometric(p: f64) -> Result<Self> {
345        if !(0.0..=1.0).contains(&p) {
346            return Err(DistributionError::InvalidParameter(
347                "Probability must be between 0 and 1".to_string(),
348            ));
349        }
350        Ok(Self::Geometric { p })
351    }
352
353    /// Create a negative binomial distribution
354    pub fn negative_binomial(r: f64, p: f64) -> Result<Self> {
355        if r <= 0.0 || !(0.0..=1.0).contains(&p) {
356            return Err(DistributionError::InvalidParameter(
357                "r must be positive and p must be between 0 and 1".to_string(),
358            ));
359        }
360        Ok(Self::NegativeBinomial { r, p })
361    }
362
363    /// Create a hypergeometric distribution
364    pub fn hypergeometric(N: u64, K: u64, n: u64) -> Result<Self> {
365        if n > N || K > N {
366            return Err(DistributionError::InvalidParameter(
367                "Invalid parameters for hypergeometric distribution".to_string(),
368            ));
369        }
370        Ok(Self::Hypergeometric { N, K, n })
371    }
372
373    /// Probability mass function
374    pub fn pmf(&self, k: u64) -> f64 {
375        match self {
376            Self::Bernoulli { p } => {
377                let dist = Bernoulli::new(*p).unwrap();
378                dist.pmf(k as u64)
379            }
380            Self::Binomial { n, p } => {
381                let dist = Binomial::new(*p, *n).unwrap();
382                dist.pmf(k as u64)
383            }
384            Self::Poisson { lambda } => {
385                let dist = Poisson::new(*lambda).unwrap();
386                dist.pmf(k as u64)
387            }
388            Self::Geometric { p } => {
389                let dist = Geometric::new(*p).unwrap();
390                dist.pmf(k as u64)
391            }
392            Self::NegativeBinomial { r, p } => {
393                let dist = NegativeBinomial::new(*r, *p).unwrap();
394                dist.pmf(k as u64)
395            }
396            Self::Hypergeometric { N, K, n } => {
397                let dist = Hypergeometric::new(*N, *K, *n).unwrap();
398                dist.pmf(k as u64)
399            }
400        }
401    }
402
403    /// Cumulative distribution function
404    pub fn cdf(&self, k: u64) -> f64 {
405        match self {
406            Self::Bernoulli { p } => {
407                let dist = Bernoulli::new(*p).unwrap();
408                dist.cdf(k as u64)
409            }
410            Self::Binomial { n, p } => {
411                let dist = Binomial::new(*p, *n).unwrap();
412                dist.cdf(k as u64)
413            }
414            Self::Poisson { lambda } => {
415                let dist = Poisson::new(*lambda).unwrap();
416                dist.cdf(k as u64)
417            }
418            Self::Geometric { p } => {
419                let dist = Geometric::new(*p).unwrap();
420                dist.cdf(k as u64)
421            }
422            Self::NegativeBinomial { r, p } => {
423                let dist = NegativeBinomial::new(*r, *p).unwrap();
424                dist.cdf(k as u64)
425            }
426            Self::Hypergeometric { N, K, n } => {
427                let dist = Hypergeometric::new(*N, *K, *n).unwrap();
428                dist.cdf(k as u64)
429            }
430        }
431    }
432
433    /// Generate random sample from distribution
434    pub fn sample<R: Rng>(&self, rng: &mut R) -> u64 {
435        match self {
436            Self::Bernoulli { p } => {
437                // Bernoulli trial: success with probability p
438                if rng.random::<f64>() < *p { 1 } else { 0 }
439            }
440            Self::Binomial { n, p } => {
441                // Sum of n Bernoulli trials
442                let mut successes = 0;
443                for _ in 0..*n {
444                    if rng.random::<f64>() < *p {
445                        successes += 1;
446                    }
447                }
448                successes
449            }
450            Self::Poisson { lambda } => {
451                // Knuth's algorithm for Poisson sampling
452                let l = (-*lambda).exp();
453                let mut k = 0;
454                let mut p = 1.0;
455                loop {
456                    k += 1;
457                    p *= rng.random::<f64>();
458                    if p <= l {
459                        break;
460                    }
461                }
462                (k - 1) as u64
463            }
464            Self::Geometric { p } => {
465                // Geometric distribution: number of trials until first success
466                ((rng.random::<f64>().ln() / (1.0 - p).ln()).floor() as u64) + 1
467            }
468            Self::NegativeBinomial { r, p } => {
469                // Number of trials until r successes
470                let mut successes = 0;
471                let mut trials = 0;
472                while successes < *r as u64 {
473                    trials += 1;
474                    if rng.random::<f64>() < *p {
475                        successes += 1;
476                    }
477                }
478                trials
479            }
480            Self::Hypergeometric { N, K, n } => {
481                // Simple implementation: sample without replacement
482                // This is inefficient for large N but works for now
483                let mut population = vec![true; *K as usize]
484                    .into_iter()
485                    .chain(vec![false; (*N - *K) as usize])
486                    .collect::<Vec<_>>();
487                population.shuffle(rng);
488                population[..*n as usize].iter().filter(|&&x| x).count() as u64
489            }
490        }
491    }
492}