scirs2_stats/distributions/bernoulli.rs
1//! Bernoulli distribution functions
2//!
3//! This module provides functionality for the Bernoulli distribution.
4
5use crate::error::{StatsError, StatsResult};
6use crate::sampling::SampleableDistribution;
7use scirs2_core::numeric::{Float, NumCast};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::{Bernoulli as RandBernoulli, Distribution};
10use scirs2_core::validation::check_probability;
11
12/// Bernoulli distribution structure
13///
14/// The Bernoulli distribution is a discrete probability distribution taking
15/// value 1 with probability p and value 0 with probability q = 1 - p.
16/// It is the discrete probability distribution of a random variable which takes
17/// the value 1 with probability p and the value 0 with probability q.
18pub struct Bernoulli<F: Float> {
19 /// Success probability p (0 ≤ p ≤ 1)
20 pub p: F,
21 /// Random number generator
22 rand_distr: RandBernoulli,
23}
24
25impl<F: Float + NumCast + std::fmt::Display> Bernoulli<F> {
26 /// Create a new Bernoulli distribution with given success probability
27 ///
28 /// # Arguments
29 ///
30 /// * `p` - Success probability (0 ≤ p ≤ 1)
31 ///
32 /// # Returns
33 ///
34 /// * A new Bernoulli distribution instance
35 ///
36 /// # Examples
37 ///
38 /// ```
39 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
40 ///
41 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
42 /// ```
43 pub fn new(p: F) -> StatsResult<Self> {
44 // Validate parameters using core validation function
45 let _ = check_probability(p, "Success probability").map_err(StatsError::from)?;
46
47 // Create RNG for Bernoulli distribution
48 let p_f64 = <f64 as scirs2_core::numeric::NumCast>::from(p).ok_or_else(|| {
49 StatsError::ComputationError("Failed to convert p to f64".to_string())
50 })?;
51 let rand_distr = match RandBernoulli::new(p_f64) {
52 Ok(distr) => distr,
53 Err(_) => {
54 return Err(StatsError::ComputationError(
55 "Failed to create Bernoulli distribution for sampling".to_string(),
56 ))
57 }
58 };
59
60 Ok(Bernoulli { p, rand_distr })
61 }
62
63 /// Calculate the probability mass function (PMF) at a given point
64 ///
65 /// # Arguments
66 ///
67 /// * `k` - The point at which to evaluate the PMF (0 or 1)
68 ///
69 /// # Returns
70 ///
71 /// * The value of the PMF at the given point
72 ///
73 /// # Examples
74 ///
75 /// ```
76 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
77 ///
78 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
79 /// let pmf_at_one = bern.pmf(1.0);
80 /// assert!((pmf_at_one - 0.3).abs() < 1e-7);
81 /// ```
82 pub fn pmf(&self, k: F) -> F {
83 let one = F::one();
84 let zero = F::zero();
85
86 // PMF is only defined for k = 0 and k = 1
87 if k == zero {
88 one - self.p // q = 1 - p
89 } else if k == one {
90 self.p
91 } else {
92 zero
93 }
94 }
95
96 /// Calculate the log of the probability mass function (log-PMF) at a given point
97 ///
98 /// # Arguments
99 ///
100 /// * `k` - The point at which to evaluate the log-PMF (0 or 1)
101 ///
102 /// # Returns
103 ///
104 /// * The value of the log-PMF at the given point
105 ///
106 /// # Examples
107 ///
108 /// ```
109 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
110 ///
111 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
112 /// let log_pmf_at_one = bern.log_pmf(1.0);
113 /// assert!((log_pmf_at_one - (-1.2039728)).abs() < 1e-6);
114 /// ```
115 pub fn log_pmf(&self, k: F) -> F {
116 let one = F::one();
117 let zero = F::zero();
118 let neg_infinity = F::neg_infinity();
119
120 // log-PMF is only defined for k = 0 and k = 1
121 if k == zero {
122 if self.p == one {
123 neg_infinity
124 } else {
125 (one - self.p).ln() // ln(q) = ln(1 - p)
126 }
127 } else if k == one {
128 if self.p == zero {
129 neg_infinity
130 } else {
131 self.p.ln() // ln(p)
132 }
133 } else {
134 neg_infinity
135 }
136 }
137
138 /// Calculate the cumulative distribution function (CDF) at a given point
139 ///
140 /// # Arguments
141 ///
142 /// * `k` - The point at which to evaluate the CDF
143 ///
144 /// # Returns
145 ///
146 /// * The value of the CDF at the given point
147 ///
148 /// # Examples
149 ///
150 /// ```
151 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
152 ///
153 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
154 /// let cdf_at_zero = bern.cdf(0.0);
155 /// assert!((cdf_at_zero - 0.7).abs() < 1e-7);
156 /// ```
157 pub fn cdf(&self, k: F) -> F {
158 let zero = F::zero();
159 let one = F::one();
160
161 if k < zero {
162 zero
163 } else if k < one {
164 one - self.p // F(0) = P(X ≤ 0) = P(X = 0) = 1 - p
165 } else {
166 one // F(k) = P(X ≤ k) = 1 for k ≥ 1
167 }
168 }
169
170 /// Inverse of the cumulative distribution function (quantile function)
171 ///
172 /// # Arguments
173 ///
174 /// * `p` - Probability value (between 0 and 1)
175 ///
176 /// # Returns
177 ///
178 /// * The value k such that CDF(k) = p
179 ///
180 /// # Examples
181 ///
182 /// ```
183 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
184 ///
185 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
186 /// let quant = bern.ppf(0.8).expect("Operation failed");
187 /// assert_eq!(quant, 1.0);
188 /// ```
189 pub fn ppf(&self, p_val: F) -> StatsResult<F> {
190 // Validate probability using core validation function
191 let p_val = check_probability(p_val, "Probability value").map_err(StatsError::from)?;
192
193 let zero = F::zero();
194 let one = F::one();
195
196 // Quantile function for Bernoulli
197 let q = one - self.p; // q = 1 - p
198
199 if p_val <= q {
200 Ok(zero) // Q(p) = 0 for p ≤ q
201 } else {
202 Ok(one) // Q(p) = 1 for p > q
203 }
204 }
205
206 /// Generate random samples from the distribution
207 ///
208 /// # Arguments
209 ///
210 /// * `size` - Number of samples to generate
211 ///
212 /// # Returns
213 ///
214 /// * Vector of random samples
215 ///
216 /// # Examples
217 ///
218 /// ```
219 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
220 ///
221 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
222 /// let samples = bern.rvs(10).expect("Operation failed");
223 /// assert_eq!(samples.len(), 10);
224 /// ```
225 pub fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
226 let mut rng = thread_rng();
227 let mut samples = Vec::with_capacity(size);
228 let zero = F::zero();
229 let one = F::one();
230
231 for _ in 0..size {
232 // Generate random Bernoulli sample (0 or 1)
233 let sample = if self.rand_distr.sample(&mut rng) {
234 one
235 } else {
236 zero
237 };
238
239 samples.push(sample);
240 }
241
242 Ok(samples)
243 }
244
245 /// Calculate the mean of the distribution
246 ///
247 /// # Returns
248 ///
249 /// * The mean of the distribution
250 ///
251 /// # Examples
252 ///
253 /// ```
254 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
255 ///
256 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
257 /// let mean = bern.mean();
258 /// assert!((mean - 0.3).abs() < 1e-7);
259 /// ```
260 pub fn mean(&self) -> F {
261 // Mean = p
262 self.p
263 }
264
265 /// Calculate the variance of the distribution
266 ///
267 /// # Returns
268 ///
269 /// * The variance of the distribution
270 ///
271 /// # Examples
272 ///
273 /// ```
274 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
275 ///
276 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
277 /// let variance = bern.var();
278 /// assert!((variance - 0.21).abs() < 1e-7);
279 /// ```
280 pub fn var(&self) -> F {
281 // Variance = p * (1 - p)
282 let one = F::one();
283 self.p * (one - self.p)
284 }
285
286 /// Calculate the standard deviation of the distribution
287 ///
288 /// # Returns
289 ///
290 /// * The standard deviation of the distribution
291 ///
292 /// # Examples
293 ///
294 /// ```
295 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
296 ///
297 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
298 /// let std_dev = bern.std();
299 /// assert!((std_dev - 0.458257).abs() < 1e-6);
300 /// ```
301 pub fn std(&self) -> F {
302 // Std = sqrt(variance)
303 self.var().sqrt()
304 }
305
306 /// Calculate the skewness of the distribution
307 ///
308 /// # Returns
309 ///
310 /// * The skewness of the distribution
311 ///
312 /// # Examples
313 ///
314 /// ```
315 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
316 ///
317 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
318 /// let skewness = bern.skewness();
319 /// assert!((skewness - 0.87287156).abs() < 1e-5);
320 /// ```
321 pub fn skewness(&self) -> F {
322 // Skewness = (1 - 2p) / sqrt(p * (1 - p))
323 let one = F::from(1.0).unwrap_or_else(|| F::zero());
324 let two = F::from(2.0).unwrap_or_else(|| F::zero());
325
326 let q = one - self.p; // q = 1 - p
327
328 // Handle special cases to avoid division by zero
329 if self.p == F::zero() || self.p == F::one() {
330 return F::zero(); // Degenerate case, skewness is not well-defined
331 }
332
333 (one - two * self.p) / (self.p * q).sqrt()
334 }
335
336 /// Calculate the kurtosis of the distribution
337 ///
338 /// # Returns
339 ///
340 /// * The excess kurtosis of the distribution
341 ///
342 /// # Examples
343 ///
344 /// ```
345 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
346 ///
347 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
348 /// let kurtosis = bern.kurtosis();
349 /// assert!((kurtosis - (-1.2351)) < 1e-4);
350 /// ```
351 pub fn kurtosis(&self) -> F {
352 // Excess Kurtosis = (1 - 6p(1-p)) / (p(1-p))
353 let one = F::from(1.0).unwrap_or_else(|| F::zero());
354 let six = F::from(6.0).unwrap_or_else(|| F::zero());
355
356 let q = one - self.p; // q = 1 - p
357 let pq = self.p * q;
358
359 // Handle special cases to avoid division by zero
360 if self.p == F::zero() || self.p == F::one() {
361 return F::zero(); // Degenerate case, kurtosis is not well-defined
362 }
363
364 (one - six * pq) / pq
365 }
366
367 /// Calculate the entropy of the distribution
368 ///
369 /// # Returns
370 ///
371 /// * The entropy value
372 ///
373 /// # Examples
374 ///
375 /// ```
376 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
377 ///
378 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
379 /// let entropy = bern.entropy();
380 /// assert!((entropy - 0.6108643).abs() < 1e-6);
381 /// ```
382 pub fn entropy(&self) -> F {
383 // Entropy = -p * ln(p) - (1-p) * ln(1-p)
384 let zero = F::zero();
385 let one = F::one();
386
387 // Handle special cases
388 if self.p == zero || self.p == one {
389 return zero; // Degenerate case, entropy is 0
390 }
391
392 let q = one - self.p; // q = 1 - p
393
394 // H(X) = -p * ln(p) - q * ln(q)
395 -(self.p * self.p.ln() + q * q.ln())
396 }
397
398 /// Calculate the median of the distribution
399 ///
400 /// # Returns
401 ///
402 /// * The median of the distribution
403 ///
404 /// # Examples
405 ///
406 /// ```
407 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
408 ///
409 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
410 /// let median = bern.median();
411 /// assert_eq!(median, 0.0);
412 /// ```
413 pub fn median(&self) -> F {
414 let zero = F::zero();
415 let one = F::one();
416 let half = F::from(0.5).expect("Failed to convert constant to float");
417
418 // Median is 0 if p < 0.5, 0 or 1 if p = 0.5, and 1 if p > 0.5
419 if self.p < half {
420 zero
421 } else if self.p > half {
422 one
423 } else {
424 // When p = 0.5, both 0 and 1 are medians
425 // We return 0 by convention
426 zero
427 }
428 }
429
430 /// Calculate the mode of the distribution
431 ///
432 /// # Returns
433 ///
434 /// * The mode of the distribution
435 ///
436 /// # Examples
437 ///
438 /// ```
439 /// use scirs2_stats::distributions::bernoulli::Bernoulli;
440 ///
441 /// let bern = Bernoulli::new(0.3f64).expect("Operation failed");
442 /// let mode = bern.mode();
443 /// assert_eq!(mode, 0.0);
444 /// ```
445 pub fn mode(&self) -> F {
446 let zero = F::zero();
447 let one = F::one();
448 let half = F::from(0.5).expect("Failed to convert constant to float");
449
450 // Mode is 0 if p < 0.5, 0 or 1 if p = 0.5, and 1 if p > 0.5
451 if self.p < half {
452 zero
453 } else if self.p > half {
454 one
455 } else {
456 // When p = 0.5, both 0 and 1 are modes
457 // We return 0 by convention
458 zero
459 }
460 }
461}
462
463/// Create a Bernoulli distribution with the given parameter.
464///
465/// This is a convenience function to create a Bernoulli distribution with
466/// the given success probability.
467///
468/// # Arguments
469///
470/// * `p` - Success probability (0 ≤ p ≤ 1)
471///
472/// # Returns
473///
474/// * A Bernoulli distribution object
475///
476/// # Examples
477///
478/// ```
479/// use scirs2_stats::distributions::bernoulli;
480///
481/// let b = bernoulli::bernoulli(0.3f64).expect("Operation failed");
482/// let pmf_at_one = b.pmf(1.0);
483/// assert!((pmf_at_one - 0.3).abs() < 1e-7);
484/// ```
485#[allow(dead_code)]
486pub fn bernoulli<F>(p: F) -> StatsResult<Bernoulli<F>>
487where
488 F: Float + NumCast + std::fmt::Display,
489{
490 Bernoulli::new(p)
491}
492
493/// Implementation of SampleableDistribution for Bernoulli
494impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for Bernoulli<F> {
495 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
496 self.rvs(size)
497 }
498}
499
500#[cfg(test)]
501mod tests {
502 use super::*;
503 use approx::assert_relative_eq;
504
505 #[test]
506 fn test_bernoulli_creation() {
507 // Valid p values
508 let bern1 = Bernoulli::new(0.0).expect("Operation failed");
509 assert_eq!(bern1.p, 0.0);
510
511 let bern2 = Bernoulli::new(0.5).expect("Operation failed");
512 assert_eq!(bern2.p, 0.5);
513
514 let bern3 = Bernoulli::new(1.0).expect("Operation failed");
515 assert_eq!(bern3.p, 1.0);
516
517 // Invalid p values
518 assert!(Bernoulli::<f64>::new(-0.1).is_err());
519 assert!(Bernoulli::<f64>::new(1.1).is_err());
520 }
521
522 #[test]
523 fn test_bernoulli_pmf() {
524 let bern = Bernoulli::new(0.3).expect("Operation failed");
525
526 // PMF at k = 0
527 let pmf_at_zero = bern.pmf(0.0);
528 assert_relative_eq!(pmf_at_zero, 0.7, epsilon = 1e-10);
529
530 // PMF at k = 1
531 let pmf_at_one = bern.pmf(1.0);
532 assert_relative_eq!(pmf_at_one, 0.3, epsilon = 1e-10);
533
534 // PMF at other values (should be 0)
535 let pmf_at_other = bern.pmf(0.5);
536 assert_eq!(pmf_at_other, 0.0);
537
538 // Corner cases
539 let bern_zero = Bernoulli::new(0.0).expect("Operation failed");
540 assert_eq!(bern_zero.pmf(0.0), 1.0);
541 assert_eq!(bern_zero.pmf(1.0), 0.0);
542
543 let bern_one = Bernoulli::new(1.0).expect("Operation failed");
544 assert_eq!(bern_one.pmf(0.0), 0.0);
545 assert_eq!(bern_one.pmf(1.0), 1.0);
546 }
547
548 #[test]
549 fn test_bernoulli_log_pmf() {
550 let bern = Bernoulli::new(0.3).expect("Operation failed");
551
552 // log-PMF at k = 0
553 let log_pmf_at_zero = bern.log_pmf(0.0);
554 assert_relative_eq!(log_pmf_at_zero, 0.7.ln(), epsilon = 1e-10);
555
556 // log-PMF at k = 1
557 let log_pmf_at_one = bern.log_pmf(1.0);
558 assert_relative_eq!(log_pmf_at_one, 0.3.ln(), epsilon = 1e-10);
559
560 // log-PMF at other values (should be -infinity)
561 let log_pmf_at_other = bern.log_pmf(0.5);
562 assert!(log_pmf_at_other.is_infinite() && log_pmf_at_other.is_sign_negative());
563
564 // Corner cases
565 let bern_zero = Bernoulli::new(0.0).expect("Operation failed");
566 assert_eq!(bern_zero.log_pmf(0.0), 0.0);
567 assert!(bern_zero.log_pmf(1.0).is_infinite() && bern_zero.log_pmf(1.0).is_sign_negative());
568
569 let bern_one = Bernoulli::new(1.0).expect("Operation failed");
570 assert!(bern_one.log_pmf(0.0).is_infinite() && bern_one.log_pmf(0.0).is_sign_negative());
571 assert_eq!(bern_one.log_pmf(1.0), 0.0);
572 }
573
574 #[test]
575 fn test_bernoulli_cdf() {
576 let bern = Bernoulli::new(0.3).expect("Operation failed");
577
578 // CDF for various values
579 assert_eq!(bern.cdf(-0.1), 0.0); // F(-0.1) = 0
580 assert_eq!(bern.cdf(0.0), 0.7); // F(0) = P(X ≤ 0) = P(X = 0) = 1 - p = 0.7
581 assert_eq!(bern.cdf(0.5), 0.7); // F(0.5) = P(X ≤ 0.5) = P(X = 0) = 1 - p = 0.7
582 assert_eq!(bern.cdf(1.0), 1.0); // F(1) = P(X ≤ 1) = 1
583 assert_eq!(bern.cdf(2.0), 1.0); // F(2) = P(X ≤ 2) = 1
584 }
585
586 #[test]
587 fn test_bernoulli_ppf() {
588 let bern = Bernoulli::new(0.3).expect("Operation failed");
589
590 // Quantile function
591 assert_eq!(bern.ppf(0.0).expect("Operation failed"), 0.0); // Q(0) = 0
592 assert_eq!(bern.ppf(0.3).expect("Operation failed"), 0.0); // Q(0.3) = 0 since 0.3 ≤ q = 0.7
593 assert_eq!(bern.ppf(0.7).expect("Operation failed"), 0.0); // Q(0.7) = 0 since 0.7 = q = 0.7
594 assert_eq!(bern.ppf(0.71).expect("Operation failed"), 1.0); // Q(0.71) = 1 since 0.71 > q = 0.7
595 assert_eq!(bern.ppf(1.0).expect("Operation failed"), 1.0); // Q(1) = 1
596
597 // Invalid p values
598 assert!(bern.ppf(-0.1).is_err());
599 assert!(bern.ppf(1.1).is_err());
600 }
601
602 #[test]
603 fn test_bernoulli_rvs() {
604 let bern = Bernoulli::new(0.5).expect("Operation failed");
605
606 // Generate samples
607 let samples = bern.rvs(100).expect("Operation failed");
608
609 // Check the number of samples
610 assert_eq!(samples.len(), 100);
611
612 // Check all values are either 0 or 1
613 for &sample in &samples {
614 assert!(sample == 0.0 || sample == 1.0);
615 }
616
617 // With p = 0.5, mean should be close to 0.5 for a large sample
618 let sum: f64 = samples.iter().sum();
619 let mean = sum / samples.len() as f64;
620
621 // Allow for some randomness, but mean should be roughly around 0.5
622 assert!(mean > 0.3 && mean < 0.7);
623 }
624
625 #[test]
626 fn test_bernoulli_stats() {
627 // Test with p = 0.3
628 let bern = Bernoulli::new(0.3).expect("Operation failed");
629
630 // Mean = p = 0.3
631 assert_eq!(bern.mean(), 0.3);
632
633 // Variance = p * (1 - p) = 0.3 * 0.7 = 0.21
634 assert_relative_eq!(bern.var(), 0.21, epsilon = 1e-10);
635
636 // Standard deviation = sqrt(variance) = sqrt(0.21) ≈ 0.458258
637 assert_relative_eq!(bern.std(), 0.21_f64.sqrt(), epsilon = 1e-10);
638
639 // Skewness = (1 - 2p) / sqrt(p * (1 - p)) = (1 - 2*0.3) / sqrt(0.3 * 0.7) = 0.4 / sqrt(0.21) ≈ 0.872872
640 let expected_skewness = (1.0 - 2.0 * 0.3) / (0.3 * 0.7).sqrt();
641 assert_relative_eq!(bern.skewness(), expected_skewness, epsilon = 1e-5);
642
643 // Kurtosis = (1 - 6p(1-p)) / (p(1-p)) = (1 - 6*0.3*0.7) / (0.3*0.7) = (1 - 1.26) / 0.21 ≈ -1.238
644 let expected_kurtosis = (1.0 - 6.0 * 0.3 * 0.7) / (0.3 * 0.7);
645 assert_relative_eq!(bern.kurtosis(), expected_kurtosis, epsilon = 1e-6);
646
647 // Entropy = -p * ln(p) - (1-p) * ln(1-p) = -0.3 * ln(0.3) - 0.7 * ln(0.7) ≈ 0.610864
648 let expected_entropy = -0.3 * 0.3.ln() - 0.7 * 0.7.ln();
649 assert_relative_eq!(bern.entropy(), expected_entropy, epsilon = 1e-6);
650
651 // Median and mode for p < 0.5 are both 0
652 assert_eq!(bern.median(), 0.0);
653 assert_eq!(bern.mode(), 0.0);
654
655 // Test with p = 0.8 (> 0.5)
656 let bern2 = Bernoulli::new(0.8).expect("Operation failed");
657
658 // Median and mode for p > 0.5 are both 1
659 assert_eq!(bern2.median(), 1.0);
660 assert_eq!(bern2.mode(), 1.0);
661
662 // Test with p = 0.5
663 let bern3 = Bernoulli::new(0.5).expect("Operation failed");
664
665 // Median and mode for p = 0.5 are either 0 or 1 (we return 0 by convention)
666 assert_eq!(bern3.median(), 0.0);
667 assert_eq!(bern3.mode(), 0.0);
668 }
669}