fugue/core/distribution.rs
1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/docs/core/distribution.md"))]
2use rand::{Rng, RngCore};
3use rand_distr::{
4 Beta as RDBeta, Binomial as RDBinomial, Distribution as RandDistr, Exp as RDExp,
5 Gamma as RDGamma, LogNormal as RDLogNormal, Normal as RDNormal, Poisson as RDPoisson,
6};
/// Natural-log probability value used throughout this module.
///
/// A plain `f64` holding `ln(p)` (or `ln` of a density). `f64::NEG_INFINITY`
/// encodes probability zero; any finite value is the natural logarithm of a
/// probability mass or density.
pub type LogF64 = f64;
12
13/// Generic interface for type-safe probability distributions.
14/// All distributions implement `Distribution<T>` where `T` is the natural return type.
15/// Example:
16///
17/// ```rust
18/// # use fugue::*;
19/// # use rand::thread_rng;
20///
21/// let mut rng = thread_rng();
22///
23/// // Type-safe sampling
24/// let coin = Bernoulli::new(0.5).unwrap();
25/// let flip: bool = coin.sample(&mut rng); // Natural boolean
26/// let prob = coin.log_prob(&flip);
27///
28/// // Safe indexing
29/// let choice = Categorical::uniform(3).unwrap();
30/// let idx: usize = choice.sample(&mut rng); // Safe for arrays
31/// let choice_prob = choice.log_prob(&idx);
32///
33/// // Natural counting
34/// let events = Poisson::new(3.0).unwrap();
35/// let count: u64 = events.sample(&mut rng); // Natural count type
36/// let count_prob = events.log_prob(&count);
37/// ```
38pub trait Distribution<T>: Send + Sync {
39 /// Generate a random sample (with its natural type), `T`, from the distribution, using the provided random number generator, `rng`.
40 ///
41 /// Example:
42 /// ```rust
43 /// # use fugue::*;
44 /// # use rand::thread_rng;
45 ///
46 /// let mut rng = thread_rng();
47 ///
48 /// // Sample different distribution types
49 /// let normal_sample: f64 = Normal::new(0.0, 1.0).unwrap().sample(&mut rng);
50 /// let coin_flip: bool = Bernoulli::new(0.5).unwrap().sample(&mut rng);
51 /// let event_count: u64 = Poisson::new(3.0).unwrap().sample(&mut rng);
52 /// let category_idx: usize = Categorical::uniform(5).unwrap().sample(&mut rng);
53 /// ```
54 fn sample(&self, rng: &mut dyn RngCore) -> T;
55
56 /// Compute the log-probability density (continuous) or mass (discrete) of a value, `x`, from the distribution.
57 ///
58 /// Example:
59 /// ```rust
60 /// # use fugue::*;
61 ///
62 /// // Continuous distribution (probability density)
63 /// let normal = Normal::new(0.0, 1.0).unwrap();
64 /// let density = normal.log_prob(&0.0); // Peak of standard normal
65 ///
66 /// // Discrete distribution (probability mass)
67 /// let coin = Bernoulli::new(0.7).unwrap();
68 /// let prob_true = coin.log_prob(&true); // ln(0.7)
69 /// let prob_false = coin.log_prob(&false); // ln(0.3)
70 ///
71 /// // Outside support returns -∞
72 /// let poisson = Poisson::new(3.0).unwrap();
73 /// let invalid = poisson.log_prob(&u64::MAX); // Very unlikely, returns -∞
74 /// ```
75 fn log_prob(&self, x: &T) -> LogF64;
76
77 /// Clone the distribution into a boxed trait object, `Box<dyn Distribution<T>>`.
78 ///
79 /// Example:
80 /// ```rust
81 /// # use fugue::*;
82 ///
83 /// // Clone a distribution into a box
84 /// let original = Normal::new(0.0, 1.0).unwrap();
85 /// let boxed: Box<dyn Distribution<f64>> = original.clone_box();
86 ///
87 /// // Useful for storing different distribution types
88 /// let mut distributions: Vec<Box<dyn Distribution<f64>>> = vec![];
89 /// distributions.push(Normal::new(0.0, 1.0).unwrap().clone_box());
90 /// distributions.push(Uniform::new(-1.0, 1.0).unwrap().clone_box());
91 /// ```
92 fn clone_box(&self) -> Box<dyn Distribution<T>>;
93}
94
95/// A continuous distribution characterized by its mean, `mu`, and standard deviation, `sigma`.
96///
97/// Mathematical Properties:
98/// - **Support**: (-∞, +∞)
99/// - **PDF**: f(x) = (1/(σ√(2π))) × exp(-0.5 × ((x-μ)/σ)²)
100/// - **Mean**: μ
101/// - **Variance**: σ²
102/// - **68-95-99.7 rule**: ~68% within 1σ, ~95% within 2σ, ~99.7% within 3σ
103///
104/// Example:
105/// ```rust
106/// # use fugue::*;
107///
108/// // Standard normal (mean=0, std=1)
109/// let standard = sample(addr!("z"), Normal::new(0.0, 1.0).unwrap());
110///
111/// // Parameter with prior
112/// let theta = sample(addr!("theta"), Normal::new(0.0, 2.0).unwrap());
113///
114/// // Likelihood with observation
115/// let likelihood = observe(addr!("y"), Normal::new(1.5, 0.5).unwrap(), 2.0);
116///
117/// // Measurement error model
118/// let true_value = sample(addr!("true_val"), Normal::new(100.0, 10.0).unwrap());
119/// let measurement = true_value.bind(|val| {
120/// observe(addr!("measured"), Normal::new(val, 2.0).unwrap(), 98.5)
121/// });
122/// ```
123#[derive(Clone, Copy, Debug)]
124pub struct Normal {
125 /// Mean of the normal distribution.
126 mu: f64,
127 /// Standard deviation of the normal distribution (must be positive).
128 sigma: f64,
129}
130impl Normal {
131 /// Create a new Normal distribution with validated parameters.
132 pub fn new(mu: f64, sigma: f64) -> crate::error::FugueResult<Self> {
133 if !mu.is_finite() {
134 return Err(crate::error::FugueError::invalid_parameters(
135 "Normal",
136 "Mean (mu) must be finite",
137 crate::error::ErrorCode::InvalidMean,
138 )
139 .with_context("mu", format!("{}", mu)));
140 }
141 if sigma <= 0.0 || !sigma.is_finite() {
142 return Err(crate::error::FugueError::invalid_parameters(
143 "Normal",
144 "Standard deviation (sigma) must be positive and finite",
145 crate::error::ErrorCode::InvalidVariance,
146 )
147 .with_context("sigma", format!("{}", sigma))
148 .with_context("expected", "> 0.0 and finite"));
149 }
150 Ok(Normal { mu, sigma })
151 }
152
153 /// Get the mean of the distribution.
154 pub fn mu(&self) -> f64 {
155 self.mu
156 }
157
158 /// Get the standard deviation of the distribution.
159 pub fn sigma(&self) -> f64 {
160 self.sigma
161 }
162}
163impl Distribution<f64> for Normal {
164 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
165 if self.sigma <= 0.0 {
166 return f64::NAN;
167 }
168 RDNormal::new(self.mu, self.sigma).unwrap().sample(rng)
169 }
170 fn log_prob(&self, x: &f64) -> LogF64 {
171 // Parameter validation
172 if self.sigma <= 0.0 || !self.sigma.is_finite() || !self.mu.is_finite() || !x.is_finite() {
173 return f64::NEG_INFINITY;
174 }
175
176 // Numerically stable computation
177 let z = (x - self.mu) / self.sigma;
178
179 // Prevent overflow for extreme values (|z| > 37 gives exp(-z²/2) < machine epsilon)
180 if z.abs() > 37.0 {
181 return f64::NEG_INFINITY;
182 }
183
184 // Use precomputed constant for better precision
185 const LN_2PI: f64 = 1.837_877_066_409_345_6; // ln(2π)
186 -0.5 * z * z - self.sigma.ln() - 0.5 * LN_2PI
187 }
188 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
189 Box::new(*self)
190 }
191}
192
193/// A continuous distribution that assigns equal probability density to all values within a specified interval, from `low` to `high`.
194///
195/// Commonly used as an uninformative prior when you want to express complete uncertainty over a bounded range.
196///
197/// Mathematical Properties:
198/// - **Support**: [low, high)
199/// - **PDF**: f(x) = 1/(high-low) for low ≤ x < high, 0 otherwise
200/// - **Mean**: (low + high) / 2
201/// - **Variance**: (high - low)² / 12
202///
203/// Example:
204///
205/// ```rust
206/// # use fugue::*;
207///
208/// // Unit interval [0, 1)
209/// let unit = sample(addr!("p"), Uniform::new(0.0, 1.0).unwrap());
210///
211/// // Symmetric around zero
212/// let symmetric = sample(addr!("x"), Uniform::new(-5.0, 5.0).unwrap());
213///
214/// // Uninformative prior for weight
215/// let weight = sample(addr!("weight"), Uniform::new(0.0, 100.0).unwrap());
216///
217/// // Random angle in radians
218/// let angle = sample(addr!("angle"), Uniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap());
219/// ```
220#[derive(Clone, Copy, Debug)]
221pub struct Uniform {
222 /// Lower bound of the uniform distribution (inclusive).
223 low: f64,
224 /// Upper bound of the uniform distribution (exclusive).
225 high: f64,
226}
227impl Uniform {
228 /// Create a new Uniform distribution with validated parameters.
229 pub fn new(low: f64, high: f64) -> crate::error::FugueResult<Self> {
230 if !low.is_finite() || !high.is_finite() {
231 return Err(crate::error::FugueError::invalid_parameters(
232 "Uniform",
233 "Bounds must be finite",
234 crate::error::ErrorCode::InvalidRange,
235 )
236 .with_context("low", format!("{}", low))
237 .with_context("high", format!("{}", high)));
238 }
239 if low >= high {
240 return Err(crate::error::FugueError::invalid_parameters(
241 "Uniform",
242 "Lower bound must be less than upper bound",
243 crate::error::ErrorCode::InvalidRange,
244 )
245 .with_context("low", format!("{}", low))
246 .with_context("high", format!("{}", high)));
247 }
248 Ok(Uniform { low, high })
249 }
250
251 /// Get the lower bound.
252 pub fn low(&self) -> f64 {
253 self.low
254 }
255
256 /// Get the upper bound.
257 pub fn high(&self) -> f64 {
258 self.high
259 }
260}
261impl Distribution<f64> for Uniform {
262 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
263 // Parameter validation
264 if self.low >= self.high || !self.low.is_finite() || !self.high.is_finite() {
265 return f64::NAN;
266 }
267 Rng::gen_range(rng, self.low..self.high)
268 }
269 fn log_prob(&self, x: &f64) -> LogF64 {
270 // Parameter validation
271 if self.low >= self.high
272 || !self.low.is_finite()
273 || !self.high.is_finite()
274 || !x.is_finite()
275 {
276 return f64::NEG_INFINITY;
277 }
278
279 // Check support with proper boundary handling
280 if *x < self.low || *x >= self.high {
281 f64::NEG_INFINITY
282 } else {
283 let width = self.high - self.low;
284 if width <= 0.0 {
285 f64::NEG_INFINITY
286 } else {
287 -width.ln()
288 }
289 }
290 }
291 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
292 Box::new(*self)
293 }
294}
295
296/// A continuous distribution where the logarithm follows a normal distribution.
297///
298/// Useful for modeling positive-valued quantities that are naturally multiplicative or skewed.
299///
300/// Mathematical Properties:
301/// - **Support**: (0, +∞)
302/// - **PDF**: f(x) = (1/(xσ√(2π))) × exp(-0.5 × ((ln(x)-μ)/σ)²)
303/// - **Mean**: exp(μ + σ²/2)
304/// - **Variance**: (exp(σ²) - 1) × exp(2μ + σ²)
305/// - **Relationship**: If X ~ LogNormal(μ, σ), then ln(X) ~ Normal(μ, σ)
306///
307/// Example:
308/// ```rust
309/// # use fugue::*;
310///
311/// // Standard log-normal (median = 1)
312/// let standard = sample(addr!("x"), LogNormal::new(0.0, 1.0).unwrap());
313///
314/// // Positive scale parameter
315/// let scale = sample(addr!("scale"), LogNormal::new(0.0, 0.5).unwrap());
316///
317/// // Income distribution
318/// let income = sample(addr!("income"), LogNormal::new(10.0, 0.8).unwrap())
319/// .map(|x| x.round() as u64); // Convert to dollars
320///
321/// // Multiplicative error model
322/// let true_value = 100.0;
323/// let measured = sample(addr!("error"), LogNormal::new(0.0, 0.1).unwrap())
324/// .map(move |error| true_value * error);
325/// ```
326#[derive(Clone, Copy, Debug)]
327pub struct LogNormal {
328 /// Mean of the underlying normal distribution.
329 mu: f64,
330 /// Standard deviation of the underlying normal distribution (must be positive).
331 sigma: f64,
332}
333impl LogNormal {
334 /// Create a new LogNormal distribution with validated parameters.
335 pub fn new(mu: f64, sigma: f64) -> crate::error::FugueResult<Self> {
336 if !mu.is_finite() {
337 return Err(crate::error::FugueError::invalid_parameters(
338 "LogNormal",
339 "Mean (mu) must be finite",
340 crate::error::ErrorCode::InvalidMean,
341 )
342 .with_context("mu", format!("{}", mu)));
343 }
344 if sigma <= 0.0 || !sigma.is_finite() {
345 return Err(crate::error::FugueError::invalid_parameters(
346 "LogNormal",
347 "Standard deviation (sigma) must be positive and finite",
348 crate::error::ErrorCode::InvalidVariance,
349 )
350 .with_context("sigma", format!("{}", sigma))
351 .with_context("expected", "> 0.0 and finite"));
352 }
353 Ok(LogNormal { mu, sigma })
354 }
355
356 /// Get the mean of the underlying normal distribution.
357 pub fn mu(&self) -> f64 {
358 self.mu
359 }
360
361 /// Get the standard deviation of the underlying normal distribution.
362 pub fn sigma(&self) -> f64 {
363 self.sigma
364 }
365}
366impl Distribution<f64> for LogNormal {
367 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
368 if self.sigma <= 0.0 {
369 return f64::NAN;
370 }
371 RDLogNormal::new(self.mu, self.sigma).unwrap().sample(rng)
372 }
373 fn log_prob(&self, x: &f64) -> LogF64 {
374 // Parameter and input validation
375 if self.sigma <= 0.0 || !self.sigma.is_finite() || !self.mu.is_finite() {
376 return f64::NEG_INFINITY;
377 }
378 if *x <= 0.0 || !x.is_finite() {
379 return f64::NEG_INFINITY;
380 }
381
382 // Numerically stable computation
383 let lx = x.ln();
384 let z = (lx - self.mu) / self.sigma;
385
386 // Prevent overflow
387 if z.abs() > 37.0 {
388 return f64::NEG_INFINITY;
389 }
390
391 // Stable computation: log_prob = -0.5*z² - ln(x) - ln(σ) - 0.5*ln(2π)
392 const LN_2PI: f64 = 1.837_877_066_409_345_6; // ln(2π)
393 -0.5 * z * z - lx - self.sigma.ln() - 0.5 * LN_2PI
394 }
395 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
396 Box::new(*self)
397 }
398}
399
400/// A continuous distribution often used to model waiting times between events.
401///
402/// Characterized by the memoryless property.
403///
404/// Mathematical Properties:
405/// - **Support**: [0, +∞)
406/// - **PDF**: f(x) = λ × exp(-λx) for x ≥ 0
407/// - **Mean**: 1 / λ
408/// - **Variance**: 1 / λ²
409/// - **Memoryless**: P(X > s + t | X > s) = P(X > t)
410///
411/// Example:
412/// ```rust
413/// # use fugue::*;
414///
415/// // Average wait time of 2 minutes (rate = 0.5 per minute)
416/// let wait_time = sample(addr!("wait"), Exponential::new(0.5).unwrap());
417///
418/// // Service time model
419/// let service = sample(addr!("service_time"), Exponential::new(1.5).unwrap())
420/// .bind(|time| {
421/// if time > 5.0 {
422/// pure("slow")
423/// } else {
424/// pure("fast")
425/// }
426/// });
427///
428/// // Observe actual waiting time
429/// let observed = observe(addr!("actual_wait"), Exponential::new(0.3).unwrap(), 4.2);
430/// ```
431#[derive(Clone, Copy, Debug)]
432pub struct Exponential {
433 /// Rate parameter λ of the exponential distribution (must be positive).
434 rate: f64,
435}
436impl Exponential {
437 /// Create a new Exponential distribution with validated parameters.
438 pub fn new(rate: f64) -> crate::error::FugueResult<Self> {
439 if rate <= 0.0 || !rate.is_finite() {
440 return Err(crate::error::FugueError::invalid_parameters(
441 "Exponential",
442 "Rate parameter must be positive and finite",
443 crate::error::ErrorCode::InvalidRate,
444 )
445 .with_context("rate", format!("{}", rate))
446 .with_context("expected", "> 0.0 and finite"));
447 }
448 Ok(Exponential { rate })
449 }
450
451 /// Get the rate parameter.
452 pub fn rate(&self) -> f64 {
453 self.rate
454 }
455}
456impl Distribution<f64> for Exponential {
457 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
458 if self.rate <= 0.0 {
459 return f64::NAN;
460 }
461 RDExp::new(self.rate).unwrap().sample(rng)
462 }
463 fn log_prob(&self, x: &f64) -> LogF64 {
464 // Parameter validation
465 if self.rate <= 0.0 || !self.rate.is_finite() || !x.is_finite() {
466 return f64::NEG_INFINITY;
467 }
468
469 if *x < 0.0 {
470 f64::NEG_INFINITY
471 } else {
472 // Check for overflow: if rate * x > 700, exp(-rate*x) underflows
473 if self.rate * x > 700.0 {
474 return f64::NEG_INFINITY;
475 }
476 self.rate.ln() - self.rate * x
477 }
478 }
479 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
480 Box::new(*self)
481 }
482}
483
484/// A discrete distribution for binary outcomes (true/false, success/failure).
485///
486/// Returns `bool` directly for type-safe boolean logic.
487///
488/// Mathematical Properties:
489/// - **Support**: {false, true}
490/// - **PMF**: P(X = true) = p, P(X = false) = 1 - p
491/// - **Mean**: p
492/// - **Variance**: p(1 - p)
493///
494/// Example:
495/// ```rust
496/// # use fugue::*;
497///
498/// // Fair coin flip
499/// let coin = sample(addr!("coin"), Bernoulli::new(0.5).unwrap());
500/// let result = coin.bind(|heads| {
501/// if heads {
502/// pure("Heads!")
503/// } else {
504/// pure("Tails!")
505/// }
506/// });
507///
508/// // Biased coin with observation
509/// let biased = observe(addr!("biased_coin"), Bernoulli::new(0.7).unwrap(), true);
510/// ```
511#[derive(Clone, Copy, Debug)]
512pub struct Bernoulli {
513 /// Probability of success (must be in [0, 1]).
514 p: f64,
515}
516impl Bernoulli {
517 /// Create a new Bernoulli distribution with validated parameters.
518 pub fn new(p: f64) -> crate::error::FugueResult<Self> {
519 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
520 return Err(crate::error::FugueError::invalid_parameters(
521 "Bernoulli",
522 "Probability must be in [0, 1]",
523 crate::error::ErrorCode::InvalidProbability,
524 )
525 .with_context("p", format!("{}", p))
526 .with_context("expected", "[0.0, 1.0]"));
527 }
528 Ok(Bernoulli { p })
529 }
530
531 /// Get the success probability.
532 pub fn p(&self) -> f64 {
533 self.p
534 }
535}
536impl Distribution<bool> for Bernoulli {
537 fn sample(&self, rng: &mut dyn RngCore) -> bool {
538 if self.p < 0.0 || self.p > 1.0 || !self.p.is_finite() {
539 return false; // Default to false for invalid parameters
540 }
541 use rand::Rng;
542 rng.gen::<f64>() < self.p
543 }
544 fn log_prob(&self, x: &bool) -> LogF64 {
545 // Parameter validation
546 if self.p < 0.0 || self.p > 1.0 || !self.p.is_finite() {
547 return f64::NEG_INFINITY;
548 }
549
550 if *x {
551 // P(X = true) = p
552 if self.p <= 0.0 {
553 f64::NEG_INFINITY
554 } else {
555 self.p.ln()
556 }
557 } else {
558 // P(X = false) = 1 - p
559 if self.p >= 1.0 {
560 f64::NEG_INFINITY
561 } else {
562 (1.0 - self.p).ln()
563 }
564 }
565 }
566 fn clone_box(&self) -> Box<dyn Distribution<bool>> {
567 Box::new(*self)
568 }
569}
570
571/// A discrete distribution for choosing among multiple categories with specified probabilities.
572///
573/// Returns `usize` for safe array indexing.
574///
575/// Mathematical Properties:
576/// - **Support**: {0, 1, ..., k-1} where k = number of categories
577/// - **PMF**: P(X = i) = probs[i]
578/// - **Mean**: Σ(i × probs[i])
579/// - **Variance**: Σ(i² × probs[i]) - mean²
580///
581/// Example:
582/// ```rust
583/// # use fugue::*;
584///
585/// // Custom probabilities
586/// let weighted = Categorical::new(vec![0.1, 0.2, 0.3, 0.4]).unwrap();
587///
588/// // Uniform distribution over k categories
589/// let uniform = Categorical::uniform(4).unwrap();
590///
591/// // Choose from three options
592/// let options = vec!["red", "green", "blue"];
593/// let choice = sample(addr!("color"), Categorical::new(vec![0.5, 0.3, 0.2]).unwrap())
594/// .map(move |idx| options[idx].to_string());
595///
596/// // Observe a specific choice
597/// let observed = observe(addr!("user_choice"),
598/// Categorical::uniform(3).unwrap(), 1usize);
599/// ```
600#[derive(Clone, Debug)]
601pub struct Categorical {
602 /// Probabilities for each category (should sum to 1.0).
603 probs: Vec<f64>,
604}
605impl Categorical {
606 /// Create a new Categorical distribution with validated parameters.
607 pub fn new(probs: Vec<f64>) -> crate::error::FugueResult<Self> {
608 if probs.is_empty() {
609 return Err(crate::error::FugueError::invalid_parameters(
610 "Categorical",
611 "Probability vector cannot be empty",
612 crate::error::ErrorCode::InvalidProbability,
613 )
614 .with_context("length", "0"));
615 }
616
617 let sum: f64 = probs.iter().sum();
618 if (sum - 1.0).abs() > 1e-6 {
619 return Err(crate::error::FugueError::invalid_parameters(
620 "Categorical",
621 "Probabilities must sum to 1.0",
622 crate::error::ErrorCode::InvalidProbability,
623 )
624 .with_context("sum", format!("{:.6}", sum))
625 .with_context("expected", "1.0")
626 .with_context("tolerance", "1e-6"));
627 }
628
629 for (i, &p) in probs.iter().enumerate() {
630 if !p.is_finite() || p < 0.0 {
631 return Err(crate::error::FugueError::invalid_parameters(
632 "Categorical",
633 "All probabilities must be non-negative and finite",
634 crate::error::ErrorCode::InvalidProbability,
635 )
636 .with_context("index", format!("{}", i))
637 .with_context("value", format!("{}", p))
638 .with_context("expected", ">= 0.0 and finite"));
639 }
640 }
641
642 Ok(Categorical { probs })
643 }
644
645 /// Create a uniform categorical distribution over k categories.
646 pub fn uniform(k: usize) -> crate::error::FugueResult<Self> {
647 if k == 0 {
648 return Err(crate::error::FugueError::invalid_parameters(
649 "Categorical",
650 "Number of categories must be positive",
651 crate::error::ErrorCode::InvalidCount,
652 )
653 .with_context("k", "0"));
654 }
655
656 let prob = 1.0 / k as f64;
657 let probs = vec![prob; k];
658 Ok(Categorical { probs })
659 }
660
661 /// Get the probability vector.
662 pub fn probs(&self) -> &[f64] {
663 &self.probs
664 }
665
666 /// Get the number of categories.
667 pub fn len(&self) -> usize {
668 self.probs.len()
669 }
670
671 /// Check if the distribution has no categories.
672 pub fn is_empty(&self) -> bool {
673 self.probs.is_empty()
674 }
675}
676impl Distribution<usize> for Categorical {
677 fn sample(&self, rng: &mut dyn RngCore) -> usize {
678 // Parameter validation
679 if self.probs.is_empty() {
680 return 0;
681 }
682
683 let prob_sum: f64 = self.probs.iter().sum();
684 if (prob_sum - 1.0).abs() > 1e-6 || self.probs.iter().any(|&p| p < 0.0 || !p.is_finite()) {
685 return 0;
686 }
687
688 use rand::Rng;
689 let u: f64 = rng.gen();
690 let mut cum = 0.0;
691 for (i, &p) in self.probs.iter().enumerate() {
692 cum += p;
693 if u <= cum {
694 return i;
695 }
696 }
697 self.probs.len() - 1
698 }
699 fn log_prob(&self, x: &usize) -> LogF64 {
700 // Parameter validation
701 if self.probs.is_empty() || *x >= self.probs.len() {
702 return f64::NEG_INFINITY;
703 }
704
705 let prob_sum: f64 = self.probs.iter().sum();
706 if (prob_sum - 1.0).abs() > 1e-6 || self.probs.iter().any(|&p| p < 0.0 || !p.is_finite()) {
707 return f64::NEG_INFINITY;
708 }
709
710 if self.probs[*x] <= 0.0 {
711 f64::NEG_INFINITY
712 } else {
713 self.probs[*x].ln()
714 }
715 }
716 fn clone_box(&self) -> Box<dyn Distribution<usize>> {
717 Box::new(self.clone())
718 }
719}
720
721/// A continuous distribution on the interval (0, 1), commonly used for modeling probabilities and proportions.
722///
723/// Conjugate prior for Bernoulli/Binomial distributions.
724///
725/// Mathematical Properties:
726/// - **Support**: (0, 1)
727/// - **PDF**: f(x) = (x^(α-1) × (1-x)^(β-1)) / B(α,β)
728/// - **Mean**: α / (α + β)
729/// - **Variance**: (αβ) / ((α+β)²(α+β+1))
730///
731/// Example:
732/// ```rust
733/// # use fugue::*;
734///
735/// // Uniform on [0,1]
736/// let uniform = sample(addr!("p"), Beta::new(1.0, 1.0).unwrap());
737///
738/// // Prior for success probability
739/// let prob_prior = sample(addr!("success_rate"), Beta::new(2.0, 5.0).unwrap());
740///
741/// // Conjugate prior-likelihood pair
742/// let model = sample(addr!("p"), Beta::new(3.0, 7.0).unwrap())
743/// .bind(|p| observe(addr!("trial"), Bernoulli::new(p).unwrap(), true));
744///
745/// // Skewed towards 0 (beta > alpha)
746/// let skewed = sample(addr!("proportion"), Beta::new(2.0, 8.0).unwrap());
747/// ```
748#[derive(Clone, Copy, Debug)]
749pub struct Beta {
750 /// First shape parameter α (must be positive).
751 alpha: f64,
752 /// Second shape parameter β (must be positive).
753 beta: f64,
754}
755impl Beta {
756 /// Create a new Beta distribution with validated parameters.
757 pub fn new(alpha: f64, beta: f64) -> crate::error::FugueResult<Self> {
758 if alpha <= 0.0 || !alpha.is_finite() {
759 return Err(crate::error::FugueError::invalid_parameters(
760 "Beta",
761 "Alpha parameter must be positive and finite",
762 crate::error::ErrorCode::InvalidShape,
763 )
764 .with_context("alpha", format!("{}", alpha))
765 .with_context("expected", "> 0.0 and finite"));
766 }
767 if beta <= 0.0 || !beta.is_finite() {
768 return Err(crate::error::FugueError::invalid_parameters(
769 "Beta",
770 "Beta parameter must be positive and finite",
771 crate::error::ErrorCode::InvalidShape,
772 )
773 .with_context("beta", format!("{}", beta))
774 .with_context("expected", "> 0.0 and finite"));
775 }
776 Ok(Beta { alpha, beta })
777 }
778
779 /// Get the alpha parameter.
780 pub fn alpha(&self) -> f64 {
781 self.alpha
782 }
783
784 /// Get the beta parameter.
785 pub fn beta(&self) -> f64 {
786 self.beta
787 }
788}
789impl Distribution<f64> for Beta {
790 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
791 if self.alpha <= 0.0 || self.beta <= 0.0 {
792 return f64::NAN;
793 }
794 RDBeta::new(self.alpha, self.beta).unwrap().sample(rng)
795 }
796 fn log_prob(&self, x: &f64) -> LogF64 {
797 // Parameter validation
798 if self.alpha <= 0.0
799 || self.beta <= 0.0
800 || !self.alpha.is_finite()
801 || !self.beta.is_finite()
802 || !x.is_finite()
803 {
804 return f64::NEG_INFINITY;
805 }
806
807 // Support validation
808 if *x <= 0.0 || *x >= 1.0 {
809 return f64::NEG_INFINITY;
810 }
811
812 // Handle edge cases near boundaries
813 if *x < 1e-100 || *x > 1.0 - 1e-100 {
814 return f64::NEG_INFINITY;
815 }
816
817 // Numerically stable computation using log-gamma
818 // log Beta(x; α, β) = (α-1)ln(x) + (β-1)ln(1-x) - log B(α,β)
819 let log_beta_fn = libm::lgamma(self.alpha) + libm::lgamma(self.beta)
820 - libm::lgamma(self.alpha + self.beta);
821
822 let ln_x = x.ln();
823 let ln_1_minus_x = (1.0 - x).ln();
824
825 // Check for extreme log values
826 if ln_x < -700.0 || ln_1_minus_x < -700.0 {
827 return f64::NEG_INFINITY;
828 }
829
830 (self.alpha - 1.0) * ln_x + (self.beta - 1.0) * ln_1_minus_x - log_beta_fn
831 }
832 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
833 Box::new(*self)
834 }
835}
836
837/// A continuous distribution over positive real numbers, parameterized by shape and rate.
838///
839/// Commonly used for modeling waiting times and as a conjugate prior for Poisson distributions.
840///
841/// Mathematical Properties:
842/// - **Support**: (0, +∞)
843/// - **PDF**: f(x) = (λ^k / Γ(k)) × x^(k-1) × exp(-λx)
844/// - **Mean**: k / λ
845/// - **Variance**: k / λ²
846///
847/// Example:
848/// ```rust
849/// # use fugue::*;
850///
851/// // Shape=1 gives Exponential distribution
852/// let exponential_like = sample(addr!("wait_time"), Gamma::new(1.0, 2.0).unwrap());
853///
854/// // Prior for precision parameter
855/// let precision = sample(addr!("precision"), Gamma::new(2.0, 1.0).unwrap());
856///
857/// // Conjugate prior for Poisson rate
858/// let model = sample(addr!("rate"), Gamma::new(3.0, 2.0).unwrap())
859/// .bind(|lambda| observe(addr!("count"), Poisson::new(lambda).unwrap(), 5u64));
860///
861/// // Scale parameter (rate = 1/scale)
862/// let scale_param = sample(addr!("scale"), Gamma::new(2.0, 0.5).unwrap()); // mean = 4
863/// ```
864#[derive(Clone, Copy, Debug)]
865pub struct Gamma {
866 /// Shape parameter k (must be positive).
867 shape: f64,
868 /// Rate parameter λ (must be positive).
869 rate: f64,
870}
871impl Gamma {
872 /// Create a new Gamma distribution with validated parameters.
873 pub fn new(shape: f64, rate: f64) -> crate::error::FugueResult<Self> {
874 if shape <= 0.0 || !shape.is_finite() {
875 return Err(crate::error::FugueError::invalid_parameters(
876 "Gamma",
877 "Shape parameter must be positive and finite",
878 crate::error::ErrorCode::InvalidShape,
879 )
880 .with_context("shape", format!("{}", shape))
881 .with_context("expected", "> 0.0 and finite"));
882 }
883 if rate <= 0.0 || !rate.is_finite() {
884 return Err(crate::error::FugueError::invalid_parameters(
885 "Gamma",
886 "Rate parameter must be positive and finite",
887 crate::error::ErrorCode::InvalidRate,
888 )
889 .with_context("rate", format!("{}", rate))
890 .with_context("expected", "> 0.0 and finite"));
891 }
892 Ok(Gamma { shape, rate })
893 }
894
895 /// Get the shape parameter.
896 pub fn shape(&self) -> f64 {
897 self.shape
898 }
899
900 /// Get the rate parameter.
901 pub fn rate(&self) -> f64 {
902 self.rate
903 }
904}
905impl Distribution<f64> for Gamma {
906 fn sample(&self, rng: &mut dyn RngCore) -> f64 {
907 if self.shape <= 0.0 || self.rate <= 0.0 {
908 return f64::NAN;
909 }
910 RDGamma::new(self.shape, 1.0 / self.rate)
911 .unwrap()
912 .sample(rng)
913 }
914 fn log_prob(&self, x: &f64) -> LogF64 {
915 // Parameter validation
916 if self.shape <= 0.0
917 || self.rate <= 0.0
918 || !self.shape.is_finite()
919 || !self.rate.is_finite()
920 || !x.is_finite()
921 {
922 return f64::NEG_INFINITY;
923 }
924
925 if *x <= 0.0 {
926 return f64::NEG_INFINITY;
927 }
928
929 // Check for overflow conditions
930 if self.rate * x > 700.0 || x.ln() * (self.shape - 1.0) < -700.0 {
931 return f64::NEG_INFINITY;
932 }
933
934 // Numerically stable computation
935 // log Gamma(x; k, λ) = k*ln(λ) + (k-1)*ln(x) - λ*x - ln Γ(k)
936 let log_rate = self.rate.ln();
937 let log_x = x.ln();
938 let log_gamma_shape = libm::lgamma(self.shape);
939
940 self.shape * log_rate + (self.shape - 1.0) * log_x - self.rate * x - log_gamma_shape
941 }
942 fn clone_box(&self) -> Box<dyn Distribution<f64>> {
943 Box::new(*self)
944 }
945}
946
947/// A discrete distribution representing the number of successes in n independent trials, with probability of success p.
948///
949/// Returns `u64` for natural success counting.
950///
951/// Mathematical Properties:
952/// - **Support**: {0, 1, ..., n}
953/// - **PMF**: P(X = k) = C(n,k) × p^k × (1-p)^(n-k)
954/// - **Mean**: n × p
955/// - **Variance**: n × p × (1-p)
956///
957/// Example:
958/// ```rust
959/// # use fugue::*;
960///
961/// // 10 coin flips
962/// let successes = sample(addr!("heads"), Binomial::new(10, 0.5).unwrap())
963/// .bind(|count| {
964/// let rate = count as f64 / 10.0;
965/// pure(format!("Success rate: {:.1}%", rate * 100.0))
966/// });
967///
968/// // Clinical trial
969/// let trial = sample(addr!("success_rate"), Beta::new(1.0, 1.0).unwrap())
970/// .bind(|p| sample(addr!("successes"), Binomial::new(100, p).unwrap()));
971///
972/// // Observe trial results
973/// let observed = observe(addr!("trial_successes"), Binomial::new(20, 0.3).unwrap(), 7u64);
974/// ```
975#[derive(Clone, Copy, Debug)]
976pub struct Binomial {
977 /// Number of trials.
978 n: u64,
979 /// Probability of success on each trial (must be in [0, 1]).
980 p: f64,
981}
982impl Binomial {
983 /// Create a new Binomial distribution with validated parameters.
984 pub fn new(n: u64, p: f64) -> crate::error::FugueResult<Self> {
985 if !p.is_finite() || !(0.0..=1.0).contains(&p) {
986 return Err(crate::error::FugueError::invalid_parameters(
987 "Binomial",
988 "Probability must be in [0, 1]",
989 crate::error::ErrorCode::InvalidProbability,
990 )
991 .with_context("p", format!("{}", p))
992 .with_context("expected", "[0.0, 1.0]"));
993 }
994 Ok(Binomial { n, p })
995 }
996
997 /// Get the number of trials.
998 pub fn n(&self) -> u64 {
999 self.n
1000 }
1001
1002 /// Get the success probability.
1003 pub fn p(&self) -> f64 {
1004 self.p
1005 }
1006}
1007impl Distribution<u64> for Binomial {
1008 fn sample(&self, rng: &mut dyn RngCore) -> u64 {
1009 RDBinomial::new(self.n, self.p).unwrap().sample(rng)
1010 }
1011 fn log_prob(&self, x: &u64) -> LogF64 {
1012 let k = *x;
1013 if k > self.n {
1014 return f64::NEG_INFINITY;
1015 }
1016 // log Binomial(k; n, p) = log C(n,k) + k*ln(p) + (n-k)*ln(1-p)
1017 let log_binom_coeff = libm::lgamma(self.n as f64 + 1.0)
1018 - libm::lgamma(k as f64 + 1.0)
1019 - libm::lgamma((self.n - k) as f64 + 1.0);
1020 log_binom_coeff + (k as f64) * self.p.ln() + ((self.n - k) as f64) * (1.0 - self.p).ln()
1021 }
1022 fn clone_box(&self) -> Box<dyn Distribution<u64>> {
1023 Box::new(*self)
1024 }
1025}
1026
1027/// A discrete distribution for modeling the number of events occurring in a fixed interval.
1028///
1029/// Returns `u64` for natural counting arithmetic.
1030///
1031/// Mathematical Properties:
1032/// - **Support**: {0, 1, 2, 3, ...}
1033/// - **PMF**: P(X = k) = (λ^k × e^(-λ)) / k!
1034/// - **Mean**: λ
1035/// - **Variance**: λ
/// - **Independent increments**: In a Poisson process, counts in disjoint intervals are independent
1037///
1038/// Example:
1039/// ```rust
1040/// # use fugue::*;
1041///
1042/// // Model event counts
1043/// let events = sample(addr!("events"), Poisson::new(3.0).unwrap())
1044/// .bind(|count| {
1045/// let status = match count {
1046/// 0 => "No events",
1047/// 1 => "Single event",
1048/// n if n > 10 => "High activity",
1049/// _ => "Normal activity"
1050/// };
1051/// pure(status.to_string())
1052/// });
1053///
1054/// // Hierarchical model with Gamma prior
1055/// let hierarchical = sample(addr!("rate"), Gamma::new(2.0, 1.0).unwrap())
1056/// .bind(|lambda| sample(addr!("count"), Poisson::new(lambda).unwrap()));
1057///
1058/// // Observe count data
1059/// let observed = observe(addr!("observed_count"), Poisson::new(4.0).unwrap(), 7u64);
1060/// ```
#[derive(Clone, Copy, Debug)]
pub struct Poisson {
    /// Rate parameter λ, which is both the mean and the variance of the distribution.
    /// `Poisson::new` only accepts finite values strictly greater than zero.
    lambda: f64,
}
1066impl Poisson {
1067 /// Create a new Poisson distribution with validated parameters.
1068 pub fn new(lambda: f64) -> crate::error::FugueResult<Self> {
1069 if lambda <= 0.0 || !lambda.is_finite() {
1070 return Err(crate::error::FugueError::invalid_parameters(
1071 "Poisson",
1072 "Rate parameter lambda must be positive and finite",
1073 crate::error::ErrorCode::InvalidRate,
1074 )
1075 .with_context("lambda", format!("{}", lambda))
1076 .with_context("expected", "> 0.0 and finite"));
1077 }
1078 Ok(Poisson { lambda })
1079 }
1080
1081 /// Get the rate parameter.
1082 pub fn lambda(&self) -> f64 {
1083 self.lambda
1084 }
1085}
1086impl Distribution<u64> for Poisson {
1087 fn sample(&self, rng: &mut dyn RngCore) -> u64 {
1088 if self.lambda <= 0.0 || !self.lambda.is_finite() {
1089 return 0;
1090 }
1091 RDPoisson::new(self.lambda).unwrap().sample(rng) as u64
1092 }
1093 fn log_prob(&self, x: &u64) -> LogF64 {
1094 // Parameter validation
1095 if self.lambda <= 0.0 || !self.lambda.is_finite() {
1096 return f64::NEG_INFINITY;
1097 }
1098
1099 let k = *x;
1100
1101 // Handle extreme cases
1102 if self.lambda > 700.0 && k == 0 {
1103 return -self.lambda; // Direct computation to avoid lgamma issues
1104 }
1105
1106 // Numerically stable computation
1107 // log Poisson(k; λ) = k*ln(λ) - λ - ln(k!)
1108 let k_f64 = k as f64;
1109 let log_lambda = self.lambda.ln();
1110 let log_factorial = libm::lgamma(k_f64 + 1.0);
1111
1112 k_f64 * log_lambda - self.lambda - log_factorial
1113 }
1114 fn clone_box(&self) -> Box<dyn Distribution<u64>> {
1115 Box::new(*self)
1116 }
1117}
1118
// Unit tests covering constructor validation, log-density support boundaries, and basic
// sampling for the distributions in this module. Sampling uses a fixed seed (42) so the
// drawn values are reproducible across runs.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Normal: a NaN location and a zero scale are rejected; the log-density is finite at
    // the mean and -inf at +infinity.
    #[test]
    fn normal_constructor_and_log_prob() {
        assert!(Normal::new(0.0, 1.0).is_ok());
        assert!(Normal::new(f64::NAN, 1.0).is_err());
        assert!(Normal::new(0.0, 0.0).is_err());

        let n = Normal::new(0.0, 1.0).unwrap();
        assert!(n.log_prob(&0.0).is_finite());
        assert_eq!(n.log_prob(&f64::INFINITY), f64::NEG_INFINITY);
    }

    // Uniform: reversed bounds are rejected. The asserted support is half-open,
    // so the upper bound itself has zero density.
    #[test]
    fn uniform_support_and_log_prob() {
        assert!(Uniform::new(0.0, 1.0).is_ok());
        assert!(Uniform::new(1.0, 0.0).is_err());
        let u = Uniform::new(-2.0, 2.0).unwrap();
        // Inside support
        let lp0 = u.log_prob(&0.0);
        assert!(lp0.is_finite());
        // Outside support (the upper bound 2.0 is treated as excluded)
        assert_eq!(u.log_prob(&2.0), f64::NEG_INFINITY);
        assert_eq!(u.log_prob(&-2.1), f64::NEG_INFINITY);
    }

    // LogNormal: a zero scale is rejected; x = 0 lies outside the support.
    #[test]
    fn lognormal_validation() {
        assert!(LogNormal::new(0.0, 1.0).is_ok());
        assert!(LogNormal::new(0.0, 0.0).is_err());
        let ln = LogNormal::new(0.0, 1.0).unwrap();
        assert_eq!(ln.log_prob(&0.0), f64::NEG_INFINITY);
        assert!(ln.log_prob(&1.0).is_finite());
    }

    // Exponential: a zero rate is rejected; negative values have zero density, and the
    // density at 0 equals the rate (so the log-density is ln(rate)).
    #[test]
    fn exponential_validation() {
        assert!(Exponential::new(1.0).is_ok());
        assert!(Exponential::new(0.0).is_err());
        let e = Exponential::new(2.0).unwrap();
        assert_eq!(e.log_prob(&-1.0), f64::NEG_INFINITY);
        assert!((e.log_prob(&0.0) - (2.0f64).ln()).abs() < 1e-12);
    }

    // Bernoulli: negative probabilities are rejected; log_prob(true) = ln p and
    // log_prob(false) = ln(1 - p).
    #[test]
    fn bernoulli_validation() {
        assert!(Bernoulli::new(0.5).is_ok());
        assert!(Bernoulli::new(-0.1).is_err());
        let b = Bernoulli::new(0.25).unwrap();
        assert!((b.log_prob(&true) - (0.25f64).ln()).abs() < 1e-12);
        assert!((b.log_prob(&false) - (0.75f64).ln()).abs() < 1e-12);
    }

    // Categorical: empty and non-normalised probability vectors are rejected;
    // out-of-range indices have zero probability.
    #[test]
    fn categorical_validation_and_log_prob() {
        assert!(Categorical::new(vec![0.5, 0.5]).is_ok());
        assert!(Categorical::new(vec![]).is_err());
        assert!(Categorical::new(vec![0.6, 0.5]).is_err());

        let c = Categorical::new(vec![0.2, 0.8]).unwrap();
        assert!((c.log_prob(&1) - (0.8f64).ln()).abs() < 1e-12);
        assert_eq!(c.log_prob(&2), f64::NEG_INFINITY);
    }

    // Beta: a zero first shape parameter is rejected; for Beta(2, 5) the endpoints 0 and 1
    // have zero density while interior points are finite.
    #[test]
    fn beta_validation_and_support() {
        assert!(Beta::new(2.0, 3.0).is_ok());
        assert!(Beta::new(0.0, 1.0).is_err());
        let b = Beta::new(2.0, 5.0).unwrap();
        assert_eq!(b.log_prob(&0.0), f64::NEG_INFINITY);
        assert_eq!(b.log_prob(&1.0), f64::NEG_INFINITY);
        assert!(b.log_prob(&0.5).is_finite());
    }

    // Gamma: a zero in either parameter is rejected; negative values are outside the support.
    #[test]
    fn gamma_validation_and_support() {
        assert!(Gamma::new(1.5, 2.0).is_ok());
        assert!(Gamma::new(0.0, 2.0).is_err());
        assert!(Gamma::new(1.0, 0.0).is_err());
        let g = Gamma::new(2.0, 1.0).unwrap();
        assert_eq!(g.log_prob(&-1.0), f64::NEG_INFINITY);
        assert!(g.log_prob(&1.0).is_finite());
    }

    // Binomial: p > 1 is rejected; counts above n have zero probability.
    #[test]
    fn binomial_validation_and_log_prob() {
        assert!(Binomial::new(10, 0.5).is_ok());
        assert!(Binomial::new(10, 1.5).is_err());
        let bi = Binomial::new(5, 0.3).unwrap();
        assert_eq!(bi.log_prob(&6), f64::NEG_INFINITY); // k > n
        assert!(bi.log_prob(&3).is_finite());
    }

    // Poisson: a zero rate is rejected; zero and positive counts have finite log-probability.
    #[test]
    fn poisson_validation_and_log_prob() {
        assert!(Poisson::new(1.0).is_ok());
        assert!(Poisson::new(0.0).is_err());
        let p = Poisson::new(3.0).unwrap();
        assert!(p.log_prob(&0).is_finite());
        assert!(p.log_prob(&5).is_finite());
    }

    // Smoke test: seeded draws are finite (Normal) and fall inside [low, high) (Uniform);
    // the Bernoulli draw only checks that sampling does not panic.
    #[test]
    fn sampling_basic_sanity() {
        let mut rng = StdRng::seed_from_u64(42);
        let n = Normal::new(0.0, 1.0).unwrap();
        let x = n.sample(&mut rng);
        assert!(x.is_finite());

        let u = Uniform::new(-1.0, 2.0).unwrap();
        let y = u.sample(&mut rng);
        assert!((-1.0..2.0).contains(&y));

        let b = Bernoulli::new(0.7).unwrap();
        let _z = b.sample(&mut rng);
    }

    // Categorical::uniform(k) yields k outcomes, each with probability 1/k.
    #[test]
    fn categorical_uniform_constructor() {
        let cu = Categorical::uniform(4).unwrap();
        assert_eq!(cu.len(), 4);
        for &p in cu.probs() {
            assert!((p - 0.25).abs() < 1e-12);
        }
    }
}